Reduce overhead of JITLinker
ricardoV94 committed Nov 29, 2024
1 parent d1c5ae2 commit a0c64b5
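
In short: the JIT thunk now reads input storage cells and writes output storage cells directly, instead of routing every value through the per-backend input_filter/output_filter hooks, and make_all returns that single thunk as the function instead of wrapping it in the streamline VM loop with its gc bookkeeping. A minimal sketch of the resulting hot path, using the names from the diff below (the surrounding linker machinery is elided):

    def thunk(fgraph_jit, thunk_inputs, thunk_outputs):
        # Unpack the raw values from the input storage cells and run the whole
        # graph as a single jitted call.
        outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
        # Write the results straight back into the output storage cells;
        # no per-output filtering or compute_map updates.
        for o_storage, o_val in zip(thunk_outputs, outputs):
            o_storage[0] = o_val

Backends that genuinely need conversion (e.g. PyTorch, which does not accept NumPy inputs) now do it once inside the callable returned by jit_compile rather than on every thunk call.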
Showing 6 changed files with 52 additions and 91 deletions.
52 changes: 10 additions & 42 deletions pytensor/link/basic.py
@@ -653,41 +653,36 @@ def create_jitable_thunk(
)

thunk_inputs = self.create_thunk_inputs(storage_map)

thunks = []

thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]

fgraph_jit = self.jit_compile(converted_fgraph)

def thunk(
fgraph=self.fgraph,
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
try:
outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
except Exception:
# TODO: Should we add a fake node that combines all outputs,
# since the error may come from any of them?
raise_with_op(self.fgraph, output_nodes[0], thunk)

# strict=False because we are in a hot loop
for o_var, o_storage, o_val in zip(
fgraph.outputs, thunk_outputs, outputs, strict=False
):
compute_map[o_var][0] = True
o_storage[0] = self.output_filter(o_var, o_val)
return outputs
for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
o_storage[0] = o_val

thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs
thunk.lazy = False

thunks.append(thunk)
thunks = [thunk]

return thunks, output_nodes, fgraph_jit

def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph
nodes = self.schedule(fgraph)
no_recycling = self.no_recycling

input_storage, output_storage, storage_map = map_storage(
fgraph, nodes, input_storage, output_storage, storage_map
@@ -701,34 +696,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
compute_map, nodes, input_storage, output_storage, storage_map
)

computed, last_user = gc_helper(nodes)

if self.allow_gc:
post_thunk_old_storage = [
[
storage_map[input]
for input in node.inputs
if (input in computed)
and (input not in fgraph.outputs)
and (node == last_user[input])
]
for node in nodes
]
else:
post_thunk_old_storage = None

if no_recycling is True:
no_recycling = list(storage_map.values())
no_recycling = difference(no_recycling, input_storage)
else:
no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs
]

fn = streamline(
fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
)

[fn] = thunks
fn.jit_fn = jit_fn
fn.allow_gc = self.allow_gc
fn.storage_map = storage_map
17 changes: 0 additions & 17 deletions pytensor/link/numba/linker.py
@@ -1,26 +1,9 @@
from typing import TYPE_CHECKING, Any

import numpy as np

import pytensor
from pytensor.link.basic import JITLinker


if TYPE_CHECKING:
from pytensor.graph.basic import Variable


class NumbaLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""

def output_filter(self, var: "Variable", out: Any) -> Any:
if not isinstance(var, np.ndarray) and isinstance(
var.type, pytensor.tensor.TensorType
):
return var.type.filter(out, allow_downcast=True)

return out

def fgraph_convert(self, fgraph, **kwargs):
from pytensor.link.numba.dispatch import numba_funcify

26 changes: 11 additions & 15 deletions pytensor/link/pytorch/linker.py
@@ -1,7 +1,3 @@
import copy
from typing import Any

from pytensor.graph.basic import Variable
from pytensor.link.basic import JITLinker
from pytensor.link.utils import unique_name_generator

@@ -13,14 +9,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gen_functors = []

def input_filter(self, inp: Any) -> Any:
from pytensor.link.pytorch.dispatch import pytorch_typify

return pytorch_typify(inp)

def output_filter(self, var: Variable, out: Any) -> Any:
return out.cpu()

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.pytorch.dispatch import pytorch_funcify

@@ -49,6 +37,8 @@ def conversion_func_register(*args, **kwargs):
def jit_compile(self, fn):
import torch

from pytensor.link.pytorch.dispatch import pytorch_typify

class wrapper:
"""
Pytorch would fail compiling our method when trying
@@ -62,7 +52,7 @@ class wrapper:

def __init__(self, fn, gen_functors):
self.fn = torch.compile(fn)
self.gen_functors = copy.copy(gen_functors)
self.gen_functors = gen_functors.copy()

def __call__(self, *args, **kwargs):
import pytensor.link.utils
@@ -83,9 +73,15 @@ def __call__(self, *args, **kwargs):
def __del__(self):
del self.gen_functors

res = wrapper(fn, self.gen_functors)
inner_fn = wrapper(fn, self.gen_functors)
self.gen_functors = []
return res

# Torch does not accept numpy inputs and may return GPU objects
def fn(*inputs, inner_fn=inner_fn):
outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
return tuple(out.cpu().numpy() for out in outs)

return fn

def create_thunk_inputs(self, storage_map):
thunk_inputs = []
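
With input_filter/output_filter removed from the base linker, the PyTorch backend now handles conversion once inside jit_compile: the fn returned above typifies NumPy inputs into torch tensors before calling the compiled wrapper and copies each output back to a CPU NumPy array. From the caller's side a PYTORCH-mode function is therefore NumPy-in / NumPy-out, roughly (a usage sketch; the variable names are illustrative):

    import numpy as np
    from pytensor import function
    from pytensor.tensor import vector

    x = vector("x")
    fn = function([x], 2 * x, mode="PYTORCH")
    res = fn(np.zeros(3, dtype=x.dtype))
    assert isinstance(res, np.ndarray)  # no .detach().cpu().numpy() needed by the caller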
17 changes: 17 additions & 0 deletions tests/link/numba/test_basic.py
@@ -889,3 +889,20 @@ def test_cache_warning_suppressed():

x_test = np.random.uniform(size=5)
np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2)


@pytest.mark.parametrize("mode", ("default", "trust_input", "direct"))
def test_function_overhead(mode, benchmark):
x = pt.vector("x")
out = pt.exp(x)

fn = function([x], out, mode="NUMBA")
if mode == "trust_input":
fn.trust_input = True
elif mode == "direct":
fn = fn.vm.jit_fn

test_x = np.zeros(1000)
assert np.sum(fn(test_x)) == 1000

benchmark(fn, test_x)
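
The three parametrizations exercise progressively thinner call paths: "default" goes through the full Function wrapper, "trust_input" skips input validation on each call, and "direct" calls fn.vm.jit_fn, the raw Numba-compiled callable that make_all now attaches to the single thunk (fn.jit_fn = jit_fn above). A rough sketch of the difference, reusing the names from the test and assuming (consistent with the thunk loop above) that the raw callable returns its outputs as a sequence:

    fn = function([x], out, mode="NUMBA")
    res = fn(test_x)                       # full wrapper: input checking, storage handling
    fn.trust_input = True                  # skip input validation on later calls
    (res_direct,) = fn.vm.jit_fn(test_x)   # raw jitted callable, no Python-side bookkeeping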
29 changes: 13 additions & 16 deletions tests/link/pytorch/test_basic.py
@@ -53,8 +53,6 @@ def compare_pytorch_and_py(
assert_fn: func, opt
Assert function used to check for equality between python and pytorch. If not
provided uses np.testing.assert_allclose
must_be_device_array: Bool
Checks if torch.device.type is cuda
"""
@@ -66,20 +64,19 @@
pytensor_torch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode)
pytorch_res = pytensor_torch_fn(*test_inputs)

if must_be_device_array:
if isinstance(pytorch_res, list):
assert all(isinstance(res, torch.Tensor) for res in pytorch_res)
else:
assert pytorch_res.device.type == "cuda"
if isinstance(pytorch_res, list):
assert all(isinstance(res, np.ndarray) for res in pytorch_res)
else:
assert isinstance(pytorch_res, np.ndarray)

pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
py_res = pytensor_py_fn(*test_inputs)

if len(fgraph.outputs) > 1:
for pytorch_res_i, py_res_i in zip(pytorch_res, py_res, strict=True):
assert_fn(pytorch_res_i.detach().cpu().numpy(), py_res_i)
assert_fn(pytorch_res_i, py_res_i)
else:
assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])
assert_fn(pytorch_res[0], py_res[0])

return pytensor_torch_fn, pytorch_res

@@ -162,23 +159,23 @@ def test_shared(device):
pytensor_torch_fn = function([], a, mode="PYTORCH")
pytorch_res = pytensor_torch_fn()

assert isinstance(pytorch_res, torch.Tensor)
assert isinstance(pytorch_res, np.ndarray)
assert isinstance(a.get_value(), np.ndarray)
np.testing.assert_allclose(pytorch_res.cpu(), a.get_value())
np.testing.assert_allclose(pytorch_res, a.get_value())

pytensor_torch_fn = function([], a * 2, mode="PYTORCH")
pytorch_res = pytensor_torch_fn()

assert isinstance(pytorch_res, torch.Tensor)
assert isinstance(pytorch_res, np.ndarray)
assert isinstance(a.get_value(), np.ndarray)
np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2)
np.testing.assert_allclose(pytorch_res, a.get_value() * 2)

new_a_value = np.array([3, 4, 5], dtype=config.floatX)
a.set_value(new_a_value)

pytorch_res = pytensor_torch_fn()
assert isinstance(pytorch_res, torch.Tensor)
np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2)
assert isinstance(pytorch_res, np.ndarray)
np.testing.assert_allclose(pytorch_res, new_a_value * 2)


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@@ -225,7 +222,7 @@ def test_alloc_and_empty():
fn = function([dim1], out, mode=pytorch_mode)
res = fn(7)
assert res.shape == (5, 7, 3)
assert res.dtype == torch.float32
assert res.dtype == np.float32

v = vector("v", shape=(3,), dtype="float64")
out = alloc(v, dim0, dim1, 3)
2 changes: 1 addition & 1 deletion tests/link/pytorch/test_elemwise.py
@@ -152,7 +152,7 @@ def test_cast():
_, [res] = compare_pytorch_and_py(
fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
)
assert res.dtype == torch.int32
assert res.dtype == np.int32


def test_vmap_elemwise():
