From 479f1bf1cc3554177ef29ea10d36002cb1a7705d Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Wed, 7 Aug 2024 09:29:15 +0100 Subject: [PATCH] Adding `compile_function` as execute option. (#536) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Adding `compile_function` as execute option. This allows users to use jit or aot compilation during the dag finalization process within executors. It should work straightforwardly on jax/numba style jit compilation. It's possible, but maybe ugly, to perform jax-aot-style compilation. * Add numba for compilation tests. * Singlequotes not needed. * Update function doc. * Added another compile test for failure case. * Added another test to ensure config was applied. * Improve tests - remove todo for new test - use pytest conventions * I don’t think numba jit works well on jax arrays. * Update plan.py Simplifying type to make mypy happy. --- cubed/core/plan.py | 40 +++++++++++++++++++-- cubed/tests/test_executor_features.py | 52 +++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 2 deletions(-) diff --git a/cubed/core/plan.py b/cubed/core/plan.py index e91070bf..9dc9f22b 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -1,4 +1,5 @@ import atexit +import dataclasses import inspect import shutil import tempfile @@ -30,6 +31,8 @@ # Delete local context dirs when Python exits CONTEXT_DIRS = set() +Decorator = Callable + def delete_on_exit(context_dir: str) -> None: if context_dir not in CONTEXT_DIRS and is_local_path(context_dir): @@ -200,13 +203,45 @@ def _create_lazy_zarr_arrays(self, dag): return dag + def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGraph: + """Compiles functions from all blockwise ops by mutating the input dag.""" + # Recommended: make a copy of the dag before calling this function. + + compile_with_config = 'config' in inspect.getfullargspec(compile_function).kwonlyargs + + for n in dag.nodes: + node = dag.nodes[n] + + if "primitive_op" not in node: + continue + + if not isinstance(node["pipeline"].config, BlockwiseSpec): + continue + + if compile_with_config: + compiled = compile_function(node["pipeline"].config.function, config=node["pipeline"].config) + else: + compiled = compile_function(node["pipeline"].config.function) + + # node is a blockwise primitive_op. + # maybe we should investigate some sort of optics library for frozen dataclasses... + new_pipeline = dataclasses.replace( + node["pipeline"], + config=dataclasses.replace(node["pipeline"].config, function=compiled) + ) + node["pipeline"] = new_pipeline + + return dag + @lru_cache def _finalize_dag( - self, optimize_graph: bool = True, optimize_function=None + self, optimize_graph: bool = True, optimize_function=None, compile_function: Optional[Decorator] = None, ) -> nx.MultiDiGraph: dag = self.optimize(optimize_function).dag if optimize_graph else self.dag # create a copy since _create_lazy_zarr_arrays mutates the dag dag = dag.copy() + if callable(compile_function): + dag = self._compile_blockwise(dag, compile_function) dag = self._create_lazy_zarr_arrays(dag) return nx.freeze(dag) @@ -216,11 +251,12 @@ def execute( callbacks=None, optimize_graph=True, optimize_function=None, + compile_function=None, resume=None, spec=None, **kwargs, ): - dag = self._finalize_dag(optimize_graph, optimize_function) + dag = self._finalize_dag(optimize_graph, optimize_function, compile_function) compute_id = f"compute-{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}" diff --git a/cubed/tests/test_executor_features.py b/cubed/tests/test_executor_features.py index f6bb5bd6..6a8dad86 100644 --- a/cubed/tests/test_executor_features.py +++ b/cubed/tests/test_executor_features.py @@ -315,3 +315,55 @@ def test_check_runtime_memory_processes(spec, executor): # OK if we use fewer workers c.compute(executor=executor, max_workers=max_workers // 2) + + +COMPILE_FUNCTIONS = [lambda fn: fn] + +try: + from numba import jit as numba_jit + COMPILE_FUNCTIONS.append(numba_jit) +except ModuleNotFoundError: + pass + +try: + if 'jax' in os.environ.get('CUBED_BACKEND_ARRAY_API_MODULE', ''): + from jax import jit as jax_jit + COMPILE_FUNCTIONS.append(jax_jit) +except ModuleNotFoundError: + pass + + +@pytest.mark.parametrize("compile_function", COMPILE_FUNCTIONS) +def test_check_compilation(spec, executor, compile_function): + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) + b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec) + c = xp.add(a, b) + assert_array_equal( + c.compute(executor=executor, compile_function=compile_function), np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]) + ) + + +def test_compilation_can_fail(spec, executor): + def compile_function(func): + raise NotImplementedError(f"Cannot compile {func}") + + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) + b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec) + c = xp.add(a, b) + with pytest.raises(NotImplementedError) as excinfo: + c.compute(executor=executor, compile_function=compile_function) + + assert "add" in str(excinfo.value), "Compile function was applied to add operation." + + +def test_compilation_with_config_can_fail(spec, executor): + def compile_function(func, *, config=None): + raise NotImplementedError(f"Cannot compile {func} with {config}") + + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) + b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec) + c = xp.add(a, b) + with pytest.raises(NotImplementedError) as excinfo: + c.compute(executor=executor, compile_function=compile_function) + + assert "BlockwiseSpec" in str(excinfo.value), "Compile function was applied with a config argument."