import warnings
from collections import Counter

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback


class MemoryWarningCallback(Callback):
    """Warn at the end of a computation about ops whose tasks exceeded ``allowed_mem``.

    For every finished task the measured peak memory reported on the event
    (``peak_measured_mem_end``) is compared with the ``allowed_mem`` of the
    primitive op the task belongs to.  When the computation ends, a single
    ``UserWarning`` lists each offending op together with the number of its
    tasks that went over the limit and the op's total ``num_tasks``.

    Attributes
    ----------
    ops : dict
        Primitive ops keyed by op name, refreshed at the start of each compute.
    counter : collections.Counter
        Number of tasks per op name whose measured peak exceeded ``allowed_mem``.
    """

    def __init__(self):
        super().__init__()
        # Create the state up front so the instance is usable even if a hook
        # is invoked before ``on_compute_start`` (previously: AttributeError).
        self.ops = {}
        self.counter = Counter()

    def on_compute_start(self, event):
        """Collect the primitive ops of the plan and reset the exceed counts.

        ``event.dag`` and ``event.resume`` are handed to ``visit_nodes``; each
        yielded node is expected to carry a ``"primitive_op"`` entry.
        """
        # store ops keyed by name
        self.ops = {
            name: node["primitive_op"]
            for name, node in visit_nodes(event.dag, event.resume)
        }

        # count number of times each op exceeds allowed mem
        self.counter = Counter()

    def on_task_end(self, event):
        """Count the task if its measured peak memory is above the op's limit.

        Raises ``KeyError`` if ``event.name`` was not seen in
        ``on_compute_start``.
        """
        allowed_mem = self.ops[event.name].allowed_mem
        peak_mem = event.peak_measured_mem_end
        # A peak of None means no measurement is attached to this event
        # (presumably the executor does not measure memory); nothing to compare.
        if peak_mem is not None and peak_mem > allowed_mem:
            self.counter[event.name] += 1

    def on_compute_end(self, event):
        """Emit one ``UserWarning`` summarising all ops that exceeded the limit.

        Nothing is emitted when no task went over ``allowed_mem``.
        """
        if not self.counter:
            return

        # Walk ``self.ops`` rather than the counter so ops are reported in the
        # order they were visited at compute start, not in the order in which
        # their tasks happened to finish.
        exceeded = [
            f"{name} ({self.counter[name]}/{op.num_tasks})"
            for name, op in self.ops.items()
            if name in self.counter
        ]
        warnings.warn(
            f"Peak memory usage exceeded allowed_mem when running tasks: {', '.join(exceeded)}",
            UserWarning,
        )
@pytest.mark.skipif(
    platform.system() == "Windows", reason="measuring memory does not run on windows"
)
def test_mem_warn(tmp_path, executor):
    """Check that MemoryWarningCallback warns when a task exceeds ``allowed_mem``.

    Only runs for the "processes" and "lithops" executors; every other
    executor is skipped as not supporting the callback.
    """
    if executor.name not in ("processes", "lithops"):
        pytest.skip(f"{executor.name} executor does not support MemoryWarningCallback")

    spec = cubed.Spec(tmp_path, allowed_mem=200_000_000, reserved_mem=100_000_000)
    mem_warn = MemoryWarningCallback()

    def func(a):
        # Allocate 100 million elements (float64 is numpy's default dtype) so
        # the task's peak memory goes well beyond what the spec allows.
        np.ones(100_000_000)  # blow memory
        return a

    a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
    b = cubed.map_blocks(func, a, dtype=a.dtype)
    # pytest.warns asserts the warning itself, so the test does not depend on
    # a warning filter that escalates UserWarning to an exception (which is
    # what pytest.raises silently required).
    with pytest.warns(
        UserWarning, match="Peak memory usage exceeded allowed_mem when running tasks"
    ):
        b.compute(executor=executor, callbacks=[mem_warn])