From 4258475f8fa05676ccd08783f78aba77b024f58a Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 10 Oct 2024 10:10:41 +0200 Subject: [PATCH] Deprecate rarely used Function functionality --- pytensor/compile/function/types.py | 11 +- tests/compile/function/test_types.py | 302 ++++++++++++++++----------- 2 files changed, 190 insertions(+), 123 deletions(-) diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 1efe40d6e1..b7caff1bf4 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -387,6 +387,9 @@ def __init__( self.nodes_with_inner_function = [] self.output_keys = output_keys + if self.output_keys is not None: + warnings.warn("output_keys is deprecated.", FutureWarning) + assert len(self.input_storage) == len(self.maker.fgraph.inputs) assert len(self.output_storage) == len(self.maker.fgraph.outputs) @@ -836,8 +839,10 @@ def __call__(self, *args, **kwargs): t0 = time.perf_counter() output_subset = kwargs.pop("output_subset", None) - if output_subset is not None and self.output_keys is not None: - output_subset = [self.output_keys.index(key) for key in output_subset] + if output_subset is not None: + warnings.warn("output_subset is deprecated.", FutureWarning) + if self.output_keys is not None: + output_subset = [self.output_keys.index(key) for key in output_subset] # Reinitialize each container's 'provided' counter if self.trust_input: @@ -1560,6 +1565,8 @@ def __init__( ) for i in self.inputs ] + if any(self.refeed): + warnings.warn("Inputs with default values are deprecated.", FutureWarning) def create(self, input_storage=None, storage_map=None): """ diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index dc94022e0b..4b6537d328 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -35,6 +35,9 @@ ) +pytestmark = pytest.mark.filterwarnings("error") + + def PatternOptimizer(p1, p2, ign=True): return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) @@ -195,7 +198,10 @@ def test_naming_rule3(self): x, s = scalars("xs") # x's name is not ignored (as in test_naming_rule2) because a has a default value. - f = function([x, In(a, value=1.0), s], a / s + x) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function([x, In(a, value=1.0), s], a / s + x) assert f(9, 2, 4) == 9.5 # can specify all args in order assert f(9, 2, s=4) == 9.5 # can give s as kwarg assert f(9, s=4) == 9.25 # can give s as kwarg, get default a @@ -214,7 +220,10 @@ def test_naming_rule4(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - f = function([x, In(a, value=1.0, name="a"), s], a / s + x) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function([x, In(a, value=1.0, name="a"), s], a / s + x) assert f(9, 2, 4) == 9.5 # can specify all args in order assert f(9, 2, s=4) == 9.5 # can give s as kwarg @@ -248,11 +257,14 @@ def test_state_access(self, mode): a = scalar() x, s = scalars("xs") - f = function( - [x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)], - s + a * x, - mode=mode, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)], + s + a * x, + mode=mode, + ) assert f[a] == 1.0 assert f[s] == 0.0 @@ -303,16 +315,19 @@ def test_copy(self): a = scalar() x, s = scalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) - g = copy.copy(f) + g = copy.copy(f) assert f.unpack_single == g.unpack_single assert f.trust_input == g.trust_input @@ -504,22 +519,25 @@ def test_shared_state0(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) - g = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=f.container[s], update=s - a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) + g = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=f.container[s], update=s - a * x, mutable=True), + ], + s + a * x, + ) f(1, 2) assert f[s] == 2 @@ -532,17 +550,20 @@ def test_shared_state1(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) - g = function( - [x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) + g = function( + [x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x + ) f(1, 2) assert f[s] == 2 @@ -556,17 +577,20 @@ def test_shared_state2(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=False), - ], - s + a * x, - ) - g = function( - [x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=False), + ], + s + a * x, + ) + g = function( + [x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x + ) f(1, 2) assert f[s] == 2 @@ -718,7 +742,10 @@ def test_default_values(self): a, b = dscalars("a", "b") c = a + b - funct = function([In(a, name="first"), In(b, value=1, name="second")], c) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + funct = function([In(a, name="first"), In(b, value=1, name="second")], c) x = funct(first=1) try: funct(second=2) @@ -775,7 +802,8 @@ def test_output_dictionary(self): # Tests that function works when outputs is a dictionary x = scalar() - f = function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4}) + with pytest.warns(FutureWarning, match="output_keys is deprecated."): + f = function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4}) outputs = f(10.0) @@ -790,7 +818,8 @@ def test_input_named_variables(self): x = scalar("x") y = scalar("y") - f = function([x, y], outputs={"a": x + y, "b": x * y}) + with pytest.warns(FutureWarning, match="output_keys is deprecated."): + f = function([x, y], outputs={"a": x + y, "b": x * y}) assert f(2, 4) == {"a": 6, "b": 8} assert f(2, y=4) == f(2, 4) @@ -805,9 +834,10 @@ def test_output_order_sorted(self): e1 = scalar("1") e2 = scalar("2") - f = function( - [x, y, z, e1, e2], outputs={"x": x, "y": y, "z": z, "1": e1, "2": e2} - ) + with pytest.warns(FutureWarning, match="output_keys is deprecated."): + f = function( + [x, y, z, e1, e2], outputs={"x": x, "y": y, "z": z, "1": e1, "2": e2} + ) assert "1" in str(f.outputs[0]) assert "2" in str(f.outputs[1]) @@ -825,7 +855,8 @@ def test_composing_function(self): a = x + y b = x * y - f = function([x, y], outputs={"a": a, "b": b}) + with pytest.warns(FutureWarning, match="output_keys is deprecated."): + f = function([x, y], outputs={"a": a, "b": b}) a = scalar("a") b = scalar("b") @@ -880,14 +911,17 @@ def test_deepcopy(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a", mutable=True), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a", mutable=True), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) try: g = copy.deepcopy(f) except NotImplementedError as e: @@ -941,14 +975,17 @@ def test_deepcopy_trust_input(self): a = dscalar() # the a is for 'anonymous' (un-named). x, s = dscalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) f.trust_input = True try: g = copy.deepcopy(f) @@ -967,11 +1004,13 @@ def test_deepcopy_trust_input(self): def test_output_keys(self): x = vector() - f = function([x], {"vec": x**2}) + with pytest.warns(FutureWarning, match="output_keys is deprecated."): + f = function([x], {"vec": x**2}) o = f([2, 3, 4]) assert isinstance(o, dict) assert np.allclose(o["vec"], [4, 9, 16]) - g = copy.deepcopy(f) + with pytest.warns(FutureWarning, match="output_keys is deprecated."): + g = copy.deepcopy(f) o = g([2, 3, 4]) assert isinstance(o, dict) assert np.allclose(o["vec"], [4, 9, 16]) @@ -980,7 +1019,10 @@ def test_deepcopy_shared_container(self): # Ensure that shared containers remain shared after a deep copy. a, x = scalars("ax") - h = function([In(a, value=0.0)], a) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + h = function([In(a, value=0.0)], a) f = function([x, In(a, value=h.container[a], implicit=True)], x + a) try: @@ -1004,14 +1046,17 @@ def test_pickle(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) try: # Note that here we also test protocol 0 on purpose, since it @@ -1105,25 +1150,31 @@ def test_multiple_functions(self): # some derived thing, whose inputs aren't all in the list list_of_things.append(a * x + s) - f1 = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f1 = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) list_of_things.append(f1) # now put in a function sharing container with the previous one - f2 = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=f1.container[s], update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f2 = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=f1.container[s], update=s + a * x, mutable=True), + ], + s + a * x, + ) list_of_things.append(f2) assert isinstance(f2.container[s].storage, list) @@ -1131,7 +1182,10 @@ def test_multiple_functions(self): # now put in a function with non-scalar v_value = np.asarray([2, 3, 4.0], dtype=config.floatX) - f3 = function([x, In(v, value=v_value)], x + v) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f3 = function([x, In(v, value=v_value)], x + v) list_of_things.append(f3) # try to pickle the entire things @@ -1263,23 +1317,29 @@ def __init__(self): self.e = a * x + s - self.f1 = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + self.f1 = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) - self.f2 = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=self.f1.container[s], update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + self.f2 = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=self.f1.container[s], update=s + a * x, mutable=True), + ], + s + a * x, + ) def test_empty_givens_updates():