From d429d75138b0001ced40984222920ee5a3977f27 Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 16 Feb 2024 09:00:52 +0000 Subject: [PATCH] Add callback for when an operation starts --- cubed/runtime/executors/coiled.py | 7 ++++++- cubed/runtime/executors/dask_distributed_async.py | 8 +++++++- cubed/runtime/executors/lithops.py | 5 ++++- cubed/runtime/executors/modal.py | 7 ++++++- cubed/runtime/executors/modal_async.py | 3 ++- cubed/runtime/executors/python.py | 2 ++ cubed/runtime/executors/python_async.py | 7 ++++++- cubed/runtime/types.py | 11 +++++++++++ cubed/runtime/utils.py | 8 +++++++- 9 files changed, 51 insertions(+), 7 deletions(-) diff --git a/cubed/runtime/executors/coiled.py b/cubed/runtime/executors/coiled.py index 5515ff08..ee61496b 100644 --- a/cubed/runtime/executors/coiled.py +++ b/cubed/runtime/executors/coiled.py @@ -5,7 +5,11 @@ from cubed.runtime.pipeline import visit_nodes from cubed.runtime.types import Callback, DagExecutor -from cubed.runtime.utils import execution_stats, handle_callbacks +from cubed.runtime.utils import ( + execution_stats, + handle_callbacks, + handle_operation_start_callbacks, +) from cubed.spec import Spec @@ -27,6 +31,7 @@ def execute_dag( ) -> None: # Note this currently only builds the task graph for each stage once it gets to that stage in computation for name, node in visit_nodes(dag, resume=resume): + handle_operation_start_callbacks(callbacks, name) pipeline = node["pipeline"] coiled_function = make_coiled_function(pipeline.function, coiled_kwargs) input = list( diff --git a/cubed/runtime/executors/dask_distributed_async.py b/cubed/runtime/executors/dask_distributed_async.py index 7f3264a3..76f540d2 100644 --- a/cubed/runtime/executors/dask_distributed_async.py +++ b/cubed/runtime/executors/dask_distributed_async.py @@ -20,7 +20,12 @@ from cubed.runtime.executors.asyncio import async_map_unordered from cubed.runtime.pipeline import visit_node_generations, visit_nodes from cubed.runtime.types import Callback, 
CubedPipeline, DagExecutor -from cubed.runtime.utils import execution_stats, gensym, handle_callbacks +from cubed.runtime.utils import ( + execution_stats, + gensym, + handle_callbacks, + handle_operation_start_callbacks, +) from cubed.spec import Spec @@ -123,6 +128,7 @@ async def async_execute_dag( if not compute_arrays_in_parallel: # run one pipeline at a time for name, node in visit_nodes(dag, resume=resume): + handle_operation_start_callbacks(callbacks, name) st = pipeline_to_stream(client, name, node["pipeline"], **kwargs) async with st.stream() as streamer: async for _, stats in streamer: diff --git a/cubed/runtime/executors/lithops.py b/cubed/runtime/executors/lithops.py index e9aa5f04..27efcfcc 100644 --- a/cubed/runtime/executors/lithops.py +++ b/cubed/runtime/executors/lithops.py @@ -27,7 +27,7 @@ ) from cubed.runtime.pipeline import visit_node_generations, visit_nodes from cubed.runtime.types import Callback, DagExecutor -from cubed.runtime.utils import handle_callbacks +from cubed.runtime.utils import handle_callbacks, handle_operation_start_callbacks from cubed.spec import Spec logger = logging.getLogger(__name__) @@ -180,6 +180,7 @@ def execute_dag( with RetryingFunctionExecutor(function_executor) as executor: if not compute_arrays_in_parallel: for name, node in visit_nodes(dag, resume=resume): + handle_operation_start_callbacks(callbacks, name) pipeline = node["pipeline"] for _, stats in map_unordered( executor, @@ -207,6 +208,8 @@ def execute_dag( group_map_functions.append(f) group_map_iterdata.append(pipeline.mappable) group_names.append(name) + for name in group_names: + handle_operation_start_callbacks(callbacks, name) for _, stats in map_unordered( executor, group_map_functions, diff --git a/cubed/runtime/executors/modal.py b/cubed/runtime/executors/modal.py index 46432711..60c3fccf 100644 --- a/cubed/runtime/executors/modal.py +++ b/cubed/runtime/executors/modal.py @@ -10,7 +10,11 @@ from cubed.runtime.pipeline import visit_nodes from 
cubed.runtime.types import Callback, DagExecutor -from cubed.runtime.utils import execute_with_stats, handle_callbacks +from cubed.runtime.utils import ( + execute_with_stats, + handle_callbacks, + handle_operation_start_callbacks, +) from cubed.spec import Spec RUNTIME_MEMORY_MIB = 2000 @@ -128,6 +132,7 @@ def execute_dag( else: raise ValueError(f"Unrecognized cloud: {cloud}") for name, node in visit_nodes(dag, resume=resume): + handle_operation_start_callbacks(callbacks, name) pipeline = node["pipeline"] task_create_tstamp = time.time() for _, stats in app_function.map( diff --git a/cubed/runtime/executors/modal_async.py b/cubed/runtime/executors/modal_async.py index 20ffa4f5..b5ae5bff 100644 --- a/cubed/runtime/executors/modal_async.py +++ b/cubed/runtime/executors/modal_async.py @@ -18,7 +18,7 @@ ) from cubed.runtime.pipeline import visit_node_generations, visit_nodes from cubed.runtime.types import Callback, DagExecutor -from cubed.runtime.utils import handle_callbacks +from cubed.runtime.utils import handle_callbacks, handle_operation_start_callbacks from cubed.spec import Spec @@ -127,6 +127,7 @@ async def async_execute_dag( if not compute_arrays_in_parallel: # run one pipeline at a time for name, node in visit_nodes(dag, resume=resume): + handle_operation_start_callbacks(callbacks, name) st = pipeline_to_stream(app_function, name, node["pipeline"], **kwargs) async with st.stream() as streamer: async for _, stats in streamer: diff --git a/cubed/runtime/executors/python.py b/cubed/runtime/executors/python.py index e4fc1eaf..cd933a5c 100644 --- a/cubed/runtime/executors/python.py +++ b/cubed/runtime/executors/python.py @@ -4,6 +4,7 @@ from cubed.runtime.pipeline import visit_nodes from cubed.runtime.types import Callback, CubedPipeline, DagExecutor, TaskEndEvent +from cubed.runtime.utils import handle_operation_start_callbacks from cubed.spec import Spec @@ -24,6 +25,7 @@ def execute_dag( **kwargs, ) -> None: for name, node in visit_nodes(dag, resume=resume): 
+ handle_operation_start_callbacks(callbacks, name) pipeline: CubedPipeline = node["pipeline"] for m in pipeline.mappable: exec_stage_func( diff --git a/cubed/runtime/executors/python_async.py b/cubed/runtime/executors/python_async.py index e475f3d9..6c5400e6 100644 --- a/cubed/runtime/executors/python_async.py +++ b/cubed/runtime/executors/python_async.py @@ -11,7 +11,11 @@ from cubed.runtime.executors.asyncio import async_map_unordered from cubed.runtime.pipeline import visit_node_generations, visit_nodes from cubed.runtime.types import Callback, CubedPipeline, DagExecutor -from cubed.runtime.utils import execution_stats, handle_callbacks +from cubed.runtime.utils import ( + execution_stats, + handle_callbacks, + handle_operation_start_callbacks, +) from cubed.spec import Spec @@ -92,6 +96,7 @@ async def async_execute_dag( if not compute_arrays_in_parallel: # run one pipeline at a time for name, node in visit_nodes(dag, resume=resume): + handle_operation_start_callbacks(callbacks, name) st = pipeline_to_stream( concurrent_executor, name, node["pipeline"], **kwargs ) diff --git a/cubed/runtime/types.py b/cubed/runtime/types.py index 0777030c..400eea6e 100644 --- a/cubed/runtime/types.py +++ b/cubed/runtime/types.py @@ -49,6 +49,14 @@ class ComputeEndEvent: """The computation DAG.""" +@dataclass +class OperationStartEvent: + """Callback information about an operation that is about to start.""" + + name: str + """Name of the operation.""" + + @dataclass class TaskEndEvent: """Callback information about a completed task (or tasks).""" @@ -101,6 +109,9 @@ def on_compute_end(self, ComputeEndEvent): """ pass # pragma: no cover + def on_operation_start(self, event): + pass + def on_task_end(self, event): """Called when the a task ends. 
diff --git a/cubed/runtime/utils.py b/cubed/runtime/utils.py
index a28f67ec..894e6ba7 100644
--- a/cubed/runtime/utils.py
+++ b/cubed/runtime/utils.py
@@ -2,7 +2,7 @@
 from functools import partial
 from itertools import islice
 
-from cubed.runtime.types import TaskEndEvent
+from cubed.runtime.types import OperationStartEvent, TaskEndEvent
 from cubed.utils import peak_measured_mem
 
 sym_counter = 0
@@ -39,6 +39,14 @@ def execution_stats(func):
     return partial(execute_with_stats, func)
 
 
+def handle_operation_start_callbacks(callbacks, name):
+    """Construct an OperationStartEvent for name and send to all callbacks."""
+    if callbacks is not None:
+        event = OperationStartEvent(name)
+        for callback in callbacks:
+            callback.on_operation_start(event)
+
+
 def handle_callbacks(callbacks, stats):
     """Construct a TaskEndEvent from stats and send to all callbacks."""
 