Source code for fast_conformation.msa_generation.colabfold

# fmt: off

############################################
# imports
###########################################
import hashlib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects

from string import ascii_uppercase, ascii_lowercase
import hashlib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects
from string import ascii_uppercase, ascii_lowercase

# Hex colour codes used in PyMOL, kept in this exact order (index = colour slot).
pymol_color_list = [
    "#33ff33", "#00ffff", "#ff33cc", "#ffff00", "#ff9999", "#e5e5e5", "#7f7fff", "#ff7f00",
    "#7fff7f", "#199999", "#ff007f", "#ffdd5e", "#8c3f99", "#b2b2b2", "#007fff", "#c4b200",
    "#8cb266", "#00bfbf", "#b27f7f", "#fcd1a5", "#ff7f7f", "#ffbfdd", "#7fffff", "#ffff7f",
    "#00ff7f", "#337fcc", "#d8337f", "#bfff3f", "#ff7fff", "#d8d8ff", "#3fffbf", "#b78c4c",
    "#339933", "#66b2b2", "#ba8c84", "#84bf00", "#b24c66", "#7f7f7f", "#3f3fa5", "#a5512b",
]

# Discrete matplotlib colormap built from the colours above.
pymol_cmap = matplotlib.colors.ListedColormap(pymol_color_list)

# Single-character labels "A".."Z" followed by "a".."z" (52 entries).
alphabet_list = [*ascii_uppercase, *ascii_lowercase]

# One-letter codes of the 20 standard amino acids.
aatypes = set("ACDEFGHIKLMNPQRSTVWY")

def get_hash(x):
    """
    Compute the SHA-1 digest of a string.

    Args:
        x (str): Text to hash. It is encoded with ``str.encode()``'s default
            (UTF-8) before hashing.

    Returns:
        str: The 40-character lowercase hexadecimal SHA-1 digest.
    """
    payload = x.encode()
    digest = hashlib.sha1(payload)
    return digest.hexdigest()
def homooligomerize(msas, deletion_matrices, homooligomer=1):
    """
    Tile MSAs (and their deletion matrices) across ``homooligomer`` identical copies.

    For copy ``k`` of each input MSA, every row is padded with gaps on the left
    for the ``k`` preceding copies and on the right for the remaining ones.
    Deletion-matrix rows are padded with zeros in the same way. The output holds
    ``homooligomer * len(msas)`` MSAs: all input MSAs for copy 0, then all for
    copy 1, and so on.

    Args:
        msas (list of list of str): MSAs; the width of each is taken from its
            first row.
        deletion_matrices (list of list of list): Per-row deletion counts aligned
            with ``msas`` (rows are concatenated with list ``+``).
        homooligomer (int): Number of copies. Default is 1.

    Returns:
        tuple: ``(new_msas, new_deletion_matrices)``. When ``homooligomer == 1``
        the inputs are returned unchanged (same objects).
    """
    if homooligomer == 1:
        return msas, deletion_matrices

    padded_msas = []
    padded_mtxs = []
    for copy_idx in range(homooligomer):
        for msa, mtx in zip(msas, deletion_matrices):
            width = len(msa[0])
            n_left = width * copy_idx
            n_right = width * (homooligomer - copy_idx - 1)
            left_gap = "-" * n_left
            right_gap = "-" * n_right
            padded_msas.append([left_gap + row + right_gap for row in msa])
            padded_mtxs.append([[0] * n_left + row + [0] * n_right for row in mtx])
    return padded_msas, padded_mtxs
# Misspelled name kept deliberately so existing callers of the old spelling keep working.
def homooliomerize(msas, deletion_matrices, homooligomer=1):
    """
    Backward-compatible alias of `homooligomerize` (note the missing "g").

    Args:
        msas (list of list of str): MSAs to tile.
        deletion_matrices (list of list of list): Deletion matrices aligned with ``msas``.
        homooligomer (int): Number of copies. Default is 1.

    Returns:
        tuple: Exactly what `homooligomerize` returns for the same arguments.
    """
    result = homooligomerize(msas, deletion_matrices, homooligomer=homooligomer)
    return result
def _place_fragments(lengths, homooligomers, placements):
    """
    Build one expanded alignment row: gaps everywhere except the given slots.

    Args:
        lengths (list of int): Columns per component.
        homooligomers (list of int): Copies per component.
        placements (iterable): ``(component, copy, seq_fragment, mtx_fragment)``
            tuples to write into the otherwise gap/zero-filled row.

    Returns:
        tuple: ``(row_sequence, row_deletions)`` spanning every copy of every
        component, in component-major, copy-minor order.
    """
    seq_slots = [["-" * n for _ in range(h)] for n, h in zip(lengths, homooligomers)]
    mtx_slots = [[[0] * n for _ in range(h)] for n, h in zip(lengths, homooligomers)]
    for comp, copy, frag_seq, frag_mtx in placements:
        seq_slots[comp][copy] = frag_seq
        mtx_slots[comp][copy] = frag_mtx
    row_seq = "".join("".join(copies) for copies in seq_slots)
    row_mtx = [v for copies in mtx_slots for frag in copies for v in frag]
    return row_seq, row_mtx


def homooligomerize_heterooligomer(msas, deletion_matrices, lengths, homooligomers):
    """
    Expand a multi-component MSA so each component appears once per homo-oligomeric copy.

    The first row of every MSA is treated as the query and is tiled: each
    component's fragment is repeated ``homooligomers[k]`` times. Every other row
    is split into per-component fragments. A component counts as covered when its
    fragment contains at least one non-gap character. Then:

    * rows covering exactly one component are emitted once per copy of that
      component, with all other slots gap-filled;
    * rows covering two or more components are emitted once per (copy, copy)
      combination of every covered *pair* of components;
    * rows covering no component are dropped.

    Args:
        msas (list of list of str): MSAs whose rows are the concatenation of all
            components (``sum(lengths)`` columns).
        deletion_matrices (list of list of list): Per-row deletion counts aligned
            with ``msas``. Rows are presumably Python lists of ints, since they are
            sliced and re-concatenated element-wise.
        lengths (list of int): Columns per component, in alignment order. A
            zero-length component is tolerated (it is never "covered").
        homooligomers (list of int): Copies per component.

    Returns:
        tuple: ``(new_msas, new_deletion_matrices)``. The inputs are returned
        unchanged when every copy count is 1. When there is a single component,
        the call is delegated to `homooligomerize`.
    """
    from itertools import combinations, product

    if max(homooligomers) == 1:
        return msas, deletion_matrices
    if len(homooligomers) == 1:
        return homooligomerize(msas, deletion_matrices, homooligomers[0])

    # [start, end) column range of each component in the un-expanded alignment.
    bounds = []
    start = 0
    for n in lengths:
        bounds.append((start, start + n))
        start += n

    mod_msas, mod_mtxs = [], []
    for msa, mtx in zip(msas, deletion_matrices):
        mod_msa, mod_mtx = [], []
        for row_idx, (seq, mtx_row) in enumerate(zip(msa, mtx)):
            frag_seqs = [seq[i:j] for i, j in bounds]
            frag_mtxs = [mtx_row[i:j] for i, j in bounds]
            # any() instead of max(): a zero-length component gives an empty
            # fragment, on which max() would raise ValueError.
            covered = [any(c != "-" for c in frag) for frag in frag_seqs]

            if row_idx == 0:
                # Query row: repeat each component's fragment once per copy.
                mod_msa.append("".join(frag * h for frag, h in zip(frag_seqs, homooligomers)))
                mod_mtx.append([v for frag, h in zip(frag_mtxs, homooligomers)
                                for _ in range(h) for v in frag])
            elif sum(covered) == 1:
                comp = covered.index(True)
                for copy in range(homooligomers[comp]):
                    row_seq, row_mtx = _place_fragments(
                        lengths, homooligomers,
                        [(comp, copy, frag_seqs[comp], frag_mtxs[comp])])
                    mod_msa.append(row_seq)
                    mod_mtx.append(row_mtx)
            else:
                # Same order as nested a < b loops with h_a outer, h_b inner.
                for a, b in combinations(range(len(lengths)), 2):
                    if not (covered[a] and covered[b]):
                        continue
                    for h_a, h_b in product(range(homooligomers[a]), range(homooligomers[b])):
                        row_seq, row_mtx = _place_fragments(
                            lengths, homooligomers,
                            [(a, h_a, frag_seqs[a], frag_mtxs[a]),
                             (b, h_b, frag_seqs[b], frag_mtxs[b])])
                        mod_msa.append(row_seq)
                        mod_mtx.append(row_mtx)
        mod_msas.append(mod_msa)
        mod_mtxs.append(mod_mtx)
    return mod_msas, mod_mtxs
def chain_break(idx_res, Ls, length=200):
    """
    Offset residue indices after each segment boundary to mark chain breaks.

    At the end of every segment except the last, ``length`` is added to all
    following indices, so the offsets accumulate: the k-th segment (0-based) ends
    up shifted by ``k * length``.

    Args:
        idx_res (ndarray): Residue indices. The array is modified in place via
            ``+=`` on slices.
        Ls (list of int): Lengths of consecutive segments.
        length (int): Offset added at each break. Default is 200.

    Returns:
        ndarray: The same ``idx_res`` object, after modification.
    """
    boundary = 0
    for seg_len in Ls[:-1]:
        boundary += seg_len
        idx_res[boundary:] += length
    return idx_res
##################################################
# plotting functions
##################################################
def plot_plddt_legend(dpi=100):
    """
    Draw a small stand-alone legend strip for pLDDT confidence bands.

    One zero-height bar is plotted per colour so that ``plt.legend`` has a handle
    for each label. The first entry (white) acts as the "plDDT:" title. Axes are
    hidden.

    Args:
        dpi (int): Figure resolution. Default is 100.

    Returns:
        module: ``matplotlib.pyplot``, with the legend figure as the current figure.
    """
    band_labels = ['plDDT:', 'Very low (<50)', 'Low (60)', 'OK (70)', 'Confident (80)', 'Very high (>90)']
    band_colors = ("#FFFFFF", "#FF0000", "#FFFF00", "#00FF00", "#00FFFF", "#0000FF")

    plt.figure(figsize=(1, 0.1), dpi=dpi)
    for colour in band_colors:
        plt.bar(0, 0, color=colour)
    plt.legend(band_labels, frameon=False, loc='center', ncol=6,
               handletextpad=1, columnspacing=1, markerscale=0.5)
    plt.axis(False)
    return plt
def plot_ticks(Ls):
    """
    Draw segment-boundary lines on the current square plot and label segments on the y axis.

    A horizontal and a vertical black line are drawn at every internal boundary,
    spanning ``[0, sum(Ls)]``. Y ticks are placed at each segment's midpoint and
    labelled "A", "B", ... from ``alphabet_list``.

    Args:
        Ls (list, tuple or 1-D array of int): Segment lengths.
    """
    # Cumulative edges [0, L1, L1+L2, ..., total]. np.append is used rather than
    # ``[0] + Ls``: for a NumPy array that expression broadcast-adds 0 instead of
    # prepending (wrong ticks), and for a tuple it raises TypeError.
    edges = np.cumsum(np.append(0, Ls))
    total = edges[-1]
    for boundary in edges[1:-1]:
        plt.plot([0, total], [boundary, boundary], color="black")
        plt.plot([boundary, boundary], [0, total], color="black")
    centers = (edges[1:] + edges[:-1]) / 2
    plt.yticks(centers, alphabet_list[:len(centers)])
def plot_confidence(plddt, pae=None, Ls=None, dpi=100):
    """
    Plot per-residue pLDDT and, when given, the PAE matrix side by side.

    Without ``pae`` a single 5x3 panel is drawn. With ``pae`` a 10x3 figure holds
    pLDDT on the left and the PAE heat-map (``bwr``, clipped to 0-30) on the right.

    Args:
        plddt (array-like): Per-residue pLDDT values, plotted against position
            with y limits 0-100.
        pae (ndarray, optional): Square PAE matrix; ``pae.shape[0]`` sets the
            image extent.
        Ls (list of int, optional): Segment lengths. Boundaries are drawn on the
            pLDDT panel. When there is more than one segment, `plot_ticks` also
            annotates the PAE panel.
        dpi (int): Figure resolution. Default is 100.

    Returns:
        module: ``matplotlib.pyplot`` with the figure as the current figure.
    """
    has_pae = pae is not None
    if has_pae:
        plt.figure(figsize=(10, 3), dpi=dpi)
        plt.subplot(1, 2, 1)
    else:
        plt.figure(figsize=(5, 3), dpi=dpi)

    plt.title('Predicted lDDT')
    plt.plot(plddt)
    if Ls is not None:
        boundary = 0
        for seg_len in Ls[:-1]:
            boundary += seg_len
            plt.plot([boundary, boundary], [0, 100], color="black")
    plt.ylim(0, 100)
    plt.ylabel('plDDT')
    plt.xlabel('position')

    if has_pae:
        plt.subplot(1, 2, 2)
        plt.title('Predicted Aligned Error')
        n_res = pae.shape[0]
        plt.imshow(pae, cmap="bwr", vmin=0, vmax=30, extent=(0, n_res, n_res, 0))
        if Ls is not None and len(Ls) > 1:
            plot_ticks(Ls)
        plt.colorbar()
        plt.xlabel('Scored residue')
        plt.ylabel('Aligned residue')
    return plt
def _msa_coverage(msas, seqs, bounds, sort_by_seqid, deduplicate):
    """
    Compute one coverage matrix per non-empty MSA for `plot_msas`.

    In each matrix, non-gap cells hold the row's identity to the query and gap
    cells are NaN. Identity is the sum over segments of the per-segment match
    fraction, divided by the number of segments the row covers.

    Args:
        msas (list of list of str): MSAs; rows are assumed to have the query's
            length (without separators).
        seqs (list of str): Query segments; their concatenation is the query.
        bounds (ndarray): Cumulative segment edges ``[0, len(s0), ...]``.
        sort_by_seqid (bool): Order rows by ascending identity (stable, so ties
            keep MSA order); otherwise reverse the MSA order.
        deduplicate (bool): Drop repeated rows, keeping first occurrence order.

    Returns:
        tuple: ``(row_counts, blocks)``. These are the number of rows kept per
        non-empty MSA and the matching list of 2-D float arrays.
    """
    query = np.array(list("".join(seqs)))  # loop-invariant, built once
    n_seg = len(seqs)
    row_counts, blocks = [], []
    for msa in msas:
        # dict.fromkeys keeps first-seen order; set() order depends on string
        # hash randomisation and made the plot differ between runs.
        rows = list(dict.fromkeys(msa)) if deduplicate else msa
        if len(rows) == 0:
            continue
        row_counts.append(len(rows))
        arr = np.asarray([list(seq) for seq in rows])
        not_gap = arr != "-"
        matches = arr == query
        # (rows, segments): True where the row has any residue in that segment.
        covered = np.stack([not_gap[:, bounds[i]:bounds[i + 1]].max(-1) for i in range(n_seg)], -1)
        seqid = np.stack([matches[:, bounds[i]:bounds[i + 1]].mean(-1) for i in range(n_seg)], -1).sum(-1) / (
            covered.sum(-1) + 1e-8)
        cells = not_gap.astype(float)
        cells[cells == 0] = np.nan
        if sort_by_seqid:
            order = seqid.argsort(kind="stable")
            blocks.append(cells[order] * seqid[order, None])
        else:
            blocks.append(cells[::-1] * seqid[::-1, None])
    return row_counts, blocks


def plot_msas(msas, ori_seq=None, sort_by_seqid=True, deduplicate=True, dpi=100, return_plt=True):
    """
    Plot sequence coverage of one or more MSAs, coloured by identity to the query.

    Each MSA contributes a horizontal band of rows; non-gap positions are coloured
    by the row's sequence identity (``rainbow_r``, 0-1) and gaps are left blank.
    Solid vertical lines mark ``':'`` boundaries of ``ori_seq``, dashed lines mark
    ``'/'`` boundaries, and horizontal lines separate MSAs. A black curve shows the
    number of non-gap rows per column.

    Args:
        msas (list of list of str): MSAs to plot.
        ori_seq (str, optional): Query, possibly containing ``':'`` / ``'/'``
            separators. Defaults to ``msas[0][0]``.
        sort_by_seqid (bool): Sort rows within each MSA by identity. Default True.
        deduplicate (bool): Remove duplicate rows within each MSA. Default True.
        dpi (int): Figure resolution, used only when a new figure is created.
        return_plt (bool): If True, create a new figure and return pyplot.
            If False, draw into the current figure and return None.

    Returns:
        module or None: ``matplotlib.pyplot`` when ``return_plt`` is True.

    Raises:
        ValueError: If every MSA is empty (there is nothing to plot).
    """
    if ori_seq is None:
        ori_seq = msas[0][0]
    seqs = ori_seq.replace("/", "").split(":")
    seqs_dash = ori_seq.replace(":", "").split("/")
    Ln = np.cumsum(np.append(0, [len(seq) for seq in seqs]))
    Ln_dash = np.cumsum(np.append(0, [len(seq) for seq in seqs_dash]))

    row_counts, blocks = _msa_coverage(msas, seqs, Ln, sort_by_seqid, deduplicate)
    if not blocks:
        raise ValueError("plot_msas: every MSA is empty; nothing to plot")
    Nn = np.cumsum(np.append(0, row_counts))
    lines = np.concatenate(blocks, 0)
    n_rows, n_cols = lines.shape

    if return_plt:
        plt.figure(figsize=(8, 5), dpi=dpi)
    plt.title("Sequence coverage")
    plt.imshow(lines, interpolation='nearest', aspect='auto', cmap="rainbow_r", vmin=0, vmax=1,
               origin='lower', extent=(0, n_cols, 0, n_rows))
    for x in Ln[1:-1]:
        plt.plot([x, x], [0, n_rows], color="black")
    for x in Ln_dash[1:-1]:
        plt.plot([x, x], [0, n_rows], "--", color="black")
    for y in Nn[1:-1]:
        plt.plot([0, n_cols], [y, y], color="black")
    plt.plot((~np.isnan(lines)).sum(0), color='black')
    plt.xlim(0, n_cols)
    plt.ylim(0, n_rows)
    plt.colorbar(label="Sequence identity to query")
    plt.xlabel("Positions")
    plt.ylabel("Sequences")
    if return_plt:
        return plt