Skip to content

Commit

Permalink
✨ plotting functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sambra95 committed Jan 23, 2025
1 parent 8536989 commit e62e9ba
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 100 deletions.
15 changes: 9 additions & 6 deletions app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
import tooltips
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shiny.types import FileInfo, ImgData
import shinywidgets as widgets

import proteusAI as pai


app_path = os.path.dirname(os.path.realpath(__file__))
google_analytics = os.path.join(app_path, "google_analytics.html")
google_analytics_string = ""
Expand Down Expand Up @@ -1990,8 +1992,8 @@ def zs_download_ui():
out = ZS_SCORES()
if out is not None:
return ui.TagList(
ui.output_plot("entropy_plot"),
ui.output_plot("scores_plot"),
widgets.output_widget("entropy_plot"),
widgets.output_widget("heatmap_plot"),
# Descriptors
ui.h4("Table Interpretation"),
ui.row(
Expand Down Expand Up @@ -2034,7 +2036,7 @@ def download_zs_df():

### OUTPUT PROTEIN MODE ###
@output
@render.plot
@widgets.render_widget
@reactive.event(input.plot_entropy)
def entropy_plot(alt=None):
"""
Expand Down Expand Up @@ -2068,11 +2070,11 @@ def entropy_plot(alt=None):

### UPDATE SCORES PLOT ###
@output
@render.plot
@widgets.render_widget
@reactive.event(input.plot_scores)
def scores_plot(alt=None):
def heatmap_plot(alt=None):
"""
Create the per position entropy plot
Create heatmap of zeroshot scores
"""
prot = PROTEIN()

Expand Down Expand Up @@ -2103,6 +2105,7 @@ def scores_plot(alt=None):
model=MODEL_DICT[input.computed_zs_scores()],
chain=chain,
)
print(fig)
return fig

### STRUCTURE MODE ###
Expand Down
182 changes: 88 additions & 94 deletions src/proteusAI/ml_tools/esm_tools/esm_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from esm.inverse_folding.util import CoordBatchConverter
from matplotlib.colors import LinearSegmentedColormap
import plotly.graph_objects as go

alphabet = torch.load(
os.path.join(Path(__file__).parent, "alphabet.pt"), weights_only=False
Expand Down Expand Up @@ -1037,79 +1038,63 @@ def plot_heatmap(
df = pd.DataFrame(
probability_distribution_np[:, list(filtered_alphabet.values())],
columns=[i for i in filtered_alphabet.keys()],
)
).T

# Calculate the symmetric vmin and vmax around zero
abs_max = max(abs(df.min().min()), abs(df.max().max()))

# colors
if color_sheme == "rwb":
colors = ["red", "white", "blue"]
cmap = LinearSegmentedColormap.from_list("red_white_blue", colors)
vmin, vmax = -abs_max, abs_max # Symmetric range around zero
elif color_sheme == "r":
cmap = "Reds"
vmin, vmax = None, None # No centering for single color schemes
else:
cmap = "Blues"
vmin, vmax = None, None # No centering for single color schemes

# Create a heatmap using seaborn
fig, ax = plt.subplots(figsize=(12, 6))

ax = plt.gca()
sns.heatmap(
df.T,
cmap=cmap,
linewidths=0.5,
annot=False,
cbar=True,
ax=ax,
vmin=vmin,
vmax=vmax,
)
data = df.copy()

# Adjust x-axis ticks and labels if 'section' is specified
if section is not None:
ax.set_xticks(
np.arange(
0, section[1] - section[0] + 1, max(1, (section[1] - section[0]) // 10)
)
)
ax.set_xticklabels(
np.arange(
section[0], section[1] + 1, max(1, (section[1] - section[0]) // 10)
)
)
# Create the heatmap using Plotly
fig = go.Figure()

# Highlight the mutated positions with a green box
if highlight_positions is not None:
for pos, mutated_residue in highlight_positions.items():
residue_index = filtered_alphabet[mutated_residue]
rect = patches.Rectangle(
(pos, residue_index - min_val),
1,
1,
linewidth=1,
edgecolor="lime",
facecolor="none",
)
ax.add_patch(rect)
# Create a custom colorscale for the heatmap
custom_colorscale = [
[0.0, "red"], # Minimum value (deep red)
[0.5, "white"], # Zero value (white)
[1.0, "blue"], # Maximum value (deep blue)
]

plt.xlabel("Sequence Position")
plt.ylabel("Residue")
if title is None:
plt.title("Per-Position Probability Distribution Heatmap")
else:
plt.title(title)
# Add the main heatmap
fig.add_trace(
go.Heatmap(
z=data.values,
x=data.columns,
y=data.index,
colorscale=custom_colorscale,
zmid=0, # Ensure 0 is centered on the colorscale
showscale=True,
hovertemplate=(
"Position: %{x}<br>"
+ "Amino Acid: %{y}<br>"
+ "Zero Shot Score: %{z}<br>"
+ "<extra></extra>" # Hides the trace name in the hover box
),
)
)

# Save the plot to the specified destination, if provided
if dest is not None:
plt.savefig(dest, dpi=400, bbox_inches="tight")
# Overlay black boxes for cells with a value of exactly 0
for i, amino_acid in enumerate(data.index):
for j, position in enumerate(data.columns):
if data.at[amino_acid, position] == 0:
fig.add_shape(
type="rect",
x0=j - 0.5,
x1=j + 0.5,
y0=i - 0.5,
y1=i + 0.5,
line=dict(color="black", width=0.5),
)

# Show the plot, if the 'show' argument is True
if show:
plt.show()
# Customize the layout for better interaction
fig.update_layout(
title="Zero-Shot Scores",
xaxis=dict(
title="Sequence Position",
tickmode="array",
tickvals=list(range(len(data.columns))),
ticktext=[x + 1 for x in data.columns],
automargin=True,
),
yaxis=dict(title="Amino Acid Type", automargin=True),
)

return fig

Expand All @@ -1125,20 +1110,20 @@ def plot_per_position_entropy(
use_normal_ticks: bool = True,
):
"""
Plot the per position entropy for a given sequence.
Plot the per position entropy for a given sequence using Plotly.
Args:
per_position_entropy (torch.Tensor): Tensor of per position entropy values with shape (batch_size, sequence_length).
sequence (str): Protein sequence.
highlight_positions (list): List of positions to highlight in red (0-indexed) (default: None).
show (bool): Display the plot if True (default: False).
dest (str): Optional path to save the plot as an image file (default: None).
dest (str): Optional path to save the plot as an HTML file (default: None).
title (str): Title of plot.
section (tuple): Section of the sequence to plot (default: None). If None, the entire sequence is plotted.
use_normal_ticks (bool): If True, use normal numerical ticks for x-axis (default: False).
Returns:
matplotlib.figure.Figure: The matplotlib figure object for the plot.
plotly.graph_objects.Figure: The Plotly figure object for the plot.
"""

# Convert the tensor to a numpy array
Expand All @@ -1160,40 +1145,49 @@ def plot_per_position_entropy(
# Create an array of positions for the x-axis
positions = np.arange(len(sequence))

# Dynamic tick interval
tick_interval = 1 if len(sequence) <= 30 else int(len(sequence) / 20)

# Create a bar plot of per position entropy
fig = plt.figure(figsize=(20, 6))

# Determine bar colors
if highlight_positions is None:
plt.bar(positions, per_position_entropy_np.squeeze())
colors = ["blue"] * len(positions)
else:
colors = ["red" if pos in highlight_positions else "blue" for pos in positions]
plt.bar(positions, per_position_entropy_np.squeeze(), color=colors)

# Set the x-axis labels
if use_normal_ticks:
plt.xticks(positions[::tick_interval])
else:
plt.xticks(
positions[::tick_interval],
[sequence[i] for i in positions[::tick_interval]],
# Create the interactive bar chart
fig = go.Figure()

fig.add_trace(
go.Bar(
x=positions if use_normal_ticks else [sequence[i] for i in positions],
y=per_position_entropy_np.squeeze(),
marker_color=colors,
hovertemplate=(
"Position: %{x}<br>" "Per Position Entropy: %{y}<br>" "<extra></extra>"
),
)
)

# Set the labels and title
plt.xlabel("Sequence Position")
plt.ylabel("Per Position Entropy")
if title is None:
plt.title("Per Position Entropy of Sequence")
else:
plt.title(title)
# Customize the layout
fig.update_layout(
title=title or "Per Position Entropy of Sequence",
xaxis=dict(
title="Sequence Position" if use_normal_ticks else "Amino Acid",
tickmode="array",
tickvals=positions,
ticktext=(
positions + 1 if use_normal_ticks else [sequence[i] for i in positions]
),
),
yaxis=dict(title="Per Position Entropy"),
bargap=0.1,
plot_bgcolor="white", # Set the plot background to white
paper_bgcolor="white", # Set the overall figure background to white
)

# Save the plot as an HTML file if a destination is provided
if dest is not None:
plt.savefig(dest, dpi=300)
fig.write_html(dest)

# Show the plot
if show:
plt.show()
fig.show()

return fig
1 change: 1 addition & 0 deletions src/proteusAI/visual_tools/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def plot_umap(
highlight_mask: Union[List[Union[int, float]], None] = None,
highlight_label: str = "Highlighted",
df: Union[pd.DataFrame, None] = None,
html: bool = False,
):
"""
Create a UMAP plot and optionally color by y values, with special coloring for points outside given thresholds.
Expand Down

0 comments on commit e62e9ba

Please sign in to comment.