From 9d5a8256e6a31cc63c310c880458054bec4097d5 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 13 Nov 2023 16:12:35 +0100 Subject: [PATCH] Add `set` and `add` TensorVariable methods for `set_subtensor` and `inc_subtensor` operations --- pytensor/tensor/variable.py | 26 ++++++++++++++++++++++++++ tests/tensor/test_variable.py | 21 ++++++++++++++++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 57804a204a..91115ad4ae 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -815,6 +815,32 @@ def compress(self, a, axis=None): """Return selected slices only.""" return at.extra_ops.compress(self, a, axis=axis) + def set(self, y, **kwargs): + """Set values to y, where y is the output of an index operation. + + Equivalent to set_subtensor(self, y). See docstrings for kwargs. + + Examples + -------- + + >>> x = matrix() + >>> out = x[0].set(5) + """ + return at.subtensor.set_subtensor(self, y, **kwargs) + + def add(self, y, **kwargs): + """Add values to y, where y is the output of an index operation. + + Equivalent to inc_subtensor(self, y). See docstrings for kwargs + + Examples + -------- + + >>> x = matrix() + >>> out = x[0].add(5) + """ + return at.inc_subtensor(self, y, **kwargs) + class TensorVariable( _tensor_py_operators, Variable[_TensorTypeType, OptionalApplyType] diff --git a/tests/tensor/test_variable.py b/tests/tensor/test_variable.py index b43cb2c4e4..3ffbdbd50a 100644 --- a/tests/tensor/test_variable.py +++ b/tests/tensor/test_variable.py @@ -14,7 +14,12 @@ from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import dot, eq, matmul from pytensor.tensor.shape import Shape -from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor +from pytensor.tensor.subtensor import ( + AdvancedSubtensor, + Subtensor, + inc_subtensor, + set_subtensor, +) from pytensor.tensor.type import ( TensorType, cscalar, @@ -428,6 +433,20 @@ def test_take(self): # Test equivalent advanced indexing assert_array_equal(X[:, indices].eval({X: x}), x[:, indices]) + def test_set_add(self): + x = matrix("x") + idx = [0] + y = 5 + + assert equal_computations([x[idx].set(y)], [set_subtensor(x[idx], y)]) + assert equal_computations([x[idx].add(y)], [inc_subtensor(x[idx], y)]) + + msg = "must be the result of a subtensor operation" + with pytest.raises(TypeError, match=msg): + x.set(y) + with pytest.raises(TypeError, match=msg): + x.add(y) + def test_deprecated_import(): with pytest.warns(