diff --git a/new-docs/source/tutorial/1-getting-started.ipynb b/new-docs/source/tutorial/1-getting-started.ipynb index 0cb4402c8..469cd8d50 100644 --- a/new-docs/source/tutorial/1-getting-started.ipynb +++ b/new-docs/source/tutorial/1-getting-started.ipynb @@ -293,6 +293,50 @@ "print(\"\\n\".join(str(p) for p in outputs.out_file))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Executing tasks in parallel\n", + "\n", + "By default, Pydra will use the *debug* worker, which executes each task sequentially.\n", + "This makes it easier to debug tasks and workflows; however, in most cases, once a workflow\n", + "is ready to go, a concurrent worker is preferable so tasks can be executed in parallel\n", + "(see [Workers](./2-advanced-execution.html#Workers)). if issubclass(field.type, FileSet): + if inspect.isclass(field.type) and issubclass(field.type, FileSet): details += ( f"- {changed}: value passed to the {field.type} field is of type " f"{field_type} ('{val}'). 
If it is intended to contain output data " diff --git a/pydra/engine/state.py b/pydra/engine/state.py index c97d71a53..ef65487ca 100644 --- a/pydra/engine/state.py +++ b/pydra/engine/state.py @@ -39,6 +39,11 @@ def __init__(self, indices: dict[str, int] | None = None): else: self.indices = OrderedDict(sorted(indices.items())) + def __repr__(self): + return ( + "StateIndex(" + ", ".join(f"{n}={v}" for n, v in self.indices.items()) + ")" + ) + def __hash__(self): return hash(tuple(self.indices.items())) diff --git a/pydra/engine/submitter.py b/pydra/engine/submitter.py index b9af10488..40fef6d8f 100644 --- a/pydra/engine/submitter.py +++ b/pydra/engine/submitter.py @@ -7,6 +7,7 @@ from pathlib import Path from tempfile import mkdtemp from copy import copy +from datetime import datetime from collections import defaultdict from .workers import Worker, WORKERS from .graph import DiGraph @@ -21,7 +22,7 @@ from .core import Task from pydra.utils.messenger import AuditFlag, Messenger from pydra.utils import user_cache_dir - +from pydra.design import workflow import logging logger = logging.getLogger("pydra.submitter") @@ -62,20 +63,37 @@ class Submitter: Messengers, by default None messenger_args : dict, optional Messenger arguments, by default None + clean_stale_locks : bool, optional + Whether to clean stale lock files, i.e. lock files that were created before the + start of the current run. Don't set if using a global cache where there are + potentially multiple workflows that are running concurrently. 
By default (None), + lock files will be cleaned if the *debug* worker is used **kwargs : dict Keyword arguments to pass on to the worker initialisation """ + cache_dir: os.PathLike + worker: Worker + environment: "Environment | None" + rerun: bool + cache_locations: list[os.PathLike] + audit_flags: AuditFlag + messengers: ty.Iterable[Messenger] + messenger_args: dict[str, ty.Any] + clean_stale_locks: bool + run_start_time: datetime | None + def __init__( self, cache_dir: os.PathLike | None = None, - worker: ty.Union[str, ty.Type[Worker]] = "debug", + worker: str | ty.Type[Worker] | Worker = "debug", environment: "Environment | None" = None, rerun: bool = False, cache_locations: list[os.PathLike] | None = None, audit_flags: AuditFlag = AuditFlag.NONE, messengers: ty.Iterable[Messenger] | None = None, messenger_args: dict[str, ty.Any] | None = None, + clean_stale_locks: bool | None = None, **kwargs, ): @@ -113,6 +131,12 @@ def __init__( except TypeError as e: e.add_note(WORKER_KWARG_FAIL_NOTE) raise + self.run_start_time = None + self.clean_stale_locks = ( + clean_stale_locks + if clean_stale_locks is not None + else (self.worker_name == "debug") + ) self.worker_kwargs = kwargs self._worker.loop = self.loop @@ -133,18 +157,16 @@ def __call__( task_def._check_rules() # If the outer task is split, create an implicit workflow to hold the split nodes if task_def._splitter: - - from pydra.design import workflow from pydra.engine.specs import TaskDef output_types = {o.name: list[o.type] for o in list_fields(task_def.Outputs)} @workflow.define(outputs=output_types) - def Split(defn: TaskDef): + def Split(defn: TaskDef, output_types: dict): node = workflow.add(defn) return tuple(getattr(node, o) for o in output_types) - task_def = Split(defn=task_def) + task_def = Split(defn=task_def, output_types=output_types) elif task_def._combiner: raise ValueError( @@ -152,17 +174,23 @@ def Split(defn: TaskDef): "Use the `split` method to split the task before combining." 
) task = Task(task_def, submitter=self, name="task", environment=self.environment) - if task.is_async: # Only workflow tasks can be async - self.loop.run_until_complete(self.worker.run_async(task, rerun=self.rerun)) - else: - self.worker.run(task, rerun=self.rerun) + try: + self.run_start_time = datetime.now() + if task.is_async: # Only workflow tasks can be async + self.loop.run_until_complete( + self.worker.run_async(task, rerun=self.rerun) + ) + else: + self.worker.run(task, rerun=self.rerun) + finally: + self.run_start_time = None PersistentCache().clean_up() result = task.result() if result is None: if task.lockfile.exists(): raise RuntimeError( f"Task {task} has a lockfile, but no result was found. " - "This may be due to another submission process running, or the hard " + "This may be due to another submission process running, or the hard " "interrupt (e.g. a debugging abortion) interrupting a previous run. " f"In the case of an interrupted run, please remove {str(task.lockfile)!r} " "and resubmit." @@ -228,18 +256,30 @@ async def expand_workflow_async(self, workflow_task: "Task[WorkflowDef]") -> Non # this might be related to some delays saving the files # so try to get_runnable_tasks for another minute ii = 0 - while not tasks and exec_graph.nodes: + while not tasks and any(not n.done for n in exec_graph.nodes): tasks = self.get_runnable_tasks(exec_graph) ii += 1 # don't block the event loop! await asyncio.sleep(1) - if ii > 60: + if ii > 10: + not_done = "\n".join( + ( + f"{n.name}: started={bool(n.started)}, " + f"blocked={list(n.blocked)}, queued={list(n.queued)}" + ) + for n in exec_graph.nodes + if not n.done + ) msg = ( - f"Graph of '{wf}' workflow is not empty, but not able to get " - "more tasks - something has gone wrong when retrieving the " - "results predecessors:\n\n" + "Something has gone wrong when retrieving the predecessor " + f"results. 
Not able to get any more tasks but the following " + f"nodes of the {wf.name!r} workflow are not done:\n{not_done}\n\n" ) - # Get blocked tasks and the predecessors they are waiting on + not_done = [n for n in exec_graph.nodes if not n.done] + msg += "\n" + ", ".join( + f"{t.name}: {t.done}" for t in not_done[0].queued.values() + ) + # Get blocked tasks and the predecessors they are blocked on outstanding: dict[Task[DefType], list[Task[DefType]]] = { t: [ p for p in exec_graph.predecessors[t.name] if not p.done @@ -248,11 +288,11 @@ async def expand_workflow_async(self, workflow_task: "Task[WorkflowDef]") -> Non } hashes_have_changed = False - for task, waiting_on in outstanding.items(): - if not waiting_on: + for task, blocked_on in outstanding.items(): + if not blocked_on: continue msg += f"- '{task.name}' node blocked due to\n" - for pred in waiting_on: + for pred in blocked_on: if ( pred.checksum != wf.inputs._graph_checksums[pred.name] @@ -302,13 +342,21 @@ def close(self): """ Close submitter. - Do not close previously running loop. + Do not close previously running loop. """ self.worker.close() if self._own_loop: self.loop.close() + def _check_locks(self, tasks: list[Task]) -> None: + """Check for stale lock files and remove them.""" + if self.clean_stale_locks: + for task in tasks: + start_time = task.run_start_time + if start_time and start_time < self.run_start_time: + task.lockfile.unlink() + def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]: """Parse a graph and return all runnable tasks. 
@@ -338,6 +386,7 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]: if not node.started: not_started.add(node) tasks.extend(node.get_runnable_tasks(graph)) + self._check_locks(tasks) return tasks @property @@ -369,10 +418,12 @@ class NodeExecution(ty.Generic[DefType]): errored: dict[StateIndex | None, "Task[DefType]"] # List of tasks that couldn't be run due to upstream errors unrunnable: dict[StateIndex | None, list["Task[DefType]"]] - # List of tasks that are running - running: dict[StateIndex | None, "Task[DefType]"] - # List of tasks that are waiting on other tasks to complete before they can be run - waiting: dict[StateIndex | None, "Task[DefType]"] + # List of tasks that are queued + queued: dict[StateIndex | None, "Task[DefType]"] + # List of tasks that are running, each paired with its run start time + running: dict[StateIndex | None, tuple["Task[DefType]", datetime]] + # List of tasks that are blocked on other tasks to complete before they can be run + blocked: dict[StateIndex | None, "Task[DefType]"] _tasks: dict[StateIndex | None, "Task[DefType]"] | None @@ -391,10 +442,11 @@ def __init__( self.submitter = submitter # Initialize the state dictionaries self._tasks = None - self.waiting = {} + self.blocked = {} self.successful = {} self.errored = {} - self.running = {} + self.queued = {} + self.running = {} # Not used in logic, but may be useful for progress tracking self.unrunnable = defaultdict(list) self.state_names = self.node.state.names self.workflow_inputs = workflow_inputs @@ -430,18 +482,44 @@ def started(self) -> bool: self.successful or self.errored or self.unrunnable - or self.running - or self.waiting + or self.queued + or self.blocked ) @property def done(self) -> bool: - return self.started and not (self.running or self.waiting) + self.update_status() + if not self.started: + return False + # Check to see if any previously queued tasks have completed + return not (self.queued or self.blocked or self.running) + + def update_status(self) -> None: + 
"""Updates the status of the tasks in the node.""" + if not self.started: + return + # Check to see if any previously queued tasks have completed + for index, task in list(self.queued.items()): + if task.done: + self.successful[task.state_index] = self.queued.pop(index) + elif task.errored: + self.errored[task.state_index] = self.queued.pop(index) + elif task.run_start_time: + self.running[task.state_index] = ( + self.queued.pop(index), + task.run_start_time, + ) + # Check to see if any previously running tasks have completed + for index, (task, start_time) in list(self.running.items()): + if task.done: + self.successful[task.state_index] = self.running.pop(index)[0] + elif task.errored: + self.errored[task.state_index] = self.running.pop(index)[0] @property def all_failed(self) -> bool: return (self.unrunnable or self.errored) and not ( - self.successful or self.waiting or self.running + self.successful or self.blocked or self.queued ) def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]: @@ -470,7 +548,7 @@ def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]: def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]: """For a given node, check to see which tasks have been successfully run, are ready - to run, can't be run due to upstream errors, or are waiting on other tasks to complete. + to run, can't be run due to upstream errors, or are blocked on other tasks to complete. 
Parameters ---------- @@ -488,29 +566,23 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]: runnable: list["Task[DefType]"] = [] self.tasks # Ensure tasks are loaded if not self.started: - self.waiting = copy(self._tasks) - # Check to see if any previously running tasks have completed - for index, task in list(self.running.items()): - if task.done: - self.successful[task.state_index] = self.running.pop(index) - elif task.errored: - self.errored[task.state_index] = self.running.pop(index) - # Check to see if any waiting tasks are now runnable/unrunnable - for index, task in list(self.waiting.items()): + self.blocked = copy(self._tasks) + # Check to see if any blocked tasks are now runnable/unrunnable + for index, task in list(self.blocked.items()): pred: NodeExecution is_runnable = True for pred in graph.predecessors[self.node.name]: if index not in pred.successful: is_runnable = False if index in pred.errored: - self.unrunnable[index].append(self.waiting.pop(index)) + self.unrunnable[index].append(self.blocked.pop(index)) if index in pred.unrunnable: self.unrunnable[index].extend(pred.unrunnable[index]) - self.waiting.pop(index) + self.blocked.pop(index) break if is_runnable: - runnable.append(self.waiting.pop(index)) - self.running.update({t.state_index: t for t in runnable}) + runnable.append(self.blocked.pop(index)) + self.queued.update({t.state_index: t for t in runnable}) return runnable diff --git a/pydra/utils/hash.py b/pydra/utils/hash.py index 224af25fb..a836eaddf 100644 --- a/pydra/utils/hash.py +++ b/pydra/utils/hash.py @@ -4,12 +4,9 @@ import os import struct import inspect -import re from datetime import datetime import typing as ty import types -import ast -import cloudpickle as cp from pathlib import Path from collections.abc import Mapping from functools import singledispatch @@ -331,7 +328,17 @@ def bytes_repr(obj: object, cache: Cache) -> Iterator[bytes]: elif hasattr(obj, "__slots__"): dct = {attr: getattr(obj, attr) for 
attr in obj.__slots__} else: - dct = obj.__dict__ + try: + dct = obj.__dict__ + except AttributeError: + dct = { + n: getattr(obj, n) + for n in dir(obj) + if not ( + (n.startswith("__") and n.endswith("__")) + or inspect.ismethod(getattr(obj, n)) + ) + } yield from bytes_repr_mapping_contents(dct, cache) yield b"}" @@ -525,31 +532,39 @@ def bytes_repr_set(obj: Set, cache: Cache) -> Iterator[bytes]: yield b"}" +@register_serializer +def bytes_repr_code(obj: types.CodeType, cache: Cache) -> Iterator[bytes]: + yield b"code:(" + yield from bytes_repr_sequence_contents( + ( + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + obj.co_filename, + obj.co_freevars, + obj.co_name, + obj.co_firstlineno, + obj.co_lnotab, + obj.co_cellvars, + ), + cache, + ) + yield b")" + + @register_serializer def bytes_repr_function(obj: types.FunctionType, cache: Cache) -> Iterator[bytes]: """Serialize a function, attempting to use the AST of the source code if available otherwise falling back to using cloudpickle to serialize the byte-code of the function.""" - try: - src = inspect.getsource(obj) - except OSError: - # Fallback to using cloudpickle to serialize the function if the source - # code is not available - bytes_repr = cp.dumps(obj) - else: - indent = re.match(r"(\s*)", src).group(1) - if indent: - src = re.sub(f"^{indent}", "", src, flags=re.MULTILINE) - src_ast = ast.parse(src) - # Remove the function definition from the source code - bytes_repr = ast.dump( - src_ast, annotate_fields=False, include_attributes=False - ).encode() - - yield b"function:(" - for i in range(0, len(bytes_repr), FUNCTION_SRC_CHUNK_LEN_DEFAULT): - yield hash_single(bytes_repr[i : i + FUNCTION_SRC_CHUNK_LEN_DEFAULT], cache) - yield b")" + yield from bytes_repr(obj.__code__, cache) def bytes_repr_mapping_contents(mapping: Mapping, cache: Cache) -> Iterator[bytes]: