From 2fdb4404a1f269286c41c37bd69dfea98d0db338 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Thu, 15 Aug 2024 18:00:28 -0400 Subject: [PATCH] adjustments for survival models --- bean/model/survival_model.py | 59 +++++++++++++++++++------------- bean/preprocessing/data_class.py | 32 +++++++++++++++-- 2 files changed, 66 insertions(+), 25 deletions(-) diff --git a/bean/model/survival_model.py b/bean/model/survival_model.py index 9fdfbc6..4dc1434 100755 --- a/bean/model/survival_model.py +++ b/bean/model/survival_model.py @@ -289,28 +289,35 @@ def MixtureNormalModel( with replicate_plate: with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)): - time_pi = data.control_timepoint - # Accounting for sample specific overall edit rate across all guides. - # P(allele | guide, bin=bulk) pi = pyro.sample( "pi", dist.Dirichlet( pi_a_scaled.unsqueeze(0).unsqueeze(0).expand(data.n_reps, 1, -1, -1) ), ) - assert pi.shape == ( - data.n_reps, - 1, - data.n_guides, - 2, - ), pi.shape - # If pi is sampled in later timepoint, account for the selection. - expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi - pyro.sample( - "bulk_allele_count", - dist.Multinomial(probs=expanded_allele_p, validate_args=False), - obs=data.allele_counts_control, - ) + with time_plate: + with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)): + time_pi = data.timepoints + # Accounting for sample specific overall edit rate across all guides. + # P(allele | guide, bin=bulk) + assert pi.shape == ( + data.n_reps, + 1, + data.n_guides, + 2, + ), pi.shape + # If pi is sampled in later timepoint, account for the selection. + + expanded_allele_p = pi * r.expand( + data.n_reps, len(data.timepoints), -1, -1 + ) ** time_pi.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand( + data.n_reps, len(data.timepoints), -1, -1 + ) + pyro.sample( + "allele_count", + dist.Multinomial(probs=expanded_allele_p, validate_args=False), + obs=data.allele_counts, + ) if scale_by_accessibility: # Endogenous target site editing rate may be different pi = scale_pi_by_accessibility( @@ -486,13 +493,19 @@ def MultiMixtureNormalModel( pi_a_scaled.unsqueeze(0).unsqueeze(0).expand(data.n_reps, 1, -1, -1) ), ) - # If pi is sampled in later timepoint, account for the selection. - expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi - pyro.sample( - "bulk_allele_count", - dist.Multinomial(probs=expanded_allele_p, validate_args=False), - obs=data.allele_counts_control, - ) + with time_plate: + with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)): + # If pi is sampled in later timepoint, account for the selection. + expanded_allele_p = pi * r.expand( + data.n_reps, 1, -1, -1 + ) ** time_pi.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand( + data.n_reps, len(data.timepoints), -1, -1 + ) + pyro.sample( + "allele_count", + dist.Multinomial(probs=expanded_allele_p, validate_args=False), + obs=data.allele_counts, + ) if scale_by_accessibility: # Endogenous target site editing rate may be different pi = scale_pi_by_accessibility( diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index 5e3da55..3e6f2f7 100755 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -331,6 +331,21 @@ def _post_init( self.screen_control.samples["size_factor_bcmatch"].to_numpy() ).reshape(self.n_reps, 1) + if hasattr(self, "timepoints") and not hasattr(self, "allele_counts"): + control_allele_counts = [] + for timepoint in self.timepoints: + screen_t = self.screen[:, self.screen.samples.time == timepoint.item()] + edited_control = self.transform_data(screen_t.layers["edits"], n_bins=1) + nonedited_control = ( + self.transform_data(screen_t.layers["X_bcmatch"], 1) + - edited_control + ) + nonedited_control[nonedited_control < 0] = 0 + control_allele_counts.append( + torch.stack([nonedited_control, edited_control], axis=-1) + ) # (n_reps, n_bins, n_guides, n_alleles) + self.allele_counts = torch.cat(control_allele_counts, axis=1) + edited_control = self.transform_data( self.screen_control.layers["edits"], n_bins=1 ) @@ -981,12 +996,25 @@ def __init__( def _pre_init(self, time_column: str, condition_column: str): self.condition_column = self.time_column = time_column try: + max_time = self.screen.samples[time_column].max() self.screen.samples[time_column] = self.screen.samples[time_column].astype( float ) self.screen.samples[time_column] = ( - self.screen.samples[time_column] - / self.screen.samples[time_column].max() + self.screen.samples[time_column] / max_time + ) + + self.screen_selected.samples[time_column] = self.screen_selected.samples[ + time_column + ].astype(float) + self.screen_selected.samples[time_column] = ( + self.screen_selected.samples[time_column] / max_time + ) + self.screen_control.samples[time_column] = self.screen_control.samples[ + time_column + ].astype(float) + self.screen_control.samples[time_column] = ( + self.screen_control.samples[time_column] / max_time ) except ValueError as e: raise ValueError(