From 3b66d163b4996a49b83476b8e710168b8d4933a3 Mon Sep 17 00:00:00 2001 From: sambra95 Date: Fri, 24 Jan 2025 14:49:13 +0100 Subject: [PATCH] zero shot loading bug --- app/app.py | 46 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/app/app.py b/app/app.py index 66dde7c4..b54d8bf9 100755 --- a/app/app.py +++ b/app/app.py @@ -20,6 +20,7 @@ from shiny import App, Inputs, Outputs, Session, reactive, render, ui from shiny.types import FileInfo, ImgData import shinywidgets as widgets +from proteusAI.io_tools.fasta import hash_sequence import proteusAI as pai @@ -1384,6 +1385,14 @@ async def _(): DATASET_PATH.set(data_path) MODE.set("zero-shot") + # if chains is list and len > 1, select first chain + if prot.chains == []: + chain = prot.chains + seq_hash = hash_sequence(prot.seq) + elif isinstance(prot.chains, list): + chain = prot.chains[0] + seq_hash = hash_sequence(prot.seq[chain]) + # check for zs-computations # TODO: test if the number of computations match with the number of sequences. zs_computed = [] rep_computed = [] @@ -1394,7 +1403,7 @@ async def _(): ) if os.path.exists(zs_path): - if "zs_scores.csv" in os.listdir(zs_path): + if f"{seq_hash}_zs_scores.csv" in os.listdir(zs_path): zs_computed.append(model) for rep in REP_TYPES: @@ -1410,9 +1419,18 @@ async def _(): rep_path = os.path.join( prot.user, f"{prot.name}/zero_shot/rep/{REP_DICT[model]}" ) + print(seq_hash) + print( + os.listdir( + os.path.join( + prot.user, + f"{prot.name}/zero_shot/results/{REP_DICT[model]}/", + ) + ) + ) df_path = os.path.join( prot.user, - f"{prot.name}/zero_shot/results/{REP_DICT[model]}/zs_scores.csv", + f"{prot.name}/zero_shot/results/{REP_DICT[model]}/{seq_hash}_zs_scores.csv", ) if os.path.exists(df_path): @@ -1539,9 +1557,10 @@ async def _(): # load zs_scores for model in ZS_MODELS: for chain in prot.chains: + seq_hash = hash_sequence(prot.seq[chain]) df_path = os.path.join( prot.user, - f"{name}/zero_shot/results/{chain}/{REP_DICT[model]}/zs_scores.csv", + f"{name}/zero_shot/results/{chain}/{REP_DICT[model]}/{seq_hash}_zs_scores.csv", ) if os.path.exists(df_path): df = pd.read_csv(df_path) # noqa: F841 @@ -1894,10 +1913,12 @@ def _(): if chain is not None and len(prot.chains) >= 1: chain = input.zs_chain() - s = f"{name}/zero_shot/results/{chain}/{REP_DICT[method]}/zs_scores.csv" + seq_hash = hash_sequence(prot.seq[chain]) + s = f"{name}/zero_shot/results/{chain}/{REP_DICT[method]}/{seq_hash}_zs_scores.csv" else: chain = None - s = f"{name}/zero_shot/results/{REP_DICT[method]}/zs_scores.csv" + seq_hash = hash_sequence(prot.seq) + s = f"{name}/zero_shot/results/{REP_DICT[method]}/{seq_hash}_zs_scores.csv" df_path = os.path.join(prot.user, s) if os.path.exists(df_path): @@ -1916,11 +1937,19 @@ def zs_df(alt=None): prot = PROTEIN() method = REP_DICT[input.computed_zs_scores()] if prot.chains is not None and len(prot.chains) >= 1: + seq_hash = hash_sequence(prot.seq[input.zs_chain()]) path = os.path.join( - prot.zs_path, "results", input.zs_chain(), method, "zs_scores.csv" + prot.zs_path, + "results", + input.zs_chain(), + method, + f"{seq_hash}_zs_scores.csv", ) else: - path = os.path.join(prot.zs_path, "results", method, "zs_scores.csv") + seq_hash = hash_sequence(prot.seq) + path = os.path.join( + prot.zs_path, "results", method, f"{seq_hash}_zs_scores.csv" + ) df = pd.read_csv(path) try: @@ -2277,9 +2306,10 @@ def mlde_dynamic_ui(): computed_zs = [] for model in ZS_MODELS: + seq_hash = hash_sequence(prot.seq[input.mlde_chain()]) df_path = os.path.join( prot.user, - f"{name}/zero_shot/results/{chain}/{REP_DICT[model]}/zs_scores.csv", + f"{name}/zero_shot/results/{chain}/{REP_DICT[model]}/{seq_hash}_zs_scores.csv", ) if os.path.exists(df_path): computed_zs.append(model)