Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change behavior of helper set/inc to act on an indexed variable directly #730

Merged
merged 1 commit into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions pytensor/tensor/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,36 +824,46 @@ def compress(self, a, axis=None):
"""Return selected slices only."""
return pt.extra_ops.compress(self, a, axis=axis)

def set(self, idx, y, **kwargs):
"""Return a copy of self with the indexed values set to y.
def set(self, y, **kwargs):
"""Return a copy of the variable indexed by self with the indexed values set to y.

Equivalent to set_subtensor(self[idx], y). See docstrings for kwargs.
Equivalent to set_subtensor(self, y). See docstrings for kwargs.

Raises
------
TypeError:
If self is not the result of a subtensor operation

Examples
--------
>>> import pytensor.tensor as pt
>>>
>>> x = pt.ones((3,))
>>> out = x.set(1, 2)
>>> out = x[1].set(2)
>>> out.eval() # array([1., 2., 1.])
"""
return pt.subtensor.set_subtensor(self[idx], y, **kwargs)
return pt.subtensor.set_subtensor(self, y, **kwargs)

def inc(self, y, **kwargs):
"""Return a copy of the variable indexed by self with the indexed values incremented by y.

def inc(self, idx, y, **kwargs):
"""Return a copy of self with the indexed values incremented by y.
Equivalent to inc_subtensor(self, y). See docstrings for kwargs.

Equivalent to inc_subtensor(self[idx], y). See docstrings for kwargs.
Raises
------
TypeError:
If self is not the result of a subtensor operation

Examples
--------

>>> import pytensor.tensor as pt
>>>
>>> x = pt.ones((3,))
>>> out = x.inc(1, 2)
>>> out = x[1].inc(2)
>>> out.eval() # array([1., 3., 1.])
"""
return pt.inc_subtensor(self[idx], y, **kwargs)
return pt.inc_subtensor(self, y, **kwargs)


class TensorVariable(
Expand Down
4 changes: 2 additions & 2 deletions tests/tensor/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,8 @@ def test_set_inc(self):
idx = [0]
y = 5

assert equal_computations([x.set(idx, y)], [set_subtensor(x[idx], y)])
assert equal_computations([x.inc(idx, y)], [inc_subtensor(x[idx], y)])
assert equal_computations([x[:, idx].set(y)], [set_subtensor(x[:, idx], y)])
assert equal_computations([x[:, idx].inc(y)], [inc_subtensor(x[:, idx], y)])

def test_set_item_error(self):
x = matrix("x")
Expand Down
Loading