Skip to content

Commit

Permalink
Merge pull request #27 from jonfunk21/refactoring_zs
Browse files Browse the repository at this point in the history
zero shot loading bug
  • Loading branch information
sambra95 authored Jan 24, 2025
2 parents 52cc692 + 3b66d16 commit b4584dc
Showing 1 changed file with 38 additions and 8 deletions.
46 changes: 38 additions & 8 deletions app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shiny.types import FileInfo, ImgData
import shinywidgets as widgets
from proteusAI.io_tools.fasta import hash_sequence

import proteusAI as pai

Expand Down Expand Up @@ -1384,6 +1385,14 @@ async def _():
DATASET_PATH.set(data_path)
MODE.set("zero-shot")

# if chains is an empty list, hash the whole sequence; if it is a non-empty list, select the first chain and hash that chain's sequence
if prot.chains == []:
chain = prot.chains
seq_hash = hash_sequence(prot.seq)
elif isinstance(prot.chains, list):
chain = prot.chains[0]
seq_hash = hash_sequence(prot.seq[chain])

# check for zs-computations # TODO: test if the number of computations matches the number of sequences.
zs_computed = []
rep_computed = []
Expand All @@ -1394,7 +1403,7 @@ async def _():
)

if os.path.exists(zs_path):
if "zs_scores.csv" in os.listdir(zs_path):
if f"{seq_hash}_zs_scores.csv" in os.listdir(zs_path):
zs_computed.append(model)

for rep in REP_TYPES:
Expand All @@ -1410,9 +1419,18 @@ async def _():
rep_path = os.path.join(
prot.user, f"{prot.name}/zero_shot/rep/{REP_DICT[model]}"
)
print(seq_hash)
print(
os.listdir(
os.path.join(
prot.user,
f"{prot.name}/zero_shot/results/{REP_DICT[model]}/",
)
)
)
df_path = os.path.join(
prot.user,
f"{prot.name}/zero_shot/results/{REP_DICT[model]}/zs_scores.csv",
f"{prot.name}/zero_shot/results/{REP_DICT[model]}/{seq_hash}_zs_scores.csv",
)

if os.path.exists(df_path):
Expand Down Expand Up @@ -1539,9 +1557,10 @@ async def _():
# load zs_scores
for model in ZS_MODELS:
for chain in prot.chains:
seq_hash = hash_sequence(prot.seq[chain])
df_path = os.path.join(
prot.user,
f"{name}/zero_shot/results/{chain}/{REP_DICT[model]}/zs_scores.csv",
f"{name}/zero_shot/results/{chain}/{REP_DICT[model]}/{seq_hash}_zs_scores.csv",
)
if os.path.exists(df_path):
df = pd.read_csv(df_path) # noqa: F841
Expand Down Expand Up @@ -1894,10 +1913,12 @@ def _():

if chain is not None and len(prot.chains) >= 1:
chain = input.zs_chain()
s = f"{name}/zero_shot/results/{chain}/{REP_DICT[method]}/zs_scores.csv"
seq_hash = hash_sequence(prot.seq[chain])
s = f"{name}/zero_shot/results/{chain}/{REP_DICT[method]}/{seq_hash}_zs_scores.csv"
else:
chain = None
s = f"{name}/zero_shot/results/{REP_DICT[method]}/zs_scores.csv"
seq_hash = hash_sequence(prot.seq)
s = f"{name}/zero_shot/results/{REP_DICT[method]}/{seq_hash}_zs_scores.csv"

df_path = os.path.join(prot.user, s)
if os.path.exists(df_path):
Expand All @@ -1916,11 +1937,19 @@ def zs_df(alt=None):
prot = PROTEIN()
method = REP_DICT[input.computed_zs_scores()]
if prot.chains is not None and len(prot.chains) >= 1:
seq_hash = hash_sequence(prot.seq[input.zs_chain()])
path = os.path.join(
prot.zs_path, "results", input.zs_chain(), method, "zs_scores.csv"
prot.zs_path,
"results",
input.zs_chain(),
method,
f"{seq_hash}_zs_scores.csv",
)
else:
path = os.path.join(prot.zs_path, "results", method, "zs_scores.csv")
seq_hash = hash_sequence(prot.seq)
path = os.path.join(
prot.zs_path, "results", method, f"{seq_hash}_zs_scores.csv"
)
df = pd.read_csv(path)

try:
Expand Down Expand Up @@ -2277,9 +2306,10 @@ def mlde_dynamic_ui():

computed_zs = []
for model in ZS_MODELS:
seq_hash = hash_sequence(prot.seq[input.mlde_chain()])
df_path = os.path.join(
prot.user,
f"{name}/zero_shot/results/{chain}/{REP_DICT[model]}/zs_scores.csv",
f"{name}/zero_shot/results/{chain}/{REP_DICT[model]}/{seq_hash}_zs_scores.csv",
)
if os.path.exists(df_path):
computed_zs.append(model)
Expand Down

0 comments on commit b4584dc

Please sign in to comment.