Skip to content

Commit

Permalink
Partial updates for removing population attr
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 30, 2024
1 parent 9b9c7bc commit 68d5951
Showing 1 changed file with 56 additions and 47 deletions.
103 changes: 56 additions & 47 deletions algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def find(self, v):

# Once we drop support for 3.9 we can use slots=True to prevent
# writing extra attrs.
@dataclasses.dataclass # (slots=True)
@dataclasses.dataclass(slots=True) # FIXME
class Segment:
"""
A class representing a single segment. Each segment has a left
Expand All @@ -130,7 +130,7 @@ class Segment:
prev: Segment = None
next: Segment = None # noqa: A003
lineage: Lineage = None
population: int = -1
# population: int = -1
# label: int = 0

def __repr__(self):
Expand All @@ -145,10 +145,11 @@ def show_chain(seg):
return s[:-2]

def __lt__(self, other):
    """
    Return True if this segment sorts strictly before ``other``.

    Segments are ordered lexicographically by
    ``(left, right, lineage.population, node)``. Both segments must be
    attached to a lineage, because the population is read from
    ``lineage.population`` rather than from the segment itself.
    """
    # TODO not clear here why we need population in the key?
    # NOTE: the right-hand key used to end with ``self.node``, which made
    # the node component compare equal to itself and never break ties.
    # Comparing against ``other.node`` makes the ordering use all four
    # fields as written.
    return (self.left, self.right, self.lineage.population, self.node) < (
        other.left,
        other.right,
        other.lineage.population,
        other.node,
    )

Expand Down Expand Up @@ -730,12 +731,14 @@ def __repr__(self):
@dataclasses.dataclass
class Lineage:
head: Segment
population: int
hull: Hull = None
label: int = 0

def __str__(self):
s = (
f"Lineage(id={hex(id(self))},label={self.label},hull={self.hull},"
f"Lineage(id={hex(id(self))},"
f"population={self.population},label={self.label},hull={self.hull},"
f"head={self.head.index},"
f"chain={Segment.show_chain(self.head)})"
)
Expand Down Expand Up @@ -969,9 +972,11 @@ def __init__(
def initialise(self, ts):
root_time = np.max(self.tables.nodes.time)
self.t = root_time

# TODO get rid of root_segments_head/tail when we add the
# tail attr to lineage
root_segments_head = [None for _ in range(ts.num_nodes)]
root_segments_tail = [None for _ in range(ts.num_nodes)]
root_lineages = [None for _ in range(ts.num_nodes)]
last_S = -1
for tree in ts.trees():
left, right = tree.interval
Expand All @@ -985,7 +990,9 @@ def initialise(self, ts):
for root in tree.roots:
population = ts.node(root).population
if root_segments_head[root] is None:
seg = self.alloc_segment(left, right, root, population)
seg = self.alloc_segment(left, right, root)
lineage = self.alloc_lineage(seg, population)
root_lineages[root] = lineage
root_segments_head[root] = seg
root_segments_tail[root] = seg
else:
Expand All @@ -1002,12 +1009,11 @@ def initialise(self, ts):

# Insert the segment chains into the algorithm state.
for node in range(ts.num_nodes):
seg = root_segments_head[node]
if seg is not None:
lineage = root_lineages[node]
if lineage is not None:
seg = lineage.head
left_end = seg.left
pop = seg.population
lineage = self.alloc_lineage(seg)
self.P[seg.population].add(lineage)
self.add_lineage(lineage)
while seg is not None:
self.set_segment_mass(seg)
seg = seg.next
Expand Down Expand Up @@ -1085,7 +1091,7 @@ def alloc_segment(
left,
right,
node,
population,
population=None,
prev=None,
next=None, # noqa: A002
lineage=None,
Expand All @@ -1097,14 +1103,13 @@ def alloc_segment(
s.left = left
s.right = right
s.node = node
s.population = population
s.next = next
s.prev = prev
s.lineage = lineage
return s

def alloc_lineage(self, head, *, label=0):
lineage = Lineage(head, label=label)
def alloc_lineage(self, head, population, *, label=0):
lineage = Lineage(head, population=population, label=label)
lineage.reset_segments()
x = head
while x is not None:
Expand All @@ -1117,7 +1122,6 @@ def copy_segment(self, segment):
left=segment.left,
right=segment.right,
node=segment.node,
population=segment.population,
next=segment.next,
prev=segment.prev,
lineage=segment.lineage,
Expand Down Expand Up @@ -1189,7 +1193,7 @@ def store_edge(self, left, right, parent, child):
)

def add_lineage(self, lineage):
pop = lineage.head.population
pop = lineage.population
self.P[pop].add(lineage, lineage.label)
# print("add", lineage)
x = lineage.head
Expand Down Expand Up @@ -1615,7 +1619,7 @@ def process_pedigree_common_ancestors(self, ind, ploid):
# ancestor in this ploid of this individual. First we remove
# them from the populations they are stored in:
for _, seg in common_ancestors:
pop = self.P[seg.population]
pop = self.P[seg.lineage.population]
pop.remove_individual(seg.lineage)

# Merge together these lists of ancestral segments to create the
Expand Down Expand Up @@ -1723,11 +1727,7 @@ def migration_event(self, j, k):
if self.additional_nodes.value & msprime.NODE_IS_MIG_EVENT > 0:
self.store_node(k, flags=msprime.NODE_IS_MIG_EVENT)
self.store_arg_edges(x)
# Set the population id for each segment also.
u = x
while u is not None:
u.population = k
u = u.next
lineage.population = k

def get_recomb_left_bound(self, seg):
"""
Expand Down Expand Up @@ -1794,6 +1794,8 @@ def hudson_recombination_event(self, label, return_heads=False):
"""
self.num_re_events += 1
y, bp = self.choose_breakpoint(self.recomb_mass_index[label], self.recomb_map)
left_lineage = y.lineage
assert left_lineage.label == label
x = y.prev
if y.left < bp:
# x y
Expand Down Expand Up @@ -1825,7 +1827,7 @@ def hudson_recombination_event(self, label, return_heads=False):
alpha = y
lhs_tail = x

right_lineage = self.alloc_lineage(alpha, label=label)
right_lineage = self.alloc_lineage(alpha, left_lineage.population, label=label)
if self.model == "smc_k":
# modify original hull
pop = alpha.population
Expand All @@ -1839,11 +1841,11 @@ def hudson_recombination_event(self, label, return_heads=False):
self.P[alpha.population].add_hull(label, alpha_hull)

self.set_segment_mass(alpha)
self.P[alpha.population].add(right_lineage, label)
self.add_lineage(right_lineage)
if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0:
self.store_node(lhs_tail.population, flags=msprime.NODE_IS_RE_EVENT)
self.store_node(left_lineage.population, flags=msprime.NODE_IS_RE_EVENT)
self.store_arg_edges(lhs_tail)
self.store_node(alpha.population, flags=msprime.NODE_IS_RE_EVENT)
self.store_node(right_lineage.population, flags=msprime.NODE_IS_RE_EVENT)
self.store_arg_edges(alpha)
ret = None
if return_heads:
Expand Down Expand Up @@ -1888,7 +1890,8 @@ def wiuf_gene_conversion_within_event(self, label):
self.num_gc_events += 1
hull = y.get_hull()
assert (self.model == "smc_k") == (hull is not None)
pop = y.population
lineage = y.lineage
pop = lineage.population
reset_right = -1

# Process left break
Expand Down Expand Up @@ -1998,7 +2001,7 @@ def wiuf_gene_conversion_within_event(self, label):
elif head is not None:
new_individual_head = head
if new_individual_head is not None:
lineage = self.alloc_lineage(new_individual_head)
lineage = self.alloc_lineage(new_individual_head, pop)
if self.model == "smc_k":
assert hull_left < hull_right
hull_right = min(self.L, hull_right + self.hull_offset)
Expand Down Expand Up @@ -2033,7 +2036,8 @@ def wiuf_gene_conversion_left_event(self, label):

self.num_gc_events += 1
x = y.prev
pop = y.population
lineage = y.lineage
pop = lineage.population
lhs_hull = y.get_hull()
assert (self.model == "smc_k") == (lhs_hull is not None)
if y.left < bp:
Expand Down Expand Up @@ -2081,8 +2085,8 @@ def wiuf_gene_conversion_left_event(self, label):

self.set_segment_mass(alpha)
assert alpha.prev is None
lineage = self.alloc_lineage(alpha)
self.P[alpha.population].add(lineage, label)
lineage = self.alloc_lineage(alpha, pop)
self.add_lineage(lineage)

def hudson_recombination_event_sweep_phase(self, label, sweep_site, pop_freq):
"""
Expand All @@ -2096,15 +2100,17 @@ def hudson_recombination_event_sweep_phase(self, label, sweep_site, pop_freq):
if sweep_site < rhs.left:
if r < 1.0 - pop_freq:
# move rhs to other population
self.P[rhs.population].remove_individual(right_lin, right_lin.label)
self.P[right_lin.population].remove_individual(
right_lin, right_lin.label
)
self.set_labels(right_lin, 1 - label)
self.P[rhs.population].add(right_lin, right_lin.label)
self.P[right_lin.population].add(right_lin, right_lin.label)
else:
if r < 1.0 - pop_freq:
# move lhs to other population
self.P[rhs.population].remove_individual(left_lin, left_lin.label)
self.P[left_lin.population].remove_individual(left_lin, left_lin.label)
self.set_labels(left_lin, 1 - label)
self.P[lhs.population].add(left_lin, left_lin.label)
self.P[left_lin.population].add(left_lin, left_lin.label)

def dtwf_generate_breakpoint(self, start):
left_bound = start + 1 if self.discrete_genome else start
Expand Down Expand Up @@ -2213,7 +2219,7 @@ def dtwf_recombine(self, lineage, ind_nodes):
lineage.reset_segments()
ret.append(lineage)
else:
ret.append(self.alloc_lineage(seg))
ret.append(self.alloc_lineage(seg, lineage.population))

return ret

Expand Down Expand Up @@ -2243,11 +2249,12 @@ def bottleneck_event(self, pop_id, label, intensity):
self.merge_ancestors(H, pop_id, label)

def store_additional_nodes_edges(self, flag, new_node_id, z):
# When `flag` is set in self.additional_nodes.value, make sure there is a
# node carrying that flag (storing a new one if new_node_id == -1), store
# the ARG edges for z against it, and return the node id. When the flag is
# not set, new_node_id is returned unchanged.
# FIXME: the store_node(population_index) call below uses a name that is
# neither a parameter nor a local of this function, and the z.population
# variant relies on the Segment.population field that this commit comments
# out. The population has to come from somewhere else (presumably the
# owning lineage or the caller -- TODO confirm).
if self.additional_nodes.value & flag > 0:
if new_node_id == -1:
new_node_id = self.store_node(z.population, flags=flag)
else:
self.update_node_flag(new_node_id, flag)
self.store_node(population_index)
new_node_id = len(self.tables.nodes) - 1
self.update_node_flag(new_node_id, flag)
self.store_arg_edges(z, new_node_id)
return new_node_id

Expand All @@ -2273,7 +2280,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
if len(X) == 1:
x = X[0]
if len(H) > 0 and H[0][0] < x.right:
alpha = self.alloc_segment(x.left, H[0][0], x.node, x.population)
alpha = self.alloc_segment(x.left, H[0][0], x.node)
x.left = H[0][0]
heapq.heappush(H, (x.left, x))
else:
Expand Down Expand Up @@ -2320,7 +2327,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
# loop tail; update alpha and integrate it into the state.
if alpha is not None:
if z is None:
new_lineage = self.alloc_lineage(alpha)
new_lineage = self.alloc_lineage(alpha, pop_id)
pop.add(new_lineage, label)
else:
if (coalescence and not self.coalescing_segments_only) or (
Expand Down Expand Up @@ -2517,7 +2524,9 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1):
# loop tail; update alpha and integrate it into the state.
if alpha is not None:
if z is None:
new_lineage = self.alloc_lineage(alpha, label=label)
new_lineage = self.alloc_lineage(
alpha, population_index, label=label
)
else:
if (coalescence and not self.coalescing_segments_only) or (
self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0
Expand Down Expand Up @@ -2629,11 +2638,12 @@ def print_state(self, verify=False):
self.verify()

def verify_segments(self):
for pop in self.P:
for pop_index, pop in enumerate(self.P):
for label in range(self.num_labels):
for lineage in pop.iter_label(label):
assert isinstance(lineage, Lineage)
assert lineage.label == label
assert lineage.population == pop_index
head = lineage.head
assert head.lineage is lineage
assert head.prev is None
Expand All @@ -2645,7 +2655,6 @@ def verify_segments(self):
assert prev.next is u
assert u.prev is prev
assert u.left >= prev.right
assert u.population == head.population
prev = u
u = u.next

Expand Down Expand Up @@ -2702,10 +2711,10 @@ def verify_mass_index(self, label, mass_index, rate_map, compute_left_bound):
for pop_index, pop in enumerate(self.P):
for lineage in pop.iter_label(label):
u = lineage.head
assert lineage.population == pop_index
assert u.prev is None
left = compute_left_bound(u)
while u is not None:
assert u.population == pop_index
assert u.left < u.right
left_bound = compute_left_bound(u)
s = rate_map.mass_between(left_bound, u.right)
Expand Down

0 comments on commit 68d5951

Please sign in to comment.