From a0c64b5f9261e989b27e1d5e84f129a7cb9faa0b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 21 Nov 2024 14:29:59 +0100 Subject: [PATCH] Reduce overhead of JITLinker --- pytensor/link/basic.py | 52 ++++++----------------------- pytensor/link/numba/linker.py | 17 ---------- pytensor/link/pytorch/linker.py | 26 ++++++--------- tests/link/numba/test_basic.py | 17 ++++++++++ tests/link/pytorch/test_basic.py | 29 ++++++++-------- tests/link/pytorch/test_elemwise.py | 2 +- 6 files changed, 52 insertions(+), 91 deletions(-) diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py index daeaa5740f..9cf34983f2 100644 --- a/pytensor/link/basic.py +++ b/pytensor/link/basic.py @@ -653,41 +653,36 @@ def create_jitable_thunk( ) thunk_inputs = self.create_thunk_inputs(storage_map) - - thunks = [] - thunk_outputs = [storage_map[n] for n in self.fgraph.outputs] - fgraph_jit = self.jit_compile(converted_fgraph) def thunk( - fgraph=self.fgraph, fgraph_jit=fgraph_jit, thunk_inputs=thunk_inputs, thunk_outputs=thunk_outputs, ): - outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs]) + try: + outputs = fgraph_jit(*(x[0] for x in thunk_inputs)) + except Exception: + # TODO: Should we add a fake node that combines all outputs, + # since the error may come from any of them? + raise_with_op(self.fgraph, output_nodes[0], thunk) # strict=False because we are in a hot loop - for o_var, o_storage, o_val in zip( - fgraph.outputs, thunk_outputs, outputs, strict=False - ): - compute_map[o_var][0] = True - o_storage[0] = self.output_filter(o_var, o_val) - return outputs + for o_storage, o_val in zip(thunk_outputs, outputs, strict=False): + o_storage[0] = o_val thunk.inputs = thunk_inputs thunk.outputs = thunk_outputs thunk.lazy = False - thunks.append(thunk) + thunks = [thunk] return thunks, output_nodes, fgraph_jit def make_all(self, input_storage=None, output_storage=None, storage_map=None): fgraph = self.fgraph nodes = self.schedule(fgraph) - no_recycling = self.no_recycling input_storage, output_storage, storage_map = map_storage( fgraph, nodes, input_storage, output_storage, storage_map @@ -701,34 +696,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None): compute_map, nodes, input_storage, output_storage, storage_map ) - computed, last_user = gc_helper(nodes) - - if self.allow_gc: - post_thunk_old_storage = [ - [ - storage_map[input] - for input in node.inputs - if (input in computed) - and (input not in fgraph.outputs) - and (node == last_user[input]) - ] - for node in nodes - ] - else: - post_thunk_old_storage = None - - if no_recycling is True: - no_recycling = list(storage_map.values()) - no_recycling = difference(no_recycling, input_storage) - else: - no_recycling = [ - storage_map[r] for r in no_recycling if r not in fgraph.inputs - ] - - fn = streamline( - fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling - ) - + [fn] = thunks fn.jit_fn = jit_fn fn.allow_gc = self.allow_gc fn.storage_map = storage_map diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index f120706f3b..553c5ef217 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -1,26 +1,9 @@ -from typing import TYPE_CHECKING, Any - -import numpy as np - -import pytensor from pytensor.link.basic import JITLinker -if TYPE_CHECKING: - from pytensor.graph.basic import Variable - - class NumbaLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using Numba.""" - def output_filter(self, var: "Variable", out: 
Any) -> Any: - if not isinstance(var, np.ndarray) and isinstance( - var.type, pytensor.tensor.TensorType - ): - return var.type.filter(out, allow_downcast=True) - - return out - def fgraph_convert(self, fgraph, **kwargs): from pytensor.link.numba.dispatch import numba_funcify diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index ec26fd252f..ac0b0c8c02 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -1,7 +1,3 @@ -import copy -from typing import Any - -from pytensor.graph.basic import Variable from pytensor.link.basic import JITLinker from pytensor.link.utils import unique_name_generator @@ -13,14 +9,6 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.gen_functors = [] - def input_filter(self, inp: Any) -> Any: - from pytensor.link.pytorch.dispatch import pytorch_typify - - return pytorch_typify(inp) - - def output_filter(self, var: Variable, out: Any) -> Any: - return out.cpu() - def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): from pytensor.link.pytorch.dispatch import pytorch_funcify @@ -49,6 +37,8 @@ def conversion_func_register(*args, **kwargs): def jit_compile(self, fn): import torch + from pytensor.link.pytorch.dispatch import pytorch_typify + class wrapper: """ Pytorch would fail compiling our method when trying @@ -62,7 +52,7 @@ class wrapper: def __init__(self, fn, gen_functors): self.fn = torch.compile(fn) - self.gen_functors = copy.copy(gen_functors) + self.gen_functors = gen_functors.copy() def __call__(self, *args, **kwargs): import pytensor.link.utils @@ -83,9 +73,15 @@ def __call__(self, *args, **kwargs): def __del__(self): del self.gen_functors - res = wrapper(fn, self.gen_functors) + inner_fn = wrapper(fn, self.gen_functors) self.gen_functors = [] - return res + + # Torch does not accept numpy inputs and may return GPU objects + def fn(*inputs, inner_fn=inner_fn): + outs = inner_fn(*(pytorch_typify(inp) for inp in inputs)) + return tuple(out.cpu().numpy() for out in outs) + + return fn def create_thunk_inputs(self, storage_map): thunk_inputs = [] diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index dd4c5b4967..ec88b0fd50 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -889,3 +889,20 @@ def test_cache_warning_suppressed(): x_test = np.random.uniform(size=5) np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2) + + +@pytest.mark.parametrize("mode", ("default", "trust_input", "direct")) +def test_function_overhead(mode, benchmark): + x = pt.vector("x") + out = pt.exp(x) + + fn = function([x], out, mode="NUMBA") + if mode == "trust_input": + fn.trust_input = True + elif mode == "direct": + fn = fn.vm.jit_fn + + test_x = np.zeros(1000) + assert np.sum(fn(test_x)) == 1000 + + benchmark(fn, test_x) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 83249d021b..d7e2aef47b 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -53,8 +53,6 @@ def compare_pytorch_and_py( assert_fn: func, opt Assert function used to check for equality between python and pytorch. 
If not provided uses np.testing.assert_allclose - must_be_device_array: Bool - Checks if torch.device.type is cuda """ @@ -66,20 +64,19 @@ def compare_pytorch_and_py( pytensor_torch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode) pytorch_res = pytensor_torch_fn(*test_inputs) - if must_be_device_array: - if isinstance(pytorch_res, list): - assert all(isinstance(res, torch.Tensor) for res in pytorch_res) - else: - assert pytorch_res.device.type == "cuda" + if isinstance(pytorch_res, list): + assert all(isinstance(res, np.ndarray) for res in pytorch_res) + else: + assert isinstance(pytorch_res, np.ndarray) pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) py_res = pytensor_py_fn(*test_inputs) if len(fgraph.outputs) > 1: for pytorch_res_i, py_res_i in zip(pytorch_res, py_res, strict=True): - assert_fn(pytorch_res_i.detach().cpu().numpy(), py_res_i) + assert_fn(pytorch_res_i, py_res_i) else: - assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0]) + assert_fn(pytorch_res[0], py_res[0]) return pytensor_torch_fn, pytorch_res @@ -162,23 +159,23 @@ def test_shared(device): pytensor_torch_fn = function([], a, mode="PYTORCH") pytorch_res = pytensor_torch_fn() - assert isinstance(pytorch_res, torch.Tensor) + assert isinstance(pytorch_res, np.ndarray) assert isinstance(a.get_value(), np.ndarray) - np.testing.assert_allclose(pytorch_res.cpu(), a.get_value()) + np.testing.assert_allclose(pytorch_res, a.get_value()) pytensor_torch_fn = function([], a * 2, mode="PYTORCH") pytorch_res = pytensor_torch_fn() - assert isinstance(pytorch_res, torch.Tensor) + assert isinstance(pytorch_res, np.ndarray) assert isinstance(a.get_value(), np.ndarray) - np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2) + np.testing.assert_allclose(pytorch_res, a.get_value() * 2) new_a_value = np.array([3, 4, 5], dtype=config.floatX) a.set_value(new_a_value) pytorch_res = pytensor_torch_fn() - assert isinstance(pytorch_res, torch.Tensor) - np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2) + assert isinstance(pytorch_res, np.ndarray) + np.testing.assert_allclose(pytorch_res, new_a_value * 2) @pytest.mark.parametrize("device", ["cpu", "cuda"]) @@ -225,7 +222,7 @@ def test_alloc_and_empty(): fn = function([dim1], out, mode=pytorch_mode) res = fn(7) assert res.shape == (5, 7, 3) - assert res.dtype == torch.float32 + assert res.dtype == np.float32 v = vector("v", shape=(3,), dtype="float64") out = alloc(v, dim0, dim1, 3) diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 20c98094c1..2a9cf39c99 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -152,7 +152,7 @@ def test_cast(): _, [res] = compare_pytorch_and_py( fgraph, [np.arange(6, dtype="float32").reshape(2, 3)] ) - assert res.dtype == torch.int32 + assert res.dtype == np.int32 def test_vmap_elemwise():
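
Usage note (not part of the patch): the new `test_function_overhead` benchmark in tests/link/numba/test_basic.py exercises the three call paths whose per-call overhead this change targets. A minimal sketch of the same pattern, assuming Numba is installed so the "NUMBA" mode is usable:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    out = pt.exp(x)

    fn = pytensor.function([x], out, mode="NUMBA")
    test_x = np.zeros(1000)

    # 1. Default call path: Python-level input checks plus the single JIT thunk.
    assert np.sum(fn(test_x)) == 1000

    # 2. Skip Python-level input validation.
    fn.trust_input = True
    assert np.sum(fn(test_x)) == 1000

    # 3. Call the jitted function directly, bypassing the thunk wrapper entirely.
    jit_fn = fn.vm.jit_fn
    assert np.sum(jit_fn(test_x)) == 1000

On the PyTorch side, input/output conversion now happens inside the function returned by `jit_compile` (`pytorch_typify` on the way in, `.cpu().numpy()` on the way out), so functions compiled with mode="PYTORCH" return plain NumPy arrays, as the updated tests assert.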