From 1418998b111173cc9f4efe7239bb8f8018b4cfef Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 14 Oct 2024 17:03:33 +0100 Subject: [PATCH] CollapseShapeOp --- xdsl/dialects/tensor.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/xdsl/dialects/tensor.py b/xdsl/dialects/tensor.py index 87bafbe7ec..53d456f0b4 100644 --- a/xdsl/dialects/tensor.py +++ b/xdsl/dialects/tensor.py @@ -8,10 +8,14 @@ from xdsl.dialects import memref from xdsl.dialects.builtin import ( + Annotated, AnySignlessIntegerOrIndexType, + ArrayAttr, ContainerType, DenseArrayBase, IndexType, + IntegerAttr, + IntegerType, TensorType, UnrankedTensorType, i64, @@ -176,6 +180,25 @@ def parse(cls, parser: Parser) -> Self: return empty +ReassociationAttr = ArrayAttr[ + ArrayAttr[IntegerAttr[Annotated[IntegerType, IntegerType(64)]]] +] + + +@irdl_op_definition +class CollapseShapeOp(IRDLOperation): + name = "tensor.collapse_shape" + + src = operand_def(TensorType[Attribute]) + result = result_def(TensorType[Attribute]) + reassociation = prop_def(ReassociationAttr) + assembly_format = ( + "$src $reassociation attr-dict `:` type($src) `into` type($result)" + ) + + traits = frozenset([NoMemoryEffect()]) + + @irdl_op_definition class ReshapeOp(IRDLOperation): name = "tensor.reshape" @@ -420,6 +443,7 @@ def from_static_parameters( ExtractSliceOp, InsertSliceOp, ReshapeOp, + CollapseShapeOp, ], [], )