import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from scipy.optimize import curve_fit
from tqdm import tqdm
from fast_conformation.ensemble_analysis.analysis_utils import parabola
from fast_conformation.ensemble_analysis.rmsd import calculate_rmsd
# tqdm ``bar_format`` string for the per-trial progress bar in TwodRMSD.get_2d_rmsd.
# (The stray ``[docs]`` line that followed this constant was a Sphinx "view source"
# artifact; as a module-level expression it raised NameError on import.)
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'
class TwodRMSD:
    """
    A class to perform 2D RMSD analysis on molecular dynamics simulations.

    For each trial, RMSD values are computed against two reference structures
    and paired into 2D points (RMSD vs. Ref1, RMSD vs. Ref2). A parabola is
    fitted through these points, points far from the fit are discarded, and the
    remaining points are clustered with KMeans.

    Attributes:
    ----------
    prediction_dicts : dict
        Per-trial data. Each value must provide 'mda_universe', 'max_seq' and
        'extra_seq'.
    input_dict : dict
        Job metadata ('jobname', 'output_path', 'analysis_range', ...). During
        `get_2d_rmsd` the current 'trial', 'max_seq' and 'extra_seq' are
        written into it as well.
    ref_gr : str or None, optional
        The reference structure for the first RMSD calculation (default is None).
    ref_alt : str or None, optional
        The reference structure for the second RMSD calculation (default is None).
    filtering_dict : dict
        Results of the most recent `fit_and_filter_data` call.
    clustering_dict : dict
        Results of the most recent `cluster_2d_data` call.
    widget : object or None
        Optional plotting widget exposing `add_plot`, `add_line` and
        `add_scatter`. A falsy value disables widget output.

    Methods:
    -------
    calculate_2d_rmsd(trial):
        Calculate 2D RMSD for a given trial.
    fit_and_filter_data(rmsd_2d_data, n_stdevs):
        Fit a parabola to the 2D RMSD data and keep points close to the fit.
    show_filt_data(rmsd_2d_data):
        Plot the 2D RMSD data with the fitted curve and kept points on the widget.
    plot_filtering_data(rmsd_2d_data):
        Generate and save a plot of the filtered 2D RMSD data with the fitted curve.
    cluster_2d_data(rmsd_2d_data, n_clusters):
        Cluster the filtered 2D RMSD data and store clustering results.
    plot_and_save_2d_data(output_path):
        Plot the clustered data, optionally save the plot, and return a
        DataFrame with the clustering information.
    get_2d_rmsd(rmsd_mode_df_path, n_stdevs, n_clusters, output_path):
        Run the full analysis for all trials and save the results as CSV.
    """

    def __init__(self, prediction_dicts, input_dict, widget, ref_gr=None, ref_alt=None):
        """
        Initialize the TwodRMSD class.

        Parameters:
        ----------
        prediction_dicts : dict
            Per-trial data with associated MDAnalysis Universes.
        input_dict : dict
            Job-related metadata (jobname, output_path, analysis range, etc.).
        widget : object or None
            A widget object used to display plots; falsy disables widget output.
        ref_gr : str or None, optional
            The reference structure for the first RMSD calculation (default is None).
        ref_alt : str or None, optional
            The reference structure for the second RMSD calculation (default is None).
        """
        self.prediction_dicts = prediction_dicts
        self.input_dict = input_dict
        self.ref_gr = ref_gr
        self.ref_alt = ref_alt
        self.filtering_dict = {}
        self.clustering_dict = {}
        self.widget = widget

    # ------------------------------------------------------------------ helpers

    def _title(self):
        """Return the '<jobname> <max_seq> <extra_seq>' title used by every plot."""
        return (f"{self.input_dict['jobname']} "
                f"{self.input_dict['max_seq']} "
                f"{self.input_dict['extra_seq']}")

    def _rmsd_2d_dir(self):
        """Return the '<output_path>/<jobname>/analysis/rmsd_2d' results directory."""
        return (f"{self.input_dict['output_path']}/"
                f"{self.input_dict['jobname']}/"
                f"analysis/"
                f"rmsd_2d")

    def _plot_path(self, suffix):
        """Return the per-trial plot file path ending in `suffix` inside the rmsd_2d directory."""
        return (f"{self._rmsd_2d_dir()}/"
                f"{self.input_dict['jobname']}_"
                f"{self.input_dict['max_seq']}_"
                f"{self.input_dict['extra_seq']}_"
                f"{suffix}")

    @staticmethod
    def _bin_coverage(x_values, n_bins=100):
        """
        Compute the fraction of `n_bins` equal-width bins over [min, max] that contain data.

        Parameters:
        ----------
        x_values : array-like
            1-D values to bin.
        n_bins : int, optional
            Number of bins (default 100).

        Returns:
        -------
        tuple
            (bins, unique_bins, ratio, bin_edges). `bins` holds 1-based bin indices
            in [1, n_bins]; `ratio` is len(unique_bins) / n_bins, so it lies in [0, 1].
        """
        x_values = np.asarray(x_values)
        if x_values.size == 0:
            empty = np.array([], dtype=np.intp)
            return empty, empty, 0.0, np.array([])
        bin_edges = np.linspace(x_values.min(), x_values.max(), n_bins + 1)
        # np.digitize places a value equal to the last edge (the maximum) into an
        # overflow bin n_bins + 1; clip it back into the last real bin so the
        # coverage ratio cannot exceed 1.
        bins = np.clip(np.digitize(x_values, bin_edges), 1, n_bins)
        unique_bins = np.unique(bins)
        ratio = len(unique_bins) / n_bins
        return bins, unique_bins, ratio, bin_edges

    @staticmethod
    def _order_clusters_by_x(labels, centroids):
        """
        Relabel clusters so that label 0 has the smallest centroid x, label 1 the next, etc.

        Returns:
        -------
        tuple
            (sorted_centroids, correct_labels), where sorted_centroids[k] is the
            centroid of the cluster now labelled k.
        """
        labels = np.asarray(labels)
        centroids = np.asarray(centroids)
        order = np.argsort(centroids[:, 0])
        # argsort of a permutation is its inverse: maps old cluster index -> new label.
        rank = np.argsort(order)
        correct_labels = rank[labels].astype(labels.dtype, copy=False)
        return centroids[order], correct_labels

    @staticmethod
    def _cluster_populations(correct_labels, n_points_total):
        """
        Compute per-cluster populations and the outlier percentage.

        Parameters:
        ----------
        correct_labels : array-like
            Cluster label of every point that survived filtering.
        n_points_total : int
            Number of points before filtering.

        Returns:
        -------
        tuple
            (unique_labels, cluster_counts, outliers). `cluster_counts[k]` is the
            percentage (rounded to 0.1) of kept points in cluster k; `outliers` is
            the percentage of ALL points that were filtered out.
        """
        correct_labels = np.asarray(correct_labels)
        n_kept = correct_labels.shape[0]
        unique_labels = set(correct_labels)
        cluster_counts = {
            k: round((np.count_nonzero(correct_labels == k) / n_kept) * 100, 1)
            for k in unique_labels
        }
        n_outliers = n_points_total - n_kept
        outliers = (n_outliers / n_points_total) * 100 if n_points_total else 0.0
        return unique_labels, cluster_counts, outliers

    # --------------------------------------------------------------- analysis

    def calculate_2d_rmsd(self, trial):
        """
        Calculate 2D RMSD for a given trial.

        Parameters:
        ----------
        trial : str
            Key of the trial in `prediction_dicts`.

        Returns:
        -------
        rmsd_2d_data : np.ndarray
            Column 0 holds RMSD vs. `ref_gr`, column 1 RMSD vs. `ref_alt`
            (shape (n, 2), assuming both RMSD series are 1-D and equally long).
        """
        universe = self.prediction_dicts[trial]['mda_universe']
        analysis_range = self.input_dict['analysis_range']
        rmsd_gr = calculate_rmsd(universe,
                                 ref=self.ref_gr,
                                 align_range='backbone',
                                 analysis_range=analysis_range)
        rmsd_alt = calculate_rmsd(universe,
                                  ref=self.ref_alt,
                                  align_range='backbone',
                                  analysis_range=analysis_range)
        # calculate_rmsd's result is indexed by the analysis-range selection.
        return np.array([rmsd_gr[analysis_range], rmsd_alt[analysis_range]]).T

    def fit_and_filter_data(self, rmsd_2d_data, n_stdevs):
        """
        Fit a parabola to the 2D RMSD data and keep points close to the fit.

        A point is kept when its vertical distance to the fitted curve is at most
        mean(distances) + n_stdevs * std(distances). The kept points' x-coverage
        over 100 bins is stored as the quality score 'ratio' (0..1).

        Parameters:
        ----------
        rmsd_2d_data : np.ndarray
            Array of shape (n, 2) of RMSD values.
        n_stdevs : int or float
            Number of standard deviations used for the filtering threshold.

        Returns:
        -------
        None
            Results are stored in `self.filtering_dict`.
        """
        x = rmsd_2d_data[:, 0]
        y = rmsd_2d_data[:, 1]
        popt, _ = curve_fit(parabola, x, y)
        fit_x = np.linspace(x.min(), x.max(), 100)
        fit_y = parabola(fit_x, *popt)
        distances = np.abs(y - parabola(x, *popt))
        # float() (not int()) so fractional thresholds such as 1.5 are honoured.
        threshold = np.mean(distances) + float(n_stdevs) * np.std(distances)
        # '<=' so that zero-variance residuals (e.g. a perfect fit) keep their points.
        close_points = distances <= threshold
        x_close = x[close_points]
        y_close = y[close_points]
        bins, unique_bins, ratio, bin_edges = self._bin_coverage(x_close)
        self.filtering_dict = {
            'fit_x': fit_x,
            'fit_y': fit_y,
            'close_points': close_points,
            'bins': bins,
            'unique_bins': unique_bins,
            'ratio': ratio,
            'bin_edges': bin_edges,
            'x_close': x_close,
            'y_close': y_close,
        }

    def show_filt_data(self, rmsd_2d_data):
        """
        Plot the 2D RMSD data, the fitted curve and the kept points on the widget.

        Parameters:
        ----------
        rmsd_2d_data : np.ndarray
            Array of shape (n, 2) of RMSD values.

        Returns:
        -------
        None
        """
        plotter = self.widget.add_plot(rmsd_2d_data[:, 0], rmsd_2d_data[:, 1], title=self._title(),
                                       x_label='RMSD vs. Ref1 (Å)', y_label='RMSD vs. Ref2 (Å)',
                                       scatter=True)
        self.widget.add_line(plotter, self.filtering_dict['fit_x'], self.filtering_dict['fit_y'],
                             label='Fitted Curve', color='r')
        self.widget.add_scatter(plotter, self.filtering_dict['x_close'],
                                self.filtering_dict['y_close'],
                                label='Close Points',
                                color=[68, 1, 84, 255])

    def plot_filtering_data(self, rmsd_2d_data):
        """
        Generate and save a plot of the filtered 2D RMSD data with the fitted curve.

        The PNG is written to
        '<output_path>/<jobname>/analysis/rmsd_2d/<jobname>_<max_seq>_<extra_seq>_rmsd_2d_fit.png'
        using values from `input_dict`.

        Parameters:
        ----------
        rmsd_2d_data : np.ndarray
            Array of shape (n, 2) of RMSD values.

        Returns:
        -------
        None
        """
        fig = plt.figure(figsize=(5, 4))
        try:
            plt.scatter(rmsd_2d_data[:, 0], rmsd_2d_data[:, 1], s=10)
            plt.plot(self.filtering_dict['fit_x'],
                     self.filtering_dict['fit_y'],
                     label='Fitted Curve', color='red')
            plt.scatter(self.filtering_dict['x_close'],
                        self.filtering_dict['y_close'],
                        label='Close Points',
                        color='green',
                        s=20)
            plt.title(self._title(), fontsize=15)
            plt.xlabel('RMSD vs. Ref1 (Å)', fontsize=14)
            plt.ylabel('RMSD vs. Ref2 (Å)', fontsize=14)
            plt.tick_params(axis='both', which='major', labelsize=12)
            plt.legend(loc='best')
            plt.tight_layout()
            plt.savefig(self._plot_path("rmsd_2d_fit.png"), dpi=300)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

    def cluster_2d_data(self, rmsd_2d_data, n_clusters):
        """
        Cluster the filtered 2D RMSD points with KMeans and store the results.

        Cluster labels are renumbered so that label 0 has the smallest centroid
        x-value. 'centroids' is stored in the same order, so centroids[k] belongs
        to label k.

        Parameters:
        ----------
        rmsd_2d_data : np.ndarray
            The unfiltered (n, 2) RMSD data; used only to count outliers.
        n_clusters : int
            Number of clusters to form.

        Returns:
        -------
        None
            Results are stored in `self.clustering_dict`.
        """
        close_points_2d = np.column_stack((self.filtering_dict['x_close'],
                                           self.filtering_dict['y_close']))
        kmeans = KMeans(n_clusters)
        kmeans.fit(close_points_2d)
        labels = kmeans.labels_
        centroids, correct_labels = self._order_clusters_by_x(labels, kmeans.cluster_centers_)
        unique_labels, cluster_counts, outliers = self._cluster_populations(
            correct_labels, rmsd_2d_data.shape[0])
        self.clustering_dict = {
            'labels': labels,
            'correct_labels': correct_labels,
            'centroids': centroids,
            'cluster_counts': cluster_counts,
            'outliers': outliers,
            'unique_labels': unique_labels,
            'close_points_2d': close_points_2d
        }

    def plot_and_save_2d_data(self, output_path):
        """
        Plot the clustered 2D RMSD data and return a DataFrame with the clustering information.

        Parameters:
        ----------
        output_path : str or None
            If truthy, a PNG is saved to
            '<input_dict["output_path"]>/<jobname>/analysis/rmsd_2d/<jobname>_<max_seq>_<extra_seq>_rmsd_2d_clustered.png'.
            NOTE(review): the location comes from `input_dict['output_path']`;
            this argument only acts as an on/off switch.

        Returns:
        -------
        df : pd.DataFrame
            One row per cluster with trial, analysis_range, score, cluster_label,
            cluster_pop, centroid_values and %_outliers.
        """
        unique_labels = self.clustering_dict['unique_labels']
        correct_labels = self.clustering_dict['correct_labels']
        cluster_counts = self.clustering_dict['cluster_counts']
        centroids = self.clustering_dict['centroids']
        outliers = self.clustering_dict['outliers']
        close_points_2d = self.clustering_dict['close_points_2d']
        score = self.filtering_dict['ratio']
        title = f"{self._title()} Score: {score:.2f}"
        colors = np.array([[68, 1, 84, 255], [58, 82, 139, 255], [32, 144, 140, 255],
                           [94, 201, 97, 255], [253, 231, 37, 255]])
        if self.widget:
            plotter = self.widget.add_plot(centroids[:, 0], centroids[:, 1], title=title,
                                           x_label='RMSD vs. Ref1 (Å)', y_label='RMSD vs. Ref2 (Å)',
                                           label='Centroids', scatter=True)
            self.widget.add_line(plotter, self.filtering_dict['fit_x'], self.filtering_dict['fit_y'],
                                 label='Fitted Curve', color='r')
            for i in unique_labels:
                cluster_points = close_points_2d[correct_labels == i]
                # Cycle the palette so more than len(colors) clusters do not raise IndexError.
                self.widget.add_scatter(plotter, cluster_points[:, 0], cluster_points[:, 1],
                                        color=colors[i % len(colors)],
                                        label=f'Cluster {i} pop: {cluster_counts[i]}')
        if output_path:
            fig = plt.figure(figsize=(5, 4))
            try:
                for i in unique_labels:
                    cluster_points = close_points_2d[correct_labels == i]
                    plt.scatter(cluster_points[:, 0], cluster_points[:, 1],
                                label=f'Cluster {i} pop: {cluster_counts[i]}', alpha=0.6)
                plt.scatter(centroids[:, 0], centroids[:, 1], s=100, c='black', marker='X',
                            label='Centroids')
                plt.plot(self.filtering_dict['fit_x'], self.filtering_dict['fit_y'],
                         label='Fitted Curve', color='red')
                plt.title(title, fontsize=16)
                plt.legend()
                plt.xlabel('RMSD vs. Ref1 (Å)', fontsize=14)
                plt.ylabel('RMSD vs. Ref2 (Å)', fontsize=14)
                plt.tick_params(axis='both', which='major', labelsize=12)
                plt.tight_layout()
                plt.savefig(self._plot_path("rmsd_2d_clustered.png"), dpi=300)
            finally:
                plt.close(fig)
        records = [
            {
                'trial': self.input_dict['trial'],
                'analysis_range': self.input_dict['analysis_range'],
                'score': score,
                'cluster_label': i,
                'cluster_pop': cluster_counts[i],
                # Valid pairing because cluster_2d_data stores centroids in label order.
                'centroid_values': centroids[i],
                '%_outliers': outliers,
            }
            for i in unique_labels
        ]
        return pd.DataFrame(records)

    def get_2d_rmsd(self, rmsd_mode_df_path, n_stdevs, n_clusters, output_path):
        """
        Run the full 2D RMSD analysis for every trial and save a combined CSV.

        For each trial listed in the mode CSV, the method:
        1. computes the 2D RMSD data;
        2. fits and filters it;
        3. optionally plots it (to file when `output_path` is truthy, and on the widget when one is set);
        4. clusters it.

        The per-cluster rows are written to
        '<input_dict["output_path"]>/<jobname>/analysis/rmsd_2d/<jobname>_clustering_analysis.csv'.

        Parameters:
        ----------
        rmsd_mode_df_path : str
            CSV with at least the columns 'trial' and 'mode_label'.
        n_stdevs : int or float
            Number of standard deviations used for filtering the data.
        n_clusters : int or None
            Number of clusters. If falsy, each trial uses (number of rows for that
            trial in the mode CSV) + 1.
        output_path : str or None
            If truthy, per-trial PNG plots are saved (location taken from
            `input_dict['output_path']`).

        Returns:
        -------
        None
        """
        df = pd.read_csv(rmsd_mode_df_path)
        unique_trials = df['trial'].unique()
        per_trial_frames = []
        with tqdm(total=len(unique_trials), bar_format=TQDM_BAR_FORMAT) as pbar:
            for trial in unique_trials:
                # Previously n_clusters_trial was only bound when n_clusters was falsy,
                # so passing an explicit n_clusters raised UnboundLocalError.
                if n_clusters:
                    n_clusters_trial = n_clusters
                else:
                    unique_df = df[df['trial'] == trial]
                    n_clusters_trial = len(unique_df['mode_label']) + 1
                pbar.set_description(f'Running 2D RMSD analysis for {trial}')
                self.input_dict['trial'] = trial
                self.input_dict['max_seq'] = self.prediction_dicts[trial]['max_seq']
                self.input_dict['extra_seq'] = self.prediction_dicts[trial]['extra_seq']
                rmsd_2d_data = self.calculate_2d_rmsd(trial)
                if len(rmsd_2d_data) > 0:
                    self.fit_and_filter_data(rmsd_2d_data, n_stdevs)
                    if output_path:
                        self.plot_filtering_data(rmsd_2d_data)
                    if self.widget:
                        self.show_filt_data(rmsd_2d_data)
                    self.cluster_2d_data(rmsd_2d_data, n_clusters_trial)
                    per_trial_frames.append(self.plot_and_save_2d_data(output_path))
                pbar.update(n=1)
        # Concatenate once instead of growing a DataFrame inside the loop.
        df_all_trials = (pd.concat(per_trial_frames, ignore_index=True)
                         if per_trial_frames else pd.DataFrame())
        csv_path = (f"{self._rmsd_2d_dir()}/"
                    f"{self.input_dict['jobname']}_"
                    f"clustering_analysis.csv")
        print(f"\nSaving {self.input_dict['jobname']} 2D RMSD analysis results to {csv_path}\n")
        df_all_trials.to_csv(csv_path, index=False)