Skip to content

Commit

Permalink
Refactoring (#21)
Browse files Browse the repository at this point in the history
* models no longer saved

* simpler plotting

* refactoring

* refactoring

* refactoring

* version update
  • Loading branch information
jonfunk21 authored Dec 26, 2024
1 parent 54ca8c7 commit ce89e8f
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 147 deletions.
28 changes: 18 additions & 10 deletions app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2367,10 +2367,10 @@ async def train_mlde_model():

search_dest = os.path.join(
f"{model.library.rep_path}",
f"../models/{model.model_type}/{model.x}/predictions",
f"../models/{model.model_type}/{model.rep}/predictions",
)
search_file = os.path.join(
search_dest, f"{model.model_type}_{model.x}_predictions.csv"
search_dest, f"{model.model_type}_{model.rep}_predictions.csv"
)

if os.path.exists(search_file):
Expand Down Expand Up @@ -2451,9 +2451,9 @@ async def mlde_search():
model = MODEL()
mlde_explore = input.mlde_explore()
optim_problem = OPTIM_DICT[input.optim_problem()]
max_eval = MAX_EVAL_DICT[model.x]
max_eval = MAX_EVAL_DICT[model.rep]
acq_fn = ACQ_DICT[input.acquisition_fn()]
batch_size = BATCH_SIZE_DICT[model.x]
batch_size = BATCH_SIZE_DICT[model.rep]

try:
loop = asyncio.get_running_loop()
Expand Down Expand Up @@ -2607,7 +2607,13 @@ async def discovery_train():
p.set(message="Visualizing results", detail="This may take a while...")

fig, ax, df = await loop.run_in_executor(
executor, model_lib.plot_umap, model.x, None, None, model_lib.names
executor,
model_lib.plot,
"umap",
model.rep,
None,
None,
model_lib.names,
)

# set reactive variables
Expand Down Expand Up @@ -2673,8 +2679,9 @@ async def clustering():

fig, ax, df = await loop.run_in_executor(
executor,
model_lib.plot_umap,
model.x,
model_lib.plot,
"umap",
model.rep,
None,
None,
model_lib.names,
Expand Down Expand Up @@ -2758,7 +2765,7 @@ async def discovery_search():

model = DISCOVERY_MODEL()
try:
batch_size = BATCH_SIZE_DICT[model.x]
batch_size = BATCH_SIZE_DICT[model.rep]
loop = asyncio.get_running_loop()
out, search_results = await loop.run_in_executor(
executor,
Expand All @@ -2780,8 +2787,9 @@ async def discovery_search():
print("We start here")
fig, ax, df = await loop.run_in_executor(
executor,
model.library.plot_umap,
model.x,
model.library.plot,
"umap",
model.rep,
None,
None,
model.library.names,
Expand Down
56 changes: 30 additions & 26 deletions demo/discovery_demo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import matplotlib.pyplot as plt

import proteusAI as pai

Expand All @@ -8,39 +9,42 @@


# will initiate storage space - else in memory
datasets = ["demo/demo_data/methyltransfereases.csv"]
y_columns = ["coverage_5"]
dataset = "demo/demo_data/methyltransfereases.csv"
y_column = "coverage_5"

results_dictionary = {}
for dataset in datasets:
for y in y_columns:
# load data from csv or excel: x should be sequences, y should be labels, y_type class or num
library = pai.Library(
source=dataset,
seqs_col="sequence",
y_col=y,
y_type="class",
names_col="uid",
)

# compute and save ESM-2 representations at example_lib/representations/esm2
library.compute(method="esm2", batch_size=10)
# load data from csv or excel: x should be sequences, y should be labels, y_type class or num
library = pai.Library(
source=dataset,
seqs_col="sequence",
y_col=y_column,
y_type="class",
names_col="uid",
)

# define a model
model = pai.Model(library=library, k_folds=5, model_type="rf", x="esm2")
# compute and save ESM-2 representations at example_lib/representations/esm2
library.compute(method="esm2_8M", batch_size=10)

# train model
model.train()
# define a model
model = pai.Model(library=library, k_folds=5, model_type="rf", rep="esm2_8M")

# search predict the classes of unknown sequences
out, mask = model.search()
# train model
model.train()

# save results
if not os.path.exists("demo/demo_data/out/"):
os.makedirs("demo/demo_data/out/", exist_ok=True)
# search predict the classes of unknown sequences
out, search_mask = model.search()

out["df"].to_csv("demo/demo_data/out/discovery_search_results.csv")
# save results
if not os.path.exists("demo/demo_data/out/"):
os.makedirs("demo/demo_data/out/", exist_ok=True)

model_lib = pai.Library(source=out)
out["df"].to_csv("demo/demo_data/out/discovery_search_results.csv")

model_lib.plot_tsne(model.x, None, None, model_lib.names)
model_lib = pai.Library(source=out)

# plot results
fig, ax, plot_df = model.library.plot(
rep="esm2_8M", use_y_pred=True, highlight_mask=search_mask
)
plt.savefig("demo/demo_data/out/search_results.png")
60 changes: 28 additions & 32 deletions demo/mlde_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,31 @@


# will initiate storage space - else in memory
datasets = ["demo/demo_data/Nitric_Oxide_Dioxygenase.csv"]
y_columns = ["Data"]

results_dictionary = {}
for dataset in datasets:
for y in y_columns:
# load data from csv or excel: x should be sequences, y should be labels, y_type class or num
library = pai.Library(
source=dataset,
seqs_col="Sequence",
y_col=y,
y_type="num",
names_col="Description",
)

# compute and save ESM-2 representations at example_lib/representations/esm2
library.compute(method="esm2", batch_size=10)

# define a model
model = pai.Model(library=library, k_folds=5, model_type="rf", x="blosum62")

# train model
model.train()

# search for new mutants
out = model.search(optim_problem="max")

# save results
if not os.path.exists("demo/demo_data/out/"):
os.makedirs("demo/demo_data/out/", exist_ok=True)

out.to_csv("demo/demo_data/out/demo_search_results.csv")
# Input data for the MLDE demo (path is relative to the repository root).
dataset = "demo/demo_data/Nitric_Oxide_Dioxygenase.csv"

# load data from csv or excel: x should be sequences, y should be labels, y_type class or num
library = pai.Library(
    source=dataset,
    seqs_col="Sequence",
    y_col="Data",
    y_type="num",
    names_col="Description",
)

# compute and save ESM-2 representations
# NOTE(review): the model below is trained on "blosum62", not on these ESM-2
# embeddings, and the discovery demo now calls this method "esm2_8M" --
# confirm that "esm2" is still an accepted method name and that this step is
# still wanted here.
library.compute(method="esm2", batch_size=10)

# define a model; the representation keyword is `rep` (formerly `x`), matching
# the discovery demo and the `model.rep` attribute used by the app
model = pai.Model(library=library, k_folds=5, model_type="rf", rep="blosum62")

# train model
model.train()

# search for new mutants
out = model.search(optim_problem="max")

# save results; exist_ok=True already makes this a no-op when the directory
# exists, so no separate os.path.exists() check is needed
os.makedirs("demo/demo_data/out/", exist_ok=True)

out.to_csv("demo/demo_data/out/demo_search_results.csv")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "proteusAI"
version = "0.0.9"
version = "0.1.0"
requires-python = ">= 3.8"
description = "ProteusAI is a python package designed for AI driven protein engineering."
readme = "README.md"
Expand Down
60 changes: 60 additions & 0 deletions src/proteusAI/Library/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,66 @@ def struc_geom(self, ref, residues: dict = {}):

return df

def plot(
    self,
    rep: str,
    method: str = "umap",
    y_upper=None,
    y_lower=None,
    names=None,
    highlight_mask=None,
    highlight_label=None,
    use_y_pred=False,
):
    """
    Plot library data using specific representations.

    Thin dispatcher that forwards to ``plot_umap``, ``plot_tsne`` or
    ``plot_pca`` depending on ``method``. All remaining arguments are
    passed through unchanged as keyword arguments.

    Args:
        rep (str): Representation type to plot.
        method (str): Method for plotting. (umap, tsne, pca)
        y_upper (float, optional): Upper threshold for special coloring.
        y_lower (float, optional): Lower threshold for special coloring.
        names (List[str], optional): List of names for each point.
        highlight_mask (list): List of 0s and 1s to highlight plot. Default None.
        highlight_label (str): Text for the legend entry of highlighted points.
        use_y_pred (bool): Forwarded as-is to the selected plotting method.
            Default False.

    Returns:
        tuple: ``(fig, ax, df)`` exactly as returned by the selected
        plotting method.

    Raises:
        ValueError: If ``method`` is not one of "umap", "tsne" or "pca".

    Note:
        Some callers pass the arguments positionally as
        ``plot("umap", rep, ...)``, i.e. with the method name first. When
        ``method`` is not a known method but ``rep`` is, the two are
        swapped so those calls keep working instead of raising.
    """
    # Method name -> name of the bound method that implements it. Looked up
    # lazily with getattr so only the selected plotter is ever touched.
    plotters = {
        "umap": "plot_umap",
        "tsne": "plot_tsne",
        "pca": "plot_pca",
    }

    def is_method(value):
        # isinstance guard keeps unhashable/odd inputs from raising TypeError
        return isinstance(value, str) and value in plotters

    # Tolerate the (method, rep) positional order; this is unambiguous
    # because the swap only happens when exactly `rep` names a method.
    if not is_method(method) and is_method(rep):
        rep, method = method, rep

    if not is_method(method):
        raise ValueError(f"Unsupported method: {method}")

    plotter = getattr(self, plotters[method])
    fig, ax, df = plotter(
        rep,
        y_upper=y_upper,
        y_lower=y_lower,
        names=names,
        highlight_mask=highlight_mask,
        highlight_label=highlight_label,
        use_y_pred=use_y_pred,
    )

    return fig, ax, df

def plot_tsne(
self,
rep: str,
Expand Down
Loading

0 comments on commit ce89e8f

Please sign in to comment.