diff --git a/bean/model/readwrite.py b/bean/model/readwrite.py index c1c68cd..a36493b 100644 --- a/bean/model/readwrite.py +++ b/bean/model/readwrite.py @@ -102,7 +102,6 @@ def write_result_table( ) fit_df = pd.DataFrame(param_dict) - fit_df["novl"] = get_novl(fit_df, "mu", "mu_sd") if "negctrl" in param_hist_dict.keys(): print("Normalizing with common negative control distribution") mu0 = param_hist_dict["negctrl"]["params"]["mu_loc"].detach().cpu().numpy() @@ -114,6 +113,8 @@ def write_result_table( .cpu() .numpy() ) + else: + sd0 = 1.0 print(f"Fitted mu0={mu0}" + (f", sd0={sd0}." if sd_is_fitted else "")) fit_df["mu_scaled"] = (mu - mu0) / sd0 fit_df["mu_sd_scaled"] = mu_sd / sd0 @@ -154,12 +155,12 @@ def write_result_table( fit_df, std, suffix="_adj", - mu_adjusted_col="mu_scaled" - if "negctrl" in param_hist_dict.keys() - else "mu", - mu_sd_adjusted_col="mu_sd_scaled" - if "negctrl" in param_hist_dict.keys() - else "mu_sd", + mu_adjusted_col=( + "mu_scaled" if "negctrl" in param_hist_dict.keys() else "mu" + ), + mu_sd_adjusted_col=( + "mu_sd_scaled" if "negctrl" in param_hist_dict.keys() else "mu_sd" + ), ) fit_df = add_credible_interval(fit_df, "mu_adj", "mu_sd_adj") if sample_covariates is not None: @@ -168,12 +169,16 @@ def write_result_table( fit_df, std, suffix=f"_{sample_cov}_adj", - mu_adjusted_col=f"mu_{sample_cov}_scaled" - if "negctrl" in param_hist_dict.keys() - else f"mu_{sample_cov}", - mu_sd_adjusted_col=f"mu_sd_{sample_cov}_scaled" - if "negctrl" in param_hist_dict.keys() - else f"mu_sd_{sample_cov}", + mu_adjusted_col=( + f"mu_{sample_cov}_scaled" + if "negctrl" in param_hist_dict.keys() + else f"mu_{sample_cov}" + ), + mu_sd_adjusted_col=( + f"mu_sd_{sample_cov}_scaled" + if "negctrl" in param_hist_dict.keys() + else f"mu_sd_{sample_cov}" + ), ) fit_df = add_credible_interval( fit_df, f"mu_{sample_cov}_adj", f"mu_sd_{sample_cov}_adj" diff --git a/bean/model/run.py b/bean/model/run.py index d09fe1e..f8bde98 100644 --- a/bean/model/run.py +++ b/bean/model/run.py @@ -488,3 +488,20 @@ def identify_model_guide(args): fit_noise=(not args.dont_fit_noise), ), ) + + +def identify_negctrl_model_guide(args, data_has_bcmatch): + if args.selection == "sorting": + m = sorting_model + else: + m = survival_model + negctrl_model = partial( + m.ControlNormalModel, + use_bcmatch=(not args.ignore_bcmatch and data_has_bcmatch), + ) + + negctrl_guide = partial( + m.ControlNormalGuide, + use_bcmatch=(not args.ignore_bcmatch and data_has_bcmatch), + ) + return negctrl_model, negctrl_guide diff --git a/bean/model/survival_model.py b/bean/model/survival_model.py index 9b68e2c..ddfb362 100644 --- a/bean/model/survival_model.py +++ b/bean/model/survival_model.py @@ -123,12 +123,12 @@ def ControlNormalModel(data, mask_thres=10, use_bcmatch=True): mu_alleles = pyro.sample("mu_alleles", dist.Laplace(0, 1)) mu = mu_alleles.repeat(data.n_guides).unsqueeze(-1) r = torch.exp(mu) + with pyro.plate("rep_plate1", data.n_reps, dim=-1): + q_0 = pyro.sample( + "initial_guide_abundance", + dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))), + ) with replicate_plate: - with pyro.plate("guide_plate2", data.n_guides): - q_0 = pyro.sample( - "initial_guide_abundance", - dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))), - ) with time_plate as t: time = data.timepoints[t] assert time.shape == (data.n_condits,) diff --git a/bin/bean-run b/bin/bean-run index b4eb5f4..540bd37 100644 --- a/bin/bean-run +++ b/bin/bean-run @@ -31,6 +31,7 @@ from bean.model.run import ( parse_args, check_args, identify_model_guide, + identify_negctrl_model_guide, ) logging.basicConfig( @@ -144,14 +145,8 @@ def main(args, bdata): run_inference(model, guide, ndata, num_steps=args.n_iter) ) if args.fit_negctrl: - negctrl_model = partial( - m.ControlNormalModel, - use_bcmatch=(not args.ignore_bcmatch and "X_bcmatch" in bdata.layers), - ) - - negctrl_guide = partial( - m.ControlNormalGuide, - use_bcmatch=(not args.ignore_bcmatch and "X_bcmatch" in bdata.layers), + negctrl_model, negctrl_guide = identify_negctrl_model_guide( + args, "X_bcmatch" in bdata.layers ) negctrl_idx = np.where( guide_info_df[args.negctrl_col].map(lambda s: s.lower())