
Commit

adjustments for survival models
jykr committed Aug 15, 2024
1 parent 86572bc commit 2fdb440
Showing 2 changed files with 66 additions and 25 deletions.
59 changes: 36 additions & 23 deletions bean/model/survival_model.py
@@ -289,28 +289,35 @@ def MixtureNormalModel(

    with replicate_plate:
        with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)):
            time_pi = data.control_timepoint
            # Account for the sample-specific overall edit rate across all guides.
            # P(allele | guide, bin=bulk)
            pi = pyro.sample(
                "pi",
                dist.Dirichlet(
                    pi_a_scaled.unsqueeze(0).unsqueeze(0).expand(data.n_reps, 1, -1, -1)
                ),
            )
            assert pi.shape == (
                data.n_reps,
                1,
                data.n_guides,
                2,
            ), pi.shape
            # If pi is sampled at a later timepoint, account for the selection.
            expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi
            pyro.sample(
                "bulk_allele_count",
                dist.Multinomial(probs=expanded_allele_p, validate_args=False),
                obs=data.allele_counts_control,
            )
    with time_plate:
        with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)):
            time_pi = data.timepoints
            # Account for the sample-specific overall edit rate across all guides.
            # P(allele | guide, timepoint)
            assert pi.shape == (
                data.n_reps,
                1,
                data.n_guides,
                2,
            ), pi.shape
            # If pi is sampled at a later timepoint, account for the selection.
            expanded_allele_p = pi * r.expand(
                data.n_reps, len(data.timepoints), -1, -1
            ) ** time_pi.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(
                data.n_reps, len(data.timepoints), -1, -1
            )
            pyro.sample(
                "allele_count",
                dist.Multinomial(probs=expanded_allele_p, validate_args=False),
                obs=data.allele_counts,
            )
    if scale_by_accessibility:
        # Endogenous target site editing rate may be different
        pi = scale_pi_by_accessibility(
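For intuition, the selection term pi * r ** t above reweights allele proportions exponentially in normalized time, and torch's Multinomial renormalizes the unnormalized probs along the last dimension. A minimal standalone sketch, not part of the diff, with toy numbers for one guide and two alleles (non-edited, edited):

import torch

pi_0 = torch.tensor([0.7, 0.3])  # allele proportions at time 0 (toy values)
r = torch.tensor([1.0, 1.5])     # per-allele fitness ratios (hypothetical)
t = 0.5                          # normalized timepoint in [0, 1]

weights = pi_0 * r**t            # unnormalized, as in expanded_allele_p
pi_t = weights / weights.sum()   # Multinomial normalizes probs the same way
print(pi_t)                      # the fitter (edited) allele's share grows with t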
@@ -486,13 +493,19 @@ def MultiMixtureNormalModel(
                    pi_a_scaled.unsqueeze(0).unsqueeze(0).expand(data.n_reps, 1, -1, -1)
                ),
            )
            # If pi is sampled at a later timepoint, account for the selection.
            expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi
            pyro.sample(
                "bulk_allele_count",
                dist.Multinomial(probs=expanded_allele_p, validate_args=False),
                obs=data.allele_counts_control,
            )
    with time_plate:
        with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)):
            # If pi is sampled at a later timepoint, account for the selection.
            expanded_allele_p = pi * r.expand(
                data.n_reps, 1, -1, -1
            ) ** time_pi.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(
                data.n_reps, len(data.timepoints), -1, -1
            )
            pyro.sample(
                "allele_count",
                dist.Multinomial(probs=expanded_allele_p, validate_args=False),
                obs=data.allele_counts,
            )
    if scale_by_accessibility:
        # Endogenous target site editing rate may be different
        pi = scale_pi_by_accessibility(
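Both models rely on broadcasting to align pi, r, and the timepoint exponent. A standalone shape check under assumed toy dimensions (not from the repo); it also verifies that expanding r's time dimension explicitly, as in MixtureNormalModel, gives the same result as leaving it at size 1, as in MultiMixtureNormalModel, since size-1 dimensions broadcast:

import torch

n_reps, n_t, n_guides, n_alleles = 2, 3, 4, 2      # assumed toy dimensions
pi = torch.rand(n_reps, 1, n_guides, n_alleles)    # one pi per replicate
r = torch.rand(1, 1, n_guides, n_alleles) + 0.5    # fitness ratios
timepoints = torch.tensor([0.25, 0.5, 1.0])        # normalized times

t = timepoints.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(
    n_reps, n_t, -1, -1
)  # (n_reps, n_t, 1, 1)
p_full = pi * r.expand(n_reps, n_t, -1, -1) ** t   # explicit time dimension
p_lazy = pi * r.expand(n_reps, 1, -1, -1) ** t     # size-1 time dim broadcasts
assert p_full.shape == (n_reps, n_t, n_guides, n_alleles)
assert torch.allclose(p_full, p_lazy)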
32 changes: 30 additions & 2 deletions bean/preprocessing/data_class.py
@@ -331,6 +331,21 @@ def _post_init(
            self.screen_control.samples["size_factor_bcmatch"].to_numpy()
        ).reshape(self.n_reps, 1)

        if hasattr(self, "timepoints") and not hasattr(self, "allele_counts"):
            control_allele_counts = []
            for timepoint in self.timepoints:
                screen_t = self.screen[:, self.screen.samples.time == timepoint.item()]
                edited_control = self.transform_data(screen_t.layers["edits"], n_bins=1)
                nonedited_control = (
                    self.transform_data(screen_t.layers["X_bcmatch"], 1)
                    - edited_control
                )
                nonedited_control[nonedited_control < 0] = 0
                control_allele_counts.append(
                    torch.stack([nonedited_control, edited_control], axis=-1)
                )  # (n_reps, n_bins, n_guides, n_alleles)
            self.allele_counts = torch.cat(control_allele_counts, axis=1)

        edited_control = self.transform_data(
            self.screen_control.layers["edits"], n_bins=1
        )
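The new allele_counts block builds a (n_reps, n_timepoints, n_guides, 2) tensor: per timepoint, the edits layer gives edited counts, the bcmatch totals minus edits give non-edited counts (floored at zero), and the two are stacked along a new allele axis before concatenating along time. A standalone sketch of that stacking with toy counts (names and shapes assumed):

import torch

n_reps, n_guides, n_timepoints = 2, 3, 4           # assumed toy dimensions
per_timepoint = []
for _ in range(n_timepoints):
    edited = torch.randint(0, 10, (n_reps, 1, n_guides)).float()
    total = edited + torch.randint(0, 10, (n_reps, 1, n_guides)).float()
    nonedited = (total - edited).clamp(min=0)      # floor negatives at zero
    per_timepoint.append(torch.stack([nonedited, edited], dim=-1))
allele_counts = torch.cat(per_timepoint, dim=1)
assert allele_counts.shape == (n_reps, n_timepoints, n_guides, 2)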
@@ -981,12 +996,25 @@ def __init__(
    def _pre_init(self, time_column: str, condition_column: str):
        self.condition_column = self.time_column = time_column
        try:
            self.screen.samples[time_column] = self.screen.samples[time_column].astype(
                float
            )
            max_time = self.screen.samples[time_column].max()
            self.screen.samples[time_column] = (
                self.screen.samples[time_column] / max_time
            )

            self.screen_selected.samples[time_column] = self.screen_selected.samples[
                time_column
            ].astype(float)
            self.screen_selected.samples[time_column] = (
                self.screen_selected.samples[time_column] / max_time
            )
            self.screen_control.samples[time_column] = self.screen_control.samples[
                time_column
            ].astype(float)
            self.screen_control.samples[time_column] = (
                self.screen_control.samples[time_column] / max_time
            )
        except ValueError as e:
            raise ValueError(
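The _pre_init change normalizes the time column of all three screen views (screen, screen_selected, screen_control) by one shared max_time, so every timepoint, including the control's, lands on a common [0, 1] scale. A standalone pandas sketch with hypothetical sample tables (column name and values made up):

import pandas as pd

screen = pd.DataFrame({"time": ["0", "7", "14"]})  # hypothetical days
control = pd.DataFrame({"time": ["0"]})

screen["time"] = screen["time"].astype(float)
control["time"] = control["time"].astype(float)
max_time = screen["time"].max()                    # shared scale from the full screen

screen["time"] = screen["time"] / max_time         # -> 0.0, 0.5, 1.0
control["time"] = control["time"] / max_time       # -> 0.0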
