Skip to content

Commit

Permalink
Add more specialized static output shape to Eye
Browse files Browse the repository at this point in the history
Importantly, it now provides broadcastability information which is needed elsewhere
  • Loading branch information
ricardoV94 committed Jun 21, 2024
1 parent 28d9d4d commit d3bd1f1
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 32 deletions.
6 changes: 5 additions & 1 deletion pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,7 @@ def triu_indices_from(


class Eye(Op):
_output_type_depends_on_input_value = True
__props__ = ("dtype",)

def __init__(self, dtype=None):
Expand All @@ -1287,10 +1288,13 @@ def make_node(self, n, m, k):
assert n.ndim == 0
assert m.ndim == 0
assert k.ndim == 0

_, static_shape = infer_static_shape((n, m))

return Apply(
self,
[n, m, k],
[TensorType(dtype=self.dtype, shape=(None, None))()],
[TensorType(dtype=self.dtype, shape=static_shape)()],
)

def perform(self, node, inp, out_):
Expand Down
70 changes: 39 additions & 31 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,38 +937,46 @@ def test_infer_static_shape():
assert static_shape == (1,)


# This is slow for the ('int8', 3) version.
def test_eye():
def check(dtype, N, M_=None, k=0):
# PyTensor does not accept None as a tensor.
# So we must use a real value.
M = M_
# Currently DebugMode does not support None as inputs even if this is
# allowed.
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
M = N
N_symb = iscalar()
M_symb = iscalar()
k_symb = iscalar()
f = function([N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype))
result = f(N, M, k)
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
assert result.dtype == np.dtype(dtype)
class TestEye:
# This is slow for the ('int8', 3) version.
def test_basic(self):
def check(dtype, N, M_=None, k=0):
# PyTensor does not accept None as a tensor.
# So we must use a real value.
M = M_
# Currently DebugMode does not support None as inputs even if this is
# allowed.
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
M = N
N_symb = iscalar()
M_symb = iscalar()
k_symb = iscalar()
f = function(
[N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype)
)
result = f(N, M, k)
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
assert result.dtype == np.dtype(dtype)

for dtype in ALL_DTYPES:
check(dtype, 3)
# M != N, k = 0
check(dtype, 3, 5)
check(dtype, 5, 3)
# N == M, k != 0
check(dtype, 3, 3, 1)
check(dtype, 3, 3, -1)
# N < M, k != 0
check(dtype, 3, 5, 1)
check(dtype, 3, 5, -1)
# N > M, k != 0
check(dtype, 5, 3, 1)
check(dtype, 5, 3, -1)
for dtype in ALL_DTYPES:
check(dtype, 3)
# M != N, k = 0
check(dtype, 3, 5)
check(dtype, 5, 3)
# N == M, k != 0
check(dtype, 3, 3, 1)
check(dtype, 3, 3, -1)
# N < M, k != 0
check(dtype, 3, 5, 1)
check(dtype, 3, 5, -1)
# N > M, k != 0
check(dtype, 5, 3, 1)
check(dtype, 5, 3, -1)

def test_static_output_type(self):
l = lscalar("l")
assert eye(5, 3, l).type.shape == (5, 3)
assert eye(1, l, 3).type.shape == (1, None)


class TestTriangle:
Expand Down

0 comments on commit d3bd1f1

Please sign in to comment.