import logging

import numpy as np

logger = logging.getLogger(__name__)


class Gene:
    """
    Per-gene mutation data together with its background mutation rate (BMR) model.

    Attributes:
        name: Name of the gene.
        counts: Per-sample mutation counts for the gene.
        bmr_pmf: Mapping from a mutation count to its probability; looked up
            with ``.get(count, 0)``, so missing counts have probability 0.
        pi: Estimated driver mutation rate; ``None`` until it is assigned.
    """

    def __init__(self, name, counts, bmr_pmf):
        """
        Initialize a Gene object.

        :param name (str): Name of the gene.
        :param counts (np.ndarray): Mutation counts for the gene, one per sample.
        :param bmr_pmf (defaultdict): BMR PMF, mapping a mutation count to its
            probability.
        """
        self.name = name
        self.counts = counts
        self.bmr_pmf = bmr_pmf
        # pi is not estimated here; it stays None until a caller assigns it.
        self.pi = None

    def compute_log_likelihood(self):
        r"""
        Compute the complete data log-likelihood for the gene given the estimated pi.

        The likelihood function is given by:
            \sum_{i=1}^{N} \log(\mathbb{P}(P_i = c_i)(1 - \pi) + \mathbb{P}(P_i = c_i - 1) \pi)

        where:
            - `N` is the number of samples.
            - `P_i` represents the RV for passenger mutations
            - `c_i` is the observed count of somatic mutations for sample i
            - `\pi` is the estimated driver mutation rate parameter value.

        :return (float): The log-likelihood value. It is ``-inf`` when some
            observed count has zero probability under the model.
        :raises ValueError: If `bmr_pmf`, `counts`, or `pi` is not defined.
        """
        if self.bmr_pmf is None or len(self.bmr_pmf) == 0:
            raise ValueError("BMR PMF is not defined for this gene.")
        # Explicit None/size checks: the truth value of a multi-element
        # numpy array is ambiguous and cannot be used in a boolean test.
        if self.counts is None or np.size(self.counts) == 0:
            raise ValueError("Counts are not defined for this gene.")
        # `is None` rather than falsiness: pi == 0 is a valid estimate.
        if self.pi is None:
            raise ValueError("Pi has not been estimated for this gene.")

        logger.info(
            "Computing log likelihood for gene %s. Pi: %s. BMR PMF: %s",
            self.name,
            self.pi,
            self.bmr_pmf,
        )

        pmf = self.bmr_pmf
        pi = self.pi
        # Mixture probability per sample: no driver (c passenger mutations)
        # or one driver (c - 1 passenger mutations).
        sample_probs = np.array(
            [
                pmf.get(c, 0) * (1 - pi) + pmf.get(c - 1, 0) * pi
                for c in np.ravel(self.counts)
            ],
            dtype=float,
        )
        # log(0) legitimately yields -inf here; silence only that warning.
        with np.errstate(divide="ignore"):
            return float(np.sum(np.log(sample_probs)))

    def binarize_counts(self):
        """
        Get the binarized counts for the gene.

        A sample is marked 1 when its count is at least 1, otherwise 0.

        :return (np.ndarray): Binarized counts.
        """
        # np.asarray lets list-like counts work as well as arrays.
        return (np.asarray(self.counts) >= 1).astype(int)

    def get_contingency_table(self, other_counts):
        """
        Get the contingency table for the current gene and another gene.

        :param other_counts (np.ndarray): Counts of another gene.
        :return (np.ndarray): Contingency table.
        :raises NotImplementedError: Always; the computation is not implemented yet.
        """
        raise NotImplementedError(
            "Contingency table computation is not yet implemented."
        )

    def compute_likelihood_ratio(self):
        """
        Compute the likelihood ratio with respect to the null hypothesis.

        :return (float): Likelihood ratio.
        :raises NotImplementedError: Always; the computation is not implemented yet.
        """
        raise NotImplementedError(
            "Likelihood ratio computation is not yet implemented."
        )

    def compute_log_odds_ratio(self):
        """
        Compute the log odds ratio log(pi / (1 - pi)) from the estimated pi.

        :return (float): Log odds ratio; ``-inf`` when pi is 0 and ``inf``
            when pi is 1.
        :raises ValueError: If pi has not been estimated or is outside [0, 1].
        """
        # `is None` rather than falsiness so that pi == 0 reaches its branch.
        if self.pi is None:
            raise ValueError("Pi has not been estimated for this gene.")
        # The chained comparison also rejects NaN.
        if not 0 <= self.pi <= 1:
            logger.warning("Invalid pi value: %s.", self.pi)
            # TODO: move this error to EM method
            raise ValueError("Estimated pi must be between 0 and 1, inclusive.")
        if self.pi == 0:
            logger.info("Pi for gene %s is 0", self.name)
            return -np.inf
        if self.pi == 1:
            logger.info("Pi for gene %s is 1", self.name)
            return np.inf
        return float(np.log(self.pi / (1 - self.pi)))