Skip to content

Commit

Permalink
dialects: (builtin) Fix TensorOrMemrefOf and add some tests (#3685)
Browse files Browse the repository at this point in the history
Fixed TensorOrMemrefOf and added some tests
  • Loading branch information
watermelonwolverine authored Jan 2, 2025
1 parent 346e1d4 commit 91eeb53
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
23 changes: 23 additions & 0 deletions tests/dialects/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
Signedness,
StridedLayoutAttr,
SymbolRefAttr,
TensorOrMemrefOf,
TensorType,
UnrealizedConversionCastOp,
VectorBaseTypeAndRankConstraint,
VectorBaseTypeConstraint,
Expand Down Expand Up @@ -514,3 +516,24 @@ def test_strides():
assert ShapedType.strides_for_shape((1,), factor=2) == (2,)
assert ShapedType.strides_for_shape((2, 3)) == (3, 1)
assert ShapedType.strides_for_shape((4, 5, 6), factor=2) == (60, 12, 2)


def test_tensor_or_memref_of_constraint_verify():
constraint = TensorOrMemrefOf(i64)

constraint.verify(MemRefType(i64, [1]), ConstraintContext())
constraint.verify(TensorType(i64, [1]), ConstraintContext())


def test_tensor_or_memref_of_constraint_attribute_mismatch():
constraint = TensorOrMemrefOf(i64)

with pytest.raises(
VerifyException, match=f"Expected tensor or memref type, got {i64}"
):
constraint.verify(i64, ConstraintContext())

with pytest.raises(
VerifyException, match=f"Expected attribute {i64} but got {i32}"
):
constraint.verify(MemRefType(i32, [1]), ConstraintContext())
4 changes: 2 additions & 2 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1898,8 +1898,8 @@ def get_resolvers(
}

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if isinstance(attr, VectorType) or isinstance(attr, TensorType):
attr = cast(VectorType[Attribute] | TensorType[Attribute], attr)
if isinstance(attr, MemRefType) or isinstance(attr, TensorType):
attr = cast(MemRefType[Attribute] | TensorType[Attribute], attr)
self.elem_constr.verify(attr.element_type, constraint_context)
else:
raise VerifyException(f"Expected tensor or memref type, got {attr}")
Expand Down

0 comments on commit 91eeb53

Please sign in to comment.