Skip to content

Commit

Permalink
fix sample ordering in data_class for survival screens
Browse files Browse the repository at this point in the history
  • Loading branch information
jykr committed Jun 17, 2024
1 parent 79ed208 commit 7669a35
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 13 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@
* Allow `bean run .. tiling` for untranslated `--allele-df-key`.

## 1.2.6
* Fix overflow in `bean run survival` and autograde error related to inplace assignment for `bean run survival tiling`.
* Fix overflow in `bean run survival` and autograde error related to inplace assignment for `bean run survival tiling`.

## 1.2.7
* **CRITICAL** Fix sample ordering & masking issue for survival screens
2 changes: 1 addition & 1 deletion bean/model/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def parse_args(parser=None):
parser.add_argument(
"--sample-mask-col",
type=str,
default=None,
default="mask",
help="Name of the column indicating the sample mask in [Reporter]Screen.samples (or AnnData.var). Sample is ignored if the value in this column is 0. This can be used to mask out low-quality samples.",
)
parser.add_argument(
Expand Down
18 changes: 16 additions & 2 deletions bean/model/survival_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,14 @@ def MixtureNormalGuide(

replicate_plate = pyro.plate("rep_plate", data.n_reps, dim=-3)
guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1)

initial_abundance = pyro.param(
"initial_abundance", torch.ones(data.n_guides) / data.n_guides
)
with pyro.plate("replicate_plate0", data.n_reps, dim=-1):
q_0 = pyro.sample(
"initial_guide_abundance",
dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))),
)
# Set the prior for phenotype means
mu_loc = pyro.param("mu_loc", torch.zeros((data.n_targets, 1)))
mu_scale = pyro.param(
Expand Down Expand Up @@ -449,7 +456,14 @@ def MultiMixtureNormalModel(
replicate_plate2 = pyro.plate("rep_plate2", data.n_reps, dim=-2)
time_plate = pyro.plate("time_plate", data.n_condits, dim=-2)
guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1)

initial_abundance = pyro.param(
"initial_abundance", torch.ones(data.n_guides) / data.n_guides
)
with pyro.plate("replicate_plate0", data.n_reps, dim=-1):
q_0 = pyro.sample(
"initial_guide_abundance",
dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))),
)
# Set the prior for phenotype means
with pyro.plate("guide_plate1", data.n_edits):
mu_edits = pyro.sample("mu_alleles", dist.Laplace(0, 1))
Expand Down
21 changes: 13 additions & 8 deletions bean/preprocessing/data_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
repguide_mask: str = None,
sample_mask_column: str = None,
shrink_alpha: bool = False,
condition_column: str = "sort",
condition_column: str = "condition",
sample_covariate_column: List[str] = [],
control_condition: str = "bulk",
accessibility_col: str = None,
Expand Down Expand Up @@ -1001,9 +1001,9 @@ def _pre_init(self, time_column: str, condition_column: str):
f"Not all replicate share same timepoint definition. If you have missing bin data, add the sample and add 'mask' column in 'screen.samples', or run `bean-qc` that automatically handles this. \n{self.screen.samples}"
)

def _post_init(
self,
):
# def _post_init(
# self,
# ):
self.timepoints = torch.as_tensor(
self.screen_selected.samples[self.time_column].unique()
)
Expand Down Expand Up @@ -1307,7 +1307,7 @@ def __init__(
)
SurvivalScreenData._pre_init(self, time_column, condition_column)
ScreenData._post_init(self)
SurvivalScreenData._post_init(self)
# SurvivalScreenData._post_init(self)
VariantScreenData._post_init(self, target_col)
if use_bcmatch:
self.set_bcmatch(
Expand Down Expand Up @@ -1393,9 +1393,14 @@ def __init__(
**kwargs,
)
SurvivalScreenData._pre_init(self, time_column, condition_column)
ScreenData._post_init(self)
SurvivalScreenData._post_init(self)
VariantScreenData._post_init(self, target_col)
ScreenData._post_init(
self,
)
# SurvivalScreenData._post_init(self)
VariantScreenData._post_init(
self,
target_col,
)
ReporterScreenData._post_init(
self,
screen,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setup(
name="crispr-bean",
version="1.2.7",
version="1.2.8",
python_requires=">=3.8.0",
author="Jayoung Ryu",
author_email="[email protected]",
Expand Down

0 comments on commit 7669a35

Please sign in to comment.