Skip to content

Commit

Permalink
Check allowed mem does not exceed total on machine for processes ex…
Browse files Browse the repository at this point in the history
…ecutor (#517)

Add psutil dependency
  • Loading branch information
tomwhite authored Jul 23, 2024
1 parent 32c9ab2 commit ac6f243
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 2 deletions.
17 changes: 15 additions & 2 deletions cubed/runtime/executors/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, AsyncIterator, Callable, Iterable, Optional, Sequence

import cloudpickle
import psutil
from aiostream import stream
from aiostream.core import Stream
from networkx import MultiDiGraph
Expand Down Expand Up @@ -154,6 +155,16 @@ def pipeline_to_stream(
)


def check_runtime_memory(spec, max_workers):
    """Raise if the workers' combined memory allowance exceeds the machine's RAM.

    Parameters
    ----------
    spec
        Object exposing an ``allowed_mem`` attribute, or ``None``. When it is
        ``None`` or its ``allowed_mem`` is ``None`` there is no limit to
        check and the function returns immediately.
    max_workers
        Number of workers that may each use up to ``allowed_mem``. ``None``
        falls back to ``os.cpu_count()`` (or 1 if that is also unknown).

    Raises
    ------
    ValueError
        If the total memory reported by ``psutil`` is less than
        ``allowed_mem * max_workers``.
    """
    allowed_mem = spec.allowed_mem if spec is not None else None
    if allowed_mem is None:
        # No limit configured, so skip the system memory query entirely.
        return

    if max_workers is None:
        # os.cpu_count() may return None; multiplying by None would raise a
        # TypeError, so fall back to a concrete worker count first.
        max_workers = os.cpu_count() or 1

    total_mem = psutil.virtual_memory().total
    required_mem = allowed_mem * max_workers
    if total_mem < required_mem:
        raise ValueError(
            f"Total memory on machine ({total_mem}) is less than allowed_mem * max_workers ({allowed_mem} * {max_workers} = {required_mem})"
        )


async def async_execute_dag(
dag: MultiDiGraph,
callbacks: Optional[Sequence[Callback]] = None,
Expand All @@ -163,16 +174,18 @@ async def async_execute_dag(
**kwargs,
) -> None:
concurrent_executor: Executor
max_workers = kwargs.pop("max_workers", os.cpu_count())
use_processes = kwargs.pop("use_processes", False)
if spec is not None:
check_runtime_memory(spec, max_workers)
if use_processes:
max_workers = kwargs.pop("max_workers", None)
context = multiprocessing.get_context("spawn")
# max_tasks_per_child is only supported from Python 3.11
concurrent_executor = ProcessPoolExecutor(
max_workers=max_workers, mp_context=context, max_tasks_per_child=1
)
else:
concurrent_executor = ThreadPoolExecutor()
concurrent_executor = ThreadPoolExecutor(max_workers=max_workers)
try:
if not compute_arrays_in_parallel:
# run one pipeline at a time
Expand Down
2 changes: 2 additions & 0 deletions cubed/runtime/executors/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"mypy_extensions", # for rechunker
"ndindex",
"networkx",
"psutil",
"pytest-mock", # TODO: only needed for tests
"s3fs",
"tenacity",
Expand All @@ -55,6 +56,7 @@
"mypy_extensions", # for rechunker
"ndindex",
"networkx",
"psutil",
"pytest-mock", # TODO: only needed for tests
"gcsfs",
"tenacity",
Expand Down
28 changes: 28 additions & 0 deletions cubed/tests/test_executor_features.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import contextlib
import os
import platform
import re

import fsspec
import numpy as np
import psutil
import pytest
from numpy.testing import assert_array_equal

Expand Down Expand Up @@ -264,3 +267,28 @@ def test_check_runtime_memory_modal(spec, modal_executor):
match=r"Runtime memory \(2097152000\) is less than allowed_mem \(4000000000\)",
):
c.compute(executor=modal_executor)


def test_check_runtime_memory_processes(spec, executor):
    """Check that the processes executor rejects an over-committed ``allowed_mem``.

    ``allowed_mem`` is set to twice an even per-CPU share of the machine's
    memory, so running with one worker per CPU must fail the runtime memory
    check, while running with half as many workers must succeed.
    """
    if executor.name != "processes":
        pytest.skip(f"{executor.name} executor does not support check_runtime_memory")

    max_workers = os.cpu_count()
    if max_workers is None or max_workers < 2:
        # With fewer than 2 CPUs, max_workers // 2 would be 0, which is not a
        # valid pool size; an unknown CPU count cannot be divided at all.
        pytest.skip("at least 2 CPUs are needed to halve the number of workers")

    total_mem = psutil.virtual_memory().total
    mem_per_worker = total_mem // max_workers
    allowed_mem = mem_per_worker * 2  # larger than will fit

    spec = cubed.Spec(spec.work_dir, allowed_mem=allowed_mem)
    a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
    b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
    c = xp.add(a, b)
    with pytest.raises(
        ValueError,
        match=re.escape(
            f"Total memory on machine ({total_mem}) is less than allowed_mem * max_workers ({allowed_mem} * {max_workers} = {allowed_mem * max_workers})"
        ),
    ):
        c.compute(executor=executor)

    # OK if we use fewer workers
    c.compute(executor=executor, max_workers=max_workers // 2)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"ndindex",
"networkx != 2.8.3, != 2.8.4, != 2.8.5, != 2.8.6, != 2.8.7, != 2.8.8, != 3.0.*, != 3.1.*, != 3.2.*",
"numpy >= 1.22",
"psutil",
"tenacity",
"toolz",
"zarr",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mypy_extensions # for rechunker
ndindex
networkx != 2.8.3, != 2.8.4, != 2.8.5, != 2.8.6, != 2.8.7, != 2.8.8, != 3.0.*, != 3.1.*, != 3.2.*
numpy >= 1.22
psutil
tenacity
toolz
zarr
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-pandas.*]
ignore_missing_imports = True
[mypy-psutil.*]
ignore_missing_imports = True
[mypy-pylab.*]
ignore_missing_imports = True
[mypy-pytest.*]
Expand Down

0 comments on commit ac6f243

Please sign in to comment.