diff --git a/src/dialect/models/interaction.py b/src/dialect/models/interaction.py index 03fb6ce..b5d02f2 100644 --- a/src/dialect/models/interaction.py +++ b/src/dialect/models/interaction.py @@ -47,10 +47,14 @@ def __str__(self): ) pi_a = ( - f"{self.gene_a.pi:.3e}" if self.gene_a.pi is not None else "Not estimated" + f"{self.gene_a.pi:.3e}" + if self.gene_a.pi is not None + else "Not estimated" ) pi_b = ( - f"{self.gene_b.pi:.3e}" if self.gene_b.pi is not None else "Not estimated" + f"{self.gene_b.pi:.3e}" + if self.gene_b.pi is not None + else "Not estimated" ) cm = self.compute_contingency_table() @@ -146,7 +150,9 @@ def verify_taus_are_valid(self, taus, tol=1e-6): :param tol: (float) Tolerance for the sum of tau parameters (default: 1e-1). :raises ValueError: If any or all tau parameters are invalid. """ - if not all(0 <= t <= 1 for t in taus) or not np.isclose(sum(taus), 1, atol=tol): + if not all(0 <= t <= 1 for t in taus) or not np.isclose( + sum(taus), 1, atol=tol + ): logging.info(f"Invalid tau parameters: {taus}") raise ValueError( "Invalid tau parameters. Ensure 0 <= tau_i <= 1 and sum(tau) == 1." @@ -173,7 +179,9 @@ def verify_pi_values(self, pi_a, pi_b): logging.warning( f"Driver probabilities (pi) are not defined for genes in interaction {self.name}." ) - raise ValueError("Driver probabilities are not defined for both genes.") + raise ValueError( + "Driver probabilities are not defined for both genes." + ) # ---------------------------------------------------------------------------- # # Likelihood & Metric Evaluation # @@ -294,7 +302,8 @@ def compute_likelihood_ratio(self, taus): driver_b_marginal = tau_01 + tau_11 tau_null = ( - (1 - driver_a_marginal) * (1 - driver_b_marginal), # both genes passengers + (1 - driver_a_marginal) + * (1 - driver_b_marginal), # both genes passengers (1 - driver_a_marginal) * driver_b_marginal, # gene a passenger, gene b driver driver_a_marginal @@ -370,7 +379,9 @@ def compute_wald_statistic(self, taus): self.verify_taus_are_valid(taus) log_odds_ratio = self.compute_log_odds_ratio(taus) if log_odds_ratio is None: - logging.warning(f"Log odds ratio is None for interaction {self.name}.") + logging.warning( + f"Log odds ratio is None for interaction {self.name}." + ) return None try: @@ -522,7 +533,9 @@ def estimate_tau_with_em_from_scratch( **Returns**: :return: (tuple) The estimated values of \( (\\tau_{00}, \\tau_{01}, \\tau_{10}, \\tau_{11}) \). """ - logging.info("Estimating tau parameters using EM algorithm from scratch.") + logging.info( + "Estimating tau parameters using EM algorithm from scratch." + ) self.verify_bmr_pmf_and_counts_exist() @@ -532,10 +545,22 @@ def estimate_tau_with_em_from_scratch( total_probabilities = self.compute_total_probability( tau_00, tau_01, tau_10, tau_11 ) # denominator in E-Step equation - z_i_00 = self.compute_joint_probability(tau_00, 0, 0) / total_probabilities - z_i_01 = self.compute_joint_probability(tau_01, 0, 1) / total_probabilities - z_i_10 = self.compute_joint_probability(tau_10, 1, 0) / total_probabilities - z_i_11 = self.compute_joint_probability(tau_11, 1, 1) / total_probabilities + z_i_00 = ( + self.compute_joint_probability(tau_00, 0, 0) + / total_probabilities + ) + z_i_01 = ( + self.compute_joint_probability(tau_01, 0, 1) + / total_probabilities + ) + z_i_10 = ( + self.compute_joint_probability(tau_10, 1, 0) + / total_probabilities + ) + z_i_11 = ( + self.compute_joint_probability(tau_11, 1, 1) + / total_probabilities + ) # M-Step: Update tau parameters curr_tau_00 = np.mean(z_i_00) @@ -566,7 +591,9 @@ def estimate_tau_with_em_from_scratch( tau_10, tau_11, ) - logging.info(" EM algorithm converged after {} iterations.".format(it + 1)) + logging.info( + " EM algorithm converged after {} iterations.".format(it + 1) + ) logging.info( f"Estimated tau parameters for interaction {self.name}: tau_00={self.tau_00}, tau_01={self.tau_01}, tau_10={self.tau_10}, tau_11={self.tau_11}" )