Skip to content

Commit

Permalink
Warn if peak mem exceeds allowed_mem (#516)
Browse files Browse the repository at this point in the history
* Warn if peak mem exceeds `allowed_mem`

* Remove usage of Python 3.10 API (Counter.total)

* Don't run mem warn test on Windows
  • Loading branch information
tomwhite authored Jul 23, 2024
1 parent ac6f243 commit 59c593d
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 1 deletion.
35 changes: 35 additions & 0 deletions cubed/extensions/mem_warn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import warnings
from collections import Counter

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback


class MemoryWarningCallback(Callback):
    """Callback that warns when tasks use more memory than their op allows.

    At the start of a computation the primitive op of every visited node is
    recorded by name. As tasks finish, each one whose measured peak memory is
    greater than its op's ``allowed_mem`` is tallied, and at the end of the
    computation a single ``UserWarning`` summarises the offending ops.
    """

    def on_compute_start(self, event):
        """Record the primitive op of each visited node and reset the tally."""
        # primitive ops keyed by node name
        self.ops = {
            op_name: dag_node["primitive_op"]
            for op_name, dag_node in visit_nodes(event.dag, event.resume)
        }

        # per-op count of tasks whose peak memory exceeded allowed_mem
        self.counter = Counter()

    def on_task_end(self, event):
        """Tally the task if its measured peak memory exceeds ``allowed_mem``."""
        limit = self.ops[event.name].allowed_mem
        peak = event.peak_measured_mem_end
        # a peak of None means there is no measurement to compare
        if peak is not None and peak > limit:
            self.counter[event.name] += 1

    def on_compute_end(self, event):
        """Emit one ``UserWarning`` listing every op that exceeded its limit."""
        if not sum(self.counter.values()) > 0:
            return

        # each entry reads "<op name> (<tasks over limit>/<op's num_tasks>)"
        summary = ", ".join(
            f"{op_name} ({hits}/{self.ops[op_name].num_tasks})"
            for op_name, hits in self.counter.items()
        )
        warnings.warn(
            f"Peak memory usage exceeded allowed_mem when running tasks: {summary}",
            UserWarning,
        )
23 changes: 23 additions & 0 deletions cubed/tests/test_executor_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import cubed.array_api as xp
import cubed.random
from cubed.extensions.history import HistoryCallback
from cubed.extensions.mem_warn import MemoryWarningCallback
from cubed.extensions.rich import RichProgressBar
from cubed.extensions.timeline import TimelineVisualizationCallback
from cubed.extensions.tqdm import TqdmProgressBar
Expand Down Expand Up @@ -148,6 +149,28 @@ def test_callbacks_modal(spec, modal_executor):
fs.rm(tmp_path, recursive=True)


@pytest.mark.skipif(
    platform.system() == "Windows", reason="measuring memory does not run on windows"
)
def test_mem_warn(tmp_path, executor):
    """Check that a task exceeding ``allowed_mem`` triggers a ``UserWarning``.

    Only the "processes" and "lithops" executors are exercised; the test is
    skipped for every other executor.
    """
    if executor.name not in ("processes", "lithops"):
        pytest.skip(f"{executor.name} executor does not support MemoryWarningCallback")

    spec = cubed.Spec(tmp_path, allowed_mem=200_000_000, reserved_mem=100_000_000)
    mem_warn = MemoryWarningCallback()

    def func(a):
        # The array is deliberately discarded: it is allocated only to push
        # the task's peak memory above allowed_mem.
        np.ones(100_000_000)  # blow memory
        return a

    a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
    b = cubed.map_blocks(func, a, dtype=a.dtype)
    # pytest.warns records the warning whatever the active warning filters
    # are; pytest.raises would only see it if warnings were turned into errors.
    with pytest.warns(
        UserWarning, match="Peak memory usage exceeded allowed_mem when running tasks"
    ):
        b.compute(executor=executor, callbacks=[mem_warn])


def test_resume(spec, executor):
if executor.name == "beam":
pytest.skip(f"{executor.name} executor does not support resume")
Expand Down
4 changes: 3 additions & 1 deletion cubed/tests/test_mem_utilization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import cubed.random
from cubed.backend_array_api import namespace as nxp
from cubed.extensions.history import HistoryCallback
from cubed.extensions.mem_warn import MemoryWarningCallback
from cubed.runtime.executors.lithops import LithopsExecutor
from cubed.tests.utils import LITHOPS_LOCAL_CONFIG

Expand Down Expand Up @@ -277,12 +278,13 @@ def run_operation(tmp_path, name, result_array, *, optimize_function=None):
# result_array.visualize(f"cubed-{name}", optimize_function=optimize_function)
executor = LithopsExecutor(config=LITHOPS_LOCAL_CONFIG)
hist = HistoryCallback()
mem_warn = MemoryWarningCallback()
# use store=None to write to temporary zarr
cubed.to_zarr(
result_array,
store=None,
executor=executor,
callbacks=[hist],
callbacks=[hist, mem_warn],
optimize_function=optimize_function,
)

Expand Down

0 comments on commit 59c593d

Please sign in to comment.