Add callback for when an operation starts
tomwhite committed Feb 16, 2024
1 parent 53b889d commit d429d75
Showing 9 changed files with 51 additions and 7 deletions.

cubed/runtime/executors/coiled.py (6 additions, 1 deletion)

@@ -5,7 +5,11 @@

 from cubed.runtime.pipeline import visit_nodes
 from cubed.runtime.types import Callback, DagExecutor
-from cubed.runtime.utils import execution_stats, handle_callbacks
+from cubed.runtime.utils import (
+    execution_stats,
+    handle_callbacks,
+    handle_operation_start_callbacks,
+)
 from cubed.spec import Spec

@@ -27,6 +31,7 @@ def execute_dag(
 ) -> None:
     # Note this currently only builds the task graph for each stage once it gets to that stage in computation
     for name, node in visit_nodes(dag, resume=resume):
+        handle_operation_start_callbacks(callbacks, name)
         pipeline = node["pipeline"]
         coiled_function = make_coiled_function(pipeline.function, coiled_kwargs)
         input = list(

cubed/runtime/executors/dask_distributed_async.py (7 additions, 1 deletion)

@@ -20,7 +20,12 @@
 from cubed.runtime.executors.asyncio import async_map_unordered
 from cubed.runtime.pipeline import visit_node_generations, visit_nodes
 from cubed.runtime.types import Callback, CubedPipeline, DagExecutor
-from cubed.runtime.utils import execution_stats, gensym, handle_callbacks
+from cubed.runtime.utils import (
+    execution_stats,
+    gensym,
+    handle_callbacks,
+    handle_operation_start_callbacks,
+)
 from cubed.spec import Spec

@@ -123,6 +128,7 @@ async def async_execute_dag(
         if not compute_arrays_in_parallel:
             # run one pipeline at a time
             for name, node in visit_nodes(dag, resume=resume):
+                handle_operation_start_callbacks(callbacks, name)
                 st = pipeline_to_stream(client, name, node["pipeline"], **kwargs)
                 async with st.stream() as streamer:
                     async for _, stats in streamer:

cubed/runtime/executors/lithops.py (4 additions, 1 deletion)

@@ -27,7 +27,7 @@
 )
 from cubed.runtime.pipeline import visit_node_generations, visit_nodes
 from cubed.runtime.types import Callback, DagExecutor
-from cubed.runtime.utils import handle_callbacks
+from cubed.runtime.utils import handle_callbacks, handle_operation_start_callbacks
 from cubed.spec import Spec

 logger = logging.getLogger(__name__)

@@ -180,6 +180,7 @@ def execute_dag(
     with RetryingFunctionExecutor(function_executor) as executor:
         if not compute_arrays_in_parallel:
             for name, node in visit_nodes(dag, resume=resume):
+                handle_operation_start_callbacks(callbacks, name)
                 pipeline = node["pipeline"]
                 for _, stats in map_unordered(
                     executor,

@@ -207,6 +208,8 @@
                     group_map_functions.append(f)
                     group_map_iterdata.append(pipeline.mappable)
                     group_names.append(name)
+                for name in group_names:
+                    handle_operation_start_callbacks(callbacks, name)
                 for _, stats in map_unordered(
                     executor,
                     group_map_functions,
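
Note that in the compute_arrays_in_parallel branch above, start callbacks are fired for every operation in a generation before the single grouped map_unordered call, since that call submits tasks from all of those operations together.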

cubed/runtime/executors/modal.py (6 additions, 1 deletion)

@@ -10,7 +10,11 @@

 from cubed.runtime.pipeline import visit_nodes
 from cubed.runtime.types import Callback, DagExecutor
-from cubed.runtime.utils import execute_with_stats, handle_callbacks
+from cubed.runtime.utils import (
+    execute_with_stats,
+    handle_callbacks,
+    handle_operation_start_callbacks,
+)
 from cubed.spec import Spec

 RUNTIME_MEMORY_MIB = 2000

@@ -128,6 +132,7 @@ def execute_dag(
         else:
             raise ValueError(f"Unrecognized cloud: {cloud}")
         for name, node in visit_nodes(dag, resume=resume):
+            handle_operation_start_callbacks(callbacks, name)
             pipeline = node["pipeline"]
             task_create_tstamp = time.time()
             for _, stats in app_function.map(

cubed/runtime/executors/modal_async.py (2 additions, 1 deletion)

@@ -18,7 +18,7 @@
 )
 from cubed.runtime.pipeline import visit_node_generations, visit_nodes
 from cubed.runtime.types import Callback, DagExecutor
-from cubed.runtime.utils import handle_callbacks
+from cubed.runtime.utils import handle_callbacks, handle_operation_start_callbacks
 from cubed.spec import Spec

@@ -127,6 +127,7 @@ async def async_execute_dag(
         if not compute_arrays_in_parallel:
             # run one pipeline at a time
             for name, node in visit_nodes(dag, resume=resume):
+                handle_operation_start_callbacks(callbacks, name)
                 st = pipeline_to_stream(app_function, name, node["pipeline"], **kwargs)
                 async with st.stream() as streamer:
                     async for _, stats in streamer:

cubed/runtime/executors/python.py (2 additions, 0 deletions)

@@ -4,6 +4,7 @@

 from cubed.runtime.pipeline import visit_nodes
 from cubed.runtime.types import Callback, CubedPipeline, DagExecutor, TaskEndEvent
+from cubed.runtime.utils import handle_operation_start_callbacks
 from cubed.spec import Spec

@@ -24,6 +25,7 @@ def execute_dag(
     **kwargs,
 ) -> None:
     for name, node in visit_nodes(dag, resume=resume):
+        handle_operation_start_callbacks(callbacks, name)
         pipeline: CubedPipeline = node["pipeline"]
         for m in pipeline.mappable:
             exec_stage_func(

cubed/runtime/executors/python_async.py (6 additions, 1 deletion)

@@ -11,7 +11,11 @@
 from cubed.runtime.executors.asyncio import async_map_unordered
 from cubed.runtime.pipeline import visit_node_generations, visit_nodes
 from cubed.runtime.types import Callback, CubedPipeline, DagExecutor
-from cubed.runtime.utils import execution_stats, handle_callbacks
+from cubed.runtime.utils import (
+    execution_stats,
+    handle_callbacks,
+    handle_operation_start_callbacks,
+)
 from cubed.spec import Spec

@@ -92,6 +96,7 @@ async def async_execute_dag(
         if not compute_arrays_in_parallel:
             # run one pipeline at a time
             for name, node in visit_nodes(dag, resume=resume):
+                handle_operation_start_callbacks(callbacks, name)
                 st = pipeline_to_stream(
                     concurrent_executor, name, node["pipeline"], **kwargs
                 )

cubed/runtime/types.py (11 additions, 0 deletions)

@@ -49,6 +49,14 @@ class ComputeEndEvent:
     """The computation DAG."""


+@dataclass
+class OperationStartEvent:
+    """Callback information about an operation that is about to start."""
+
+    name: str
+    """Name of the operation."""
+
+
 @dataclass
 class TaskEndEvent:
     """Callback information about a completed task (or tasks)."""

@@ -101,6 +109,9 @@ def on_compute_end(self, ComputeEndEvent):
         """
         pass  # pragma: no cover

+    def on_operation_start(self, event):
+        pass
+
     def on_task_end(self, event):
         """Called when a task ends.
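
Taken together, OperationStartEvent and the new Callback.on_operation_start hook let a user-defined callback react as each operation begins. A minimal sketch of such a callback (the class name and print format are illustrative, not part of this commit):

from cubed.runtime.types import Callback


class OperationLogger(Callback):
    """Hypothetical callback that logs each operation as it starts."""

    def on_operation_start(self, event):
        # event is an OperationStartEvent; its only field is the operation's name
        print(f"operation starting: {event.name}")

A list of such callbacks is what the executors above receive; each executor fires the event before scheduling any of an operation's tasks.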

cubed/runtime/utils.py (7 additions, 1 deletion)

@@ -2,7 +2,7 @@
 from functools import partial
 from itertools import islice

-from cubed.runtime.types import TaskEndEvent
+from cubed.runtime.types import OperationStartEvent, TaskEndEvent
 from cubed.utils import peak_measured_mem

 sym_counter = 0

@@ -39,6 +39,12 @@ def execution_stats(func):
     return partial(execute_with_stats, func)


+def handle_operation_start_callbacks(callbacks, name):
+    if callbacks is not None:
+        event = OperationStartEvent(name)
+        [callback.on_operation_start(event) for callback in callbacks]
+
+
 def handle_callbacks(callbacks, stats):
     """Construct a TaskEndEvent from stats and send to all callbacks."""
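
A quick sketch of how the new helper behaves, reusing the hypothetical OperationLogger from above (the operation name here is made up):

handle_operation_start_callbacks([OperationLogger()], "op-001")  # prints: operation starting: op-001
handle_operation_start_callbacks(None, "op-001")  # no-op, so executors can run without any callbacks

The list comprehension in handle_operation_start_callbacks is evaluated only for its side effect of invoking each callback; its result is discarded.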
