From 8c66014029418dbb1e9bc4ac61cde45ea642080d Mon Sep 17 00:00:00 2001 From: Tom White Date: Sun, 1 Dec 2024 11:01:20 +0000 Subject: [PATCH] Include task result in callbacks --- cubed/runtime/executors/coiled.py | 4 ++-- cubed/runtime/executors/dask.py | 8 ++++---- cubed/runtime/executors/lithops.py | 8 ++++---- cubed/runtime/executors/local.py | 12 ++++++------ cubed/runtime/executors/modal.py | 8 ++++---- cubed/runtime/types.py | 5 ++++- cubed/runtime/utils.py | 5 +++-- 7 files changed, 27 insertions(+), 23 deletions(-) diff --git a/cubed/runtime/executors/coiled.py b/cubed/runtime/executors/coiled.py index fc7b51f4..3ecbc40b 100644 --- a/cubed/runtime/executors/coiled.py +++ b/cubed/runtime/executors/coiled.py @@ -47,8 +47,8 @@ def execute_dag( coiled_function.cluster.adapt(minimum=minimum_workers) # coiled expects a sequence (it calls `len` on it) input = list(pipeline.mappable) - for _, stats in coiled_function.map(input, config=pipeline.config): + for result, stats in coiled_function.map(input, config=pipeline.config): if callbacks is not None: if name is not None: stats["name"] = name - handle_callbacks(callbacks, stats) + handle_callbacks(callbacks, result, stats) diff --git a/cubed/runtime/executors/dask.py b/cubed/runtime/executors/dask.py index 4c6e6471..67d5098b 100644 --- a/cubed/runtime/executors/dask.py +++ b/cubed/runtime/executors/dask.py @@ -134,8 +134,8 @@ async def async_execute_dag( handle_operation_start_callbacks(callbacks, name) st = pipeline_to_stream(client, name, node["pipeline"], **kwargs) async with st.stream() as streamer: - async for _, stats in streamer: - handle_callbacks(callbacks, stats) + async for result, stats in streamer: + handle_callbacks(callbacks, result, stats) else: for gen in visit_node_generations(dag, resume=resume): # run pipelines in the same topological generation in parallel by merging their streams @@ -145,8 +145,8 @@ async def async_execute_dag( ] merged_stream = stream.merge(*streams) async with merged_stream.stream() as streamer: - async for _, stats in streamer: - handle_callbacks(callbacks, stats) + async for result, stats in streamer: + handle_callbacks(callbacks, result, stats) class DaskExecutor(DagExecutor): diff --git a/cubed/runtime/executors/lithops.py b/cubed/runtime/executors/lithops.py index 870f1c90..f9974fa3 100644 --- a/cubed/runtime/executors/lithops.py +++ b/cubed/runtime/executors/lithops.py @@ -193,7 +193,7 @@ def execute_dag( for name, node in visit_nodes(dag, resume=resume): handle_operation_start_callbacks(callbacks, name) pipeline = node["pipeline"] - for _, stats in map_unordered( + for result, stats in map_unordered( executor, [run_func], [pipeline.mappable], @@ -207,7 +207,7 @@ def execute_dag( name=name, compute_id=compute_id, ): - handle_callbacks(callbacks, stats) + handle_callbacks(callbacks, result, stats) else: for gen in visit_node_generations(dag, resume=resume): group_map_functions = [] @@ -223,7 +223,7 @@ def execute_dag( group_names.append(name) for name in group_names: handle_operation_start_callbacks(callbacks, name) - for _, stats in map_unordered( + for result, stats in map_unordered( executor, group_map_functions, group_map_iterdata, @@ -234,7 +234,7 @@ def execute_dag( # TODO: other kwargs (func, config, name) compute_id=compute_id, ): - handle_callbacks(callbacks, stats) + handle_callbacks(callbacks, result, stats) def standardise_lithops_stats(future: RetryingFuture) -> Dict[str, Any]: diff --git a/cubed/runtime/executors/local.py b/cubed/runtime/executors/local.py index 21ef55fe..1614f02c 100644 --- a/cubed/runtime/executors/local.py +++ b/cubed/runtime/executors/local.py @@ -50,7 +50,7 @@ def execute_dag( handle_operation_start_callbacks(callbacks, name) pipeline: CubedPipeline = node["pipeline"] for m in pipeline.mappable: - exec_stage_func( + result = exec_stage_func( m, pipeline.function, config=pipeline.config, @@ -58,7 +58,7 @@ def execute_dag( compute_id=compute_id, ) if callbacks is not None: - event = TaskEndEvent(name=name) + event = TaskEndEvent(name=name, result=result) [callback.on_task_end(event) for callback in callbacks] @@ -223,8 +223,8 @@ async def async_execute_dag( concurrent_executor, run_func, name, node["pipeline"], **kwargs ) async with st.stream() as streamer: - async for _, stats in streamer: - handle_callbacks(callbacks, stats) + async for result, stats in streamer: + handle_callbacks(callbacks, result, stats) else: for gen in visit_node_generations(dag, resume=resume): # run pipelines in the same topological generation in parallel by merging their streams @@ -236,8 +236,8 @@ async def async_execute_dag( ] merged_stream = stream.merge(*streams) async with merged_stream.stream() as streamer: - async for _, stats in streamer: - handle_callbacks(callbacks, stats) + async for result, stats in streamer: + handle_callbacks(callbacks, result, stats) finally: # don't wait for any cancelled tasks diff --git a/cubed/runtime/executors/modal.py b/cubed/runtime/executors/modal.py index e1c441d8..520877f9 100644 --- a/cubed/runtime/executors/modal.py +++ b/cubed/runtime/executors/modal.py @@ -226,8 +226,8 @@ async def async_execute_dag( handle_operation_start_callbacks(callbacks, name) st = pipeline_to_stream(app_function, name, node["pipeline"], **kwargs) async with st.stream() as streamer: - async for _, stats in streamer: - handle_callbacks(callbacks, stats) + async for result, stats in streamer: + handle_callbacks(callbacks, result, stats) else: for gen in visit_node_generations(dag, resume=resume): # run pipelines in the same topological generation in parallel by merging their streams @@ -237,8 +237,8 @@ async def async_execute_dag( ] merged_stream = stream.merge(*streams) async with merged_stream.stream() as streamer: - async for _, stats in streamer: - handle_callbacks(callbacks, stats) + async for result, stats in streamer: + handle_callbacks(callbacks, result, stats) class ModalExecutor(DagExecutor): diff --git a/cubed/runtime/types.py b/cubed/runtime/types.py index 1b26c34a..37f63634 100644 --- a/cubed/runtime/types.py +++ b/cubed/runtime/types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Iterable, Optional +from typing import Any, Iterable, Optional from networkx import MultiDiGraph @@ -71,6 +71,9 @@ class TaskEndEvent: num_tasks: int = 1 """Number of tasks that this event applies to (default 1).""" + result: Optional[Any] = None + """Return value of the task.""" + task_create_tstamp: Optional[float] = None """Timestamp of when the task was created by the client.""" diff --git a/cubed/runtime/utils.py b/cubed/runtime/utils.py index 8c611c8d..cd97bb68 100644 --- a/cubed/runtime/utils.py +++ b/cubed/runtime/utils.py @@ -99,18 +99,19 @@ def handle_operation_start_callbacks(callbacks, name): [callback.on_operation_start(event) for callback in callbacks] -def handle_callbacks(callbacks, stats): +def handle_callbacks(callbacks, result, stats): """Construct a TaskEndEvent from stats and send to all callbacks.""" if callbacks is not None: if "task_result_tstamp" not in stats: task_result_tstamp = time.time() event = TaskEndEvent( + result=result, task_result_tstamp=task_result_tstamp, **stats, ) else: - event = TaskEndEvent(**stats) + event = TaskEndEvent(result=result, **stats) [callback.on_task_end(event) for callback in callbacks]