Improve static output shape of Reshape
ricardoV94 committed Jun 21, 2024
1 parent 734009a commit 28d9d4d
Showing 2 changed files with 49 additions and 1 deletion.
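For context, a rough sketch of the user-facing effect (assuming the public pytensor.tensor API; the "before" behavior follows from the deleted line below, where any negative entry in a constant shape became an unknown static dimension):

import pytensor.tensor as pt

x = pt.tensor(shape=(2, 2, 6))      # input with a fully known static shape
y = pt.reshape(x, (6, -1))

# Previously the -1 entry stayed unknown: y.type.shape == (6, None).
# With the deduction added in this commit it is filled in:
print(y.type.shape)                 # (6, 4)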
25 changes: 24 additions & 1 deletion pytensor/tensor/shape.py
@@ -669,7 +669,7 @@ def make_node(self, x, shp):
         assert shp.ndim == 1
 
         if isinstance(shp, TensorConstant):
-            out_shape = tuple(int(s) if s >= 0 else None for s in shp.data)
+            out_shape = [int(s) if s >= 0 else None for s in shp.data]
         else:
             out_shape = [None] * self.ndim
             shp_list = shp_orig
@@ -685,6 +685,29 @@ def make_node(self, x, shp):
                 except NotScalarConstantError:
                     pass
 
+        # If we only don't know the size of one output dimension,
+        # but we know all the input dimensions we can deduce it
+        # This happens often when there is -1 as an input of Reshape
+        if None not in x.type.shape and out_shape.count(None) == 1:
+            full_size = np.prod(x.type.shape)
+            known_size = np.prod([s for s in out_shape if s is not None])
+            out_shape[out_shape.index(None)] = int(full_size // known_size)
+
+        out_shape = tuple(out_shape)
+
+        # Run some eager error checks
+        if len(out_shape) != self.ndim:
+            raise ValueError(
+                "Shape argument to Reshape has incorrect length:"
+                f" {len(out_shape)}, should be {self.ndim}"
+            )
+
+        if None not in x.type.shape and None not in out_shape:
+            if np.prod(x.type.shape) != np.prod(out_shape):
+                raise ValueError(
+                    f"Reshape: Input shape {x.type.shape} is incompatible with new shape {out_shape}"
+                )
+
         return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])
 
     def perform(self, node, inp, out_):
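To restate the added inference in one place, a minimal standalone sketch (the helper name and signature are illustrative only, not part of pytensor):

import numpy as np

def infer_reshape_static_shape(in_shape, out_shape, ndim):
    # in_shape: tuple of int/None, the input's static shape
    # out_shape: list of int/None, the requested shape (-1 / unknown entries as None)
    # Deduce a single unknown output dimension when the input size is fully known
    if None not in in_shape and out_shape.count(None) == 1:
        full_size = np.prod(in_shape)
        known_size = np.prod([s for s in out_shape if s is not None])
        out_shape[out_shape.index(None)] = int(full_size // known_size)

    out_shape = tuple(out_shape)

    # Eager error checks, mirroring the commit
    if len(out_shape) != ndim:
        raise ValueError(
            f"Shape argument to Reshape has incorrect length: {len(out_shape)}, should be {ndim}"
        )
    if None not in in_shape and None not in out_shape:
        if np.prod(in_shape) != np.prod(out_shape):
            raise ValueError(
                f"Reshape: Input shape {in_shape} is incompatible with new shape {out_shape}"
            )
    return out_shape

# infer_reshape_static_shape((2, 2, 6), [6, None, 2], 3)  -> (6, 2, 2)
# infer_reshape_static_shape((2, 2, 6), [6, None], 2)     -> (6, 4)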
25 changes: 25 additions & 0 deletions tests/tensor/test_shape.py
@@ -1,3 +1,5 @@
+import re
+
 import numpy as np
 import pytest
 
@@ -353,6 +355,29 @@ def test_rebuild(self):
         assert tuple(y_new.shape.eval({i: i_test})) == (4, 25)
         assert y_new.eval({i: i_test}).shape == (4, 25)
 
+    def test_static_shape(self):
+        dim = lscalar("dim")
+        x1 = tensor(shape=(2, 2, None))
+        x2 = specify_shape(x1, (2, 2, 6))
+
+        assert reshape(x1, (6, 2)).type.shape == (6, 2)
+        assert reshape(x1, (6, -1)).type.shape == (6, None)
+        assert reshape(x1, (6, dim)).type.shape == (6, None)
+        assert reshape(x1, (6, dim, 2)).type.shape == (6, None, 2)
+        assert reshape(x1, (6, 3, 99)).type.shape == (6, 3, 99)
+
+        assert reshape(x2, (6, 4)).type.shape == (6, 4)
+        assert reshape(x2, (6, -1)).type.shape == (6, 4)
+        assert reshape(x2, (6, dim)).type.shape == (6, 4)
+        assert reshape(x2, (6, dim, 2)).type.shape == (6, 2, 2)
+        with pytest.raises(
+            ValueError,
+            match=re.escape(
+                "Reshape: Input shape (2, 2, 6) is incompatible with new shape (6, 3, 99)"
+            ),
+        ):
+            reshape(x2, (6, 3, 99))
+
 
 def test_shape_i_hash():
     assert isinstance(Shape_i(np.int64(1)).__hash__(), int)
