Adding compile_function as execute option. (#536)
* Adding `compile_function` as execute option.

This allows users to apply JIT or AOT compilation during the DAG finalization process within executors. It should work straightforwardly with jax/numba-style JIT compilation. Jax-AOT-style compilation is possible too, though potentially awkward. (A usage sketch follows this change list.)

* Add numba for compilation tests.

* Single quotes not needed.

* Update function doc.

* Added another compile test for failure case.

* Added another test to ensure config was applied.

* Improve tests

- remove todo for new test
- use pytest conventions

* I don’t think numba jit works well on jax arrays.

* Update plan.py

Simplifying type to make mypy happy.
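
For context, here is a minimal usage sketch of the new option, mirroring the tests added in this commit. It is not part of the diff itself and assumes cubed and numba are installed; the default spec and executor are used.

import numpy as np
from numba import jit

import cubed.array_api as xp

# Build a small lazy computation.
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2))
c = xp.add(a, b)

# Pass a decorator-like callable as `compile_function`; during DAG
# finalization each blockwise function is wrapped with it (here numba's jit)
# before the executor runs the plan.
result = c.compute(compile_function=jit)
print(np.asarray(result))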
alxmrs authored Aug 7, 2024
1 parent e8aaf3f commit 479f1bf
Showing 2 changed files with 90 additions and 2 deletions.
40 changes: 38 additions & 2 deletions cubed/core/plan.py
@@ -1,4 +1,5 @@
import atexit
import dataclasses
import inspect
import shutil
import tempfile
@@ -30,6 +31,8 @@
# Delete local context dirs when Python exits
CONTEXT_DIRS = set()

Decorator = Callable


def delete_on_exit(context_dir: str) -> None:
if context_dir not in CONTEXT_DIRS and is_local_path(context_dir):
@@ -200,13 +203,45 @@ def _create_lazy_zarr_arrays(self, dag):

        return dag

    def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGraph:
        """Compiles functions from all blockwise ops by mutating the input dag."""
        # Recommended: make a copy of the dag before calling this function.

        compile_with_config = "config" in inspect.getfullargspec(compile_function).kwonlyargs

        for n in dag.nodes:
            node = dag.nodes[n]

            if "primitive_op" not in node:
                continue

            if not isinstance(node["pipeline"].config, BlockwiseSpec):
                continue

            if compile_with_config:
                compiled = compile_function(
                    node["pipeline"].config.function, config=node["pipeline"].config
                )
            else:
                compiled = compile_function(node["pipeline"].config.function)

            # node is a blockwise primitive_op.
            # maybe we should investigate some sort of optics library for frozen dataclasses...
            new_pipeline = dataclasses.replace(
                node["pipeline"],
                config=dataclasses.replace(node["pipeline"].config, function=compiled),
            )
            node["pipeline"] = new_pipeline

        return dag

    @lru_cache
    def _finalize_dag(
        self, optimize_graph: bool = True, optimize_function=None
        self, optimize_graph: bool = True, optimize_function=None, compile_function: Optional[Decorator] = None,
    ) -> nx.MultiDiGraph:
        dag = self.optimize(optimize_function).dag if optimize_graph else self.dag
        # create a copy since _create_lazy_zarr_arrays mutates the dag
        dag = dag.copy()
        if callable(compile_function):
            dag = self._compile_blockwise(dag, compile_function)
        dag = self._create_lazy_zarr_arrays(dag)
        return nx.freeze(dag)

@@ -216,11 +251,12 @@ def execute(
        callbacks=None,
        optimize_graph=True,
        optimize_function=None,
        compile_function=None,
        resume=None,
        spec=None,
        **kwargs,
    ):
        dag = self._finalize_dag(optimize_graph, optimize_function)
        dag = self._finalize_dag(optimize_graph, optimize_function, compile_function)

        compute_id = f"compute-{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}"

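Worth noting from `_compile_blockwise` above: a compile function can opt in to receiving the op's `BlockwiseSpec` by declaring a keyword-only `config` parameter, which is detected via `inspect.getfullargspec`. Below is a minimal sketch of such a function; it is illustrative only and not part of the diff, and the print plus pass-through return are placeholders for a real compiler.

from typing import Callable


def compile_with_spec(func: Callable, *, config=None) -> Callable:
    # Because `config` is keyword-only, _compile_blockwise finds it in
    # inspect.getfullargspec(compile_with_spec).kwonlyargs and passes the
    # op's BlockwiseSpec; otherwise only `func` would be passed.
    print(f"compiling {getattr(func, '__name__', func)} with {type(config).__name__}")
    return func  # a real implementation would return a compiled/wrapped callable


# Usage, assuming `c` is a cubed array as in the tests below:
# c.compute(compile_function=compile_with_spec)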
52 changes: 52 additions & 0 deletions cubed/tests/test_executor_features.py
@@ -315,3 +315,55 @@ def test_check_runtime_memory_processes(spec, executor):

    # OK if we use fewer workers
    c.compute(executor=executor, max_workers=max_workers // 2)


COMPILE_FUNCTIONS = [lambda fn: fn]

try:
    from numba import jit as numba_jit
    COMPILE_FUNCTIONS.append(numba_jit)
except ModuleNotFoundError:
    pass

try:
    if "jax" in os.environ.get("CUBED_BACKEND_ARRAY_API_MODULE", ""):
        from jax import jit as jax_jit
        COMPILE_FUNCTIONS.append(jax_jit)
except ModuleNotFoundError:
    pass


@pytest.mark.parametrize("compile_function", COMPILE_FUNCTIONS)
def test_check_compilation(spec, executor, compile_function):
    a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
    b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
    c = xp.add(a, b)
    assert_array_equal(
        c.compute(executor=executor, compile_function=compile_function),
        np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]),
    )


def test_compilation_can_fail(spec, executor):
    def compile_function(func):
        raise NotImplementedError(f"Cannot compile {func}")

    a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
    b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
    c = xp.add(a, b)
    with pytest.raises(NotImplementedError) as excinfo:
        c.compute(executor=executor, compile_function=compile_function)

    assert "add" in str(excinfo.value), "Compile function was applied to add operation."


def test_compilation_with_config_can_fail(spec, executor):
    def compile_function(func, *, config=None):
        raise NotImplementedError(f"Cannot compile {func} with {config}")

    a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
    b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
    c = xp.add(a, b)
    with pytest.raises(NotImplementedError) as excinfo:
        c.compute(executor=executor, compile_function=compile_function)

    assert "BlockwiseSpec" in str(excinfo.value), "Compile function was applied with a config argument."
