diff --git a/sisyphus/graph.py b/sisyphus/graph.py index d9c7e70..4a4bed4 100644 --- a/sisyphus/graph.py +++ b/sisyphus/graph.py @@ -488,7 +488,7 @@ def get_unfinished_jobs(job): self.for_all_nodes(get_unfinished_jobs, nodes=nodes) return states - def for_all_nodes(self, f, nodes=None, bottom_up=False): + def for_all_nodes(self, f, nodes=None, bottom_up=False, *, pool: Optional[ThreadPool] = None): """ Run function f for each node and ancestor for `nodes` from top down, stop expanding tree branch if functions returns False. Does not stop on None to allow functions with no @@ -497,6 +497,7 @@ def for_all_nodes(self, f, nodes=None, bottom_up=False): :param (Job)->bool f: function will be executed for all nodes :param nodes: all nodes that will be checked, defaults to all output nodes in graph :param bool bottom_up: start with deepest nodes first, ignore return value of f + :param pool: use custom thread pool :return: set with all visited nodes """ @@ -544,7 +545,8 @@ def for_all_nodes(self, f, nodes=None, bottom_up=False): pool_lock = threading.Lock() finished_lock = threading.Lock() - pool = self.pool + if not pool: + pool = self.pool # recursive function to run through tree def runner(job): diff --git a/sisyphus/manager.py b/sisyphus/manager.py index d3539bc..a41f7cf 100644 --- a/sisyphus/manager.py +++ b/sisyphus/manager.py @@ -38,7 +38,7 @@ def f(job): return True while not self.stopped: - self.sis_graph.for_all_nodes(f) + self.sis_graph.for_all_nodes(f, pool=self.thread_pool) time.sleep(gs.JOB_CLEANER_INTERVAL) def close(self):