Skip to content

Commit

Permalink
Working
Browse files Browse the repository at this point in the history
  • Loading branch information
aarmey committed Feb 24, 2024
1 parent 4e1d3ae commit 6c28062
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 174 deletions.
24 changes: 15 additions & 9 deletions lineage/BaumWelch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,18 @@ def do_E_step(tHMMobj: tHMM) -> Tuple[list, list, list, list]:
EL = get_Emission_Likelihoods(tHMMobj.X, tHMMobj.estimate.E)

for ii, lO in enumerate(tHMMobj.X):
MSD.append(get_MSD(lO.cell_to_parent, tHMMobj.estimate.pi, tHMMobj.estimate.T))
NF.append(get_leaf_Normalizing_Factors(lO.leaves_idx, MSD[ii], EL[ii]))
MSD.append(get_MSD(len(lO), tHMMobj.estimate.pi, tHMMobj.estimate.T))
NF.append(get_leaf_Normalizing_Factors(MSD[ii], EL[ii]))
betas.append(
get_beta(
lO.leaves_idx,
lO.cell_to_daughters,
tHMMobj.estimate.T,
MSD[ii],
EL[ii],
NF[ii],
)
)
gammas.append(
get_gamma(lO.cell_to_daughters, tHMMobj.estimate.T, MSD[ii], betas[ii])
get_gamma(tHMMobj.estimate.T, MSD[ii], betas[ii])
)

return MSD, NF, betas, gammas
Expand Down Expand Up @@ -172,14 +170,12 @@ def do_M_T_step(
for num, lO in enumerate(tt.X):
# local T estimate
numer_e += get_all_zetas(
lO.leaves_idx,
lO.cell_to_daughters,
betas[i][num],
MSD[i][num],
gammas[i][num],
tt.estimate.T,
)
denom_e += sum_nonleaf_gammas(lO.leaves_idx, gammas[i][num])
denom_e += sum_nonleaf_gammas(gammas[i][num])

T_estimate = numer_e / denom_e[:, np.newaxis]
T_estimate /= T_estimate.sum(axis=1)[:, np.newaxis]
Expand All @@ -200,7 +196,17 @@ def do_M_E_step(tHMMobj: tHMM, gammas: list[np.ndarray]):
:type tHMMobj: object
:param gammas: gamma values. The conditional probability of states, given the observation of the whole tree
"""
all_cells = [cell.obs for lineage in tHMMobj.X for cell in lineage.output_lineage]
all_cells: list[np.ndarray] = []

for lineage in tHMMobj.X:
for cell in lineage.output_lineage:
if cell is None:
all_cells.append(-1 * np.ones(all_cells[0].size))
else:
all_cells.append(np.array(cell.obs))

all_cells = np.array(all_cells) # type: ignore

all_gammas = np.vstack(gammas)
for state_j in range(tHMMobj.num_states):
tHMMobj.estimate.E[state_j].estimator(all_cells, all_gammas[:, state_j])
Expand Down
35 changes: 16 additions & 19 deletions lineage/HMM/E_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@


def get_leaf_Normalizing_Factors(
leaves_idx: npt.NDArray[np.uintp],
MSD: npt.NDArray[np.float64],
EL: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
Expand Down Expand Up @@ -34,18 +33,19 @@ def get_leaf_Normalizing_Factors(
:return: normalizing factor. The marginal observation distribution P(x_n = x)
"""
NF_array = np.zeros(MSD.shape[0], dtype=float) # instantiating N by 1 array
first_leaf = int(np.floor(MSD.shape[0] / 2))

# P(x_n = x , z_n = k) = P(x_n = x | z_n = k) * P(z_n = k)
# this product is the joint probability
# P(x_n = x) = sum_k ( P(x_n = x , z_n = k) )
# the sum of the joint probabilities is the marginal probability
NF_array[leaves_idx] = np.sum(MSD[leaves_idx, :] * EL[leaves_idx, :], axis=1)
NF_array[first_leaf:] = np.sum(MSD[first_leaf:, :] * EL[first_leaf:, :], axis=1)
assert np.all(np.isfinite(NF_array))
return NF_array


def get_MSD(
cell_to_parent: np.ndarray, pi: npt.NDArray[np.float64], T: npt.NDArray[np.float64]
n_cells: int, pi: npt.NDArray[np.float64], T: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
r"""Marginal State Distribution (MSD) matrix by upward recursion.
This is the probability that a hidden state variable :math:`z_n` is of
Expand All @@ -68,21 +68,20 @@ def get_MSD(
:param T: State transitions matrix
:return: The marginal state distribution
"""
m = np.zeros((cell_to_parent.size, pi.size))
m = np.zeros((n_cells, pi.size))
m[0, :] = pi

# recursion based on parent cell
for cIDX, pIDX in enumerate(cell_to_parent[1:]):
m[cIDX + 1, :] = m[pIDX, :] @ T
for cIDX in range(1, n_cells):
pIDX = int(np.floor(cIDX / 2))
m[cIDX, :] = m[pIDX, :] @ T

# Assert all ~= 1.0
assert np.linalg.norm(np.sum(m, axis=1) - 1.0) < 1e-9
return m


def get_beta(
leaves_idx: npt.NDArray[np.uintp],
cell_to_daughters: npt.NDArray[np.intp],
T: npt.NDArray[np.float64],
MSD: npt.NDArray[np.float64],
EL: npt.NDArray[np.float64],
Expand Down Expand Up @@ -124,11 +123,12 @@ def get_beta(
:return: beta values. The conditional probability of states, given observations of the sub-tree rooted in cell_n
"""
beta = np.zeros_like(MSD)
first_leaf = int(np.floor(MSD.shape[0] / 2))

# Emission Likelihood, Marginal State Distribution, Normalizing Factor (same regardless of state)
# P(x_n = x | z_n = k), P(z_n = k), P(x_n = x)
ZZ = EL[leaves_idx, :] * MSD[leaves_idx, :] / NF[leaves_idx, np.newaxis]
beta[leaves_idx, :] = ZZ
ZZ = EL[first_leaf:, :] * MSD[first_leaf:, :] / NF[first_leaf:, np.newaxis]
beta[first_leaf:, :] = ZZ

# Assert all ~= 1.0
assert np.abs(np.sum(beta[-1]) - 1.0) < 1e-9
Expand All @@ -138,12 +138,11 @@ def get_beta(
) # MSD of the respective lineage
ELMSD = EL * MSD

cIDXs = np.arange(MSD.shape[0])
cIDXs = np.delete(cIDXs, leaves_idx)
cIDXs = np.arange(first_leaf)
cIDXs = np.flip(cIDXs)

for pii in cIDXs:
ch_ii = cell_to_daughters[pii, :]
ch_ii = np.array([pii * 2 + 1, pii * 2 + 2])
ratt = (beta[ch_ii, :] / MSD_array[ch_ii, :]) @ T.T
fac1 = np.prod(ratt, axis=0) * ELMSD[pii, :]

Expand All @@ -154,7 +153,6 @@ def get_beta(


def get_gamma(
cell_to_daughters: npt.NDArray[np.uintp],
T: npt.NDArray[np.float64],
MSD: npt.NDArray[np.float64],
beta: npt.NDArray[np.float64],
Expand All @@ -177,12 +175,11 @@ def get_gamma(
coeffs = np.maximum(coeffs, epss)
beta_parents = T @ coeffs.T

# Getting lineage by generation, but it is sorted this way
for pidx, cis in enumerate(cell_to_daughters):
for ci in cis:
if ci == -1:
continue
first_leaf = int(np.floor(MSD.shape[0] / 2))

# Getting lineage by generation, but it is sorted this way
for pidx in range(first_leaf):
for ci in [pidx * 2 + 1, pidx * 2 + 2]:
A = gamma[pidx, :].T / beta_parents[:, ci]

gamma[ci, :] = coeffs[ci, :] * (A @ T)
Expand Down
17 changes: 5 additions & 12 deletions lineage/HMM/M_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import numpy.typing as npt


def sum_nonleaf_gammas(
leaves_idx, gammas: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
def sum_nonleaf_gammas(gammas: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
"""
Sum of the gammas of the cells that are able to divide, that is,
sum the of the gammas of all the nonleaf cells. It is used in estimating the transition probability matrix.
Expand All @@ -16,16 +14,13 @@ def sum_nonleaf_gammas(
:param gamma_arr: the gamma values for each lineage
:return: the sum of gamma values for each state for non-leaf cells.
"""
# Remove leaves
gs = np.delete(gammas, leaves_idx, axis=0)
first_leaf = int(np.floor(gammas.shape[0] / 2))

# sum the gammas for cells that are transitioning (all but gen 0)
return np.sum(gs[1:, :], axis=0)
return np.sum(gammas[1:first_leaf, :], axis=0)


def get_all_zetas(
leaves_idx: npt.NDArray[np.uintp],
cell_to_daughters: npt.NDArray[np.uintp],
beta_array: npt.NDArray[np.float64],
MSD_array: npt.NDArray[np.float64],
gammas: npt.NDArray[np.float64],
Expand All @@ -45,10 +40,8 @@ def get_all_zetas(
betaMSD = beta_array / np.clip(MSD_array, np.finfo(float).eps, np.inf)
TbetaMSD = np.clip(betaMSD @ T.T, np.finfo(float).eps, np.inf)

cIDXs = np.arange(gammas.shape[0])
cIDXs = np.delete(cIDXs, leaves_idx, axis=0)

dIDXs = cell_to_daughters[cIDXs, :]
cIDXs = np.arange(int(np.floor(gammas.shape[0] / 2)) - 1)
dIDXs = np.vstack((cIDXs * 2 + 1, cIDXs * 2 + 2)).T

# Getting lineage by generation, but it is sorted this way
js = gammas[cIDXs, np.newaxis, :] / TbetaMSD[dIDXs, :]
Expand Down
84 changes: 23 additions & 61 deletions lineage/LineageTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,26 @@ class LineageTree:

pi: npt.NDArray[np.float64]
T: npt.NDArray[np.float64]
leaves_idx: np.ndarray
idx_by_gen: list[np.ndarray]
output_lineage: list[CellVar]
cell_to_parent: np.ndarray
cell_to_daughters: np.ndarray

def __init__(self, list_of_cells: list, E: list):
self.E = E
# output_lineage must be sorted according to generation
self.output_lineage = sorted(list_of_cells, key=operator.attrgetter("gen"))
self.idx_by_gen = max_gen(self.output_lineage)
# sort according to generation
sorted_cells = sorted(list_of_cells, key=operator.attrgetter("gen"))

# assign times using the state distribution specific time model
E[0].assign_times(self.output_lineage)
E[0].assign_times(sorted_cells)

# add root
self.output_lineage = [sorted_cells[0]]

self.cell_to_parent = cell_to_parent(self.output_lineage)
self.cell_to_daughters = cell_to_daughters(self.output_lineage)
# build remaining tree
for parent_idx in range(2**(sorted_cells[-1].gen - 1) - 1):
parent = self.output_lineage[parent_idx]

# Leaves have no daughters
self.leaves_idx = np.nonzero(np.all(self.cell_to_daughters == -1, axis=1))[0]
if parent is not None:
self.output_lineage.append(parent.left)
self.output_lineage.append(parent.right)

@classmethod
def rand_init(
Expand Down Expand Up @@ -121,7 +122,16 @@ def get_Emission_Likelihoods(X: list[LineageTree], E: list) -> list:
:param E: The emissions likelihood
:return: The marginal state distribution
"""
all_cells = np.array([cell.obs for lineage in X for cell in lineage.output_lineage])
all_cells: list[np.ndarray] = []

for lineage in X:
for cell in lineage.output_lineage:
if cell is None:
all_cells.append(-1 * np.ones(all_cells[0].size))
else:
all_cells.append(np.array(cell.obs))

all_cells = np.array(all_cells) # type: ignore
ELstack = np.zeros((len(all_cells), len(E)))

for k in range(len(E)): # for each state
Expand All @@ -136,51 +146,3 @@ def get_Emission_Likelihoods(X: list[LineageTree], E: list) -> list:
ii += nl

return EL


def max_gen(lineage: list[CellVar]) -> list[np.ndarray]:
"""
Finds the maximal generation in the tree, and cells organized by their generations.
This walks through the cells in a given lineage, finds the maximal generation, and the group of cells belonging to a same generation and
creates a list of them, appends the lists leading to have a list of the lists of cells in specific generations.
:param lineage: The list of cells in a lineageTree object.
:return max: The maximal generation in the tree.
:return cells_by_gen: The list of lists of cells belonging to the same generation separated by specific generations.
"""
gens = sorted(
{cell.gen for cell in lineage}
) # appending the generation of cells in the lineage
cells_by_gen: list[np.ndarray] = []
for gen in gens:
level = np.array(
[
lineage.index(cell)
for cell in lineage
if (cell.gen == gen and cell.observed)
],
dtype=int,
)
cells_by_gen.append(level)
return cells_by_gen


def cell_to_parent(lineage: list[CellVar]) -> np.ndarray:
output = np.full(len(lineage), -1, dtype=int)
for ii, cell in enumerate(lineage):
parent = cell.parent
if parent is not None:
output[ii] = lineage.index(parent)

return output


def cell_to_daughters(lineage: list[CellVar]) -> np.ndarray:
output = np.full((len(lineage), 2), -1, dtype=int)
for ii, cell in enumerate(lineage):
if cell.left is not None and cell.left in lineage:
output[ii, 0] = lineage.index(cell.left)

if cell.right is not None and cell.right in lineage:
output[ii, 1] = lineage.index(cell.right)

return output
Loading

0 comments on commit 6c28062

Please sign in to comment.