Skip to content

Commit

Permalink
First pass
Browse files Browse the repository at this point in the history
  • Loading branch information
aarmey committed Jan 31, 2024
1 parent 0a9893f commit 2e88d3a
Show file tree
Hide file tree
Showing 13 changed files with 71 additions and 366 deletions.
7 changes: 2 additions & 5 deletions lineage/BaumWelch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,9 @@ def do_E_step(tHMMobj: tHMM) -> Tuple[list, list, list, list]:

for ii, lO in enumerate(tHMMobj.X):
MSD.append(get_MSD(lO.cell_to_parent, tHMMobj.estimate.pi, tHMMobj.estimate.T))
NF.append(get_leaf_Normalizing_Factors(lO.leaves_idx, MSD[ii], EL[ii]))
NF.append(get_leaf_Normalizing_Factors(MSD[ii], EL[ii]))
betas.append(
get_beta(
lO.leaves_idx,
lO.cell_to_daughters,
tHMMobj.estimate.T,
MSD[ii],
EL[ii],
Expand Down Expand Up @@ -171,14 +169,13 @@ def do_M_T_step(
for num, lO in enumerate(tt.X):
# local T estimate
numer_e += get_all_zetas(
lO.leaves_idx,
lO.cell_to_daughters,
betas[i][num],
MSD[i][num],
gammas[i][num],
tt.estimate.T,
)
denom_e += sum_nonleaf_gammas(lO.leaves_idx, gammas[i][num])
denom_e += sum_nonleaf_gammas(gammas[i][num])

T_estimate = numer_e / denom_e[:, np.newaxis]
T_estimate /= T_estimate.sum(axis=1)[:, np.newaxis]
Expand Down
118 changes: 0 additions & 118 deletions lineage/CellVar.py

This file was deleted.

41 changes: 8 additions & 33 deletions lineage/HMM/E_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from numba import njit


@njit
def get_leaf_Normalizing_Factors(
leaves_idx: npt.NDArray[np.uintp],
MSD: npt.NDArray[np.float64],
EL: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
Expand Down Expand Up @@ -36,6 +34,7 @@ def get_leaf_Normalizing_Factors(
:return: normalizing factor. The marginal observation distribution P(x_n = x)
"""
NF_array = np.zeros(MSD.shape[0], dtype=float) # instantiating N by 1 array
leaves_idx = slice(int(MSD.shape[0] / 2), MSD.shape[0])

# P(x_n = x , z_n = k) = P(x_n = x | z_n = k) * P(z_n = k)
# this product is the joint probability
Expand All @@ -46,7 +45,6 @@ def get_leaf_Normalizing_Factors(
return NF_array


@njit
def get_MSD(
cell_to_parent: np.ndarray, pi: npt.NDArray[np.float64], T: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
Expand Down Expand Up @@ -83,25 +81,7 @@ def get_MSD(
return m


@njit
def np_apply_along_axis(func1d, axis, arr):
    """
    Reduce a 2-D array along one axis by calling ``func1d`` on each 1-D slice.

    :param func1d: callable that takes a 1-D slice of ``arr`` and returns a scalar
    :param axis: 0 to reduce over rows (one output per column),
        1 to reduce over columns (one output per row)
    :param arr: the 2-D array to reduce
    :return: 1-D array holding one ``func1d`` result per column (axis 0)
        or per row (axis 1), allocated with ``np.empty``'s default dtype
    """
    # Only 2-D inputs and the two axes below are handled.
    assert arr.ndim == 2
    assert axis in [0, 1]

    n_rows, n_cols = arr.shape

    # NOTE(review): presumably written as explicit loops so it can be compiled
    # by numba, where np.apply_along_axis is unavailable -- confirm.
    if axis == 1:
        # One value per row: func1d sees each row in turn.
        out = np.empty(n_rows)
        for row in range(n_rows):
            out[row] = func1d(arr[row, :])
        return out

    # axis == 0: one value per column, func1d sees each column in turn.
    out = np.empty(n_cols)
    for col in range(n_cols):
        out[col] = func1d(arr[:, col])
    return out


@njit
def get_beta(
leaves_idx: npt.NDArray[np.uintp],
cell_to_daughters: npt.NDArray[np.intp],
T: npt.NDArray[np.float64],
MSD: npt.NDArray[np.float64],
EL: npt.NDArray[np.float64],
Expand Down Expand Up @@ -136,13 +116,13 @@ def get_beta(
the terms in the NF equation. This term is also used in the calculation
of the betas.
:param tHMMobj: A class object with properties of the lineages of cells
:param MSD: The marginal state distribution P(z_n = k)
:param EL: The emissions likelihood
:param NF: normalizing factor. The marginal observation distribution P(x_n = x)
:return: beta values. The conditional probability of states, given observations of the sub-tree rooted in cell_n
"""
beta = np.zeros_like(MSD)
leaves_idx = slice(MSD.shape[0] / 2, MSD.shape[0])

# Emission Likelihood, Marginal State Distribution, Normalizing Factor (same regardless of state)
# P(x_n = x | z_n = k), P(z_n = k), P(x_n = x)
Expand All @@ -157,24 +137,19 @@ def get_beta(
) # MSD of the respective lineage
ELMSD = EL * MSD

cIDXs = np.arange(MSD.shape[0])
cIDXs = np.delete(cIDXs, leaves_idx)
cIDXs = np.flip(cIDXs)

for pii in cIDXs:
ch_ii = cell_to_daughters[pii, :]
# Loop over non-leaves backwards
for pii in reversed(range(int(MSD.shape[0] / 2))):
ch_ii = np.array([pii*2+1, pii*2+2]).T
ratt = (beta[ch_ii, :] / MSD_array[ch_ii, :]) @ T.T
fac1 = np_apply_along_axis(np.prod, axis=0, arr=ratt) * ELMSD[pii, :]
fac1 = np.prod(ratt, axis=0) * ELMSD[pii, :]

NF[pii] = np.sum(fac1)
beta[pii, :] = fac1 / NF[pii]

return beta


@njit
def get_gamma(
cell_to_daughters: npt.NDArray[np.uintp],
T: npt.NDArray[np.float64],
MSD: npt.NDArray[np.float64],
beta: npt.NDArray[np.float64],
Expand All @@ -198,8 +173,8 @@ def get_gamma(
beta_parents = T @ coeffs.T

# Getting lineage by generation, but it is sorted this way
for pidx, cis in enumerate(cell_to_daughters):
for ci in cis:
for pidx in range(int(gamma.shape[0] / 2)):
for ci in (pidx*2+1, pidx*2+2):
if ci == -1:
continue

Expand Down
21 changes: 9 additions & 12 deletions lineage/HMM/M_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


def sum_nonleaf_gammas(
leaves_idx, gammas: npt.NDArray[np.float64]
gammas: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
"""
Sum of the gammas of the cells that are able to divide, that is,
Expand All @@ -12,20 +12,15 @@ def sum_nonleaf_gammas(
This is downward recursion.
:param lO: the object of lineage tree
:param gamma_arr: the gamma values for each lineage
:return: the sum of gamma values for each state for non-leaf cells.
"""
# Remove leaves
gs = np.delete(gammas, leaves_idx, axis=0)

# sum the gammas for cells that are transitioning (all but gen 0)
return np.sum(gs[1:, :], axis=0)
# also leave out leaves
return np.sum(gammas[1:int(gammas.shape[0] / 2), :], axis=0)


def get_all_zetas(
leaves_idx: npt.NDArray[np.uintp],
cell_to_daughters: npt.NDArray[np.uintp],
beta_array: npt.NDArray[np.float64],
MSD_array: npt.NDArray[np.float64],
gammas: npt.NDArray[np.float64],
Expand All @@ -42,15 +37,17 @@ def get_all_zetas(
:param T: transition probability matrix
:return: numerator for calculating the transition probabilities
"""

betaMSD = beta_array / np.clip(MSD_array, np.finfo(float).eps, np.inf)
TbetaMSD = np.clip(betaMSD @ T.T, np.finfo(float).eps, np.inf)

cIDXs = np.arange(gammas.shape[0])
cIDXs = np.delete(cIDXs, leaves_idx, axis=0)
nonleaf_cell_idxs = np.arange(int(gammas.shape[0] / 2))

cell_to_daughters = 2 * np.arange(nonleaf_cell_idxs.size)[:, np.newaxis] + np.array([1, 2])[np.newaxis, :]

dIDXs = cell_to_daughters[cIDXs, :]
dIDXs = cell_to_daughters[nonleaf_cell_idxs, :]

# Getting lineage by generation, but it is sorted this way
js = gammas[cIDXs, np.newaxis, :] / TbetaMSD[dIDXs, :]
js = gammas[nonleaf_cell_idxs, np.newaxis, :] / TbetaMSD[dIDXs, :]
holder = np.einsum("ijk,ijl->kl", js, betaMSD[dIDXs, :])
return holder * T
Loading

0 comments on commit 2e88d3a

Please sign in to comment.