Skip to content

Commit

Permalink
Add set and add TensorVariable methods for set_subtensor and `i…
Browse files Browse the repository at this point in the history
…nc_subtensor` operations
  • Loading branch information
ricardoV94 committed Nov 13, 2023
1 parent a6f3f2d commit 9d5a825
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
26 changes: 26 additions & 0 deletions pytensor/tensor/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,32 @@ def compress(self, a, axis=None):
"""Return selected slices only."""
return at.extra_ops.compress(self, a, axis=axis)

def set(self, y, **kwargs):
"""Set values to y, where y is the output of an index operation.
Equivalent to set_subtensor(self, y). See docstrings for kwargs.
Examples
--------
>>> x = matrix()
>>> out = x[0].set(5)
"""
return at.subtensor.set_subtensor(self, y, **kwargs)

def add(self, y, **kwargs):
"""Add values to y, where y is the output of an index operation.
Equivalent to inc_subtensor(self, y). See docstrings for kwargs
Examples
--------
>>> x = matrix()
>>> out = x[0].add(5)
"""
return at.inc_subtensor(self, y, **kwargs)


class TensorVariable(
_tensor_py_operators, Variable[_TensorTypeType, OptionalApplyType]
Expand Down
21 changes: 20 additions & 1 deletion tests/tensor/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import dot, eq, matmul
from pytensor.tensor.shape import Shape
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
from pytensor.tensor.subtensor import (
AdvancedSubtensor,
Subtensor,
inc_subtensor,
set_subtensor,
)
from pytensor.tensor.type import (
TensorType,
cscalar,
Expand Down Expand Up @@ -428,6 +433,20 @@ def test_take(self):
# Test equivalent advanced indexing
assert_array_equal(X[:, indices].eval({X: x}), x[:, indices])

def test_set_add(self):
x = matrix("x")
idx = [0]
y = 5

assert equal_computations([x[idx].set(y)], [set_subtensor(x[idx], y)])
assert equal_computations([x[idx].add(y)], [inc_subtensor(x[idx], y)])

msg = "must be the result of a subtensor operation"
with pytest.raises(TypeError, match=msg):
x.set(y)
with pytest.raises(TypeError, match=msg):
x.add(y)


def test_deprecated_import():
with pytest.warns(
Expand Down

0 comments on commit 9d5a825

Please sign in to comment.