Skip to content

Commit

Permalink
implemented EM algorithm from scratch for single gene model and creat…
Browse files Browse the repository at this point in the history
…ed single unit test to validate baseline functionality
  • Loading branch information
ashuaibi7 committed Dec 10, 2024
1 parent b5f2fd9 commit 85fc53d
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
61 changes: 59 additions & 2 deletions src/dialect/models/gene.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,70 @@ def negative_log_likelihood(pi):
self.pi = result.x[0]
logging.info(f"Estimated pi for gene {self.name}: {self.pi}")

def estimate_pi_with_em_from_scratch(self, max_iter=1000, tol=1e-6, pi_init=0.5):
    """
    Estimate the pi parameter using the Expectation-Maximization (EM) algorithm.
    Implements the EM algorithm from scratch.

    For each retained sample count c, the E-step treats the count as a
    two-component mixture of the background PMF and the same PMF shifted by one:

        E-step: z_i = pi * bmr_pmf[c_i - 1]
                      / ((1 - pi) * bmr_pmf[c_i] + pi * bmr_pmf[c_i - 1])
        M-step: pi  = mean(z_i)

    A bmr_pmf entry that is absent for c_i - 1 is treated as probability 0.
    Samples whose count is missing from bmr_pmf, or has zero background
    probability, are excluded from the E-step.

    Convergence is judged on the value returned by self.compute_log_likelihood
    (defined elsewhere on this class).

    :param max_iter (int): Maximum number of iterations (default: 1000).
    :param tol (float): Convergence threshold for the absolute change in
        log-likelihood between successive iterations (default: 1e-6).
    :param pi_init (float): Starting value for pi, within [0, 1] (default: 0.5).
    :return (float): The estimated value of pi (also stored on self.pi).
    :raises ValueError: If bmr_pmf or counts is undefined, pi_init is outside
        [0, 1], no sample has positive background probability, or a sample has
        zero likelihood under the current pi.
    """
    logging.info(f"Estimating pi for gene {self.name} using the EM algorithm.")

    if self.bmr_pmf is None:
        raise ValueError("BMR PMF is not defined for this gene.")
    if self.counts is None:
        raise ValueError("Counts are not defined for this gene.")
    if not 0.0 <= pi_init <= 1.0:
        raise ValueError(f"pi_init must lie in [0, 1]; got {pi_init}.")

    # TODO: Refactor to Extract Method between here and log likelihood calculation
    missing_bmr_pmf_counts = [c for c in self.counts if c not in self.bmr_pmf]
    if missing_bmr_pmf_counts:
        logging.warning(
            f"Counts {missing_bmr_pmf_counts} are not in bmr_pmf for gene {self.name}. "
            f"These samples will be skipped. Please ensure bmr_pmf includes all relevant counts."
        )

    # TODO: Double check logic and validity of removing counts here
    valid_counts = [c for c in self.counts if self.bmr_pmf.get(c, 0) > 0]
    if not valid_counts:
        # Without this guard np.mean([]) would silently turn pi into nan.
        raise ValueError(
            f"No counts with positive bmr_pmf probability for gene {self.name}."
        )

    # The (background, shifted) probabilities per sample never change between
    # iterations, so look them up once instead of on every E-step.
    components = [
        (self.bmr_pmf[c], self.bmr_pmf.get(c - 1, 0)) for c in valid_counts
    ]

    pi_hat = pi_init
    log_likelihood = self.compute_log_likelihood(pi_hat)
    for iteration in range(1, max_iter + 1):
        # E-step: Compute responsibilities
        responsibilities = []
        for background, shifted in components:
            denominator = background * (1 - pi_hat) + shifted * pi_hat
            if denominator <= 0:
                # Only reachable when pi_hat == 1 and the shifted probability is 0;
                # the posterior is undefined (0 / 0) in that case.
                raise ValueError(
                    f"A sample has zero likelihood under pi={pi_hat} for gene {self.name}."
                )
            responsibilities.append(pi_hat * shifted / denominator)

        # M-step: Update pi as the mean of responsibilities
        new_pi = float(np.mean(responsibilities))
        new_log_likelihood = self.compute_log_likelihood(new_pi)

        # Report the updated pi together with its own log-likelihood.
        logging.info(
            f"Iteration {iteration}: pi={new_pi:.4f}, log_likelihood={new_log_likelihood:.4f}"
        )

        converged = abs(new_log_likelihood - log_likelihood) < tol
        # Always keep the latest update, including the one that triggered convergence.
        pi_hat, log_likelihood = new_pi, new_log_likelihood
        if converged:
            logging.info(f"EM algorithm converged after {iteration} iterations.")
            break
    else:
        logging.warning(
            f"EM algorithm did not converge within {max_iter} iterations for gene {self.name}."
        )

    self.pi = pi_hat
    logging.info(f"Estimated pi for gene {self.name}: {self.pi:.4f}")
    return self.pi

def estimate_pi_with_em_using_pomegranate(self):
"""
Expand Down
14 changes: 14 additions & 0 deletions tests/test_gene.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,17 @@ def test_estimate_pi_with_optimization(self):
self.assertIsInstance(self.gene.pi, float)
self.assertGreaterEqual(self.gene.pi, 0.0)
self.assertLessEqual(self.gene.pi, 1.0)

def test_estimate_pi_with_em(self):
    """
    Run estimate_pi_with_em_from_scratch on the gene held by this test case
    and check that the resulting pi is a float within [0, 1].
    """
    self.gene.estimate_pi_with_em_from_scratch()

    # Read the estimate once; the checks below only validate its type and range.
    estimated_pi = self.gene.pi
    self.assertIsInstance(estimated_pi, float)
    self.assertGreaterEqual(estimated_pi, 0.0)
    self.assertLessEqual(estimated_pi, 1.0)

# TODO Add additional edge cases and tests, including:
# - TODO: test for non-normalized bmr_pmf
# - TODO: tests for values outside of allowed bounds for all methods
# - TODO: tests for missing values in bmr_pmf

0 comments on commit 85fc53d

Please sign in to comment.