From f5b76e41223dc535f70051a06a2862ea1f1a2995 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 25 Apr 2024 19:49:30 +0200 Subject: [PATCH] Change behavior of helper set/inc to act on an indexed variable directly --- pytensor/tensor/variable.py | 30 ++++++++++++++++++++---------- tests/tensor/test_variable.py | 4 ++-- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 6100108380..e881331017 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -824,25 +824,35 @@ def compress(self, a, axis=None): """Return selected slices only.""" return pt.extra_ops.compress(self, a, axis=axis) - def set(self, idx, y, **kwargs): - """Return a copy of self with the indexed values set to y. + def set(self, y, **kwargs): + """Return a copy of the variable indexed by self with the indexed values set to y. - Equivalent to set_subtensor(self[idx], y). See docstrings for kwargs. + Equivalent to set_subtensor(self, y). See docstrings for kwargs. + + Raises + ------ + TypeError: + If self is not the result of a subtensor operation Examples -------- >>> import pytensor.tensor as pt >>> >>> x = pt.ones((3,)) - >>> out = x.set(1, 2) + >>> out = x[1].set(2) >>> out.eval() # array([1., 2., 1.]) """ - return pt.subtensor.set_subtensor(self[idx], y, **kwargs) + return pt.subtensor.set_subtensor(self, y, **kwargs) + + def inc(self, y, **kwargs): + """Return a copy of the variable indexed by self with the indexed values incremented by y. - def inc(self, idx, y, **kwargs): - """Return a copy of self with the indexed values incremented by y. + Equivalent to inc_subtensor(self, y). See docstrings for kwargs. - Equivalent to inc_subtensor(self[idx], y). See docstrings for kwargs. + Raises + ------ + TypeError: + If self is not the result of a subtensor operation Examples -------- @@ -850,10 +860,10 @@ def inc(self, idx, y, **kwargs): >>> import pytensor.tensor as pt >>> >>> x = pt.ones((3,)) - >>> out = x.inc(1, 2) + >>> out = x[1].inc(2) >>> out.eval() # array([1., 3., 1.]) """ - return pt.inc_subtensor(self[idx], y, **kwargs) + return pt.inc_subtensor(self, y, **kwargs) class TensorVariable( diff --git a/tests/tensor/test_variable.py b/tests/tensor/test_variable.py index 70ad04999a..50c36a05fc 100644 --- a/tests/tensor/test_variable.py +++ b/tests/tensor/test_variable.py @@ -438,8 +438,8 @@ def test_set_inc(self): idx = [0] y = 5 - assert equal_computations([x.set(idx, y)], [set_subtensor(x[idx], y)]) - assert equal_computations([x.inc(idx, y)], [inc_subtensor(x[idx], y)]) + assert equal_computations([x[:, idx].set(y)], [set_subtensor(x[:, idx], y)]) + assert equal_computations([x[:, idx].inc(y)], [inc_subtensor(x[:, idx], y)]) def test_set_item_error(self): x = matrix("x")