Skip to content

Commit

Permalink
Fix logic bug with lineage struct
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 17, 2024
1 parent cafec23 commit 0a171f0
Showing 1 changed file with 60 additions and 37 deletions.
97 changes: 60 additions & 37 deletions algorithms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Python version of the simulation algorithm.
"""

from __future__ import annotations

import argparse
Expand Down Expand Up @@ -120,6 +121,7 @@ class Segment:
and right, denoting the loci over which it spans, a node and a
next, giving the next in the chain.
"""

index: int
left: float = None
right: float = None
Expand All @@ -135,17 +137,17 @@ class Segment:
# def __repr__(self):
# return object.__repr__(self)

# def __init__(self, index):
# self.left = None
# self.right = None
# self.node = None
# self.prev = None
# self.next = None
# self.index = index
# self.lineage = None
# self.population = None
# self.label = 0
# self.hull = None
# def __init__(self, index):
# self.left = None
# self.right = None
# self.node = None
# self.prev = None
# self.next = None
# self.index = index
# self.lineage = None
# self.population = None
# self.label = 0
# self.hull = None

# def __repr__(self):
# return repr((self.left, self.right, self.node))
Expand Down Expand Up @@ -1123,6 +1125,7 @@ def alloc_segment(
next=None, # noqa: A002
label=0,
hull=None,
lineage=None,
):
"""
Pops a new segment off the stack and sets its properties.
Expand All @@ -1136,6 +1139,7 @@ def alloc_segment(
s.prev = prev
s.label = label
s.hull = hull
s.lineage = lineage
return s

def alloc_lineage(self, head=None, tail=None, population=None, label=None):
Expand All @@ -1158,6 +1162,7 @@ def copy_segment(self, segment):
next=segment.next,
prev=segment.prev,
label=segment.label,
lineage=segment.lineage,
)

def free_segment(self, u):
Expand Down Expand Up @@ -1864,7 +1869,6 @@ def hudson_recombination_event(self, label, return_heads=False):
seg.lineage = right_lineage
seg = seg.next


if self.model == "smc_k":
# modify original hull
pop = alpha.population
Expand All @@ -1885,6 +1889,7 @@ def hudson_recombination_event(self, label, return_heads=False):
self.store_node(alpha.population, flags=msprime.NODE_IS_RE_EVENT)
self.store_arg_edges(alpha)

assert not return_heads
self.verify()

ret = None
Expand Down Expand Up @@ -2470,6 +2475,7 @@ def merge_two_ancestors(self, population_index, label, lin_x, lin_y, u=-1):
self.num_ca_events += 1
new_lineage = None
merged_head = None
z = None
coalescence = False
defrag_required = False
x = lin_x.head
Expand Down Expand Up @@ -2548,13 +2554,15 @@ def merge_two_ancestors(self, population_index, label, lin_x, lin_y, u=-1):

# loop tail; update alpha and integrate it into the state.
if alpha is not None:
print("ADD alpha", alpha, repr(new_lineage))
if new_lineage is None:
new_lineage = self.alloc_lineage(alpha, alpha, population=population_index, label=label)
if z is None:
new_lineage = self.alloc_lineage(
alpha, alpha, population=population_index, label=label
)
pop.add(new_lineage, label)
z = None
else:
z = new_lineage.tail
# print("add", new_lineage)
# print("alpha", alpha)
# z = new_lineage.tail
if (coalescence and not self.coalescing_segments_only) or (
self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0
):
Expand All @@ -2566,8 +2574,23 @@ def merge_two_ancestors(self, population_index, label, lin_x, lin_y, u=-1):
z.next = alpha
new_lineage.tail = alpha
alpha.lineage = new_lineage
# print("TAIL:", alpha, id(alpha.lineage))
alpha.prev = z
self.set_segment_mass(alpha)
z = alpha

if new_lineage is not None:
x = new_lineage.tail
while x is not None:
x.lineage = new_lineage
new_lineage.tail = x
x = x.next

x = new_lineage.head
while x is not None:
x.lineage = new_lineage
new_lineage.head = x
x = x.prev

if coalescence:
if not self.coalescing_segments_only:
Expand All @@ -2581,18 +2604,18 @@ def merge_two_ancestors(self, population_index, label, lin_x, lin_y, u=-1):
if coalescence:
self.defrag_breakpoints()

self.verify()
# self.print_state()

if merged_head is not None and self.model == "smc_k":
assert False # get rid of merged_head
assert merged_head.prev is None
hull = self.alloc_hull(merged_head.left, merged_head.right, merged_head)
while merged_head is not None:
right = merged_head.right
merged_head = merged_head.next
hull.right = min(right + self.hull_offset, self.L)
pop.add_hull(label, hull)
# self.verify()

# if merged_head is not None and self.model == "smc_k":
# assert False # get rid of merged_head
# assert merged_head.prev is None
# hull = self.alloc_hull(merged_head.left, merged_head.right, merged_head)
# while merged_head is not None:
# right = merged_head.right
# merged_head = merged_head.next
# hull.right = min(right + self.hull_offset, self.L)
# pop.add_hull(label, hull)

def print_state(self, verify=False):
print("State @ time ", self.t)
Expand Down Expand Up @@ -2660,25 +2683,25 @@ def print_state(self, verify=False):
self.verify()

def verify_segments(self):
print("VERIFY")
# print("VERIFY")
for pop_index, pop in enumerate(self.P):
for label in range(self.num_labels):
assert len(set(id(lin) for lin in pop._ancestors[label])) == len(pop._ancestors[label])
segment_ids = set()
assert len(set(id(lin) for lin in pop._ancestors[label])) == len(
pop._ancestors[label]
)
for lin in pop.iter_label(label):
assert lin.population == pop_index
assert lin.label == label
head = lin.head
assert head.prev is None
u = head
while u is not None:
assert u.lineage is lin
u = u.next

prev = head
u = head.next
print(id(lin), repr(lin))
while u is not None:
print(u)
assert u.index not in segment_ids
segment_ids.add(u.index)
print("\t", u.lineage)
# print(u, u.lineage)
assert u.lineage is lin
assert prev.next is u
assert u.prev is prev
Expand Down

0 comments on commit 0a171f0

Please sign in to comment.