From 19420747aa9fcbfc060ec5ee6afc680759800438 Mon Sep 17 00:00:00 2001 From: jykr Date: Tue, 2 Apr 2024 17:37:13 -0500 Subject: [PATCH] fix replicate assignment --- bean/cli/filter.py | 40 ++++++++++++++++++++++++++++++++++++- bean/cli/run.py | 15 +++++++++----- bean/preprocessing/utils.py | 4 +--- 3 files changed, 50 insertions(+), 9 deletions(-) diff --git a/bean/cli/filter.py b/bean/cli/filter.py index bfbbdc7..c707c8c 100755 --- a/bean/cli/filter.py +++ b/bean/cli/filter.py @@ -194,5 +194,43 @@ def main(args): f"Saving plotting result and log at {args.output_prefix}.[filtered_allele_stats.pdf, filter_log.txt]." ) with open(f"{args.output_prefix}.filter_log.txt", "w") as out_log: + out_log.write( + "filter_step\tn_alleles\tn_var\tn_noncoding_var\tn_coding_var\tn_synonymous_var\n" + ) for key in allele_df_keys: - out_log.write(f"{key}\t{len(bdata.uns[key])}\n") + if "translate" in key: + n_coding_vars = len( + set().union( + *bdata.uns[key] + .aa_allele.map(lambda a: list(a.aa_allele.edits)) + .tolist() + ) + ) + n_syn_vars = len( + set().union( + *bdata.uns[key] + .aa_allele.map( + lambda a: {e for e in a.aa_allele.edits if e.ref == e.alt} + ) + .tolist() + ) + ) + n_noncoding_vars = len( + set().union( + *bdata.uns[key] + .aa_allele.map(lambda a: list(a.nt_allele.edits)) + .tolist() + ) + ) + out_log.write( + f"{key}\t{len(bdata.uns[key])}\t{n_coding_vars + n_noncoding_vars}\t{n_noncoding_vars}\t{n_coding_vars}\t{n_syn_vars}\n" + ) + else: + n_noncoding_vars = len( + set().union( + *bdata.uns[key].allele.map(lambda a: list(a.edits)).tolist() + ) + ) + out_log.write( + f"{key}\t{len(bdata.uns[key])}\t{n_noncoding_vars}\t{n_noncoding_vars}\t{0}\t{0}\n" + ) diff --git a/bean/cli/run.py b/bean/cli/run.py index b91457a..b602da6 100755 --- a/bean/cli/run.py +++ b/bean/cli/run.py @@ -44,6 +44,7 @@ warn = logging.warning debug = logging.debug info = logging.info + pyro.set_rng_seed(101) warnings.filterwarnings( @@ -72,16 +73,20 @@ def main(args): print("bean-run: Run model to identify targeted variants and their impact.") bdata = be.read_h5ad(args.bdata_path) args, bdata = check_args(args, bdata) - if args.cuda: - os.environ["CUDA_VISIBLE_DEVICES"] = "1" - torch.set_default_tensor_type(torch.cuda.FloatTensor) - else: - torch.set_default_tensor_type(torch.FloatTensor) prefix = ( args.outdir + "/bean_run_result." + os.path.basename(args.bdata_path).rsplit(".", 1)[0] ) + file_logger = logging.FileHandler(f"{prefix}.log") + file_logger.setLevel(logging.INFO) + logging.getLogger().addHandler(file_logger) + if args.cuda: + os.environ["CUDA_VISIBLE_DEVICES"] = "1" + torch.set_default_tensor_type(torch.cuda.FloatTensor) + else: + torch.set_default_tensor_type(torch.FloatTensor) + os.makedirs(prefix, exist_ok=True) model_label, model, guide = identify_model_guide(args) info("Done loading data. Preprocessing...") diff --git a/bean/preprocessing/utils.py b/bean/preprocessing/utils.py index 21bf6e3..b9e7c2c 100755 --- a/bean/preprocessing/utils.py +++ b/bean/preprocessing/utils.py @@ -24,9 +24,7 @@ def __set__(self, obj, value): def prepare_bdata(bdata: be.ReporterScreen, args, warn, prefix: str): """Utility function for formatting bdata for bean-run""" bdata = bdata.copy() - bdata.samples[args.replicate_col] = bdata.samples[args.replicate_col].astype( - "category" - ) + bdata.samples["replicate"] = bdata.samples[args.replicate_col].astype("category") bdata.guides = bdata.guides.loc[:, ~bdata.guides.columns.duplicated()].copy() if args.library_design == "variant": if bdata.guides[args.target_col].isnull().any():