Source code for fast_conformation.gui.make_predictions
from PyQt5.QtWidgets import QApplication, QWidget, QGridLayout, QLabel, QComboBox, QLineEdit, QPushButton, QVBoxLayout, QHBoxLayout, QFileDialog, QCheckBox
import sys
from fast_conformation.predict_ensemble import run_ensemble_prediction
from fast_conformation.predict_ensemble import run_ensemble_prediction
from fast_conformation.gui.widget_base import AnalysisWidgetBase, merge_configs
from PyQt5.QtWidgets import QVBoxLayout, QGridLayout, QLabel, QComboBox, QLineEdit, QPushButton, QCheckBox, QFileDialog, QHBoxLayout
[docs]
class MakePredictionsWidget(AnalysisWidgetBase):
    """
    The MakePredictionsWidget class provides a user interface for configuring
    and running predictions using different MSA (Multiple Sequence Alignment)
    options and settings.

    Methods:
        init_ui: Initializes the user interface components.
        select_msa_path: Opens a file dialog to select the MSA file.
        select_output_path: Opens a file dialog to select the output directory.
        add_seq_pair: Adds a new sequence pair input to the interface.
        remove_seq_pair: Removes an existing sequence pair input.
        validate_inputs: Validates the user inputs (including every
            max_seq:extra_seq pair) to ensure they can be converted.
        get_specific_options: Retrieves the specific options set by the user.
        run_analysis: Validates inputs, merges configurations, and starts the prediction job.
        get_seq_pairs: Retrieves the sequence pairs input by the user.
    """

    def __init__(self, job_manager, general_options_getter=None, *args, **kwargs):
        """
        Initialize the MakePredictionsWidget with a job manager and optional general options getter.

        Args:
            job_manager: The manager responsible for handling job execution.
            general_options_getter: An optional callable to retrieve general analysis options.
            *args: Additional arguments to pass to the parent class.
            **kwargs: Additional keyword arguments to pass to the parent class.
        """
        super().__init__(job_manager, *args, **kwargs)
        self.general_options_getter = general_options_getter
        self.init_ui()

    def init_ui(self):
        """
        Initializes the user interface components and layout for the widget.
        """
        layout = QGridLayout()

        # Engine
        self.engine_label = QLabel("Engine:")
        self.engine_dropdown = QComboBox()
        self.engine_dropdown.addItems(["alphafold2"])

        self.setStyleSheet("""
            QToolBar {
                background-color: #333333;
                color: white;
                padding: 10px;
            }
            QPushButton {
                background-color: #555555;
                color: white;
                border: none;
                padding: 8px 16px;
                border-radius: 4px;
                margin: 0 5px;
            }
            QPushButton:hover {
                background-color: #666666;
            }
            QPushButton:pressed {
                background-color: #777777;
            }
        """)

        # Job name
        self.job_name_label = QLabel("Job name:")
        self.job_name_input = QLineEdit()

        # MSA Path
        self.msa_path_label = QLabel("MSA Path:")
        self.msa_path_input = QLineEdit()
        self.msa_path_button = QPushButton("Select MSA File")
        self.msa_path_button.clicked.connect(self.select_msa_path)

        # MSA From
        self.msa_from_label = QLabel("MSA From:")
        self.msa_from_dropdown = QComboBox()
        self.msa_from_dropdown.addItems(["mmseqs2", "jackhmmer"])

        # Seq Pairs: one QHBoxLayout row per pair is added to this vertical layout.
        self.seq_pairs_label = QLabel("max_seq:extra_seq pairs:")
        self.seq_pairs_layout = QVBoxLayout()
        self.add_seq_pair_button = QPushButton("Add Pair")
        self.add_seq_pair_button.clicked.connect(lambda: self.add_seq_pair(seq1="", seq2=""))

        # Seeds
        self.seeds_label = QLabel("Seeds:")
        self.seeds_input = QLineEdit("10")

        # Platform
        self.platform_label = QLabel("Platform:")
        self.platform_dropdown = QComboBox()
        self.platform_dropdown.addItems(["cpu", "gpu"])

        # Save All
        self.save_all_label = QLabel("Save All:")
        self.save_all_checkbox = QCheckBox()
        self.save_all_checkbox.setChecked(False)

        # Subset MSA To (shown to the user as "Max MSA Depth")
        self.subset_msa_to_label = QLabel("Max MSA Depth:")
        self.subset_msa_to_input = QLineEdit("")

        # Output Path
        self.output_path_label = QLabel("Output Path:")
        self.output_path_input = QLineEdit("")
        self.output_path_button = QPushButton("Browse")
        self.output_path_button.clicked.connect(self.select_output_path)

        # Run Button; the lambda makes run_analysis be called with no arguments.
        self.run_button = QPushButton("Run")
        self.run_button.clicked.connect(lambda: self.run_analysis())

        # Adding widgets to the layout
        layout.addWidget(self.engine_label, 0, 0)
        layout.addWidget(self.engine_dropdown, 0, 1)
        layout.addWidget(self.msa_path_label, 1, 0)
        layout.addWidget(self.msa_path_input, 1, 1)
        layout.addWidget(self.msa_path_button, 1, 2)
        layout.addWidget(self.msa_from_label, 2, 0)
        layout.addWidget(self.msa_from_dropdown, 2, 1)
        layout.addWidget(self.seq_pairs_label, 3, 0)
        layout.addLayout(self.seq_pairs_layout, 3, 1, 1, 2)
        layout.addWidget(self.add_seq_pair_button, 4, 1)
        layout.addWidget(self.seeds_label, 5, 0)
        layout.addWidget(self.seeds_input, 5, 1)
        layout.addWidget(self.platform_label, 6, 0)
        layout.addWidget(self.platform_dropdown, 6, 1)
        layout.addWidget(self.save_all_label, 7, 0)
        layout.addWidget(self.save_all_checkbox, 7, 1)
        layout.addWidget(self.subset_msa_to_label, 10, 0)
        layout.addWidget(self.subset_msa_to_input, 10, 1)
        layout.addWidget(self.output_path_label, 11, 0)
        layout.addWidget(self.output_path_input, 11, 1)
        layout.addWidget(self.output_path_button, 11, 2)
        layout.addWidget(self.job_name_label, 12, 0)
        layout.addWidget(self.job_name_input, 12, 1)
        layout.addWidget(self.run_button, 13, 1)

        self.setLayout(layout)
        self.setWindowTitle("Advanced MSA Options")

        # Start with one default max_seq:extra_seq pair.
        self.add_seq_pair(seq1="256", seq2="512")

    def select_msa_path(self):
        """
        Opens a file dialog to allow the user to select the MSA file.
        """
        options = QFileDialog.Options()
        options |= QFileDialog.DontUseNativeDialog
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Select MSA File", "", "MSA Files (*.a3m);;All Files (*)", options=options
        )
        if file_path:
            self.msa_path_input.setText(file_path)

    def select_output_path(self):
        """
        Opens a file dialog to allow the user to select the output directory.
        """
        directory = QFileDialog.getExistingDirectory(self, "Select Directory")
        if directory:
            self.output_path_input.setText(directory)

    def add_seq_pair(self, seq1="", seq2=""):
        """
        Adds a new sequence pair input row to the interface.

        The row holds, in order: the first value field, the second value
        field, and a "Remove" button. ``get_seq_pairs`` relies on this order.

        Args:
            seq1: Initial text of the first field in the pair.
            seq2: Initial text of the second field in the pair.
        """
        seq_pair_layout = QHBoxLayout()
        seq1_input = QLineEdit(seq1)
        seq1_input.setPlaceholderText("Sequence 1")
        seq2_input = QLineEdit(seq2)
        seq2_input.setPlaceholderText("Sequence 2")
        remove_button = QPushButton("Remove")
        remove_button.clicked.connect(lambda: self.remove_seq_pair(seq_pair_layout))
        seq_pair_layout.addWidget(seq1_input)
        seq_pair_layout.addWidget(seq2_input)
        seq_pair_layout.addWidget(remove_button)
        self.seq_pairs_layout.addLayout(seq_pair_layout)

    def remove_seq_pair(self, layout):
        """
        Removes an existing sequence pair row from the interface.

        Args:
            layout: The row layout containing the sequence pair to be removed.
        """
        # Empty the row first so its widgets are scheduled for deletion,
        # then detach the (now empty) row layout from the pairs layout.
        while layout.count():
            child = layout.takeAt(0)
            widget = child.widget()
            if widget is not None:
                widget.deleteLater()
        self.seq_pairs_layout.removeItem(layout)
        layout.deleteLater()

    @staticmethod
    def _parse_int(text):
        """
        Convert user-entered text to an int, or return None if it is not a
        plain non-negative decimal integer.

        ``str.isdecimal`` is used instead of ``str.isdigit`` because isdigit
        also accepts characters such as superscripts that ``int()`` rejects.
        Surrounding whitespace is ignored.
        """
        stripped = text.strip()
        if stripped.isdecimal():
            return int(stripped)
        return None

    def _iter_seq_pair_texts(self):
        """
        Yield ``(seq1_text, seq2_text)`` for each pair row currently shown.
        """
        for index in range(self.seq_pairs_layout.count()):
            row = self.seq_pairs_layout.itemAt(index).layout()
            if row is None:
                # Only row layouts are expected here; skip anything else.
                continue
            yield row.itemAt(0).widget().text(), row.itemAt(1).widget().text()

    def validate_inputs(self):
        """
        Validates the user inputs to ensure they can be turned into options.

        Returns:
            A list of error messages if validation fails, otherwise an empty list.
        """
        errors = []
        if not self.msa_path_input.text().strip():
            errors.append("MSA Path cannot be empty.")
        if self._parse_int(self.seeds_input.text()) is None:
            errors.append("Seeds must be a number.")
        subset_text = self.subset_msa_to_input.text()
        if subset_text.strip() and self._parse_int(subset_text) is None:
            errors.append("Max MSA Depth must be a number.")
        # get_seq_pairs converts every field with int(), so each one must be
        # checked here; otherwise a blank/invalid pair crashes run_analysis.
        for row_number, (seq1, seq2) in enumerate(self._iter_seq_pair_texts(), start=1):
            if self._parse_int(seq1) is None or self._parse_int(seq2) is None:
                errors.append(
                    f"Sequence pair {row_number} must contain two numbers (max_seq:extra_seq)."
                )
        return errors

    def get_specific_options(self):
        """
        Retrieves the specific options set by the user.

        Assumes ``validate_inputs`` returned no errors; numeric fields are
        converted with ``int()`` and would raise ``ValueError`` otherwise.

        Returns:
            A dictionary containing the options for the prediction job.
            ``subset_msa_to`` is None when the Max MSA Depth field is blank.
        """
        subset_text = self.subset_msa_to_input.text().strip()
        return {
            'msa_path': self.msa_path_input.text(),
            'output_path': self.output_path_input.text(),
            'jobname': self.job_name_input.text(),
            'seq_pairs': self.get_seq_pairs(),
            'seeds': int(self.seeds_input.text().strip()),
            'save_all': self.save_all_checkbox.isChecked(),
            'platform': self.platform_dropdown.currentText(),
            'subset_msa_to': int(subset_text) if subset_text else None,
            'msa_from': self.msa_from_dropdown.currentText(),
        }

    def _get_general_options(self):
        """
        Return the general options from ``general_options_getter``.

        Returns an empty dict when no getter was supplied. If the getter
        raises, the error is logged and an empty dict is used so the job can
        still be started with the widget-specific options.
        """
        if self.general_options_getter is None:
            return {}
        try:
            return self.general_options_getter()
        except Exception:
            import logging
            logging.getLogger(__name__).exception(
                "Failed to retrieve general options; continuing without them."
            )
            return {}

    def run_analysis(self):
        """
        Validates inputs, merges configurations, and starts the prediction job.

        On validation failure the errors are shown to the user and no job is
        started.
        """
        errors = self.validate_inputs()
        if errors:
            self.show_error_message(errors)
            return
        g_options = self._get_general_options()
        specific_options = self.get_specific_options()
        config = merge_configs(g_options, specific_options)
        self.job_manager.run_job(run_ensemble_prediction, (config,), config['jobname'])
        self.show_info_message(f"Job {config['jobname']} started.")

    def get_seq_pairs(self):
        """
        Retrieves the sequence pairs input by the user.

        Returns:
            A list of sequence pairs, each represented as a list of two integers.

        Raises:
            ValueError: If a field is not an integer (``validate_inputs``
                reports this case before ``run_analysis`` gets here).
        """
        return [[int(seq1), int(seq2)] for seq1, seq2 in self._iter_seq_pair_texts()]