diff --git a/algorithms.py b/algorithms.py index 5925b16b2..dec9976e7 100644 --- a/algorithms.py +++ b/algorithms.py @@ -2,6 +2,7 @@ Python version of the simulation algorithm. """ import argparse +import dataclasses import heapq import itertools import logging @@ -127,6 +128,7 @@ def __init__(self, index): self.label = 0 self.index = index self.hull = None + self.lineage = None def __repr__(self): return repr((self.left, self.right, self.node)) @@ -164,6 +166,11 @@ def get_left_index(self): return index +@dataclasses.dataclass +class Lineage: + head: Segment + + class Population: """ Class representing a population in the simulation. @@ -204,8 +211,8 @@ def print_state(self): print("\tAncestors: ", len(self._ancestors)) for label, ancestors in enumerate(self._ancestors): print("\tLabel = ", label) - for u in ancestors: - print("\t\t" + Segment.show_chain(u)) + for lineage in ancestors: + print("\t\t" + Segment.show_chain(lineage.head)) def set_growth_rate(self, growth_rate, time): # TODO This doesn't work because we need to know what the time @@ -366,6 +373,7 @@ def remove_individual(self, individual, label=0): """ Removes the given individual from its population. """ + assert isinstance(individual, Lineage) return self._ancestors[label].remove(individual) def add_hull(self, label, hull): @@ -408,7 +416,8 @@ def add(self, individual, label=0): """ Inserts the specified individual into this population. 
""" - assert individual.label == label + assert isinstance(individual, Lineage) + assert individual.head.label == label self._ancestors[label].append(individual) def __iter__(self): @@ -429,12 +438,6 @@ def iter_ancestors(self): for ancestors in self._ancestors: yield from ancestors - def find_indv(self, indv): - """ - find the index of an ancestor in population - """ - return self._ancestors[indv.label].index(indv) - class Pedigree: """ @@ -694,7 +697,7 @@ class Hull: def __init__(self, index): self.left = None self.right = None - self.lineage_head = None + self.lineage = None self.index = index self.insertion_order = math.inf @@ -984,7 +987,8 @@ def initialise(self, ts): left_end = seg.left pop = seg.population label = seg.label - self.P[seg.population].add(seg) + lineage = self.alloc_lineage(seg) + self.P[seg.population].add(lineage) while seg is not None: self.set_segment_mass(seg) seg = seg.next @@ -996,9 +1000,9 @@ def initialise(self, ts): left_end = seg.left pop = seg.population label = seg.label - lineage_head = seg + lineage = seg.lineage right_end = root_segments_tail[node].right - new_hull = self.alloc_hull(left_end, right_end, lineage_head) + new_hull = self.alloc_hull(left_end, right_end, lineage) # insert Hull floor = self.P[pop].hulls_left[label].floor_key(new_hull) insertion_order = 0 @@ -1049,15 +1053,15 @@ def change_population_growth_rate(self, pop_id, rate, time): def change_migration_matrix_element(self, pop_i, pop_j, rate): self.migration_matrix[pop_i][pop_j] = rate - def alloc_hull(self, left, right, lineage_head): - alpha = lineage_head + def alloc_hull(self, left, right, lineage): + alpha = lineage.head hull = self.hull_stack.pop() hull.left = left hull.right = right while alpha.prev is not None: alpha = alpha.prev assert alpha is not None - hull.lineage_head = alpha + hull.lineage = lineage alpha.hull = hull return hull @@ -1086,6 +1090,11 @@ def alloc_segment( s.hull = hull return s + def alloc_lineage(self, head): + lineage = 
Lineage(head) + head.lineage = lineage + return lineage + def copy_segment(self, segment): return self.alloc_segment( left=segment.left, @@ -1171,11 +1180,11 @@ def finalise(self): # Insert unary edges for any remainining lineages. current_time = self.t for population in self.P: - for ancestor in population.iter_ancestors(): + for lineage in population.iter_ancestors(): node = tskit.NULL # See if there is already a node in this ancestor at the # current time - seg = ancestor + seg = lineage.head while seg is not None: if self.tables.nodes[seg.node].time == current_time: node = seg.node @@ -1187,7 +1196,7 @@ def finalise(self): flags=0, time=current_time, population=population.id ) # Add in edges pointing to this ancestor - seg = ancestor + seg = lineage.head while seg is not None: if seg.node != node: self.tables.edges.add_row(seg.left, seg.right, node, seg.node) @@ -1383,12 +1392,12 @@ def single_sweep_simulate(self): # a bit ugly with the two loops because # of dealing with the pops indices = [] - for idx, u in enumerate(self.P[0].iter_label(0)): + for idx, lineage in enumerate(self.P[0].iter_label(0)): if random.random() < x: - self.set_labels(u, 1) + self.set_labels(lineage, 1) indices.append(idx) else: - assert u.label == 0 + assert lineage.head.label == 0 popped = 0 for i in indices: tmp = self.P[0].remove(i - popped, 0) @@ -1469,9 +1478,9 @@ def single_sweep_simulate(self): 0, self.sweep_site, 1.0 - x ) # clean up the labels at end - for idx, u in enumerate(self.P[0].iter_label(1)): - tmp = self.P[0].remove(idx, u.label) - self.set_labels(u, 0) + for idx, lineage in enumerate(self.P[0].iter_label(1)): + tmp = self.P[0].remove(idx, label=1) + self.set_labels(lineage, 0) self.P[0].add(tmp) def pedigree_simulate(self): @@ -1524,18 +1533,17 @@ def dtwf_generation(self): parent_nodes = [-1, -1] H = [[], []] for child in children: - segs_pair = self.dtwf_recombine(child, parent_nodes) - for seg in segs_pair: - if seg is not None and seg.index != child.index: - 
pop.add(seg) + lin_pair = self.dtwf_recombine(child, parent_nodes) + for lin in lin_pair: + if lin is not None and lin != child: + pop.add(lin) self.verify() # Collect segments inherited from the same individual - for i, seg in enumerate(segs_pair): - if seg is None: - continue - assert seg.prev is None - heapq.heappush(H[i], (seg.left, seg)) + for i, lin in enumerate(lin_pair): + if lin is not None: + assert lin.head.prev is None + heapq.heappush(H[i], (lin.head.left, lin.head)) # Merge segments for ploid, h in enumerate(H): @@ -1552,8 +1560,8 @@ def dtwf_generation(self): ) h = [] elif segments_to_merge >= 2: - for _, individual in h: - pop.remove_individual(individual) + for _, seg in h: + pop.remove_individual(seg.lineage) # parent_nodes[ploid] does not need to be updated here if segments_to_merge == 2: self.merge_two_ancestors( @@ -1580,9 +1588,9 @@ def process_pedigree_common_ancestors(self, ind, ploid): # All the segment chains in common_ancestors reach a common # ancestor in this ploid of this individual. First we remove # them from the populations they are stored in: - for _, anc in common_ancestors: - pop = self.P[anc.population] - pop.remove_individual(anc) + for _, seg in common_ancestors: + pop = self.P[seg.population] + pop.remove_individual(seg.lineage) # Merge together these lists of ancestral segments to create the # monoploid genome for this ploid of this individual. @@ -1600,7 +1608,7 @@ def process_pedigree_common_ancestors(self, ind, ploid): # simulation because we are *not* simulating the entire # population process, only the subset that we have information # about within the pedigree. - seg = genome + seg = genome.head while seg is not None: if seg.node != node: self.store_edge(seg.left, seg.right, parent=node, child=seg.node) @@ -1613,15 +1621,16 @@ def process_pedigree_common_ancestors(self, ind, ploid): # to create two independent lines of ancestry. 
parent = self.pedigree.individuals[ind.parents[ploid]] parent_ancestry = self.dtwf_recombine(genome, parent.nodes) + assert len(parent_ancestry) == ind.ploidy for parent_ploid in range(ind.ploidy): - seg = parent_ancestry[parent_ploid] - if seg is not None: + parent_lin = parent_ancestry[parent_ploid] + if parent_lin is not None: # Add this segment chain of ancestry to the accumulating # set in the parent on the corresponding ploid. - parent.add_common_ancestor(seg, ploid=parent_ploid) - if seg != genome: + parent.add_common_ancestor(parent_lin.head, ploid=parent_ploid) + if parent_lin != genome: # Add the recombined ancestor to the population - pop.add(seg) + pop.add(parent_lin) self.flush_edges() self.verify() @@ -1636,11 +1645,12 @@ def dtwf_climb_pedigree(self): # Go through the extant lineages and gather the ancestral material # into the corresponding pedigree individuals. - for anc in pop.iter_ancestors(): - node = self.tables.nodes[anc.node] + for lineage in pop.iter_ancestors(): + u = lineage.head.node + node = self.tables.nodes[u] assert node.individual != tskit.NULL ind = self.pedigree.individuals[node.individual] - ind.add_common_ancestor(anc, ploid=ind.nodes.index(anc.node)) + ind.add_common_ancestor(lineage.head, ploid=ind.nodes.index(u)) # Visit pedigree individuals in time order. 
visit_order = sorted(self.pedigree.individuals, key=lambda x: (x.time, x.id)) @@ -1676,10 +1686,11 @@ def migration_event(self, j, k): source = self.P[j] dest = self.P[k] index = random.randint(0, source.get_num_ancestors(label) - 1) - x = source.remove(index, label) + lineage = source.remove(index, label) + x = lineage.head hull = x.get_hull() assert (self.model == "smc_k") == (hull is not None) - dest.add(x, label) + dest.add(lineage, label) if self.model == "smc_k": source.remove_hull(label, hull) dest.add_hull(label, hull) @@ -1723,11 +1734,12 @@ def set_segment_mass(self, seg): gc_mass = self.gc_map.mass_between(gc_left_bound, seg.right) mass_index.set_value(seg.index, gc_mass) - def set_labels(self, segment, new_label): + def set_labels(self, lineage, new_label): """ - Move the specified segment to the specified label. + Move the specified lineage to the specified label. """ mass_indexes = [self.recomb_mass_index, self.gc_mass_index] + segment = lineage.head while segment is not None: masses = [] for mass_index in mass_indexes: @@ -1789,6 +1801,7 @@ def hudson_recombination_event(self, label, return_heads=False): alpha = y lhs_tail = x + right_lineage = self.alloc_lineage(alpha) if self.model == "smc_k": # modify original hull pop = alpha.population @@ -1798,11 +1811,11 @@ def hudson_recombination_event(self, label, return_heads=False): self.P[pop].reset_hull_right(label, lhs_hull, rhs_right, lhs_hull.right) # create hull for alpha - alpha_hull = self.alloc_hull(alpha.left, rhs_right, alpha) + alpha_hull = self.alloc_hull(alpha.left, rhs_right, right_lineage) self.P[alpha.population].add_hull(label, alpha_hull) self.set_segment_mass(alpha) - self.P[alpha.population].add(alpha, label) + self.P[alpha.population].add(right_lineage, label) if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0: self.store_node(lhs_tail.population, flags=msprime.NODE_IS_RE_EVENT) self.store_arg_edges(lhs_tail) @@ -1814,7 +1827,8 @@ def hudson_recombination_event(self, 
label, return_heads=False): # Seek back to the head of the x chain while x.prev is not None: x = x.prev - ret = x, alpha + left_lineage = x.lineage + ret = left_lineage, right_lineage return ret def generate_gc_tract_length(self): @@ -1959,15 +1973,16 @@ def wiuf_gene_conversion_within_event(self, label): elif head is not None: new_individual_head = head if new_individual_head is not None: + lineage = self.alloc_lineage(new_individual_head) if self.model == "smc_k": assert hull_left < hull_right hull_right = min(self.L, hull_right + self.hull_offset) - hull = self.alloc_hull(hull_left, hull_right, new_individual_head) + hull = self.alloc_hull(hull_left, hull_right, lineage) self.P[new_individual_head.population].add_hull( new_individual_head.label, hull ) self.P[new_individual_head.population].add( - new_individual_head, new_individual_head.label + lineage, new_individual_head.label ) def wiuf_gene_conversion_left_event(self, label): @@ -1976,7 +1991,8 @@ def wiuf_gene_conversion_left_event(self, label): """ random_gc_left = random.uniform(0, self.get_total_gc_left(label)) # Get segment where gene conversion starts from left - y = self.find_cleft_individual(label, random_gc_left) + lineage = self.find_cleft_individual(label, random_gc_left) + y = lineage.head assert y is not None # generate tract_length @@ -2044,29 +2060,30 @@ def wiuf_gene_conversion_left_event(self, label): self.set_segment_mass(alpha) assert alpha.prev is None - self.P[alpha.population].add(alpha, label) + lineage = self.alloc_lineage(alpha) + self.P[alpha.population].add(lineage, label) def hudson_recombination_event_sweep_phase(self, label, sweep_site, pop_freq): """ Implements a recombination event in during a selective sweep. 
""" - lhs, rhs = self.hudson_recombination_event(label, return_heads=True) + left_lin, right_lin = self.hudson_recombination_event(label, return_heads=True) + lhs = left_lin.head + rhs = right_lin.head r = random.random() if sweep_site < rhs.left: if r < 1.0 - pop_freq: # move rhs to other population - t_idx = self.P[rhs.population].find_indv(rhs) - self.P[rhs.population].remove(t_idx, rhs.label) - self.set_labels(rhs, 1 - label) - self.P[rhs.population].add(rhs, rhs.label) + self.P[rhs.population].remove_individual(right_lin, rhs.label) + self.set_labels(right_lin, 1 - label) + self.P[rhs.population].add(right_lin, rhs.label) else: if r < 1.0 - pop_freq: # move lhs to other population - t_idx = self.P[lhs.population].find_indv(lhs) - self.P[lhs.population].remove(t_idx, lhs.label) - self.set_labels(lhs, 1 - label) - self.P[lhs.population].add(lhs, lhs.label) + self.P[rhs.population].remove_individual(left_lin, lhs.label) + self.set_labels(left_lin, 1 - label) + self.P[lhs.population].add(left_lin, lhs.label) def dtwf_generate_breakpoint(self, start): left_bound = start + 1 if self.discrete_genome else start @@ -2076,7 +2093,7 @@ def dtwf_generate_breakpoint(self, start): bp = math.floor(bp) return bp - def dtwf_recombine(self, x, ind_nodes): + def dtwf_recombine(self, lineage, ind_nodes): """ Chooses breakpoints and returns segments sorted by inheritance direction, by iterating through segment chain starting with x @@ -2084,6 +2101,7 @@ def dtwf_recombine(self, x, ind_nodes): u = self.alloc_segment(-1, -1, -1, -1, None, None) v = self.alloc_segment(-1, -1, -1, -1, None, None) seg_tails = [u, v] + x = lineage.head # TODO Should this be the recombination rate going foward from x.left? 
if self.recomb_map.total_mass > 0: @@ -2162,12 +2180,22 @@ def dtwf_recombine(self, x, ind_nodes): segment, ) - return u, v + ret = [] + for seg in [u, v]: + if seg is None: + ret.append(None) + else: + if seg.lineage is lineage: + ret.append(lineage) + else: + ret.append(self.alloc_lineage(seg)) + + return ret def census_event(self, time): for pop in self.P: - for ancestor in pop.iter_ancestors(): - seg = ancestor + for lineage in pop.iter_ancestors(): + seg = lineage.head u = self.tables.nodes.add_row( time=time, flags=msprime.NODE_IS_CEN_EVENT, population=pop.id ) @@ -2184,7 +2212,8 @@ def bottleneck_event(self, pop_id, label, intensity): H = [] for _ in range(pop.get_num_ancestors()): if random.random() < intensity: - x = pop.remove(0) + lineage = pop.remove(0) + x = lineage.head heapq.heappush(H, (x.left, x)) self.merge_ancestors(H, pop_id, label) @@ -2204,7 +2233,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): pass_through = len(H) == 1 alpha = None z = None - merged_head = None + new_lineage = None while len(H) > 0: alpha = None left = H[0][0] @@ -2267,8 +2296,8 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): # loop tail; update alpha and integrate it into the state. 
if alpha is not None: if z is None: - pop.add(alpha, label) - merged_head = alpha + new_lineage = self.alloc_lineage(alpha) + pop.add(new_lineage, label) else: if (coalescence and not self.coalescing_segments_only) or ( self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0 @@ -2303,7 +2332,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): self.defrag_segment_chain(z) if coalescence: self.defrag_breakpoints() - return merged_head + return new_lineage def defrag_segment_chain(self, z): y = z @@ -2363,11 +2392,11 @@ def common_ancestor_event(self, population_index, label): hull_i_ptr, hull_j_ptr = random_pair hull_i = self.hulls[hull_i_ptr] hull_j = self.hulls[hull_j_ptr] - x = hull_i.lineage_head - y = hull_j.lineage_head - pop.remove_individual(x, label) + x_lin = hull_i.lineage + y_lin = hull_j.lineage + pop.remove_individual(x_lin, label) pop.remove_hull(label, hull_i) - pop.remove_individual(y, label) + pop.remove_individual(y_lin, label) pop.remove_hull(label, hull_j) self.free_hull(hull_i) self.free_hull(hull_j) @@ -2375,17 +2404,18 @@ def common_ancestor_event(self, population_index, label): else: # Choose two ancestors uniformly. j = random.randint(0, pop.get_num_ancestors(label) - 1) - x = pop.remove(j, label) + x_lin = pop.remove(j, label) j = random.randint(0, pop.get_num_ancestors(label) - 1) - y = pop.remove(j, label) - + y_lin = pop.remove(j, label) + x = x_lin.head + y = y_lin.head self.merge_two_ancestors(population_index, label, x, y) def merge_two_ancestors(self, population_index, label, x, y, u=-1): pop = self.P[population_index] self.num_ca_events += 1 z = None - merged_head = None + new_lineage = None coalescence = False defrag_required = False while x is not None or y is not None: @@ -2463,8 +2493,8 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): # loop tail; update alpha and integrate it into the state. 
if alpha is not None: if z is None: - pop.add(alpha, label) - merged_head = alpha + new_lineage = self.alloc_lineage(alpha) + pop.add(new_lineage, label) else: if (coalescence and not self.coalescing_segments_only) or ( self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0 @@ -2491,9 +2521,10 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): if coalescence: self.defrag_breakpoints() - if merged_head is not None and self.model == "smc_k": + if new_lineage is not None and self.model == "smc_k": + merged_head = new_lineage.head assert merged_head.prev is None - hull = self.alloc_hull(merged_head.left, merged_head.right, merged_head) + hull = self.alloc_hull(merged_head.left, merged_head.right, new_lineage) while merged_head is not None: right = merged_head.right merged_head = merged_head.next @@ -2505,15 +2536,19 @@ def print_state(self, verify=False): for label in range(self.num_labels): print( "Recomb mass = ", - 0 - if self.recomb_mass_index is None - else self.recomb_mass_index[label].get_total(), + ( + 0 + if self.recomb_mass_index is None + else self.recomb_mass_index[label].get_total() + ), ) print( "GC mass = ", - 0 - if self.gc_mass_index is None - else self.gc_mass_index[label].get_total(), + ( + 0 + if self.gc_mass_index is None + else self.gc_mass_index[label].get_total() + ), ) print("Modifier events = ") for t, f, args in self.modifier_events: @@ -2564,7 +2599,10 @@ def print_state(self, verify=False): def verify_segments(self): for pop in self.P: for label in range(self.num_labels): - for head in pop.iter_label(label): + for lineage in pop.iter_label(label): + assert isinstance(lineage, Lineage) + head = lineage.head + assert head.lineage is lineage assert head.prev is None prev = head u = head.next @@ -2581,7 +2619,8 @@ def verify_overlaps(self): overlap_counter = OverlapCounter(self.L) for pop in self.P: for label in range(self.num_labels): - for u in pop.iter_label(label): + for lineage in pop.iter_label(label): + u = 
lineage.head while u is not None: overlap_counter.increment_interval(u.left, u.right) u = u.next @@ -2597,7 +2636,8 @@ def verify_overlaps(self): A[self.L] = -1 for pop in self.P: for label in range(self.num_labels): - for u in pop.iter_label(label): + for lineage in pop.iter_label(label): + u = lineage.head while u is not None: if u.left not in A: k = A.floor_key(u.left) @@ -2626,7 +2666,8 @@ def verify_mass_index(self, label, mass_index, rate_map, compute_left_bound): total_mass = 0 alt_total_mass = 0 for pop_index, pop in enumerate(self.P): - for u in pop.iter_label(label): + for lineage in pop.iter_label(label): + u = lineage.head assert u.prev is None left = compute_left_bound(u) while u is not None: @@ -2717,8 +2758,9 @@ def verify(self): self.verify_hulls() -def make_hull(a, L, offset=0): +def make_hull(lineage, L, offset=0): hull = Hull(-1) + a = lineage.head assert a.prev is None b = a tracked_hull = a.get_hull() @@ -2729,7 +2771,7 @@ def make_hull(a, L, offset=0): hull.right = min(right + offset, L) assert tracked_hull.left == hull.left assert tracked_hull.right == hull.right - assert tracked_hull.lineage_head == a + assert tracked_hull.lineage.head == a return hull diff --git a/lib/msprime.c b/lib/msprime.c index aeceb7248..a60b0dd46 100644 --- a/lib/msprime.c +++ b/lib/msprime.c @@ -1,5 +1,5 @@ /* -** Copyright (C) 2015-2021 University of Oxford +** Copyright (C) 2015-2024 University of Oxford ** ** This file is part of msprime. 
** @@ -102,9 +102,11 @@ get_population_size(population_t *pop, double t) static int cmp_individual(const void *a, const void *b) { - const segment_t *ia = (const segment_t *) a; - const segment_t *ib = (const segment_t *) b; - return (ia->id > ib->id) - (ia->id < ib->id); + const lineage_t *ia = (const lineage_t *) a; + const lineage_t *ib = (const lineage_t *) b; + /* Compare by ID of the head segment to ensure reproducibility of + * results when we use the same seed */ + return (ia->head->id > ib->head->id) - (ia->head->id < ib->head->id); } /* For the segment priority queue we want to sort on the left @@ -197,8 +199,9 @@ segment_get_hull(segment_t *seg) while (seg->prev != NULL) { seg = seg->prev; } + tsk_bug_assert(seg->lineage != NULL); hull = seg->hull; - tsk_bug_assert(hull->lineage == seg); + tsk_bug_assert(hull->lineage == seg->lineage); return hull; } @@ -466,6 +469,7 @@ msp_reindex_segments(msp_t *self) avl_node_t *node; avl_tree_t *population_ancestors; segment_t *seg; + lineage_t *lin; size_t j; label_id_t label; @@ -473,7 +477,8 @@ msp_reindex_segments(msp_t *self) for (label = 0; label < (label_id_t) self->num_labels; label++) { population_ancestors = &self->populations[j].ancestors[label]; for (node = population_ancestors->head; node != NULL; node = node->next) { - for (seg = (segment_t *) node->item; seg != NULL; seg = seg->next) { + lin = (lineage_t *) node->item; + for (seg = lin->head; seg != NULL; seg = seg->next) { msp_set_segment_mass(self, seg); } } @@ -878,6 +883,26 @@ msp_alloc_segment(msp_t *self, double left, double right, tsk_id_t value, return seg; } +static lineage_t *MSP_WARN_UNUSED +msp_alloc_lineage(msp_t *self, segment_t *head) +{ + lineage_t *lin = NULL; + + if (object_heap_empty(&self->lineage_heap)) { + if (object_heap_expand(&self->lineage_heap) != 0) { + goto out; + } + } + lin = (lineage_t *) object_heap_alloc_object(&self->lineage_heap); + if (lin == NULL) { + goto out; + } + lin->head = head; + head->lineage = lin; +out: + 
return lin; +} + static segment_t *MSP_WARN_UNUSED msp_copy_segment(msp_t *self, const segment_t *seg) { @@ -886,13 +911,14 @@ msp_copy_segment(msp_t *self, const segment_t *seg) } static hull_t *MSP_WARN_UNUSED -msp_alloc_hull(msp_t *self, double left, double right, segment_t *lineage) +msp_alloc_hull(msp_t *self, double left, double right, lineage_t *lineage) { hull_t *hull = NULL; label_id_t label; uint32_t j; - label = lineage->label; + tsk_bug_assert(lineage != NULL); + label = lineage->head->label; if (object_heap_empty(&self->hull_heap[label])) { if (object_heap_expand(&self->hull_heap[label]) != 0) { @@ -923,9 +949,8 @@ msp_alloc_hull(msp_t *self, double left, double right, segment_t *lineage) hull->lineage = lineage; hull->count = 0; hull->insertion_order = UINT64_MAX; - tsk_bug_assert(lineage->prev == NULL); - lineage->hull = hull; - + tsk_bug_assert(lineage->head->prev == NULL); + lineage->head->hull = hull; out: return hull; } @@ -1041,6 +1066,11 @@ msp_alloc_memory_blocks(msp_t *self) if (ret != 0) { goto out; } + ret = object_heap_init( + &self->lineage_heap, sizeof(lineage_t), self->node_mapping_block_size, NULL); + if (ret != 0) { + goto out; + } /* allocate the segments */ for (j = 0; j < self->num_labels; j++) { ret = object_heap_init(&self->segment_heap[j], sizeof(segment_t), @@ -1141,6 +1171,7 @@ msp_free(msp_t *self) /* free the object heaps */ object_heap_free(&self->avl_node_heap); object_heap_free(&self->node_mapping_heap); + object_heap_free(&self->lineage_heap); rate_map_free(&self->recomb_map); rate_map_free(&self->gc_map); if (self->model.free != NULL) { @@ -1219,6 +1250,13 @@ msp_free_hullend(msp_t *self, hullend_t *hullend, label_id_t label) object_heap_free_object(&self->hullend_heap[label], hullend); } +static void +msp_free_lineage(msp_t *self, lineage_t *lineage) +{ + object_heap_free_object(&self->lineage_heap, lineage); + lineage->head = NULL; +} + /* * Returns the segment with the specified id. 
*/ @@ -1296,7 +1334,7 @@ msp_insert_hull(msp_t *self, hull_t *hull) /* setting hull->count requires two steps step 1: num_starting before hull->left */ tsk_bug_assert(hull != NULL); - u = hull->lineage; + u = hull->lineage->head; hulls_left = &self->populations[u->population].hulls_left[u->label]; coal_mass_index = &self->populations[u->population].coal_mass_index[u->label]; /* insert hull into state */ @@ -1371,7 +1409,7 @@ msp_remove_hull(msp_t *self, hull_t *hull) fenwick_t *coal_mass_index; segment_t *u; - u = hull->lineage; + u = hull->lineage->head; tsk_bug_assert(u != NULL); hulls_left = &self->populations[u->population].hulls_left[u->label]; coal_mass_index = &self->populations[u->population].coal_mass_index[u->label]; @@ -1420,35 +1458,37 @@ msp_remove_hull(msp_t *self, hull_t *hull) } static inline int MSP_WARN_UNUSED -msp_insert_individual(msp_t *self, segment_t *u) +msp_insert_individual(msp_t *self, lineage_t *lin) { int ret = 0; avl_node_t *node; - tsk_bug_assert(u != NULL); + tsk_bug_assert(lin != NULL); + tsk_bug_assert(lin->head != NULL); node = msp_alloc_avl_node(self); if (node == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; } - avl_init_node(node, u); - node = avl_insert_node(msp_get_segment_population(self, u), node); + avl_init_node(node, lin); + node = avl_insert_node(msp_get_segment_population(self, lin->head), node); tsk_bug_assert(node != NULL); out: return ret; } static inline void -msp_remove_individual(msp_t *self, segment_t *u) +msp_remove_individual(msp_t *self, lineage_t *lin) { avl_node_t *node; - avl_tree_t *pop = msp_get_segment_population(self, u); - - tsk_bug_assert(u != NULL); - node = avl_search(pop, u); + avl_tree_t *pop; + tsk_bug_assert(lin != NULL); + pop = msp_get_segment_population(self, lin->head); + node = avl_search(pop, lin); tsk_bug_assert(node != NULL); avl_unlink_node(pop, node); msp_free_avl_node(self, node); + msp_free_lineage(self, lin); } static void @@ -1456,7 +1496,7 @@ 
msp_remove_individuals_from_population(msp_t *self, avl_tree_t *Q) { avl_node_t *node; for (node = Q->head; node != NULL; node = node->next) { - msp_remove_individual(self, (segment_t *) node->item); + msp_remove_individual(self, ((segment_t *) node->item)->lineage); } } @@ -1497,8 +1537,11 @@ static void msp_print_segment_chain(msp_t *MSP_UNUSED(self), segment_t *head, FILE *out) { segment_t *s = head; + lineage_t *lin = head->lineage; + + tsk_bug_assert(lin != NULL); - fprintf(out, "[pop=%d,label=%d]", s->population, s->label); + fprintf(out, "[%p,pop=%d,label=%d]", (void *) lin, s->population, s->label); while (s != NULL) { fprintf(out, "[(%.14g,%.14g) %d] ", s->left, s->right, (int) s->value); s = s->next; @@ -1517,6 +1560,7 @@ msp_verify_segment_index( size_t j, k; const double epsilon = 1e-10; avl_node_t *node; + lineage_t *lin; segment_t *u; for (k = 0; k < self->num_labels; k++) { @@ -1525,7 +1569,8 @@ msp_verify_segment_index( for (j = 0; j < self->num_populations; j++) { node = (&self->populations[j].ancestors[k])->head; while (node != NULL) { - u = (segment_t *) node->item; + lin = (lineage_t *) node->item; + u = lin->head; left = u->left; while (u != NULL) { if (u->prev != NULL) { @@ -1574,6 +1619,7 @@ msp_verify_segments(msp_t *self, bool verify_breakpoints) avl_node_t *node; segment_t *u; individual_t *ind; + lineage_t *lin; for (j = 0; j < self->input_position.nodes; j++) { for (u = self->root_segments[j]; u != NULL; u = u->next) { @@ -1589,7 +1635,9 @@ msp_verify_segments(msp_t *self, bool verify_breakpoints) for (j = 0; j < self->num_populations; j++) { node = (&self->populations[j].ancestors[k])->head; while (node != NULL) { - u = (segment_t *) node->item; + lin = (lineage_t *) node->item; + u = lin->head; + tsk_bug_assert(u->lineage == lin); tsk_bug_assert(u->prev == NULL); while (u != NULL) { label_segments++; @@ -1617,6 +1665,8 @@ msp_verify_segments(msp_t *self, bool verify_breakpoints) total_avl_nodes = msp_get_num_ancestors(self) + 
avl_count(&self->breakpoints) + avl_count(&self->overlap_counts) + avl_count(&self->non_empty_populations); + tsk_bug_assert(msp_get_num_ancestors(self) + == object_heap_get_num_allocated(&self->lineage_heap)); if (self->model.type == MSP_MODEL_SMC_K) { for (j = 0; j < self->num_populations; j++) { for (k = 0; k < self->num_labels; k++) { @@ -1778,6 +1828,7 @@ msp_verify_overlaps(msp_t *self) avl_node_t *node; node_mapping_t *nm; sampling_event_t se; + lineage_t *lin; segment_t *u; size_t j; uint32_t label, count; @@ -1798,7 +1849,8 @@ msp_verify_overlaps(msp_t *self) for (j = 0; j < self->num_populations; j++) { for (node = (&self->populations[j].ancestors[label])->head; node != NULL; node = node->next) { - for (u = (segment_t *) node->item; u != NULL; u = u->next) { + lin = (lineage_t *) node->item; + for (u = lin->head; u != NULL; u = u->next) { overlap_counter_increment_interval(&counter, u->left, u->right); } } @@ -1875,6 +1927,7 @@ msp_verify_hulls(msp_t *self) int count, num_coalescing_pairs; avl_tree_t *avl; avl_node_t *a, *b; + lineage_t *lin; segment_t *x, *y; hull_t *hull, hull_a, hull_b; hullend_t *hullend; @@ -1894,7 +1947,8 @@ msp_verify_hulls(msp_t *self) continue; } for (a = avl->head; a->next != NULL; a = a->next) { - x = (segment_t *) a->item; + lin = (lineage_t *) a->item; + x = lin->head; hull_right = x->hull->right; hull_a.left = x->left; while (x->next != NULL) { @@ -1905,7 +1959,8 @@ msp_verify_hulls(msp_t *self) self->sequence_length); tsk_bug_assert(hull_a.right == hull_right); for (b = a->next; b != NULL; b = b->next) { - y = (segment_t *) b->item; + lin = (lineage_t *) b->item; + y = lin->head; hull_b.left = y->left; while (y->next != NULL) { y = y->next; @@ -2289,6 +2344,8 @@ msp_print_state(msp_t *self, FILE *out) object_heap_print_state(&self->avl_node_heap, out); fprintf(out, "node_mapping_heap:"); object_heap_print_state(&self->node_mapping_heap, out); + fprintf(out, "lineage_heap:"); + object_heap_print_state(&self->lineage_heap, 
out); fflush(out); msp_verify(self, 0); out: @@ -2511,7 +2568,8 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, population_id_t dest_pop, label_id_t dest_label) { int ret = 0; - segment_t *ind, *x, *y, *new_ind; + lineage_t *ind; + segment_t *x, *y; double recomb_mass, gc_mass; hull_t *hull, *new_hull, *h; @@ -2520,12 +2578,12 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, goto out; } - ind = (segment_t *) node->item; + ind = (lineage_t *) node->item; avl_unlink_node(source, node); msp_free_avl_node(self, node); hull = NULL; if (self->model.type == MSP_MODEL_SMC_K) { - hull = segment_get_hull(ind); + hull = segment_get_hull(ind->head); tsk_bug_assert(hull != NULL); msp_remove_hull(self, hull); } @@ -2536,16 +2594,15 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, if (ret < 0) { goto out; } - ret = msp_store_arg_edges(self, ind, TSK_NULL); + ret = msp_store_arg_edges(self, ind->head, TSK_NULL); if (ret != 0) { goto out; } } - if (ind->label == dest_label) { + if (ind->head->label == dest_label) { /* Need to set the population and label for each segment. */ - new_ind = ind; new_hull = hull; - for (x = ind; x != NULL; x = x->next) { + for (x = ind->head; x != NULL; x = x->next) { if (self->store_migrations) { ret = msp_record_migration( self, x->left, x->right, x->value, x->population, dest_pop); @@ -2558,7 +2615,6 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, } else { /* Because we are changing to a different Fenwick tree we must allocate * new segments each time. 
*/ - new_ind = NULL; y = NULL; new_hull = NULL; tsk_bug_assert(hull == NULL); @@ -2568,11 +2624,12 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, // msp_free_hull(self, hull, ind->population, ind->label); //} h = new_hull; - for (x = ind; x != NULL; x = x->next) { + for (x = ind->head; x != NULL; x = x->next) { y = msp_alloc_segment(self, x->left, x->right, x->value, x->population, dest_label, y, NULL, h); - if (new_ind == NULL) { - new_ind = y; + if (x->prev == NULL) { + ind->head = y; + y->lineage = ind; } else { y->prev->next = y; } @@ -2591,13 +2648,13 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, } } if (new_hull != NULL) { - new_hull->lineage = new_ind; + new_hull->lineage = ind; ret = msp_insert_hull(self, new_hull); if (ret != 0) { goto out; } } - ret = msp_insert_individual(self, new_ind); + ret = msp_insert_individual(self, ind); out: return ret; } @@ -2870,7 +2927,7 @@ msp_pedigree_initialise(msp_t *self) { int ret = 0; population_t *pop; - segment_t *segment; + lineage_t *lin; avl_node_t *a; label_id_t label = 0; tsk_size_t j; @@ -2898,8 +2955,8 @@ msp_pedigree_initialise(msp_t *self) for (j = 0; j < self->num_populations; j++) { pop = &self->populations[j]; for (a = pop->ancestors[label].head; a != NULL; a = a->next) { - segment = (segment_t *) a->item; - ret = msp_pedigree_add_sample_ancestry(self, segment); + lin = (lineage_t *) a->item; + ret = msp_pedigree_add_sample_ancestry(self, lin->head); if (ret != 0) { goto out; } @@ -2912,16 +2969,19 @@ msp_pedigree_initialise(msp_t *self) static int MSP_WARN_UNUSED msp_dtwf_recombine( - msp_t *self, segment_t *x, segment_t **u, segment_t **v, tsk_id_t *ind_nodes) + msp_t *self, segment_t *x_head, segment_t **u, segment_t **v, tsk_id_t *ind_nodes) { int ret = 0; int ix; + int j; double k; - segment_t *y, *z, *tail; + lineage_t *lin; + segment_t *x, *y, *z, *tail; segment_t s1, s2; segment_t *seg_tails[] = { &s1, &s2 }; segment_t 
**rec_heads[MSP_MAX_PED_PLOIDY] = { u, v }; + x = x_head; k = msp_dtwf_generate_breakpoint(self, x->left); s1.next = NULL; s2.next = NULL; @@ -2987,16 +3047,27 @@ msp_dtwf_recombine( x = y; } } - // Remove sentinal segments + // Remove sentinel segments *u = s1.next; *v = s2.next; + for (j = 0; j < MSP_MAX_PED_PLOIDY; j++) { + y = *rec_heads[j]; + if (y != x_head && y != NULL) { + lin = msp_alloc_lineage(self, y); + if (lin == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + } + } + if (*u != NULL && *v != NULL) { - for (int i = 0; i < MSP_MAX_PED_PLOIDY; i++) { - ret = msp_store_additional_nodes_edges(self, *rec_heads[i], ind_nodes[i], - MSP_NODE_IS_RE_EVENT, (*rec_heads[i])->population, TSK_NULL, - &ind_nodes[i]); + for (j = 0; j < MSP_MAX_PED_PLOIDY; j++) { + ret = msp_store_additional_nodes_edges(self, *rec_heads[j], ind_nodes[j], + MSP_NODE_IS_RE_EVENT, (*rec_heads[j])->population, TSK_NULL, + &ind_nodes[j]); if (ret < 0) { goto out; } @@ -3195,6 +3266,7 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ { int ret = 0; double breakpoint; + lineage_t *right_lineage; segment_t *x, *y, *alpha, *lhs_tail; hull_t *lhs_hull, *rhs_hull; double lhs_right, rhs_right; @@ -3243,7 +3315,12 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ } tsk_bug_assert(alpha->left < alpha->right); msp_set_segment_mass(self, alpha); - ret = msp_insert_individual(self, alpha); + right_lineage = msp_alloc_lineage(self, alpha); + if (right_lineage == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + ret = msp_insert_individual(self, right_lineage); if (ret != 0) { goto out; } @@ -3258,7 +3335,7 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ self, lhs_hull, rhs_right, lhs_right, lhs_tail->population, label); /* create new hull for alpha */ - rhs_hull = msp_alloc_hull(self, alpha->left, rhs_right, alpha); + rhs_hull = msp_alloc_hull(self, alpha->left, rhs_right, alpha->lineage); if 
(rhs_hull == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3315,6 +3392,7 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) { int ret = 0; segment_t *x, *y, *alpha, *head, *tail, *z, *new_individual_head; + lineage_t *new_lineage; double left_breakpoint, right_breakpoint, tl; bool insert_alpha; hull_t *hull = NULL; @@ -3499,13 +3577,18 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) new_individual_head = head; } if (new_individual_head != NULL) { + new_lineage = msp_alloc_lineage(self, new_individual_head); + if (new_lineage == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + ret = msp_insert_individual(self, new_lineage); if (self->model.type == MSP_MODEL_SMC_K) { tsk_bug_assert(tract_hull_left < tract_hull_right); tract_hull_right = GSL_MIN( tract_hull_right + self->model.params.smc_k_coalescent.hull_offset, self->sequence_length); - hull = msp_alloc_hull( - self, tract_hull_left, tract_hull_right, new_individual_head); + hull = msp_alloc_hull(self, tract_hull_left, tract_hull_right, new_lineage); if (hull == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3515,7 +3598,6 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) goto out; } } - ret = msp_insert_individual(self, new_individual_head); } else { self->num_noneffective_gc_events++; } @@ -3576,6 +3658,7 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l double l, r, l_min, r_max; avl_node_t *node; node_mapping_t *nm, search; + lineage_t *new_lineage; segment_t *x, *y, *z, *alpha, *beta, *merged_head; hull_t *hull = NULL; @@ -3704,7 +3787,12 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l } if (alpha != NULL) { if (z == NULL) { - ret = msp_insert_individual(self, alpha); + new_lineage = msp_alloc_lineage(self, alpha); + if (new_lineage == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + ret = msp_insert_individual(self, new_lineage); if (ret != 0) { goto out; } @@ -3765,8 +3853,8 @@ 
msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l y = y->next; } r += self->model.params.smc_k_coalescent.hull_offset; - hull = msp_alloc_hull( - self, merged_head->left, GSL_MIN(r, self->sequence_length), merged_head); + hull = msp_alloc_hull(self, merged_head->left, + GSL_MIN(r, self->sequence_length), merged_head->lineage); if (hull == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3831,6 +3919,7 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, segment_t *x, *z, *alpha; segment_t **H = NULL; segment_t *merged_head = NULL; + lineage_t *new_lineage = NULL; tsk_id_t individual = TSK_NULL; H = malloc(avl_count(Q) * sizeof(segment_t *)); @@ -3965,7 +4054,12 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, if (alpha != NULL) { if (z == NULL) { merged_head = alpha; - ret = msp_insert_individual(self, alpha); + new_lineage = msp_alloc_lineage(self, alpha); + if (new_lineage == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + ret = msp_insert_individual(self, new_lineage); if (ret != 0) { goto out; } @@ -4039,9 +4133,10 @@ msp_merge_n_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, /* Migrate any of the child segments to this population, if necessary */ for (a = Q->head; a != NULL; a = a->next) { u = (segment_t *) a->item; + tsk_bug_assert(u->lineage != NULL); if (u->population != population_id) { current_pop = &self->populations[u->population]; - avl_node = avl_search(&current_pop->ancestors[label], u); + avl_node = avl_search(&current_pop->ancestors[label], u->lineage); tsk_bug_assert(avl_node != NULL); ret = msp_move_individual( self, avl_node, &current_pop->ancestors[label], population_id, label); @@ -4107,6 +4202,7 @@ msp_reset_memory_state(msp_t *self) avl_node_t *node; node_mapping_t *nm; population_t *pop; + lineage_t *lin; segment_t *u, *v; hull_t *x; hullend_t *y; @@ -4117,7 +4213,8 @@ msp_reset_memory_state(msp_t *self) pop = &self->populations[j]; for 
(label = 0; label < (label_id_t) self->num_labels; label++) { for (node = pop->ancestors[label].head; node != NULL; node = node->next) { - u = (segment_t *) node->item; + lin = (lineage_t *) node->item; + u = lin->head; while (u != NULL) { v = u->next; msp_free_segment(self, u); @@ -4125,6 +4222,7 @@ msp_reset_memory_state(msp_t *self) } avl_unlink_node(&pop->ancestors[label], node); msp_free_avl_node(self, node); + msp_free_lineage(self, lin); } if (pop->hulls_left != NULL) { for (node = pop->hulls_left[label].head; node != NULL; @@ -4165,6 +4263,7 @@ static int msp_insert_root_segments(msp_t *self, const segment_t *head, segment_t **new_head) { int ret = 0; + lineage_t *lineage; segment_t *copy, *prev; const segment_t *seg; double breakpoints[2]; @@ -4196,14 +4295,19 @@ msp_insert_root_segments(msp_t *self, const segment_t *head, segment_t **new_hea } copy->prev = prev; if (prev == NULL) { - ret = msp_insert_individual(self, copy); + lineage = msp_alloc_lineage(self, copy); + if (lineage == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + ret = msp_insert_individual(self, lineage); if (ret != 0) { goto out; } if (self->model.type == MSP_MODEL_SMC_K) { if (self->state != MSP_STATE_NEW) { /* correct hull->right is set at the end */ - hull = msp_alloc_hull(self, head->left, copy->right, copy); + hull = msp_alloc_hull(self, head->left, copy->right, lineage); if (hull == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4575,7 +4679,8 @@ msp_initialise_smc_k(msp_t *self) avl_node_t *h_node, *a_node; hull_t *hull; double left, right; - segment_t *seg, *head; + lineage_t *lin; + segment_t *seg; for (population_id = 0; population_id < (population_id_t) self->num_populations; population_id++) { @@ -4584,10 +4689,10 @@ msp_initialise_smc_k(msp_t *self) hulls_left = &self->populations[population_id].hulls_left[label_id]; for (a_node = population_ancestors->head; a_node != NULL; a_node = a_node->next) { - seg = (segment_t *) a_node->item; + lin = (lineage_t *) a_node->item; 
+ seg = lin->head; tsk_bug_assert(seg->prev == NULL); left = seg->left; - head = seg; while (seg != NULL) { right = seg->right; seg = seg->next; @@ -4595,7 +4700,7 @@ msp_initialise_smc_k(msp_t *self) /* insert into hulls_left */ right += self->model.params.smc_k_coalescent.hull_offset; right = GSL_MIN(right, self->sequence_length); - hull = msp_alloc_hull(self, left, right, head); + hull = msp_alloc_hull(self, left, right, lin); if (hull == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4962,13 +5067,13 @@ msp_get_total_gc_left(msp_t *self) return total; } -static segment_t * +static lineage_t * msp_find_gc_left_individual(msp_t *self, label_id_t label, double value) { size_t j, num_ancestors, individual_index; avl_tree_t *ancestors; avl_node_t *node; - segment_t *ind; + lineage_t *ind; double mean_gc_rate = rate_map_get_total_mass(&self->gc_map) / self->sequence_length; individual_index = (size_t) floor(value / (mean_gc_rate * self->gc_tract_length)); @@ -4979,7 +5084,7 @@ msp_find_gc_left_individual(msp_t *self, label_id_t label, double value) /* Choose the correct individual */ node = avl_at(ancestors, (unsigned int) individual_index); assert(node != NULL); - ind = (segment_t *) node->item; + ind = (lineage_t *) node->item; return ind; } else { individual_index -= num_ancestors; @@ -5023,12 +5128,15 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) const double gc_left_total = msp_get_total_gc_left(self); double h = gsl_rng_uniform(self->rng) * gc_left_total; double tl, bp, lhs_old_right, lhs_new_right; + lineage_t *lineage; segment_t *y, *x, *alpha; hull_t *rhs_hull; hull_t *lhs_hull = NULL; lhs_hull = NULL; - y = msp_find_gc_left_individual(self, label, h); + lineage = msp_find_gc_left_individual(self, label, h); + assert(lineage != NULL); + y = lineage->head; assert(y != NULL); /* generate tract length */ @@ -5107,6 +5215,22 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) } lhs_new_right = y->right; + lineage = 
msp_alloc_lineage(self, alpha); + if (lineage == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + msp_set_segment_mass(self, alpha); + tsk_bug_assert(alpha->prev == NULL); + + ret = msp_insert_individual(self, lineage); + if (self->additional_nodes & MSP_NODE_IS_GC_EVENT) { + ret = msp_store_arg_gene_conversion(self, NULL, y, alpha); + if (ret != 0) { + goto out; + } + } + if (self->model.type == MSP_MODEL_SMC_K) { // lhs logic is identical to the lhs recombination event lhs_old_right = lhs_hull->right; @@ -5118,7 +5242,7 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) // rhs tsk_bug_assert(alpha->left < lhs_old_right); - rhs_hull = msp_alloc_hull(self, alpha->left, lhs_old_right, alpha); + rhs_hull = msp_alloc_hull(self, alpha->left, lhs_old_right, alpha->lineage); if (rhs_hull == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -5129,15 +5253,6 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) } } - msp_set_segment_mass(self, alpha); - tsk_bug_assert(alpha->prev == NULL); - ret = msp_insert_individual(self, alpha); - if (self->additional_nodes & MSP_NODE_IS_GC_EVENT) { - ret = msp_store_arg_gene_conversion(self, NULL, y, alpha); - if (ret != 0) { - goto out; - } - } out: return ret; } @@ -5336,6 +5451,7 @@ msp_pedigree_process_common_ancestors(msp_t *self, individual_t *ind, tsk_size_t goto out; } if (genome != NULL) { + tsk_bug_assert(genome->lineage != NULL); tsk_bug_assert(genome->prev == NULL); if (parent == TSK_NULL) { @@ -5376,6 +5492,8 @@ msp_pedigree_process_common_ancestors(msp_t *self, individual_t *ind, tsk_size_t for (j = 0; j < ploidy; j++) { seg = parent_ancestry[j]; if (seg != NULL) { + tsk_bug_assert(seg->lineage != NULL); + tsk_bug_assert(seg->lineage->head == seg); tsk_bug_assert(seg->prev == NULL); ret = msp_pedigree_add_individual_common_ancestor( self, parent, seg, j); @@ -5383,7 +5501,7 @@ msp_pedigree_process_common_ancestors(msp_t *self, individual_t *ind, tsk_size_t goto out; } if (seg != genome) { - ret = 
msp_insert_individual(self, seg); + ret = msp_insert_individual(self, seg->lineage); if (ret != 0) { goto out; } @@ -5486,6 +5604,7 @@ msp_dtwf_generation(msp_t *self) segment_list_t **parents = NULL; segment_list_t *segment_mem = NULL; segment_list_t *s; + lineage_t *lin; avl_node_t *a, *node; avl_tree_t Q[2]; /* Only support single structured coalescent label for now. */ @@ -5542,7 +5661,8 @@ msp_dtwf_generation(msp_t *self) } for (s = parents[k]; s != NULL; s = s->next) { node = s->node; - x = (segment_t *) node->item; + lin = (lineage_t *) node->item; + x = lin->head; // Recombine ancestor // TODO Should this be the recombination rate going foward from x.left? if (rate_map_get_total_mass(&self->recomb_map) > 0) { @@ -5552,7 +5672,7 @@ msp_dtwf_generation(msp_t *self) } for (i = 0; i < 2; i++) { if (u[i] != NULL && u[i] != x) { - ret = msp_insert_individual(self, u[i]); + ret = msp_insert_individual(self, u[i]->lineage); if (ret != 0) { goto out; } @@ -5889,7 +6009,7 @@ msp_change_label(msp_t *self, segment_t *ind, label_id_t label) avl_node_t *node; /* Find the this individual in the AVL tree. */ - node = avl_search(pop, ind); + node = avl_search(pop, ind->lineage); tsk_bug_assert(node != NULL); ret = msp_move_individual(self, node, pop, ind->population, label); return ret; @@ -5908,6 +6028,9 @@ msp_sweep_recombination_event( if (ret != 0) { goto out; } + tsk_bug_assert(lhs->lineage != NULL); + tsk_bug_assert(rhs->lineage != NULL); + /* NOTE: we can look at rhs->left when we compare to the sweep site. 
*/ r = gsl_rng_uniform(self->rng); if (sweep_locus < rhs->left) { @@ -5987,6 +6110,7 @@ msp_run_sweep(msp_t *self) if (ret != 0) { goto out; } + msp_verify(self, 0); ret = msp_sweep_initialise(self, allele_frequency[0]); if (ret != 0) { goto out; @@ -5995,6 +6119,7 @@ msp_run_sweep(msp_t *self) curr_step = 1; while (msp_get_num_ancestors(self) > 0 && curr_step < num_steps) { events++; + msp_verify(self, 0); /* Set pop sizes & rec_rates */ for (j = 0; j < self->num_labels; j++) { label = (label_id_t) j; @@ -6040,6 +6165,7 @@ msp_run_sweep(msp_t *self) rec_rates[1], self->ploidy); printf("event_prob: %g rand: %g\n", event_prob, event_rand); */ + event_prob *= 1.0 - total_rate; curr_step++; @@ -6057,7 +6183,6 @@ msp_run_sweep(msp_t *self) t_unscaled = time[curr_step - 1] * self->ploidy * pop_size; tsk_bug_assert(t_unscaled > 0); self->time = t_start + t_unscaled; - /* printf("event time: %g\n", self->time); */ if (tmp_rand < e_sum / sweep_pop_tot_rate) { /* coalescent in b background */ ret = self->common_ancestor_event(self, 0, 0); @@ -6082,7 +6207,6 @@ msp_run_sweep(msp_t *self) if (ret != 0) { goto out; } - /* msp_print_state(self, stdout); */ } /* TODO we should probably support fixed events here using @@ -6177,6 +6301,7 @@ msp_insert_uncoalesced_edges(msp_t *self) label_id_t label; avl_node_t *a; segment_t *seg; + lineage_t *lin; tsk_id_t node; int64_t edge_start; tsk_node_table_t *nodes = &self->tables->nodes; @@ -6194,7 +6319,8 @@ msp_insert_uncoalesced_edges(msp_t *self) * could only have arisen as the result of a coalescence and so this * node really does represent the current ancestor */ node = TSK_NULL; - for (seg = (segment_t *) a->item; seg != NULL; seg = seg->next) { + lin = (lineage_t *) a->item; + for (seg = lin->head; seg != NULL; seg = seg->next) { if (nodes->time[seg->value] == current_time) { node = seg->value; break; @@ -6211,7 +6337,7 @@ msp_insert_uncoalesced_edges(msp_t *self) } /* For every segment add an edge pointing to this new node */ - 
for (seg = (segment_t *) a->item; seg != NULL; seg = seg->next) { + for (seg = lin->head; seg != NULL; seg = seg->next) { if (seg->value != node) { tsk_bug_assert(nodes->time[node] > nodes->time[seg->value]); ret = tsk_edge_table_add_row(&self->tables->edges, seg->left, @@ -6475,6 +6601,7 @@ msp_get_ancestors(msp_t *self, segment_t **ancestors) int ret = -1; avl_node_t *node; avl_tree_t *population_ancestors; + lineage_t *lineage; size_t j; label_id_t label; size_t k = 0; @@ -6483,7 +6610,8 @@ msp_get_ancestors(msp_t *self, segment_t **ancestors) for (label = 0; label < (label_id_t) self->num_labels; label++) { population_ancestors = &self->populations[j].ancestors[label]; for (node = population_ancestors->head; node != NULL; node = node->next) { - ancestors[k] = (segment_t *) node->item; + lineage = (lineage_t *) node->item; + ancestors[k] = lineage->head; k++; } } @@ -7200,6 +7328,7 @@ msp_simple_bottleneck(msp_t *self, demographic_event_t *event) population_id_t N = (population_id_t) self->num_populations; avl_node_t *node, *next, *q_node; avl_tree_t *pop, Q; + lineage_t *lin; segment_t *u; label_id_t label = 0; /* For now only support label 0 */ @@ -7222,9 +7351,11 @@ msp_simple_bottleneck(msp_t *self, demographic_event_t *event) while (node != NULL) { next = node->next; if (gsl_rng_uniform(self->rng) < p) { - u = (segment_t *) node->item; + lin = (lineage_t *) node->item; + u = lin->head; avl_unlink_node(pop, node); msp_free_avl_node(self, node); + msp_free_lineage(self, lin); q_node = msp_alloc_avl_node(self); if (q_node == NULL) { ret = MSP_ERR_NO_MEMORY; @@ -7299,7 +7430,7 @@ msp_instantaneous_bottleneck(msp_t *self, demographic_event_t *event) double rate, t; avl_tree_t *pop; avl_node_t *node, *set_node; - segment_t *individual; + lineage_t *lin; label_id_t label = 0; /* For now only support label 0 */ if (self->model.type == MSP_MODEL_DTWF) { @@ -7373,7 +7504,7 @@ msp_instantaneous_bottleneck(msp_t *self, demographic_event_t *event) if (u >= (tsk_id_t) n) 
{ /* Remove this node from the population, and add it into the * set for the root at u */ - individual = (segment_t *) avl_nodes[j]->item; + lin = (lineage_t *) avl_nodes[j]->item; avl_unlink_node(pop, avl_nodes[j]); msp_free_avl_node(self, avl_nodes[j]); set_node = msp_alloc_avl_node(self); @@ -7381,7 +7512,8 @@ msp_instantaneous_bottleneck(msp_t *self, demographic_event_t *event) ret = MSP_ERR_NO_MEMORY; goto out; } - avl_init_node(set_node, individual); + avl_init_node(set_node, lin->head); + msp_free_lineage(self, lin); set_node = avl_insert_node(&sets[u], set_node); tsk_bug_assert(set_node != NULL); } @@ -7396,18 +7528,10 @@ msp_instantaneous_bottleneck(msp_t *self, demographic_event_t *event) } } out: - if (lineages != NULL) { - free(lineages); - } - if (pi != NULL) { - free(pi); - } - if (sets != NULL) { - free(sets); - } - if (avl_nodes != NULL) { - free(avl_nodes); - } + msp_safe_free(lineages); + msp_safe_free(pi); + msp_safe_free(sets); + msp_safe_free(avl_nodes); return ret; } @@ -7459,6 +7583,7 @@ msp_census_event(msp_t *self, demographic_event_t *event) avl_tree_t *ancestors; avl_node_t *node; segment_t *seg; + lineage_t *lin; tsk_id_t i, j; tsk_id_t u; @@ -7470,8 +7595,8 @@ msp_census_event(msp_t *self, demographic_event_t *event) node = ancestors->head; while (node != NULL) { - seg = (segment_t *) node->item; - + lin = (lineage_t *) node->item; + seg = lin->head; while (seg != NULL) { // Add an edge to the edge table. 
ret = tsk_node_table_add_row(&self->tables->nodes, @@ -7635,6 +7760,7 @@ msp_std_common_ancestor_event( uint32_t j, n; avl_tree_t *ancestors; avl_node_t *x_node, *y_node, *node; + lineage_t *x_lin, *y_lin; segment_t *x, *y; ancestors = &self->populations[population_id].ancestors[label]; @@ -7643,12 +7769,14 @@ msp_std_common_ancestor_event( j = (uint32_t) gsl_rng_uniform_int(self->rng, n); x_node = avl_at(ancestors, j); tsk_bug_assert(x_node != NULL); - x = (segment_t *) x_node->item; + x_lin = (lineage_t *) x_node->item; + x = x_lin->head; avl_unlink_node(ancestors, x_node); j = (uint32_t) gsl_rng_uniform_int(self->rng, n - 1); y_node = avl_at(ancestors, j); tsk_bug_assert(y_node != NULL); - y = (segment_t *) y_node->item; + y_lin = (lineage_t *) y_node->item; + y = y_lin->head; avl_unlink_node(ancestors, y_node); /* For SMC and SMC' models we reject some events to get the required @@ -7656,16 +7784,18 @@ msp_std_common_ancestor_event( if (msp_reject_ca_event(self, x, y)) { self->num_rejected_ca_events++; /* insert x and y back into the population */ - tsk_bug_assert(x_node->item == x); + tsk_bug_assert(x_node->item == x_lin); node = avl_insert_node(ancestors, x_node); tsk_bug_assert(node != NULL); - tsk_bug_assert(y_node->item == y); + tsk_bug_assert(y_node->item == y_lin); node = avl_insert_node(ancestors, y_node); tsk_bug_assert(node != NULL); } else { self->num_ca_events++; msp_free_avl_node(self, x_node); + msp_free_lineage(self, x_lin); msp_free_avl_node(self, y_node); + msp_free_lineage(self, y_lin); ret = msp_merge_two_ancestors(self, population_id, label, x, y, TSK_NULL, NULL); } return ret; @@ -7711,6 +7841,7 @@ msp_smc_k_common_ancestor_event( avl_tree_t *avl; avl_node_t *x_node, *y_node, *search; hull_t *x_hull, *y_hull = NULL; + lineage_t *x_lin, *y_lin; segment_t *x, *y; /* find first hull */ @@ -7741,14 +7872,15 @@ msp_smc_k_common_ancestor_event( /* retrieve ancestors linked to both hulls */ avl = &self->populations[population_id].ancestors[label]; 
- x = (segment_t *) x_hull->lineage; - x_node = avl_search(avl, x); + x_lin = x_hull->lineage; + x = x_lin->head; + x_node = avl_search(avl, x_lin); tsk_bug_assert(x_node != NULL); avl_unlink_node(avl, x_node); - y = (segment_t *) y_hull->lineage; - y_node = avl_search(avl, y); + y_lin = y_hull->lineage; + y = y_lin->head; + y_node = avl_search(avl, y_lin); tsk_bug_assert(y_node != NULL); - y = (segment_t *) y_node->item; avl_unlink_node(avl, y_node); self->num_ca_events++; @@ -7756,8 +7888,9 @@ msp_smc_k_common_ancestor_event( msp_free_hull(self, y_hull, population_id, label); msp_free_avl_node(self, x_node); msp_free_avl_node(self, y_node); + msp_free_lineage(self, x_lin); + msp_free_lineage(self, y_lin); ret = msp_merge_two_ancestors(self, population_id, label, x, y, TSK_NULL, NULL); - return ret; } @@ -7836,6 +7969,7 @@ msp_dirac_common_ancestor_event(msp_t *self, population_id_t pop_id, label_id_t avl_tree_t *ancestors, Q[4]; /* MSVC won't let us use num_pots here */ avl_node_t *x_node, *y_node; segment_t *x, *y; + lineage_t *x_lin, *y_lin; double nC2, p; double psi = self->model.params.dirac_coalescent.psi; @@ -7865,16 +7999,20 @@ msp_dirac_common_ancestor_event(msp_t *self, population_id_t pop_id, label_id_t j = (uint32_t) gsl_rng_uniform_int(self->rng, n); x_node = avl_at(ancestors, j); tsk_bug_assert(x_node != NULL); - x = (segment_t *) x_node->item; + x_lin = (lineage_t *) x_node->item; + x = x_lin->head; avl_unlink_node(ancestors, x_node); j = (uint32_t) gsl_rng_uniform_int(self->rng, n - 1); y_node = avl_at(ancestors, j); tsk_bug_assert(y_node != NULL); - y = (segment_t *) y_node->item; + y_lin = (lineage_t *) y_node->item; + y = y_lin->head; avl_unlink_node(ancestors, y_node); self->num_ca_events++; msp_free_avl_node(self, x_node); + msp_free_lineage(self, x_lin); msp_free_avl_node(self, y_node); + msp_free_lineage(self, y_lin); ret = msp_merge_two_ancestors(self, pop_id, label, x, y, TSK_NULL, NULL); } } else { @@ -8003,6 +8141,7 @@ 
msp_multi_merger_common_ancestor_event( uint32_t j, i, l; avl_node_t *node, *q_node; segment_t *u; + lineage_t *lin; uint32_t pot_size; uint32_t cumul_pot_size = 0; @@ -8019,9 +8158,11 @@ msp_multi_merger_common_ancestor_event( node = avl_at(ancestors, j); tsk_bug_assert(node != NULL); - u = (segment_t *) node->item; + lin = (lineage_t *) node->item; + u = lin->head; avl_unlink_node(ancestors, node); msp_free_avl_node(self, node); + msp_free_lineage(self, lin); q_node = msp_alloc_avl_node(self); if (q_node == NULL) { diff --git a/lib/msprime.h b/lib/msprime.h index dc42a40f4..f27a22c6f 100644 --- a/lib/msprime.h +++ b/lib/msprime.h @@ -85,8 +85,13 @@ typedef struct segment_t_t { struct segment_t_t *prev; struct segment_t_t *next; struct hull_t_t *hull; + struct lineage_t_t *lineage; } segment_t; +typedef struct lineage_t_t { + segment_t *head; +} lineage_t; + typedef struct { double position; uint32_t value; @@ -95,7 +100,7 @@ typedef struct { typedef struct hull_t_t { double left; double right; - segment_t *lineage; + lineage_t *lineage; size_t id; uint64_t count; uint64_t insertion_order; @@ -277,6 +282,7 @@ typedef struct _msp_t { /* memory management */ object_heap_t avl_node_heap; object_heap_t node_mapping_heap; + object_heap_t lineage_heap; /* We keep an independent segment heap for each label */ object_heap_t *segment_heap; /* We keep an independent hull heap for each label */ diff --git a/lib/tests/test_ancestry.c b/lib/tests/test_ancestry.c index 1c4380256..e165d0bc9 100644 --- a/lib/tests/test_ancestry.c +++ b/lib/tests/test_ancestry.c @@ -1,5 +1,5 @@ /* -** Copyright (C) 2016-2021 University of Oxford +** Copyright (C) 2016-2024 University of Oxford ** ** This file is part of msprime. 
** @@ -1645,6 +1645,7 @@ test_multiple_mergers_unary_nodes(void) CU_ASSERT_EQUAL(ret, 0); msp_verify(&msp, 0); + /* msp_print_state(&msp, stdout); */ CU_ASSERT_TRUE(msp_get_num_breakpoints(&msp) > 0); // verify whether there is at least one unary node num_edges = tables.edges.num_rows; @@ -4307,7 +4308,7 @@ main(int argc, char **argv) { "test_multiple_mergers_growth_rate", test_multiple_mergers_growth_rate }, { "test_dirac_coalescent_bad_parameters", test_dirac_coalescent_bad_parameters }, { "test_beta_coalescent_bad_parameters", test_beta_coalescent_bad_parameters }, - { "test_multipe_mergers_unary_nodes", test_multiple_mergers_unary_nodes }, + { "test_multiple_mergers_unary_nodes", test_multiple_mergers_unary_nodes }, { "test_simulator_getters_setters", test_simulator_getters_setters }, { "test_demographic_events", test_demographic_events }, diff --git a/lib/tests/test_sweeps.c b/lib/tests/test_sweeps.c index aa3f149d2..93f107750 100644 --- a/lib/tests/test_sweeps.c +++ b/lib/tests/test_sweeps.c @@ -374,8 +374,9 @@ static void test_sweep_genic_selection_mimic_msms(void) { /* To mimic the nrepeats = 300 parameter in msms cmdline arguments*/ - for (int i = 0; i < 300; i++) + for (int i = 0; i < 300; i++) { sweep_genic_selection_mimic_msms_single_run(i + 1); + } } int diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 0a95abd78..0abd6e579 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -432,33 +432,19 @@ def test_one_gen_pedigree(self, num_founders): ts = self.run_script(f"0 --from-ts {ts_path} -r 1 --model=fixed_pedigree") assert len(ts.dump_tables().edges) == 0 - def test_smck(self): - ts = self.run_script("10 -L 1000 -d -r 0.01 --model smc_k") - assert ts.num_trees > 1 - for tree in ts.trees(): - assert tree.num_roots == 1 - - ts = self.run_script("10 -L 1000 -r 0.01 --model smc_k") - assert ts.num_trees > 1 - for tree in ts.trees(): - assert tree.num_roots == 1 - - ts = self.run_script("10 -L 1000 -r 0.01 --model smc_k 
--offset 0.50") - assert ts.num_trees > 1 - for tree in ts.trees(): - assert tree.num_roots == 1 - - ts = self.run_script("10 -L 1000 -d -r 0.01 --model smc_k -p 2 -g 0.1") - assert ts.num_trees > 1 - for tree in ts.trees(): - assert tree.num_roots == 1 - - ts = self.run_script("10 -L 1000 -d -c 0.04 2 --model smc_k") - assert ts.num_trees > 1 - for tree in ts.trees(): - assert tree.num_roots == 1 - - ts = self.run_script("10 -L 1000 -c 0.04 2 --model smc_k --offset 0.75") + @pytest.mark.parametrize( + "cmd", + [ + "10 -L 1000 -d -r 0.01 --model smc_k", + "10 -L 1000 -r 0.01 --model smc_k", + "10 -L 1000 -r 0.01 --model smc_k --offset 0.50", + "10 -L 1000 -d -r 0.01 --model smc_k -p 2 -g 0.1", + "10 -L 1000 -d -c 0.04 2 --model smc_k", + "10 -L 1000 -c 0.04 2 --model smc_k --offset 0.75", + ], + ) + def test_smck(self, cmd): + ts = self.run_script(cmd) assert ts.num_trees > 1 for tree in ts.trees(): assert tree.num_roots == 1 diff --git a/tests/test_provenance.py b/tests/test_provenance.py index 805235155..5eba25b85 100644 --- a/tests/test_provenance.py +++ b/tests/test_provenance.py @@ -80,7 +80,7 @@ class TestBuildObjects: def decode(self, prov): # Supress warnings about schemas here - it's no big deal and # not easy to fix - with pytest.warns(UserWarning): + with pytest.warns((UserWarning, DeprecationWarning)): builder = pjs.ObjectBuilder(tskit.provenance.get_schema()) ns = builder.build_classes() return ns.TskitProvenance.from_json(prov)