From 7669a35697e215c5d745c082e4ebd496bd7244e5 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Mon, 17 Jun 2024 17:00:03 -0400 Subject: [PATCH] fix sample ordering in data_class for survival screens --- CHANGELOG.md | 5 ++++- bean/model/parser.py | 2 +- bean/model/survival_model.py | 18 ++++++++++++++++-- bean/preprocessing/data_class.py | 21 +++++++++++++-------- setup.py | 2 +- 5 files changed, 35 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29b22fd..bf8155c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,4 +3,7 @@ * Allow `bean run .. tiling` for untranslated `--allele-df-key`. ## 1.2.6 -* Fix overflow in `bean run survival` and autograde error related to inplace assignment for `bean run survival tiling`. \ No newline at end of file +* Fix overflow in `bean run survival` and autograde error related to inplace assignment for `bean run survival tiling`. + +## 1.2.7 +* **CRITICAL** Fix sample ordering & masking issue for survival screens \ No newline at end of file diff --git a/bean/model/parser.py b/bean/model/parser.py index f5f74e3..5786f1d 100755 --- a/bean/model/parser.py +++ b/bean/model/parser.py @@ -147,7 +147,7 @@ def parse_args(parser=None): parser.add_argument( "--sample-mask-col", type=str, - default=None, + default="mask", help="Name of the column indicating the sample mask in [Reporter]Screen.samples (or AnnData.var). Sample is ignored if the value in this column is 0. This can be used to mask out low-quality samples.", ) parser.add_argument( diff --git a/bean/model/survival_model.py b/bean/model/survival_model.py index 877b0a8..282b324 100755 --- a/bean/model/survival_model.py +++ b/bean/model/survival_model.py @@ -366,7 +366,14 @@ def MixtureNormalGuide( replicate_plate = pyro.plate("rep_plate", data.n_reps, dim=-3) guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) - + initial_abundance = pyro.param( + "initial_abundance", torch.ones(data.n_guides) / data.n_guides + ) + with pyro.plate("replicate_plate0", data.n_reps, dim=-1): + q_0 = pyro.sample( + "initial_guide_abundance", + dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))), + ) # Set the prior for phenotype means mu_loc = pyro.param("mu_loc", torch.zeros((data.n_targets, 1))) mu_scale = pyro.param( @@ -449,7 +456,14 @@ def MultiMixtureNormalModel( replicate_plate2 = pyro.plate("rep_plate2", data.n_reps, dim=-2) time_plate = pyro.plate("time_plate", data.n_condits, dim=-2) guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) - + initial_abundance = pyro.param( + "initial_abundance", torch.ones(data.n_guides) / data.n_guides + ) + with pyro.plate("replicate_plate0", data.n_reps, dim=-1): + q_0 = pyro.sample( + "initial_guide_abundance", + dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))), + ) # Set the prior for phenotype means with pyro.plate("guide_plate1", data.n_edits): mu_edits = pyro.sample("mu_alleles", dist.Laplace(0, 1)) diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index 7d1f1b3..6bba505 100755 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -39,7 +39,7 @@ def __init__( repguide_mask: str = None, sample_mask_column: str = None, shrink_alpha: bool = False, - condition_column: str = "sort", + condition_column: str = "condition", sample_covariate_column: List[str] = [], control_condition: str = "bulk", accessibility_col: str = None, @@ -1001,9 +1001,9 @@ def _pre_init(self, time_column: str, condition_column: str): f"Not all replicate share same timepoint definition. If you have missing bin data, add the sample and add 'mask' column in 'screen.samples', or run `bean-qc` that automatically handles this. \n{self.screen.samples}" ) - def _post_init( - self, - ): + # def _post_init( + # self, + # ): self.timepoints = torch.as_tensor( self.screen_selected.samples[self.time_column].unique() ) @@ -1307,7 +1307,7 @@ def __init__( ) SurvivalScreenData._pre_init(self, time_column, condition_column) ScreenData._post_init(self) - SurvivalScreenData._post_init(self) + # SurvivalScreenData._post_init(self) VariantScreenData._post_init(self, target_col) if use_bcmatch: self.set_bcmatch( @@ -1393,9 +1393,14 @@ def __init__( **kwargs, ) SurvivalScreenData._pre_init(self, time_column, condition_column) - ScreenData._post_init(self) - SurvivalScreenData._post_init(self) - VariantScreenData._post_init(self, target_col) + ScreenData._post_init( + self, + ) + # SurvivalScreenData._post_init(self) + VariantScreenData._post_init( + self, + target_col, + ) ReporterScreenData._post_init( self, screen, diff --git a/setup.py b/setup.py index 86d8921..fd4b6e2 100755 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name="crispr-bean", - version="1.2.7", + version="1.2.8", python_requires=">=3.8.0", author="Jayoung Ryu", author_email="jayoung_ryu@g.harvard.edu",