Source code for fast_conformation.ensemble_analysis.rmsf

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib import cm

from scipy.signal import find_peaks

from MDAnalysis.analysis import rms, align
import pyqtgraph as pg
from tqdm import tqdm


TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

[docs] def calculate_rmsf_and_call_peaks(jobname, prediction_dicts, align_range, output_path, peak_width, prominence, threshold, widget): """ Calculate Root Mean Square Fluctuation (RMSF) and detect peaks for multiple molecular dynamics predictions. Parameters: jobname (str): The name of the job or analysis. prediction_dicts (dict): A dictionary containing prediction data with associated MDAnalysis Universes. align_range (str): Atom selection string for alignment of trajectories (MDAnalysis selection syntax). output_path (str): Directory where the analysis results and plots will be saved. peak_width (int): Minimum width of peaks in the RMSF data to be considered. prominence (float): Prominence required for a peak in the RMSF data. threshold (float): Threshold for peak detection based on the standard deviation of RMSF values. widget (object): A widget object for displaying plots interactively. Returns: dict: The updated prediction_dicts with detected peaks added to each prediction's data. """ with tqdm(total=len(prediction_dicts), bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}') as pbar: for result in prediction_dicts: pbar.set_description(f'Running RMSF peak analysis for {result}') u = prediction_dicts[result]['mda_universe'] max_seq = prediction_dicts[result]['max_seq'] extra_seq = prediction_dicts[result]['extra_seq'] average = align.AverageStructure(u, u, select=align_range, ref_frame=0).run() ref = average.results.universe align.AlignTraj(u, ref, select=align_range, in_memory=True).run() atom_sel = u.select_atoms(align_range) r = rms.RMSF(atom_sel).run() rmsf_values = r.results.rmsf resids = atom_sel.resids mean_rmsf = np.mean(rmsf_values) std_rmsf = np.std(rmsf_values) threshold = mean_rmsf + threshold * std_rmsf # Detect peaks peaks, properties = find_peaks(rmsf_values, width=peak_width, prominence=prominence) # Plot RMSF with detected peaks x_label = 'Residue Number' y_label = 'RMSF (Å)' title = f'{jobname} {max_seq} {extra_seq}' plt.figure(figsize=(8, 3)) plt.title(title, fontsize=16) plt.xlabel(x_label, fontsize=14) plt.ylabel(y_label, fontsize=14) plt.tick_params(axis='both', which='major', labelsize=12) plt.plot(resids, rmsf_values) plt.plot(resids[peaks], rmsf_values[peaks], "x", color="C1") plt.vlines(x=resids[peaks], ymin=rmsf_values[peaks] - properties["prominences"], ymax=rmsf_values[peaks], color="C1") plt.hlines(y=properties["width_heights"], xmin=resids[properties["left_ips"].astype(int)], xmax=resids[properties["right_ips"].astype(int)], color="C1") plt.tight_layout() full_output_path = (f"{output_path}/" f"{jobname}/" f"analysis/" f"mobile_detection/" f"{jobname}_" f"{max_seq}_" f"{extra_seq}_" f"rmsf_peaks.png") plt.savefig(full_output_path, dpi=300) # Extract and save detected peaks information detected_peaks = {} peak_counter = 0 for i, peak in enumerate(peaks): left_ip = int(properties["left_ips"][i]) right_ip = int(properties["right_ips"][i]) peak_resids = resids[left_ip:right_ip + 1] peak_prop = {'starting_residue': peak_resids[0], 'ending_residue': peak_resids[-1], 'length': len(peak_resids), 'peak_value': rmsf_values[peak], 'prominence': properties["prominences"][i], 'width_height': properties["width_heights"][i]} peak_counter += 1 detected_peaks[f'detected_peak_{peak_counter}'] = peak_prop prediction_dicts[result]['detected_peaks'] = detected_peaks pbar.update(n=1) return prediction_dicts
[docs] def calculate_rmsf_multiple(jobname, prediction_dicts, align_range, output_path, widget): """ Calculate Root Mean Square Fluctuation (RMSF) for multiple molecular dynamics predictions and plot the results. Parameters: jobname (str): The name of the job or analysis. prediction_dicts (dict): A dictionary containing prediction data with associated MDAnalysis Universes. align_range (str): Atom selection string for alignment of trajectories (MDAnalysis selection syntax). output_path (str): Directory where the analysis results and plots will be saved. widget (object): A widget object for displaying plots interactively. Returns: None """ labels = [] x_label = 'Residue number' y_label = 'RMSF (Å)' title = f'{jobname}' if output_path: plt.figure(figsize=(8, 3)) plt.title(title, fontsize=16) plt.xlabel(x_label, fontsize=14) plt.ylabel(y_label, fontsize=14) plt.tick_params(axis='both', which='major', labelsize=12) colors = ['blue', 'green', 'magenta', 'orange', 'grey', 'brown', 'cyan', 'purple'] with tqdm(total=len(prediction_dicts), bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}') as pbar: plotter = None for idx, (result, data) in enumerate(prediction_dicts.items()): pbar.set_description(f'Measuring RMSF for {result}') u = data['mda_universe'] average = align.AverageStructure(u, u, select=align_range, ref_frame=0).run() ref = average.results.universe align.AlignTraj(u, ref, select=align_range, in_memory=True).run() atom_sel = u.select_atoms(align_range) r = rms.RMSF(atom_sel).run() resids = atom_sel.resids if output_path: plt.plot(resids, r.results.rmsf, color=colors[idx % len(colors)], label=result) if widget: if plotter is None: plotter = widget.add_plot(resids, r.results.rmsf, color=colors[idx % len(colors)], label=result, title=title, x_label=x_label, y_label=y_label) else: widget.add_line(plotter, resids, r.results.rmsf, color=colors[idx % len(colors)], label=result) labels.append(result) pbar.update(n=1) if output_path: plt.legend() plt.tight_layout() full_output_path = (f"{output_path}/" f"{jobname}/" f"analysis/" f"rmsf_plddt/" f"{jobname}_rmsf_all.png") plt.savefig(full_output_path, dpi=300) plt.close()
[docs] def plot_plddt_rmsf_corr(jobname, prediction_dicts, plddt_dict, output_path, widget): """ Plot the correlation between pLDDT scores and RMSF values for each prediction. Parameters: jobname (str): The name of the job or analysis. prediction_dicts (dict): A dictionary containing prediction data with associated MDAnalysis Universes. plddt_dict (dict): A dictionary containing pLDDT scores for each prediction. output_path (str): Directory where the analysis results and plots will be saved. widget (object): A widget object for displaying plots interactively. Returns: None """ with tqdm(total=len(prediction_dicts), bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}') as pbar: plot_item = None for result in prediction_dicts: pbar.set_description(f'Running pLDDT/RMSF Correlation Analysis for {result}') max_seq = prediction_dicts[result]['max_seq'] extra_seq = prediction_dicts[result]['extra_seq'] plddt_data = plddt_dict[result]['all_plddts'] arrays = np.array(plddt_data) plddt_avg = np.mean(arrays, axis=0) u = prediction_dicts[result]['mda_universe'] average = align.AverageStructure(u, u, select='name CA', ref_frame=0).run() ref = average.results.universe align.AlignTraj(u, ref, select='name CA', in_memory=True).run() atom_sel = u.select_atoms('name CA') r = rms.RMSF(atom_sel).run() rmsf_values = r.results.rmsf resids = atom_sel.resids norm = Normalize(vmin=resids.min(), vmax=resids.max()) cmap = plt.get_cmap('viridis') title = f'{jobname} {max_seq} {extra_seq}' x_label = 'C-Alpha RMSF (A)' y_label = 'Average pLDDT' if widget: plot_item = widget.add_plot(rmsf_values, plddt_avg, title=title, x_label=x_label, y_label=y_label, resids=resids, scatter=True, colorbar=True) if output_path: plt.scatter(rmsf_values, plddt_avg, c=[cmap(norm(resid)) for resid in resids]) plt.title(title, fontsize=16) plt.xlabel(x_label, fontsize=14) plt.ylabel(y_label, fontsize=14) plt.tick_params(axis='both', which='major', labelsize=12) plt.tick_params(axis='both', which='minor', labelsize=12) plt.tight_layout() plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=plt.gca(), label='Residue #') full_output_path = (f"{output_path}/" f"{jobname}/" f"analysis/" f"rmsf_plddt/" f"{jobname}_" f"{max_seq}_" f"{extra_seq}_" f"plddt_rmsf_corr.png") plt.savefig(full_output_path, dpi=300) plt.close() pbar.update(n=1)
[docs] def plot_plddt_line(jobname, plddt_dict, output_path, custom_start_residue, widget): """ Plot pLDDT scores across residues for multiple predictions. Parameters: jobname (str): The name of the job or analysis. plddt_dict (dict): A dictionary containing pLDDT scores for each prediction. output_path (str): Directory where the analysis results and plots will be saved. custom_start_residue (int): A custom starting residue number for plotting. widget (object): A widget object for displaying plots interactively. Returns: None """ plotter = None colors = ['red', 'blue', 'green', 'purple', 'orange', 'grey', 'brown', 'cyan', 'magenta'] x_label = 'Residue number' y_label = 'pLDDT' title = f'{jobname}' if output_path: plt.figure(figsize=(8, 3)) plt.title(title, fontsize=16) plt.xlabel(x_label, fontsize=14) plt.ylabel(y_label, fontsize=14) plt.tick_params(axis='both', which='major', labelsize=12) for idx, result in enumerate(plddt_dict): plddt_data = plddt_dict[result]['all_plddts'] arrays = np.array(plddt_data) plddt_avg = np.mean(arrays, axis=0) length_avg = plddt_avg.shape[0] if np.ndim(arrays) != 1 else arrays.shape[0] residue_numbers = np.arange(length_avg) if custom_start_residue is not None: residue_numbers += custom_start_residue if output_path: plt.plot(residue_numbers, plddt_avg, color=colors[idx % len(colors)], label=result) if plotter is None and widget: plotter = widget.add_plot(residue_numbers, plddt_avg, color=colors[idx % len(colors)], label=result, title=title, x_label=x_label, y_label=y_label) elif widget: widget.add_line(plotter, residue_numbers, plddt_avg, color=colors[idx % len(colors)], label=result) if output_path: plt.legend() plt.tight_layout() full_output_path = (f"{output_path}/" f"{jobname}/" f"analysis/" f"rmsf_plddt/" f"{jobname}_" f"all_plddt.png") plt.savefig(full_output_path, dpi=300) plt.close()
[docs] def build_dataset_rmsf_peaks(jobname, results_dict, output_path, engine): """ Build a dataset from detected RMSF peaks and save it as a CSV file. Parameters: jobname (str): The name of the job or analysis. results_dict (dict): A dictionary containing detected peaks for each prediction. output_path (str): Directory where the analysis results and dataset will be saved. engine (str): The name of the engine used for the analysis (e.g., AlphaFold2, OpenFold). Returns: None: The function saves the dataset as a CSV file in the specified output directory. """ trials = [] peak_labels = [] peak_starting_residues = [] peak_ending_residues = [] peak_lengths = [] peak_rmsf_values = [] peak_prominences = [] peak_width_heights = [] for result in results_dict: peak_data = results_dict[result]['detected_peaks'] for peak in peak_data: trials.append(result) peak_labels.append(peak) peak_starting_residues.append(results_dict[result]['detected_peaks'][peak]['starting_residue']) peak_ending_residues.append(results_dict[result]['detected_peaks'][peak]['ending_residue']) peak_rmsf_values.append(results_dict[result]['detected_peaks'][peak]['peak_value']) peak_prominences.append(results_dict[result]['detected_peaks'][peak]['prominence']) peak_lengths.append(results_dict[result]['detected_peaks'][peak]['length']) peak_width_heights.append(results_dict[result]['detected_peaks'][peak]['width_height']) df = pd.DataFrame({'peak_label': peak_labels, 'starting_residue': peak_starting_residues, 'ending_residue': peak_ending_residues, 'length': peak_lengths, 'rmsf_value': peak_rmsf_values, 'prominence': peak_prominences, 'width_height': peak_width_heights, 'trial': trials}) full_output_path = (f"{output_path}/" f"{jobname}/" f"analysis/" f"mobile_detection/" f"{jobname}_" f"rmsf_peak_calling_results.csv") df.to_csv(full_output_path, index=False) print(f"\nSaved peak calling output of all results at {output_path}/{jobname}/predictions/{engine}/ to " f"{full_output_path}\n")