Skip to content

Commit

Permalink
Implement Cast in PyTorch backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and twiecki committed Oct 10, 2024
1 parent be6a032 commit be358ed
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
11 changes: 11 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
Cast,
ScalarOp,
)

Expand Down Expand Up @@ -38,3 +39,13 @@ def pytorch_func(*args):
)

return pytorch_func


@pytorch_funcify.register(Cast)
def pytorch_funcify_Cast(op: Cast, node, **kwargs):
dtype = getattr(torch, op.o_type.dtype)

def cast(x):
return x.to(dtype=dtype)

return cast
13 changes: 13 additions & 0 deletions tests/link/pytorch/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from tests.link.pytorch.test_basic import compare_pytorch_and_py


torch = pytest.importorskip("torch")


def test_pytorch_Dimshuffle():
a_pt = matrix("a")

Expand Down Expand Up @@ -137,3 +140,13 @@ def test_softmax_grad(axis):
out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out])
compare_pytorch_and_py(fgraph, [dy_value, sm_value])


def test_cast():
x = matrix("x", dtype="float32")
out = pt.cast(x, "int32")
fgraph = FunctionGraph([x], [out])
_, [res] = compare_pytorch_and_py(
fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
)
assert res.dtype == torch.int32

0 comments on commit be358ed

Please sign in to comment.