Commit: Include task result in callbacks (#632)
tomwhite authored Dec 4, 2024
1 parent f949583 commit cbe7f66
Showing 7 changed files with 27 additions and 23 deletions.
4 changes: 2 additions & 2 deletions cubed/runtime/executors/coiled.py
@@ -47,8 +47,8 @@ def execute_dag(
         coiled_function.cluster.adapt(minimum=minimum_workers)
         # coiled expects a sequence (it calls `len` on it)
         input = list(pipeline.mappable)
-        for _, stats in coiled_function.map(input, config=pipeline.config):
+        for result, stats in coiled_function.map(input, config=pipeline.config):
             if callbacks is not None:
                 if name is not None:
                     stats["name"] = name
-                handle_callbacks(callbacks, stats)
+                handle_callbacks(callbacks, result, stats)
8 changes: 4 additions & 4 deletions cubed/runtime/executors/dask.py
@@ -134,8 +134,8 @@ async def async_execute_dag(
                 handle_operation_start_callbacks(callbacks, name)
                 st = pipeline_to_stream(client, name, node["pipeline"], **kwargs)
                 async with st.stream() as streamer:
-                    async for _, stats in streamer:
-                        handle_callbacks(callbacks, stats)
+                    async for result, stats in streamer:
+                        handle_callbacks(callbacks, result, stats)
         else:
             for gen in visit_node_generations(dag, resume=resume):
                 # run pipelines in the same topological generation in parallel by merging their streams
@@ -145,8 +145,8 @@ async def async_execute_dag(
                 ]
                 merged_stream = stream.merge(*streams)
                 async with merged_stream.stream() as streamer:
-                    async for _, stats in streamer:
-                        handle_callbacks(callbacks, stats)
+                    async for result, stats in streamer:
+                        handle_callbacks(callbacks, result, stats)
 
 
 class DaskExecutor(DagExecutor):
8 changes: 4 additions & 4 deletions cubed/runtime/executors/lithops.py
@@ -193,7 +193,7 @@ def execute_dag(
         for name, node in visit_nodes(dag, resume=resume):
             handle_operation_start_callbacks(callbacks, name)
             pipeline = node["pipeline"]
-            for _, stats in map_unordered(
+            for result, stats in map_unordered(
                 executor,
                 [run_func],
                 [pipeline.mappable],
@@ -207,7 +207,7 @@ def execute_dag(
                 name=name,
                 compute_id=compute_id,
             ):
-                handle_callbacks(callbacks, stats)
+                handle_callbacks(callbacks, result, stats)
     else:
         for gen in visit_node_generations(dag, resume=resume):
             group_map_functions = []
@@ -223,7 +223,7 @@ def execute_dag(
                 group_names.append(name)
             for name in group_names:
                 handle_operation_start_callbacks(callbacks, name)
-            for _, stats in map_unordered(
+            for result, stats in map_unordered(
                 executor,
                 group_map_functions,
                 group_map_iterdata,
@@ -234,7 +234,7 @@ def execute_dag(
                 # TODO: other kwargs (func, config, name)
                 compute_id=compute_id,
             ):
-                handle_callbacks(callbacks, stats)
+                handle_callbacks(callbacks, result, stats)
 
 
 def standardise_lithops_stats(future: RetryingFuture) -> Dict[str, Any]:
12 changes: 6 additions & 6 deletions cubed/runtime/executors/local.py
@@ -50,15 +50,15 @@ def execute_dag(
         handle_operation_start_callbacks(callbacks, name)
         pipeline: CubedPipeline = node["pipeline"]
         for m in pipeline.mappable:
-            exec_stage_func(
+            result = exec_stage_func(
                 m,
                 pipeline.function,
                 config=pipeline.config,
                 name=name,
                 compute_id=compute_id,
             )
             if callbacks is not None:
-                event = TaskEndEvent(name=name)
+                event = TaskEndEvent(name=name, result=result)
                 [callback.on_task_end(event) for callback in callbacks]
 
 
@@ -223,8 +223,8 @@ async def async_execute_dag(
                     concurrent_executor, run_func, name, node["pipeline"], **kwargs
                 )
                 async with st.stream() as streamer:
-                    async for _, stats in streamer:
-                        handle_callbacks(callbacks, stats)
+                    async for result, stats in streamer:
+                        handle_callbacks(callbacks, result, stats)
         else:
             for gen in visit_node_generations(dag, resume=resume):
                 # run pipelines in the same topological generation in parallel by merging their streams
@@ -236,8 +236,8 @@ async def async_execute_dag(
                 ]
                 merged_stream = stream.merge(*streams)
                 async with merged_stream.stream() as streamer:
-                    async for _, stats in streamer:
-                        handle_callbacks(callbacks, stats)
+                    async for result, stats in streamer:
+                        handle_callbacks(callbacks, result, stats)
 
     finally:
         # don't wait for any cancelled tasks
8 changes: 4 additions & 4 deletions cubed/runtime/executors/modal.py
@@ -226,8 +226,8 @@ async def async_execute_dag(
                 handle_operation_start_callbacks(callbacks, name)
                 st = pipeline_to_stream(app_function, name, node["pipeline"], **kwargs)
                 async with st.stream() as streamer:
-                    async for _, stats in streamer:
-                        handle_callbacks(callbacks, stats)
+                    async for result, stats in streamer:
+                        handle_callbacks(callbacks, result, stats)
         else:
             for gen in visit_node_generations(dag, resume=resume):
                 # run pipelines in the same topological generation in parallel by merging their streams
@@ -237,8 +237,8 @@ async def async_execute_dag(
                 ]
                 merged_stream = stream.merge(*streams)
                 async with merged_stream.stream() as streamer:
-                    async for _, stats in streamer:
-                        handle_callbacks(callbacks, stats)
+                    async for result, stats in streamer:
+                        handle_callbacks(callbacks, result, stats)
 
 
 class ModalExecutor(DagExecutor):
5 changes: 4 additions & 1 deletion cubed/runtime/types.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Iterable, Optional
+from typing import Any, Iterable, Optional
 
 from networkx import MultiDiGraph
 
@@ -71,6 +71,9 @@ class TaskEndEvent:
     num_tasks: int = 1
     """Number of tasks that this event applies to (default 1)."""
 
+    result: Optional[Any] = None
+    """Return value of the task."""
+
     task_create_tstamp: Optional[float] = None
     """Timestamp of when the task was created by the client."""
 
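With this field in place, a TaskEndEvent can carry a task's return value alongside its timing stats. A minimal sketch of constructing one (field names come from this diff; the name and result values are illustrative):

from cubed.runtime.types import TaskEndEvent

# "rechunk-001" and 42 are made-up illustrative values
event = TaskEndEvent(name="rechunk-001", result=42)
assert event.result == 42
assert event.num_tasks == 1  # dataclass default shown in the hunk above

Executors that report no return value leave result at its default of None, so existing callbacks keep working unchanged.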
5 changes: 3 additions & 2 deletions cubed/runtime/utils.py
@@ -99,18 +99,19 @@ def handle_operation_start_callbacks(callbacks, name):
         [callback.on_operation_start(event) for callback in callbacks]
 
 
-def handle_callbacks(callbacks, stats):
+def handle_callbacks(callbacks, result, stats):
     """Construct a TaskEndEvent from stats and send to all callbacks."""
 
     if callbacks is not None:
         if "task_result_tstamp" not in stats:
             task_result_tstamp = time.time()
             event = TaskEndEvent(
+                result=result,
                 task_result_tstamp=task_result_tstamp,
                 **stats,
             )
         else:
-            event = TaskEndEvent(**stats)
+            event = TaskEndEvent(result=result, **stats)
         [callback.on_task_end(event) for callback in callbacks]
 
 
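Downstream, any registered callback can now read the return value in its on_task_end hook, which the list comprehension above invokes for every task. A hypothetical sketch (ResultCollector is illustrative and not part of this commit; it assumes the Callback base class lives alongside TaskEndEvent in cubed.runtime.types):

from cubed.runtime.types import Callback, TaskEndEvent

class ResultCollector(Callback):
    """Illustrative callback that gathers each task's return value."""

    def __init__(self):
        self.results = []

    def on_task_end(self, event: TaskEndEvent):
        # event.result defaults to None when an executor reports no result
        self.results.append(event.result)

Passed in via an executor's callbacks argument, such a callback would accumulate results as tasks complete; that is the path the handle_callbacks(callbacks, result, stats) change feeds on every executor touched by this commit.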
