Skip to content

Commit

Permalink
extracted and refactored methods to verify parameter values
Browse files Browse the repository at this point in the history
  • Loading branch information
ashuaibi7 committed Dec 12, 2024
1 parent 87c1867 commit de148ad
Showing 1 changed file with 79 additions and 62 deletions.
141 changes: 79 additions & 62 deletions src/dialect/models/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,54 @@ def __init__(self, gene_a, gene_b):
self.tau_10 = None # P(D = 1, D' = 0) for genes A and B
self.tau_11 = None # P(D = 1, D' = 1) for genes A and B

# ---------------------------------------------------------------------------- #
# DATA VALIDATION & LOGGING #
# ---------------------------------------------------------------------------- #

def verify_bmr_pmf_and_counts_exist(self):
"""
Verify that BMR PMFs and counts exist for both genes in the interaction pair.
:raises ValueError: If BMR PMFs or counts are not defined.
"""
if not self.gene_a.bmr_pmf or not self.gene_b.bmr_pmf:
raise ValueError("BMR PMFs are not defined for one or both genes.")

if not self.gene_a.counts or not self.gene_b.counts:
raise ValueError("Counts are not defined for one or both genes.")

def verify_taus_are_valid(self, taus):
"""
Verify that tau parameters are valid (0 <= tau_i <= 1 and sum(tau) == 1).
:param taus: (list of float) Tau parameters to validate.
:raises ValueError: If any or all tau parameters are invalid.
"""
if not all(0 <= t <= 1 for t in taus) or not np.isclose(sum(taus), 1):
logging.info(f"Invalid tau parameters: {taus}")
raise ValueError(
"Invalid tau parameters. Ensure 0 <= tau_i <= 1 and sum(tau) == 1."
)

def verify_pi_values(self, pi_a, pi_b):
"""
Verify that driver probabilities (pi values) are defined for both genes in the interaction.
:param pi_a: (float or None) Driver probability for gene A.
:param pi_b: (float or None) Driver probability for gene B.
:return: None if either pi value is not defined.
:raises ValueError: If both pi values are missing.
"""
if pi_a is None or pi_b is None:
logging.warning(
f"Driver probabilities (pi) are not defined for genes in interaction {self.name}."
)
raise ValueError("Driver probabilities are not defined for both genes.")

# ---------------------------------------------------------------------------- #
# Likelihood & Metric Evaluation #
# ---------------------------------------------------------------------------- #
# TODO: Add additional metrics (KL, MI, etc.) for further exploration

def compute_log_likelihood(self, taus):
"""
Expand All @@ -50,18 +95,11 @@ def compute_log_likelihood(self, taus):
:return (float): The log-likelihood value.
:raises ValueError: If `bmr_pmf` or `counts` are not defined for either gene, or if `tau` is invalid.
"""
if not self.gene_a.bmr_pmf or not self.gene_b.bmr_pmf:
raise ValueError("BMR PMFs are not defined for one or both genes.")
if not self.gene_a.counts or not self.gene_b.counts:
raise ValueError("Counts are not defined for one or both genes.")
if not all(0 <= t <= 1 for t in taus) or not np.isclose(sum(taus), 1):
logging.info(f"Invalid tau parameters: {taus}")
raise ValueError(
"Invalid tau parameters. Ensure 0 <= tau_i <= 1 and sum(tau) == 1."
)

logging.info(f"Computing log likelihood for {self.name}. Taus: {taus}")

self.verify_bmr_pmf_and_counts_exist()
self.verify_taus_are_valid(taus)

a_counts, b_counts = self.gene_a.counts, self.gene_b.counts
a_bmr_pmf, b_bmr_pmf = self.gene_a.bmr_pmf, self.gene_b.bmr_pmf
tau_00, tau_01, tau_10, tau_11 = taus
Expand All @@ -76,7 +114,7 @@ def compute_log_likelihood(self, taus):
)
return log_likelihood

def compute_likelihood_ratio(self):
def compute_likelihood_ratio(self, taus):
"""
Compute the likelihood ratio test statistic (lambda_LR) with respect to the null hypothesis.
Expand All @@ -90,13 +128,14 @@ def compute_likelihood_ratio(self):
:return (float): Likelihood ratio.
"""
logging.info(f"Computing likelihood ratio for interaction {self.name}.")

self.verify_pi_values(self.gene_a.pi, self.gene_b.pi)
pi_a, pi_b = self.gene_a.pi, self.gene_b.pi
if pi_a is None or pi_b is None:
logging.warning(
f"Driver probabilities (pi) are not defined for genes in interaction {self.name}."
)
return None
tau_00, tau_01, tau_10, tau_11 = taus

# TODO: Validate the null hypothesis is correct
# ? Why shouldn't we use the marginals instead?
tau_null = (
(1 - pi_a) * (1 - pi_b), # tau_00: neither gene has a driver mutation
(1 - pi_a) * pi_b, # tau_01: only gene_b has a driver mutation
Expand All @@ -105,13 +144,11 @@ def compute_likelihood_ratio(self):
)
lambda_LR = -2 * (
self.compute_log_likelihood(tau_null)
- self.compute_log_likelihood(
(self.tau_00, self.tau_01, self.tau_10, self.tau_11)
)
- self.compute_log_likelihood((tau_00, tau_01, tau_10, tau_11))
)
return lambda_LR

def compute_log_odds_ratio(self):
def compute_log_odds_ratio(self, taus):
"""
Compute the log odds ratio for the interaction based on the tau parameters.
Expand All @@ -121,35 +158,25 @@ def compute_log_odds_ratio(self):
:return (float): The log odds ratio.
:raises ValueError: If tau parameters are invalid or lead to division by zero.
"""
# Validate tau parameters
if not all(
0 <= t <= 1 for t in [self.tau_00, self.tau_01, self.tau_10, self.tau_11]
) or not np.isclose(
sum([self.tau_00, self.tau_01, self.tau_10, self.tau_11]), 1
):
logging.info(
f"Invalid tau parameters: tau_00={self.tau_00}, tau_01={self.tau_01}, tau_10={self.tau_10}, tau_11={self.tau_11}"
)
raise ValueError(
"Invalid tau parameters. Ensure 0 <= tau_ij <= 1 and sum(tau) == 1."
)
logging.info(f"Computing log odds ratio for interaction {self.name}.")

self.verify_taus_are_valid(taus)
tau_00, tau_01, tau_10, tau_11 = taus

if self.tau_01 * self.tau_10 == 0 or self.tau_00 * self.tau_11 == 0:
if tau_01 * tau_10 == 0 or tau_00 * tau_11 == 0:
logging.warning(
f"Zero encountered in odds ratio computation for interaction {self.name}. "
f"tau_01={self.tau_01}, tau_10={self.tau_10}, tau_00={self.tau_00}, tau_11={self.tau_11}"
f"tau_01={tau_01}, tau_10={tau_10}, tau_00={tau_00}, tau_11={tau_11}"
)
return None # Return None when numerator or denominator is zero

log_odds_ratio = np.log(
(self.tau_01 * self.tau_10) / (self.tau_00 * self.tau_11)
)
log_odds_ratio = np.log((tau_01 * tau_10) / (tau_00 * tau_11))
logging.info(
f"Computed log odds ratio for interaction {self.name}: {log_odds_ratio}"
)
return log_odds_ratio

def compute_wald_statistic(self):
def compute_wald_statistic(self, taus):
"""
Compute the Wald statistic for the interaction.
Expand All @@ -164,19 +191,13 @@ def compute_wald_statistic(self):
:return (float or None): The Wald statistic, or None
"""
logging.info(f"Computing Wald statistic for interaction {self.name}.")
log_odds_ratio = self.compute_log_odds_ratio()

self.verify_taus_are_valid(taus)
log_odds_ratio = self.compute_log_odds_ratio(taus)
if log_odds_ratio is None:
logging.warning(f"Log odds ratio is None for interaction {self.name}.")
return None

if any(t <= 0 for t in [self.tau_00, self.tau_01, self.tau_10, self.tau_11]):
logging.warning(
f"Invalid tau parameters for interaction {self.name}. "
f"tau_00={self.tau_00}, tau_01={self.tau_01}, tau_10={self.tau_10}, tau_11={self.tau_11}."
"All tau values must be positive to compute the Wald statistic."
)
raise ValueError("Invalid tau parameters for Wald statistic computation.")

try:
std_err = np.sqrt(
(1 / self.tau_01)
Expand All @@ -196,7 +217,7 @@ def compute_wald_statistic(self):
)
return wald_statistic

def compute_rho(self):
def compute_rho(self, taus):
"""
Compute the interaction measure rho (ρ) for the given tau parameters - ρ is given by:
ρ = (tau_01 * tau_10 - tau_11 * tau_00) / sqrt(tau_0* * tau_1* * tau_*0 * tau_*1)
Expand All @@ -209,18 +230,9 @@ def compute_rho(self):
:return (float or None): The value of rho, or None if the computation is invalid (e.g., division by zero).
"""
if not all(
0 <= t <= 1 for t in [self.tau_00, self.tau_01, self.tau_10, self.tau_11]
) or not np.isclose(
sum([self.tau_00, self.tau_01, self.tau_10, self.tau_11]), 1
):
logging.warning(
f"Invalid tau parameters for interaction {self.name}: "
f"tau_00={self.tau_00}, tau_01={self.tau_01}, tau_10={self.tau_10}, tau_11={self.tau_11}."
)
raise ValueError(
"Invalid tau parameters. Ensure 0 <= tau_ij <= 1 and sum(tau) == 1."
)
logging.info(f"Computing rho for interaction {self.name}.")

self.verify_taus_are_valid(taus)

tau_0X = self.tau_00 + self.tau_01
tau_1X = self.tau_10 + self.tau_11
Expand Down Expand Up @@ -256,6 +268,8 @@ def estimate_tau_with_optimization_using_scipy(
"""
logging.info(f"Estimating tau params for {self.name} using L-BFGS-B.")

self.verify_bmr_pmf_and_counts_exist()

def negative_log_likelihood(tau):
return -self.compute_log_likelihood(tau)

Expand All @@ -281,11 +295,14 @@ def negative_log_likelihood(tau):
return self.tau_00, self.tau_01, self.tau_10, self.tau_11

# TODO: Implement this method
def estimate_tau_with_em_from_scratch():
def estimate_tau_with_em_from_scratch(self):
logging.info("Estimating tau parameters using EM algorithm from scratch.")

self.verify_bmr_pmf_and_counts_exist()

raise NotImplementedError("Method is not yet implemented.")

# TODO: Implement below to increase speed relative to from-scratch EM
def estimate_tau_with_em_using_pomegranate():
def estimate_tau_with_em_using_pomegranate(self):
logging.info("Estimating tau parameters using pomegranate.")
raise NotImplementedError("Method is not yet implemented.")

0 comments on commit de148ad

Please sign in to comment.