# Source code for fast_conformation.gui.plot_widget

import sys
import numpy as np
import pyqtgraph as pg
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QVBoxLayout, QWidget
)
from pyqtgraph.colormap import ColorMap
from matplotlib.colors import Normalize

class PlotWidget(pg.GraphicsLayoutWidget):
    """
    PlotWidget is a custom widget for plotting data using pyqtgraph.

    It supports both line plots and scatter plots with optional color mapping
    and colorbars. Every call to ``add_plot`` places a new plot in its own row
    of the underlying graphics layout.

    Methods:
        add_plot: Adds a new plot to the widget.
        add_borders: Adds borders to the plot.
        add_scatter: Adds a scatter plot to the plot item.
        add_line: Adds a line plot to the plot item.
        add_colorbar: Adds a colorbar to the widget.
    """

    # RGBA colour stops shared by the scatter brushes and the colorbar so the
    # two always agree (the values look like five samples of the viridis
    # palette, running from dark purple to yellow).
    _COLOR_STOPS = (
        (68, 1, 84, 255),
        (58, 82, 139, 255),
        (32, 144, 140, 255),
        (94, 201, 97, 255),
        (253, 231, 37, 255),
    )

    # Human-friendly line-style spellings mapped to the names of the
    # corresponding Qt pen styles (looked up on ``pg.QtCore.Qt`` when used).
    _PEN_STYLE_NAMES = {
        'solid': 'SolidLine', '-': 'SolidLine',
        'dashed': 'DashLine', '--': 'DashLine',
        'dotted': 'DotLine', ':': 'DotLine',
        'dashdot': 'DashDotLine', '-.': 'DashDotLine',
    }

    def __init__(self, parent=None):
        """
        Initializes the PlotWidget.

        Args:
            parent: The parent widget, if any.
        """
        super().__init__(parent)
        # Plot items created through add_plot, in insertion order.
        self.plots = []
        self.setBackground('w')

    @classmethod
    def _residue_colormap(cls):
        """Builds the ColorMap with the shared stops spread evenly over [0, 1]."""
        stops = np.array(cls._COLOR_STOPS)
        return ColorMap(pos=np.linspace(0, 1, len(stops)), color=stops)

    @staticmethod
    def _resid_values(resids):
        """Returns resids as a NumPy array, or None if there is nothing to map."""
        if resids is None:
            return None
        values = np.asarray(resids)
        return values if values.size else None

    @classmethod
    def _resolve_pen_style(cls, lstyle):
        """Translates lstyle into a Qt pen style, or None to keep the default."""
        if lstyle is None:
            return None
        if isinstance(lstyle, str):
            style_name = cls._PEN_STYLE_NAMES.get(lstyle.strip().lower())
            # Unknown spellings are ignored, as every lstyle historically was.
            if style_name is None:
                return None
            return getattr(pg.QtCore.Qt, style_name)
        # Anything else is assumed to already be a Qt pen style.
        return lstyle

    def add_plot(self, x_data, y_data, title, x_label, y_label, color=None, label=None,
                 resids=None, scatter=False, colorbar=False):
        """
        Adds a new plot to the widget and advances the layout to the next row.

        Args:
            x_data: The data for the x-axis.
            y_data: The data for the y-axis.
            title: The title of the plot.
            x_label: The label for the x-axis.
            y_label: The label for the y-axis.
            color: The color of the line or scatter points. When omitted,
                lines use pyqtgraph's default pen color and scatter points
                use blue.
            label: The label for the legend.
            resids: Per-point residue numbers used for coloring scatter
                points (the colorbar is labelled 'Residue #').
            scatter: Whether to create a scatter plot instead of a line.
            colorbar: Whether to color scatter points by ``resids`` and add
                a colorbar. Ignored when ``resids`` is missing or empty.

        Returns:
            The created plot item.
        """
        plot_item = self.addPlot(title=title)
        plot_item.addLegend()
        self.setBackground('w')

        if scatter:
            # Without usable resids there is nothing to map, so fall back to
            # a plain scatter instead of failing inside add_colorbar.
            use_colorbar = bool(colorbar) and self._resid_values(resids) is not None
            self.add_scatter(
                plot_item, x_data, y_data, resids,
                color='b' if color is None else color,
                colorbar=use_colorbar,
                label=label,
            )
            if use_colorbar:
                # Must run before this plot is appended to self.plots:
                # add_colorbar uses len(self.plots) as the target row.
                self.add_colorbar(resids)
        else:
            self.add_line(plot_item, x_data, y_data, color, label)

        plot_item.setLabel('left', y_label)
        plot_item.setLabel('bottom', x_label)
        plot_item.showGrid(x=True, y=True)

        self.nextRow()
        self.plots.append(plot_item)
        self.add_borders(plot_item)
        return plot_item

    def add_borders(self, plot):
        """
        Adds borders to the given plot.

        Args:
            plot: The plot item to which borders will be added.
        """
        plot.getViewBox().setBorder(pg.mkPen(color='lightgrey', width=1))

    def add_scatter(self, plot_item, x_data, y_data, resids=None, color='b', colorbar=False,
                    label=None):
        """
        Adds a scatter plot to the given plot item.

        Args:
            plot_item: The plot item to which the scatter plot will be added.
            x_data: The data for the x-axis.
            y_data: The data for the y-axis.
            resids: Per-point residue numbers for coloring the scatter
                points; only used when ``colorbar`` is true.
            color: The color of the scatter points when they are not
                color-mapped.
            colorbar: Whether to color each point by its ``resids`` value.
                Falls back to ``color`` if ``resids`` is missing or empty.
            label: The label for the legend.

        Returns:
            The created scatter plot item.
        """
        values = self._resid_values(resids) if colorbar else None
        if values is not None:
            norm = Normalize(vmin=values.min(), vmax=values.max())
            # Normalize returns a masked array; take the plain data and map
            # all points in one call rather than one colormap call per point.
            fractions = np.ma.getdata(norm(values))
            rgba_rows = self._residue_colormap().map(fractions)
            brushes = [pg.mkBrush(*rgba) for rgba in rgba_rows]
        else:
            brushes = pg.mkBrush(pg.mkColor(color))

        # ScatterPlotItem's size option is ``size``; ``symbolSize`` belongs to
        # PlotDataItem and is not applied here.
        scatter = pg.ScatterPlotItem(x=x_data, y=y_data, symbol='o', size=5,
                                     brush=brushes, name=label, pen=None)
        plot_item.addItem(scatter)
        return scatter

    def add_line(self, plot_item, x_data, y_data, color, label, lstyle=None):
        """
        Adds a line plot to the given plot item.

        Args:
            plot_item: The plot item to which the line plot will be added.
            x_data: The data for the x-axis.
            y_data: The data for the y-axis.
            color: The color of the line.
            label: The label for the legend.
            lstyle: The line style. Accepts 'solid', 'dashed', 'dotted',
                'dashdot', the shorthands '-', '--', ':', '-.', or a Qt pen
                style. None (the default) and unrecognised strings leave the
                pen style at its default.

        Returns:
            The created line item.
        """
        # The style has to be set on the pen; plot() has no line-style option.
        pen = pg.mkPen(color=color, width=2, style=self._resolve_pen_style(lstyle))
        return plot_item.plot(x_data, y_data, pen=pen, name=label)

    def add_colorbar(self, resids):
        """
        Adds a colorbar to the widget based on residue numbers.

        The colorbar is placed in column 1 of the row whose index equals the
        number of plots registered so far.

        Args:
            resids: Residue numbers whose minimum and maximum define the
                range shown on the colorbar.

        Returns:
            The created colorbar item.

        Raises:
            ValueError: If ``resids`` is None or empty.
        """
        values = self._resid_values(resids)
        if values is None:
            raise ValueError("resids must contain at least one value to build a colorbar")

        colorbar = pg.ColorBarItem(values=(values.min(), values.max()),
                                   colorMap=self._residue_colormap(),
                                   label='Residue #')
        self.addItem(colorbar, row=len(self.plots), col=1)
        return colorbar