From 91eeb53f463ca698ad28679f3f9a99a6f28df66d Mon Sep 17 00:00:00 2001 From: Watermelon Wolverine <29666253+watermelonwolverine@users.noreply.github.com> Date: Thu, 2 Jan 2025 16:49:32 +0100 Subject: [PATCH] dialects: (builtin) Fix TensorOrMemrefOf and add some tests (#3685) Fixed TensorOrMemrefOf and added some tests --- tests/dialects/test_builtin.py | 23 +++++++++++++++++++++++ xdsl/dialects/builtin.py | 4 ++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/tests/dialects/test_builtin.py b/tests/dialects/test_builtin.py index 1a983111fe..6068aa37f3 100644 --- a/tests/dialects/test_builtin.py +++ b/tests/dialects/test_builtin.py @@ -30,6 +30,8 @@ Signedness, StridedLayoutAttr, SymbolRefAttr, + TensorOrMemrefOf, + TensorType, UnrealizedConversionCastOp, VectorBaseTypeAndRankConstraint, VectorBaseTypeConstraint, @@ -514,3 +516,24 @@ def test_strides(): assert ShapedType.strides_for_shape((1,), factor=2) == (2,) assert ShapedType.strides_for_shape((2, 3)) == (3, 1) assert ShapedType.strides_for_shape((4, 5, 6), factor=2) == (60, 12, 2) + + +def test_tensor_or_memref_of_constraint_verify(): + constraint = TensorOrMemrefOf(i64) + + constraint.verify(MemRefType(i64, [1]), ConstraintContext()) + constraint.verify(TensorType(i64, [1]), ConstraintContext()) + + +def test_tensor_or_memref_of_constraint_attribute_mismatch(): + constraint = TensorOrMemrefOf(i64) + + with pytest.raises( + VerifyException, match=f"Expected tensor or memref type, got {i64}" + ): + constraint.verify(i64, ConstraintContext()) + + with pytest.raises( + VerifyException, match=f"Expected attribute {i64} but got {i32}" + ): + constraint.verify(MemRefType(i32, [1]), ConstraintContext()) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index e03b9be0d4..77830a2a35 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -1898,8 +1898,8 @@ def get_resolvers( } def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: - if isinstance(attr, VectorType) or isinstance(attr, TensorType): - attr = cast(VectorType[Attribute] | TensorType[Attribute], attr) + if isinstance(attr, MemRefType) or isinstance(attr, TensorType): + attr = cast(MemRefType[Attribute] | TensorType[Attribute], attr) self.elem_constr.verify(attr.element_type, constraint_context) else: raise VerifyException(f"Expected tensor or memref type, got {attr}")