diff --git a/algorithms.py b/algorithms.py index 5925b16b2..58bc448b5 100644 --- a/algorithms.py +++ b/algorithms.py @@ -1,7 +1,10 @@ """ Python version of the simulation algorithm. """ +from __future__ import annotations + import argparse +import dataclasses import heapq import itertools import logging @@ -110,26 +113,42 @@ def find(self, v): return j + 1 +@dataclasses.dataclass(repr=False, eq=False) class Segment: """ A class representing a single segment. Each segment has a left and right, denoting the loci over which it spans, a node and a next, giving the next in the chain. """ - - def __init__(self, index): - self.left = None - self.right = None - self.node = None - self.prev = None - self.next = None - self.population = None - self.label = 0 - self.index = index - self.hull = None - - def __repr__(self): - return repr((self.left, self.right, self.node)) + index: int + left: float = None + right: float = None + node: int = None + prev: Segment = None + next: Segment = None + lineage: Lineage = None + # REMOVE + population: int = None + label: int = 0 + hull: Hull = None + + # def __repr__(self): + # return object.__repr__(self) + +# def __init__(self, index): +# self.left = None +# self.right = None +# self.node = None +# self.prev = None +# self.next = None +# self.index = index +# self.lineage = None +# self.population = None +# self.label = 0 +# self.hull = None + + # def __repr__(self): + # return repr((self.left, self.right, self.node)) @staticmethod def show_chain(seg): @@ -164,6 +183,27 @@ def get_left_index(self): return index +@dataclasses.dataclass(eq=False) +class Lineage: + """ + A class representing a single lineage. Each lineage has a segment chain + represented by a linked list of Segments, which are accessed by the + head and tail attributes. + """ + + head: Segment = None + tail: Segment = None + population: int = -1 + label: int = -1 + # hull + + def summary(self): + return ( + f"pop:{self.population}; label:{self.label}; " + f"head={self.head.index}; tail={self.tail.index}: {Segment.show_chain(self.head)}" + ) + + class Population: """ Class representing a population in the simulation. @@ -204,8 +244,9 @@ def print_state(self): print("\tAncestors: ", len(self._ancestors)) for label, ancestors in enumerate(self._ancestors): print("\tLabel = ", label) - for u in ancestors: - print("\t\t" + Segment.show_chain(u)) + for lin in ancestors: + # print(f"\t\t{lin}") + print(f"\t\t{lin.summary()}") # Segment.show_chain(u)) def set_growth_rate(self, growth_rate, time): # TODO This doesn't work because we need to know what the time @@ -404,6 +445,7 @@ def add_hull(self, label, hull): coal_mass_index = self.coal_mass_index[label] self.increment_avl(ost_left, coal_mass_index, hull, 1) + # TODO change all "individual" references here to "lineage" def add(self, individual, label=0): """ Inserts the specified individual into this population. @@ -947,6 +989,8 @@ def initialise(self, ts): root_time = np.max(self.tables.nodes.time) self.t = root_time + root_lineages = [None for _ in range(ts.num_nodes)] + root_segments_head = [None for _ in range(ts.num_nodes)] root_segments_tail = [None for _ in range(ts.num_nodes)] last_S = -1 @@ -961,30 +1005,38 @@ def initialise(self, ts): if tree.num_roots > 1: for root in tree.roots: population = ts.node(root).population - if root_segments_head[root] is None: - seg = self.alloc_segment(left, right, root, population) - root_segments_head[root] = seg - root_segments_tail[root] = seg + if root_lineages[root] is None: + seg = self.alloc_segment(left, right, root) + root_lineages[root] = self.alloc_lineage( + seg, seg, population, label=0 + ) + seg.lineage = root_lineages[root] + # # if root_segments_head[root] is None: + # seg = self.alloc_segment(left, right, root, population) + # root_segments_head[root] = seg + # root_segments_tail[root] = seg else: - tail = root_segments_tail[root] + tail = root_lineages[root].tail if tail.right == left: tail.right = right else: - seg = self.alloc_segment( - left, right, root, population, tail - ) + seg = self.alloc_segment(left, right, root) + seg.prev = tail tail.next = seg - root_segments_tail[root] = seg + seg.lineage = root_lineages[root] + root_lineages[root].tail = seg self.S[self.L] = -1 # Insert the segment chains into the algorithm state. for node in range(ts.num_nodes): - seg = root_segments_head[node] - if seg is not None: - left_end = seg.left - pop = seg.population - label = seg.label - self.P[seg.population].add(seg) + # seg = root_segments_head[node] + lin = root_lineages[node] + if lin is not None: + left_end = lin.head.left + pop = lin.population + label = lin.label + self.P[pop].add(lin) + seg = lin.head while seg is not None: self.set_segment_mass(seg) seg = seg.next @@ -1066,7 +1118,7 @@ def alloc_segment( left, right, node, - population, + population=-1, prev=None, next=None, # noqa: A002 label=0, @@ -1086,6 +1138,17 @@ def alloc_segment( s.hull = hull return s + def alloc_lineage(self, head=None, tail=None, population=None, label=None): + """ + Pops a new lineage off the stack and sets its properties. + """ + lin = Lineage() + lin.head = head + lin.tail = tail + lin.population = population + lin.label = label + return lin + def copy_segment(self, segment): return self.alloc_segment( left=segment.left, @@ -1776,6 +1839,7 @@ def hudson_recombination_event(self, label, return_heads=False): y.right = bp self.set_segment_mass(y) lhs_tail = y + left_lineage = y.lineage else: # x y # ===== | ========= ... @@ -1788,6 +1852,18 @@ def hudson_recombination_event(self, label, return_heads=False): y.prev = None alpha = y lhs_tail = x + left_lineage = x.lineage + + left_lineage.tail = lhs_tail + right_lineage = self.alloc_lineage( + alpha, population=left_lineage.population, label=label + ) + seg = right_lineage.head + while seg is not None: + right_lineage.tail = seg + seg.lineage = right_lineage + seg = seg.next + if self.model == "smc_k": # modify original hull @@ -1802,19 +1878,25 @@ def hudson_recombination_event(self, label, return_heads=False): self.P[alpha.population].add_hull(label, alpha_hull) self.set_segment_mass(alpha) - self.P[alpha.population].add(alpha, label) + self.P[alpha.population].add(right_lineage, label) if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0: self.store_node(lhs_tail.population, flags=msprime.NODE_IS_RE_EVENT) self.store_arg_edges(lhs_tail) self.store_node(alpha.population, flags=msprime.NODE_IS_RE_EVENT) self.store_arg_edges(alpha) + + self.verify() + ret = None if return_heads: + # FIXME return heads is now obsolete x = lhs_tail # Seek back to the head of the x chain while x.prev is not None: x = x.prev ret = x, alpha + + # self.print_state() return ret def generate_gc_tract_length(self): @@ -2305,8 +2387,8 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): self.defrag_breakpoints() return merged_head - def defrag_segment_chain(self, z): - y = z + def defrag_segment_chain(self, lineage): + y = lineage.tail while y.prev is not None: x = y.prev if x.right == y.left and x.node == y.node: @@ -2374,6 +2456,7 @@ def common_ancestor_event(self, population_index, label): else: # Choose two ancestors uniformly. + self.verify() j = random.randint(0, pop.get_num_ancestors(label) - 1) x = pop.remove(j, label) j = random.randint(0, pop.get_num_ancestors(label) - 1) @@ -2381,13 +2464,16 @@ def common_ancestor_event(self, population_index, label): self.merge_two_ancestors(population_index, label, x, y) - def merge_two_ancestors(self, population_index, label, x, y, u=-1): + def merge_two_ancestors(self, population_index, label, lin_x, lin_y, u=-1): + pop = self.P[population_index] self.num_ca_events += 1 - z = None + new_lineage = None merged_head = None coalescence = False defrag_required = False + x = lin_x.head + y = lin_y.head while x is not None or y is not None: alpha = None if x is None or y is None: @@ -2462,10 +2548,13 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): # loop tail; update alpha and integrate it into the state. if alpha is not None: - if z is None: - pop.add(alpha, label) - merged_head = alpha + print("ADD alpha", alpha, repr(new_lineage)) + if new_lineage is None: + new_lineage = self.alloc_lineage(alpha, alpha, population=population_index, label=label) + pop.add(new_lineage, label) + z = None else: + z = new_lineage.tail if (coalescence and not self.coalescing_segments_only) or ( self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0 ): @@ -2475,9 +2564,10 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): z.right == alpha.left and z.node == alpha.node ) z.next = alpha + new_lineage.tail = alpha + alpha.lineage = new_lineage alpha.prev = z self.set_segment_mass(alpha) - z = alpha if coalescence: if not self.coalescing_segments_only: @@ -2487,11 +2577,15 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): self.store_additional_nodes_edges(msprime.NODE_IS_CA_EVENT, u, z) if defrag_required: - self.defrag_segment_chain(z) + self.defrag_segment_chain(new_lineage) if coalescence: self.defrag_breakpoints() + self.verify() + # self.print_state() + if merged_head is not None and self.model == "smc_k": + assert False # get rid of merged_head assert merged_head.prev is None hull = self.alloc_hull(merged_head.left, merged_head.right, merged_head) while merged_head is not None: @@ -2505,15 +2599,19 @@ def print_state(self, verify=False): for label in range(self.num_labels): print( "Recomb mass = ", - 0 - if self.recomb_mass_index is None - else self.recomb_mass_index[label].get_total(), + ( + 0 + if self.recomb_mass_index is None + else self.recomb_mass_index[label].get_total() + ), ) print( "GC mass = ", - 0 - if self.gc_mass_index is None - else self.gc_mass_index[label].get_total(), + ( + 0 + if self.gc_mass_index is None + else self.gc_mass_index[label].get_total() + ), ) print("Modifier events = ") for t, f, args in self.modifier_events: @@ -2562,26 +2660,41 @@ def print_state(self, verify=False): self.verify() def verify_segments(self): - for pop in self.P: + print("VERIFY") + for pop_index, pop in enumerate(self.P): for label in range(self.num_labels): - for head in pop.iter_label(label): + assert len(set(id(lin) for lin in pop._ancestors[label])) == len(pop._ancestors[label]) + segment_ids = set() + for lin in pop.iter_label(label): + assert lin.population == pop_index + assert lin.label == label + head = lin.head assert head.prev is None prev = head u = head.next + print(id(lin), repr(lin)) while u is not None: + print(u) + assert u.index not in segment_ids + segment_ids.add(u.index) + print("\t", u.lineage) + # print(u, u.lineage) + assert u.lineage is lin assert prev.next is u assert u.prev is prev assert u.left >= prev.right - assert u.label == head.label - assert u.population == head.population + # assert u.label == head.label + # assert u.population == head.population prev = u u = u.next + assert lin.tail == prev def verify_overlaps(self): overlap_counter = OverlapCounter(self.L) for pop in self.P: for label in range(self.num_labels): - for u in pop.iter_label(label): + for lin in pop.iter_label(label): + u = lin.head while u is not None: overlap_counter.increment_interval(u.left, u.right) u = u.next @@ -2597,7 +2710,8 @@ def verify_overlaps(self): A[self.L] = -1 for pop in self.P: for label in range(self.num_labels): - for u in pop.iter_label(label): + for lin in pop.iter_label(label): + u = lin.head while u is not None: if u.left not in A: k = A.floor_key(u.left) @@ -2626,11 +2740,12 @@ def verify_mass_index(self, label, mass_index, rate_map, compute_left_bound): total_mass = 0 alt_total_mass = 0 for pop_index, pop in enumerate(self.P): - for u in pop.iter_label(label): + for lin in pop.iter_label(label): + assert lin.population == pop_index + u = lin.head assert u.prev is None left = compute_left_bound(u) while u is not None: - assert u.population == pop_index assert u.left < u.right left_bound = compute_left_bound(u) s = rate_map.mass_between(left_bound, u.right)