Skip to content

Commit

Permalink
Move population attr to lineage
Browse files Browse the repository at this point in the history
Finish moving population to Lineage in Python

Get population on lineage working for C

FIXUP python algorithms code

fixup Pyhton c interface

Tidy Python
  • Loading branch information
jeromekelleher committed Jul 31, 2024
1 parent f5376bc commit be2644d
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 129 deletions.
120 changes: 62 additions & 58 deletions algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,6 @@ class Segment:
prev: Segment = None
next: Segment = None # noqa: A003
lineage: Lineage = None
population: int = -1
# label: int = 0

def __repr__(self):
return repr((self.left, self.right, self.node))
Expand All @@ -145,10 +143,11 @@ def show_chain(seg):
return s[:-2]

def __lt__(self, other):
return (self.left, self.right, self.population, self.node) < (
# TODO not clear here why we need population in the key?
return (self.left, self.right, self.lineage.population, self.node) < (
other.left,
other.right,
other.population,
other.lineage.population,
self.node,
)

Expand Down Expand Up @@ -730,12 +729,14 @@ def __repr__(self):
@dataclasses.dataclass
class Lineage:
head: Segment
population: int
hull: Hull = None
label: int = 0

def __str__(self):
s = (
f"Lineage(id={hex(id(self))},label={self.label},hull={self.hull},"
f"Lineage(id={hex(id(self))},"
f"population={self.population},label={self.label},hull={self.hull},"
f"head={self.head.index},"
f"chain={Segment.show_chain(self.head)})"
)
Expand Down Expand Up @@ -969,9 +970,12 @@ def __init__(
def initialise(self, ts):
root_time = np.max(self.tables.nodes.time)
self.t = root_time

# Note: this is done slightly differently to the C code, which
# stores the root segments so that we can implement sampling
# events easily.
root_segments_head = [None for _ in range(ts.num_nodes)]
root_segments_tail = [None for _ in range(ts.num_nodes)]
root_lineages = [None for _ in range(ts.num_nodes)]
last_S = -1
for tree in ts.trees():
left, right = tree.interval
Expand All @@ -985,7 +989,9 @@ def initialise(self, ts):
for root in tree.roots:
population = ts.node(root).population
if root_segments_head[root] is None:
seg = self.alloc_segment(left, right, root, population)
seg = self.alloc_segment(left, right, root)
lineage = self.alloc_lineage(seg, population)
root_lineages[root] = lineage
root_segments_head[root] = seg
root_segments_tail[root] = seg
else:
Expand All @@ -996,29 +1002,29 @@ def initialise(self, ts):
seg = self.alloc_segment(
left, right, root, population, tail
)
seg.lineage = root_lineages[root]
tail.next = seg
root_segments_tail[root] = seg
self.S[self.L] = -1

# Insert the segment chains into the algorithm state.
for node in range(ts.num_nodes):
seg = root_segments_head[node]
if seg is not None:
lineage = root_lineages[node]
if lineage is not None:
seg = lineage.head
left_end = seg.left
pop = seg.population
lineage = self.alloc_lineage(seg)
self.P[seg.population].add(lineage)
while seg is not None:
self.set_segment_mass(seg)
seg = seg.next
self.add_lineage(lineage)

if self.model == "smc_k":
for node in range(ts.num_nodes):
seg = root_segments_head[node]
if seg is not None:
lineage = root_lineages[node]
if lineage is not None:
seg = lineage.head
left_end = seg.left
pop = seg.population
lineage = seg.lineage
pop = lineage.population
label = lineage.label
right_end = root_segments_tail[node].right
new_hull = self.alloc_hull(left_end, right_end, lineage)
Expand Down Expand Up @@ -1085,7 +1091,7 @@ def alloc_segment(
left,
right,
node,
population,
population=None,
prev=None,
next=None, # noqa: A002
lineage=None,
Expand All @@ -1097,14 +1103,13 @@ def alloc_segment(
s.left = left
s.right = right
s.node = node
s.population = population
s.next = next
s.prev = prev
s.lineage = lineage
return s

def alloc_lineage(self, head, *, label=0):
lineage = Lineage(head, label=label)
def alloc_lineage(self, head, population, *, label=0):
lineage = Lineage(head, population=population, label=label)
lineage.reset_segments()
x = head
while x is not None:
Expand All @@ -1117,7 +1122,6 @@ def copy_segment(self, segment):
left=segment.left,
right=segment.right,
node=segment.node,
population=segment.population,
next=segment.next,
prev=segment.prev,
lineage=segment.lineage,
Expand Down Expand Up @@ -1189,7 +1193,7 @@ def store_edge(self, left, right, parent, child):
)

def add_lineage(self, lineage):
pop = lineage.head.population
pop = lineage.population
self.P[pop].add(lineage, lineage.label)
# print("add", lineage)
x = lineage.head
Expand Down Expand Up @@ -1615,7 +1619,7 @@ def process_pedigree_common_ancestors(self, ind, ploid):
# ancestor in this ploid of this individual. First we remove
# them from the populations they are stored in:
for _, seg in common_ancestors:
pop = self.P[seg.population]
pop = self.P[seg.lineage.population]
pop.remove_individual(seg.lineage)

# Merge together these lists of ancestral segments to create the
Expand Down Expand Up @@ -1723,11 +1727,7 @@ def migration_event(self, j, k):
if self.additional_nodes.value & msprime.NODE_IS_MIG_EVENT > 0:
self.store_node(k, flags=msprime.NODE_IS_MIG_EVENT)
self.store_arg_edges(x)
# Set the population id for each segment also.
u = x
while u is not None:
u.population = k
u = u.next
lineage.population = k

def get_recomb_left_bound(self, seg):
"""
Expand Down Expand Up @@ -1794,6 +1794,8 @@ def hudson_recombination_event(self, label, return_heads=False):
"""
self.num_re_events += 1
y, bp = self.choose_breakpoint(self.recomb_mass_index[label], self.recomb_map)
left_lineage = y.lineage
assert left_lineage.label == label
x = y.prev
if y.left < bp:
# x y
Expand Down Expand Up @@ -1825,25 +1827,25 @@ def hudson_recombination_event(self, label, return_heads=False):
alpha = y
lhs_tail = x

right_lineage = self.alloc_lineage(alpha, label=label)
right_lineage = self.alloc_lineage(alpha, left_lineage.population, label=label)
if self.model == "smc_k":
# modify original hull
pop = alpha.population
pop = left_lineage.population
lhs_hull = lhs_tail.get_hull()
rhs_right = lhs_hull.right
lhs_hull.right = min(lhs_tail.right + self.hull_offset, self.L)
self.P[pop].reset_hull_right(label, lhs_hull, rhs_right, lhs_hull.right)

# create hull for alpha
alpha_hull = self.alloc_hull(alpha.left, rhs_right, right_lineage)
self.P[alpha.population].add_hull(label, alpha_hull)
self.P[pop].add_hull(label, alpha_hull)

self.set_segment_mass(alpha)
self.P[alpha.population].add(right_lineage, label)
self.add_lineage(right_lineage)
if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0:
self.store_node(lhs_tail.population, flags=msprime.NODE_IS_RE_EVENT)
self.store_node(left_lineage.population, flags=msprime.NODE_IS_RE_EVENT)
self.store_arg_edges(lhs_tail)
self.store_node(alpha.population, flags=msprime.NODE_IS_RE_EVENT)
self.store_node(right_lineage.population, flags=msprime.NODE_IS_RE_EVENT)
self.store_arg_edges(alpha)
ret = None
if return_heads:
Expand Down Expand Up @@ -1888,7 +1890,8 @@ def wiuf_gene_conversion_within_event(self, label):
self.num_gc_events += 1
hull = y.get_hull()
assert (self.model == "smc_k") == (hull is not None)
pop = y.population
lineage = y.lineage
pop = lineage.population
reset_right = -1

# Process left break
Expand Down Expand Up @@ -1998,12 +2001,12 @@ def wiuf_gene_conversion_within_event(self, label):
elif head is not None:
new_individual_head = head
if new_individual_head is not None:
lineage = self.alloc_lineage(new_individual_head)
lineage = self.alloc_lineage(new_individual_head, pop)
if self.model == "smc_k":
assert hull_left < hull_right
hull_right = min(self.L, hull_right + self.hull_offset)
hull = self.alloc_hull(hull_left, hull_right, lineage)
self.P[new_individual_head.population].add_hull(lineage.label, hull)
self.P[lineage.population].add_hull(lineage.label, hull)
self.add_lineage(lineage)

def wiuf_gene_conversion_left_event(self, label):
Expand Down Expand Up @@ -2033,7 +2036,8 @@ def wiuf_gene_conversion_left_event(self, label):

self.num_gc_events += 1
x = y.prev
pop = y.population
lineage = y.lineage
pop = lineage.population
lhs_hull = y.get_hull()
assert (self.model == "smc_k") == (lhs_hull is not None)
if y.left < bp:
Expand Down Expand Up @@ -2081,30 +2085,30 @@ def wiuf_gene_conversion_left_event(self, label):

self.set_segment_mass(alpha)
assert alpha.prev is None
lineage = self.alloc_lineage(alpha)
self.P[alpha.population].add(lineage, label)
lineage = self.alloc_lineage(alpha, pop)
self.add_lineage(lineage)

def hudson_recombination_event_sweep_phase(self, label, sweep_site, pop_freq):
"""
Implements a recombination event in during a selective sweep.
"""
left_lin, right_lin = self.hudson_recombination_event(label, return_heads=True)
lhs = left_lin.head
rhs = right_lin.head

r = random.random()
if sweep_site < rhs.left:
if sweep_site < right_lin.head.left:
if r < 1.0 - pop_freq:
# move rhs to other population
self.P[rhs.population].remove_individual(right_lin, right_lin.label)
self.P[right_lin.population].remove_individual(
right_lin, right_lin.label
)
self.set_labels(right_lin, 1 - label)
self.P[rhs.population].add(right_lin, right_lin.label)
self.P[right_lin.population].add(right_lin, right_lin.label)
else:
if r < 1.0 - pop_freq:
# move lhs to other population
self.P[rhs.population].remove_individual(left_lin, left_lin.label)
self.P[left_lin.population].remove_individual(left_lin, left_lin.label)
self.set_labels(left_lin, 1 - label)
self.P[lhs.population].add(left_lin, left_lin.label)
self.P[left_lin.population].add(left_lin, left_lin.label)

def dtwf_generate_breakpoint(self, start):
left_bound = start + 1 if self.discrete_genome else start
Expand Down Expand Up @@ -2213,7 +2217,7 @@ def dtwf_recombine(self, lineage, ind_nodes):
lineage.reset_segments()
ret.append(lineage)
else:
ret.append(self.alloc_lineage(seg))
ret.append(self.alloc_lineage(seg, lineage.population))

return ret

Expand Down Expand Up @@ -2245,9 +2249,8 @@ def bottleneck_event(self, pop_id, label, intensity):
def store_additional_nodes_edges(self, flag, new_node_id, z):
if self.additional_nodes.value & flag > 0:
if new_node_id == -1:
new_node_id = self.store_node(z.population, flags=flag)
else:
self.update_node_flag(new_node_id, flag)
new_node_id = self.store_node(z.lineage.population)
self.update_node_flag(new_node_id, flag)
self.store_arg_edges(z, new_node_id)
return new_node_id

Expand All @@ -2273,7 +2276,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
if len(X) == 1:
x = X[0]
if len(H) > 0 and H[0][0] < x.right:
alpha = self.alloc_segment(x.left, H[0][0], x.node, x.population)
alpha = self.alloc_segment(x.left, H[0][0], x.node)
x.left = H[0][0]
heapq.heappush(H, (x.left, x))
else:
Expand Down Expand Up @@ -2320,7 +2323,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
# loop tail; update alpha and integrate it into the state.
if alpha is not None:
if z is None:
new_lineage = self.alloc_lineage(alpha)
new_lineage = self.alloc_lineage(alpha, pop_id)
pop.add(new_lineage, label)
else:
if (coalescence and not self.coalescing_segments_only) or (
Expand Down Expand Up @@ -2517,7 +2520,9 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1):
# loop tail; update alpha and integrate it into the state.
if alpha is not None:
if z is None:
new_lineage = self.alloc_lineage(alpha, label=label)
new_lineage = self.alloc_lineage(
alpha, population_index, label=label
)
else:
if (coalescence and not self.coalescing_segments_only) or (
self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0
Expand Down Expand Up @@ -2629,23 +2634,22 @@ def print_state(self, verify=False):
self.verify()

def verify_segments(self):
for pop in self.P:
for pop_index, pop in enumerate(self.P):
for label in range(self.num_labels):
for lineage in pop.iter_label(label):
assert isinstance(lineage, Lineage)
assert lineage.label == label
assert lineage.population == pop_index
head = lineage.head
assert head.lineage is lineage
assert head.prev is None
prev = head
u = head.next
# print("LIN", lineage)
while u is not None:
assert u.lineage == lineage
assert prev.next is u
assert u.prev is prev
assert u.left >= prev.right
assert u.population == head.population
prev = u
u = u.next

Expand Down Expand Up @@ -2702,10 +2706,10 @@ def verify_mass_index(self, label, mass_index, rate_map, compute_left_bound):
for pop_index, pop in enumerate(self.P):
for lineage in pop.iter_label(label):
u = lineage.head
assert lineage.population == pop_index
assert u.prev is None
left = compute_left_bound(u)
while u is not None:
assert u.population == pop_index
assert u.left < u.right
left_bound = compute_left_bound(u)
s = rate_map.mass_between(left_bound, u.right)
Expand Down
Loading

0 comments on commit be2644d

Please sign in to comment.