Include task result in callbacks #632

Merged 1 commit on Dec 4, 2024

cubed/runtime/executors/coiled.py (4 changes: 2 additions & 2 deletions)

@@ -47,8 +47,8 @@ def execute_dag(
             coiled_function.cluster.adapt(minimum=minimum_workers)
         # coiled expects a sequence (it calls `len` on it)
         input = list(pipeline.mappable)
-        for _, stats in coiled_function.map(input, config=pipeline.config):
+        for result, stats in coiled_function.map(input, config=pipeline.config):
             if callbacks is not None:
                 if name is not None:
                     stats["name"] = name
-                handle_callbacks(callbacks, stats)
+                handle_callbacks(callbacks, result, stats)

cubed/runtime/executors/dask.py (8 changes: 4 additions & 4 deletions)

@@ -134,8 +134,8 @@ async def async_execute_dag(
                 handle_operation_start_callbacks(callbacks, name)
                 st = pipeline_to_stream(client, name, node["pipeline"], **kwargs)
                 async with st.stream() as streamer:
-                    async for _, stats in streamer:
-                        handle_callbacks(callbacks, stats)
+                    async for result, stats in streamer:
+                        handle_callbacks(callbacks, result, stats)
         else:
             for gen in visit_node_generations(dag, resume=resume):
                 # run pipelines in the same topological generation in parallel by merging their streams
@@ -145,8 +145,8 @@ async def async_execute_dag(
                 ]
                 merged_stream = stream.merge(*streams)
                 async with merged_stream.stream() as streamer:
-                    async for _, stats in streamer:
-                        handle_callbacks(callbacks, stats)
+                    async for result, stats in streamer:
+                        handle_callbacks(callbacks, result, stats)


 class DaskExecutor(DagExecutor):

cubed/runtime/executors/lithops.py (8 changes: 4 additions & 4 deletions)

@@ -193,7 +193,7 @@ def execute_dag(
         for name, node in visit_nodes(dag, resume=resume):
             handle_operation_start_callbacks(callbacks, name)
             pipeline = node["pipeline"]
-            for _, stats in map_unordered(
+            for result, stats in map_unordered(
                 executor,
                 [run_func],
                 [pipeline.mappable],
@@ -207,7 +207,7 @@ def execute_dag(
                 name=name,
                 compute_id=compute_id,
             ):
-                handle_callbacks(callbacks, stats)
+                handle_callbacks(callbacks, result, stats)
     else:
         for gen in visit_node_generations(dag, resume=resume):
             group_map_functions = []
@@ -223,7 +223,7 @@ def execute_dag(
                 group_names.append(name)
             for name in group_names:
                 handle_operation_start_callbacks(callbacks, name)
-            for _, stats in map_unordered(
+            for result, stats in map_unordered(
                 executor,
                 group_map_functions,
                 group_map_iterdata,
@@ -234,7 +234,7 @@ def execute_dag(
                 # TODO: other kwargs (func, config, name)
                 compute_id=compute_id,
             ):
-                handle_callbacks(callbacks, stats)
+                handle_callbacks(callbacks, result, stats)


 def standardise_lithops_stats(future: RetryingFuture) -> Dict[str, Any]:

cubed/runtime/executors/local.py (12 changes: 6 additions & 6 deletions)

@@ -50,15 +50,15 @@ def execute_dag(
         handle_operation_start_callbacks(callbacks, name)
         pipeline: CubedPipeline = node["pipeline"]
         for m in pipeline.mappable:
-            exec_stage_func(
+            result = exec_stage_func(
                 m,
                 pipeline.function,
                 config=pipeline.config,
                 name=name,
                 compute_id=compute_id,
             )
             if callbacks is not None:
-                event = TaskEndEvent(name=name)
+                event = TaskEndEvent(name=name, result=result)
                 [callback.on_task_end(event) for callback in callbacks]


@@ -223,8 +223,8 @@ async def async_execute_dag(
                     concurrent_executor, run_func, name, node["pipeline"], **kwargs
                 )
                 async with st.stream() as streamer:
-                    async for _, stats in streamer:
-                        handle_callbacks(callbacks, stats)
+                    async for result, stats in streamer:
+                        handle_callbacks(callbacks, result, stats)
         else:
             for gen in visit_node_generations(dag, resume=resume):
                 # run pipelines in the same topological generation in parallel by merging their streams
@@ -236,8 +236,8 @@ async def async_execute_dag(
                 ]
                 merged_stream = stream.merge(*streams)
                 async with merged_stream.stream() as streamer:
-                    async for _, stats in streamer:
-                        handle_callbacks(callbacks, stats)
+                    async for result, stats in streamer:
+                        handle_callbacks(callbacks, result, stats)

     finally:
         # don't wait for any cancelled tasks

cubed/runtime/executors/modal.py (8 changes: 4 additions & 4 deletions)

@@ -226,8 +226,8 @@ async def async_execute_dag(
             handle_operation_start_callbacks(callbacks, name)
             st = pipeline_to_stream(app_function, name, node["pipeline"], **kwargs)
             async with st.stream() as streamer:
-                async for _, stats in streamer:
-                    handle_callbacks(callbacks, stats)
+                async for result, stats in streamer:
+                    handle_callbacks(callbacks, result, stats)
     else:
         for gen in visit_node_generations(dag, resume=resume):
             # run pipelines in the same topological generation in parallel by merging their streams
@@ -237,8 +237,8 @@ async def async_execute_dag(
             ]
             merged_stream = stream.merge(*streams)
             async with merged_stream.stream() as streamer:
-                async for _, stats in streamer:
-                    handle_callbacks(callbacks, stats)
+                async for result, stats in streamer:
+                    handle_callbacks(callbacks, result, stats)


 class ModalExecutor(DagExecutor):

cubed/runtime/types.py (5 changes: 4 additions & 1 deletion)

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Iterable, Optional
+from typing import Any, Iterable, Optional

 from networkx import MultiDiGraph

@@ -71,6 +71,9 @@ class TaskEndEvent:
     num_tasks: int = 1
     """Number of tasks that this event applies to (default 1)."""

+    result: Optional[Any] = None
+    """Return value of the task."""
+
     task_create_tstamp: Optional[float] = None
     """Timestamp of when the task was created by the client."""
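
With `result` now carried on `TaskEndEvent`, a callback can inspect each task's return value as tasks complete. A minimal sketch of a consumer, assuming the `Callback` base class (defined alongside `TaskEndEvent` in cubed/runtime/types.py) provides no-op defaults for the other hooks; `ResultCollector` is a hypothetical name, not part of this PR:

from cubed.runtime.types import Callback, TaskEndEvent


class ResultCollector(Callback):
    """Hypothetical callback that accumulates task return values."""

    def __init__(self):
        self.results = []

    def on_task_end(self, event: TaskEndEvent) -> None:
        # result defaults to None for executors that don't supply one
        if event.result is not None:
            self.results.append(event.result)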

cubed/runtime/utils.py (5 changes: 3 additions & 2 deletions)

@@ -99,18 +99,19 @@ def handle_operation_start_callbacks(callbacks, name):
         [callback.on_operation_start(event) for callback in callbacks]


-def handle_callbacks(callbacks, stats):
+def handle_callbacks(callbacks, result, stats):
     """Construct a TaskEndEvent from stats and send to all callbacks."""

     if callbacks is not None:
         if "task_result_tstamp" not in stats:
             task_result_tstamp = time.time()
             event = TaskEndEvent(
+                result=result,
                 task_result_tstamp=task_result_tstamp,
                 **stats,
             )
         else:
-            event = TaskEndEvent(**stats)
+            event = TaskEndEvent(result=result, **stats)
         [callback.on_task_end(event) for callback in callbacks]
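
Every executor call site in this PR now follows the same convention: unpack the `(result, stats)` pair the executor yields and pass both to `handle_callbacks`, which forwards `result` into the `TaskEndEvent`. A minimal sketch of that flow outside any real executor; the `PrintResults` class and the sample `(result, stats)` pairs are illustrative, and the sketch assumes `Callback` tolerates overriding only `on_task_end`:

from cubed.runtime.types import Callback, TaskEndEvent
from cubed.runtime.utils import handle_callbacks


class PrintResults(Callback):
    """Hypothetical callback that prints each completed task's result."""

    def on_task_end(self, event: TaskEndEvent) -> None:
        print(f"task {event.name}: result={event.result}")


callbacks = [PrintResults()]
# simulate an executor yielding one (result, stats) pair per task;
# handle_callbacks adds a task_result_tstamp when stats lacks one
for result, stats in [(3, {"name": "op-001"}), (None, {"name": "op-002"})]:
    handle_callbacks(callbacks, result, stats)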