From 28d9d4dc0bd0fc0060670622514b83cc281cda3e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 19 Jun 2024 18:51:20 +0200 Subject: [PATCH] Improve static output shape of Reshape --- pytensor/tensor/shape.py | 25 ++++++++++++++++++++++++- tests/tensor/test_shape.py | 25 +++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index b01aa19465..b6fcd9fb21 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -669,7 +669,7 @@ def make_node(self, x, shp): assert shp.ndim == 1 if isinstance(shp, TensorConstant): - out_shape = tuple(int(s) if s >= 0 else None for s in shp.data) + out_shape = [int(s) if s >= 0 else None for s in shp.data] else: out_shape = [None] * self.ndim shp_list = shp_orig @@ -685,6 +685,29 @@ def make_node(self, x, shp): except NotScalarConstantError: pass + # If we only don't know the size of one output dimension, + # but we know all the input dimensions we can deduce it + # This happens often when there is -1 as an input of Reshape + if None not in x.type.shape and out_shape.count(None) == 1: + full_size = np.prod(x.type.shape) + known_size = np.prod([s for s in out_shape if s is not None]) + out_shape[out_shape.index(None)] = int(full_size // known_size) + + out_shape = tuple(out_shape) + + # Run some eager error checks + if len(out_shape) != self.ndim: + raise ValueError( + "Shape argument to Reshape has incorrect length:" + f" {len(out_shape)}, should be {self.ndim}" + ) + + if None not in x.type.shape and None not in out_shape: + if np.prod(x.type.shape) != np.prod(out_shape): + raise ValueError( + f"Reshape: Input shape {x.type.shape} is incompatible with new shape {out_shape}" + ) + return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)]) def perform(self, node, inp, out_): diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index db802268af..7fa8133c4e 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -1,3 +1,5 @@ +import re + import numpy as np import pytest @@ -353,6 +355,29 @@ def test_rebuild(self): assert tuple(y_new.shape.eval({i: i_test})) == (4, 25) assert y_new.eval({i: i_test}).shape == (4, 25) + def test_static_shape(self): + dim = lscalar("dim") + x1 = tensor(shape=(2, 2, None)) + x2 = specify_shape(x1, (2, 2, 6)) + + assert reshape(x1, (6, 2)).type.shape == (6, 2) + assert reshape(x1, (6, -1)).type.shape == (6, None) + assert reshape(x1, (6, dim)).type.shape == (6, None) + assert reshape(x1, (6, dim, 2)).type.shape == (6, None, 2) + assert reshape(x1, (6, 3, 99)).type.shape == (6, 3, 99) + + assert reshape(x2, (6, 4)).type.shape == (6, 4) + assert reshape(x2, (6, -1)).type.shape == (6, 4) + assert reshape(x2, (6, dim)).type.shape == (6, 4) + assert reshape(x2, (6, dim, 2)).type.shape == (6, 2, 2) + with pytest.raises( + ValueError, + match=re.escape( + "Reshape: Input shape (2, 2, 6) is incompatible with new shape (6, 3, 99)" + ), + ): + reshape(x2, (6, 3, 99)) + def test_shape_i_hash(): assert isinstance(Shape_i(np.int64(1)).__hash__(), int)