Skip to content

Commit

Permalink
debugged topological sort
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitri-yatsenko committed Sep 15, 2024
1 parent adfdc65 commit 24c090d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
19 changes: 8 additions & 11 deletions datajoint/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from collections import defaultdict
from .errors import DataJointError


def extract_master(part_table):
"""
given a part table name, return master part. None if not a part table
given a part table name, return master part. None if not a part table
"""
match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", part_table)
return match['master'] + '`' if match else None

return match["master"] + "`" if match else None


def topo_sort(graph):
Expand Down Expand Up @@ -39,22 +39,19 @@ def topo_sort(graph):
# to ensure correct topological ordering of the masters.
for part in graph:
# find the part's master
master = extract_master(part)
if master:
if (master := extract_master(part)) in graph:
for edge in graph.in_edges(part):
parent = edge[0]
if parent != master and extract_master(parent) != master:
graph.add_edge(parent, master)

sorted_nodes = list(nx.topological_sort(graph))

# bring parts up to their masters
pos = len(sorted_nodes) - 1
placed = set()
while pos > 1:
part = sorted_nodes[pos]
master = extract_master(part)
if not master or part in placed:
if not (master := extract_master) or part in placed:
pos -= 1
else:
placed.add(part)
Expand All @@ -63,7 +60,7 @@ def topo_sort(graph):
except ValueError:
# master not found
pass
else:
else:
if pos > j + 1:
# move the part to its master
del sorted_nodes[pos]
Expand Down Expand Up @@ -214,8 +211,8 @@ def descendants(self, full_table_name):
:return: all dependent tables sorted in topological order. Self is included.
"""
self.load(force=False)
nodes = self.subgraph(nx.descendants(self, full_table_name))
return [full_table_name] + nodes.topo_sort()
nodes = self.subgraph(nx.descendants(self, full_table_name))
return [full_table_name] + nodes.topo_sort()

def ancestors(self, full_table_name):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def test_list_tables(schema_simp):
actual = set(schema_simp.list_tables())
assert actual == expected, f"Missing from list_tables(): {expected - actual}"


def test_schema_save_any(schema_any):
assert "class Experiment(dj.Imported)" in schema_any.code

Expand Down

0 comments on commit 24c090d

Please sign in to comment.