From e62e9ba8739a19742b2a0d19517539e313a5950d Mon Sep 17 00:00:00 2001 From: sambra95 Date: Thu, 23 Jan 2025 15:32:58 +0100 Subject: [PATCH] :sparkles: plotting functions --- app/app.py | 15 +- src/proteusAI/ml_tools/esm_tools/esm_tools.py | 182 +++++++++--------- src/proteusAI/visual_tools/plots.py | 1 + 3 files changed, 98 insertions(+), 100 deletions(-) diff --git a/app/app.py b/app/app.py index f086bc6a..4dcd98b3 100755 --- a/app/app.py +++ b/app/app.py @@ -19,9 +19,11 @@ import tooltips from shiny import App, Inputs, Outputs, Session, reactive, render, ui from shiny.types import FileInfo, ImgData +import shinywidgets as widgets import proteusAI as pai + app_path = os.path.dirname(os.path.realpath(__file__)) google_analytics = os.path.join(app_path, "google_analytics.html") google_analytics_string = "" @@ -1990,8 +1992,8 @@ def zs_download_ui(): out = ZS_SCORES() if out is not None: return ui.TagList( - ui.output_plot("entropy_plot"), - ui.output_plot("scores_plot"), + widgets.output_widget("entropy_plot"), + widgets.output_widget("heatmap_plot"), # Descriptors ui.h4("Table Interpretation"), ui.row( @@ -2034,7 +2036,7 @@ def download_zs_df(): ### OUTPUT PROTEIN MODE ### @output - @render.plot + @widgets.render_widget @reactive.event(input.plot_entropy) def entropy_plot(alt=None): """ @@ -2068,11 +2070,11 @@ def entropy_plot(alt=None): ### UPDATE SCORES PLOT ### @output - @render.plot + @widgets.render_widget @reactive.event(input.plot_scores) - def scores_plot(alt=None): + def heatmap_plot(alt=None): """ - Create the per position entropy plot + Create heatmap of zeroshot scores """ prot = PROTEIN() @@ -2103,6 +2105,7 @@ def scores_plot(alt=None): model=MODEL_DICT[input.computed_zs_scores()], chain=chain, ) + print(fig) return fig ### STRUCTURE MODE ### diff --git a/src/proteusAI/ml_tools/esm_tools/esm_tools.py b/src/proteusAI/ml_tools/esm_tools/esm_tools.py index 0f971733..2f811484 100755 --- a/src/proteusAI/ml_tools/esm_tools/esm_tools.py +++ b/src/proteusAI/ml_tools/esm_tools/esm_tools.py @@ -36,6 +36,7 @@ ) from esm.inverse_folding.util import CoordBatchConverter from matplotlib.colors import LinearSegmentedColormap +import plotly.graph_objects as go alphabet = torch.load( os.path.join(Path(__file__).parent, "alphabet.pt"), weights_only=False @@ -1037,79 +1038,63 @@ def plot_heatmap( df = pd.DataFrame( probability_distribution_np[:, list(filtered_alphabet.values())], columns=[i for i in filtered_alphabet.keys()], - ) + ).T - # Calculate the symmetric vmin and vmax around zero - abs_max = max(abs(df.min().min()), abs(df.max().max())) - - # colors - if color_sheme == "rwb": - colors = ["red", "white", "blue"] - cmap = LinearSegmentedColormap.from_list("red_white_blue", colors) - vmin, vmax = -abs_max, abs_max # Symmetric range around zero - elif color_sheme == "r": - cmap = "Reds" - vmin, vmax = None, None # No centering for single color schemes - else: - cmap = "Blues" - vmin, vmax = None, None # No centering for single color schemes - - # Create a heatmap using seaborn - fig, ax = plt.subplots(figsize=(12, 6)) - - ax = plt.gca() - sns.heatmap( - df.T, - cmap=cmap, - linewidths=0.5, - annot=False, - cbar=True, - ax=ax, - vmin=vmin, - vmax=vmax, - ) + data = df.copy() - # Adjust x-axis ticks and labels if 'section' is specified - if section is not None: - ax.set_xticks( - np.arange( - 0, section[1] - section[0] + 1, max(1, (section[1] - section[0]) // 10) - ) - ) - ax.set_xticklabels( - np.arange( - section[0], section[1] + 1, max(1, (section[1] - section[0]) // 10) - ) - ) + # Create the heatmap using Plotly + fig = go.Figure() - # Highlight the mutated positions with a green box - if highlight_positions is not None: - for pos, mutated_residue in highlight_positions.items(): - residue_index = filtered_alphabet[mutated_residue] - rect = patches.Rectangle( - (pos, residue_index - min_val), - 1, - 1, - linewidth=1, - edgecolor="lime", - facecolor="none", - ) - ax.add_patch(rect) + # Create a custom colorscale for the heatmap + custom_colorscale = [ + [0.0, "red"], # Minimum value (deep red) + [0.5, "white"], # Zero value (white) + [1.0, "blue"], # Maximum value (deep blue) + ] - plt.xlabel("Sequence Position") - plt.ylabel("Residue") - if title is None: - plt.title("Per-Position Probability Distribution Heatmap") - else: - plt.title(title) + # Add the main heatmap + fig.add_trace( + go.Heatmap( + z=data.values, + x=data.columns, + y=data.index, + colorscale=custom_colorscale, + zmid=0, # Ensure 0 is centered on the colorscale + showscale=True, + hovertemplate=( + "Position: %{x}
" + + "Amino Acid: %{y}
" + + "Zero Shot Score: %{z}
" + + "" # Hides the trace name in the hover box + ), + ) + ) - # Save the plot to the specified destination, if provided - if dest is not None: - plt.savefig(dest, dpi=400, bbox_inches="tight") + # Overlay black boxes for cells with a value of exactly 0 + for i, amino_acid in enumerate(data.index): + for j, position in enumerate(data.columns): + if data.at[amino_acid, position] == 0: + fig.add_shape( + type="rect", + x0=j - 0.5, + x1=j + 0.5, + y0=i - 0.5, + y1=i + 0.5, + line=dict(color="black", width=0.5), + ) - # Show the plot, if the 'show' argument is True - if show: - plt.show() + # Customize the layout for better interaction + fig.update_layout( + title="Zero-Shot Scores", + xaxis=dict( + title="Sequence Position", + tickmode="array", + tickvals=list(range(len(data.columns))), + ticktext=[x + 1 for x in data.columns], + automargin=True, + ), + yaxis=dict(title="Amino Acid Type", automargin=True), + ) return fig @@ -1125,20 +1110,20 @@ def plot_per_position_entropy( use_normal_ticks: bool = True, ): """ - Plot the per position entropy for a given sequence. + Plot the per position entropy for a given sequence using Plotly. Args: per_position_entropy (torch.Tensor): Tensor of per position entropy values with shape (batch_size, sequence_length). sequence (str): Protein sequence. highlight_positions (list): List of positions to highlight in red (0-indexed) (default: None). show (bool): Display the plot if True (default: False). - dest (str): Optional path to save the plot as an image file (default: None). + dest (str): Optional path to save the plot as an HTML file (default: None). title (str): Title of plot. section (tuple): Section of the sequence to plot (default: None). If None, the entire sequence is plotted. use_normal_ticks (bool): If True, use normal numerical ticks for x-axis (default: False). Returns: - matplotlib.figure.Figure: The matplotlib figure object for the plot. + plotly.graph_objects.Figure: The Plotly figure object for the plot. """ # Convert the tensor to a numpy array @@ -1160,40 +1145,49 @@ def plot_per_position_entropy( # Create an array of positions for the x-axis positions = np.arange(len(sequence)) - # Dynamic tick interval - tick_interval = 1 if len(sequence) <= 30 else int(len(sequence) / 20) - - # Create a bar plot of per position entropy - fig = plt.figure(figsize=(20, 6)) - + # Determine bar colors if highlight_positions is None: - plt.bar(positions, per_position_entropy_np.squeeze()) + colors = ["blue"] * len(positions) else: colors = ["red" if pos in highlight_positions else "blue" for pos in positions] - plt.bar(positions, per_position_entropy_np.squeeze(), color=colors) - # Set the x-axis labels - if use_normal_ticks: - plt.xticks(positions[::tick_interval]) - else: - plt.xticks( - positions[::tick_interval], - [sequence[i] for i in positions[::tick_interval]], + # Create the interactive bar chart + fig = go.Figure() + + fig.add_trace( + go.Bar( + x=positions if use_normal_ticks else [sequence[i] for i in positions], + y=per_position_entropy_np.squeeze(), + marker_color=colors, + hovertemplate=( + "Position: %{x}
" "Per Position Entropy: %{y}
" "" + ), ) + ) - # Set the labels and title - plt.xlabel("Sequence Position") - plt.ylabel("Per Position Entropy") - if title is None: - plt.title("Per Position Entropy of Sequence") - else: - plt.title(title) + # Customize the layout + fig.update_layout( + title=title or "Per Position Entropy of Sequence", + xaxis=dict( + title="Sequence Position" if use_normal_ticks else "Amino Acid", + tickmode="array", + tickvals=positions, + ticktext=( + positions + 1 if use_normal_ticks else [sequence[i] for i in positions] + ), + ), + yaxis=dict(title="Per Position Entropy"), + bargap=0.1, + plot_bgcolor="white", # Set the plot background to white + paper_bgcolor="white", # Set the overall figure background to white + ) + # Save the plot as an HTML file if a destination is provided if dest is not None: - plt.savefig(dest, dpi=300) + fig.write_html(dest) # Show the plot if show: - plt.show() + fig.show() return fig diff --git a/src/proteusAI/visual_tools/plots.py b/src/proteusAI/visual_tools/plots.py index a01d1a2f..dbb67b21 100755 --- a/src/proteusAI/visual_tools/plots.py +++ b/src/proteusAI/visual_tools/plots.py @@ -230,6 +230,7 @@ def plot_umap( highlight_mask: Union[List[Union[int, float]], None] = None, highlight_label: str = "Highlighted", df: Union[pd.DataFrame, None] = None, + html: bool = False, ): """ Create a UMAP plot and optionally color by y values, with special coloring for points outside given thresholds.