From dfb5921d93bc48247d34a947585eb77d541ef795 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Mon, 9 Sep 2024 09:15:10 +0200 Subject: [PATCH] [FXML-4791] Lower memref expand/collapse to EmitC (#313) --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 3 +- .../MemRefToEmitC/MemRefToEmitC.cpp | 67 ++++++++++++++++++- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 21 +++++- .../MemRefToEmitC/memref-to-emitc-failed.mlir | 18 +++++ .../MemRefToEmitC/memref-to-emitc.mlir | 19 ++++++ mlir/test/Dialect/EmitC/invalid_ops.mlir | 2 +- mlir/test/Dialect/EmitC/ops.mlir | 5 ++ mlir/test/Target/Cpp/cast.mlir | 9 +++ 8 files changed, 139 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 0c945ab2c40304..bbb539c2b3f2a7 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -265,8 +265,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { def EmitC_CastOp : EmitC_Op<"cast", [CExpression, - DeclareOpInterfaceMethods, - SameOperandsAndResultShape]> { + DeclareOpInterfaceMethods]> { let summary = "Cast operation"; let description = [{ The `cast` operation performs an explicit type conversion and is emitted diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index e0c421741b3055..f6ce553dd899a0 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -15,6 +15,8 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -166,6 +168,68 @@ struct ConvertStore final : public OpConversionPattern { return success(); } }; + +struct ConvertCollapseShape final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CollapseShapeOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + auto arrayValue = dyn_cast>(operands.getSrc()); + if (!arrayValue) { + return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); + } + + auto resultTy = getTypeConverter()->convertType(op.getType()); + if (!resultTy) { + return rewriter.notifyMatchFailure(op.getLoc(), + "cannot convert result type"); + } + + // Do not generate casts between arrays with dynamic shapes + if (!arrayValue.getType().hasStaticShape()) + return rewriter.notifyMatchFailure(op.getLoc(), + "dynamic shapes not supported"); + auto newCastOp = rewriter.create(op->getLoc(), resultTy, + operands.getSrc()); + newCastOp.setReference(true); + rewriter.replaceOp(op, newCastOp); + return success(); + } +}; + +struct ConvertExpandShape final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ExpandShapeOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + auto arrayValue = dyn_cast>(operands.getSrc()); + if (!arrayValue) { + return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); + } + + auto resultTy = getTypeConverter()->convertType(op.getType()); + if (!resultTy) { + return rewriter.notifyMatchFailure(op.getLoc(), + "cannot convert result type"); + } + + // Do not generate casts between arrays with dynamic shapes + if (!arrayValue.getType().hasStaticShape()) + return rewriter.notifyMatchFailure(op.getLoc(), + "dynamic shapes not supported"); + + auto newCastOp = rewriter.create(op->getLoc(), resultTy, + operands.getSrc()); + newCastOp.setReference(true); + rewriter.replaceOp(op, newCastOp); + return success(); + } +}; + } // namespace void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { @@ -187,5 +251,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter) { patterns.add(converter, patterns.getContext()); + ConvertStore, ConvertCollapseShape, ConvertExpandShape>( + converter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index aa2495bc42ba03..3f994344ffeee6 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -228,6 +228,15 @@ LogicalResult emitc::AssignOp::verify() { bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { Type input = inputs.front(), output = outputs.front(); + // Cast to array is only possible from an array + if (isa(input) != isa(output)) + return false; + + // Arrays can be casted to arrays by reference. + if (isa(input) && isa(output)) + return true; + + // Scalars return ( (emitc::isIntegerIndexOrOpaqueType(input) || emitc::isSupportedFloatType(input) || isa(input)) && @@ -236,7 +245,15 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { } LogicalResult CastOp::verify() { - if (getReference()) + bool isReference = getReference(); + + if (isa(getDest().getType())) { + if (!isReference) + return emitOpError("cast of array must bear a reference"); + return success(); + } + + if (isReference) return emitOpError("cast of value type must not bear a reference"); return success(); @@ -954,6 +971,8 @@ LogicalResult emitc::ArrayType::verify( for (int64_t dim : shape) { if (dim < 0) return emitError() << "dimensions must have non-negative size"; + if (dim == ShapedType::kDynamic) + return emitError() << "dimensions must have static size"; } if (!elementType) diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir index 89dafa7529ed53..4df7bac0b55806 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir @@ -43,3 +43,21 @@ func.func @zero_rank() { // expected-error@+1 {{failed to legalize operation 'memref.global'}} memref.global "nested" constant @nested_global : memref<3x7xf32> + +// ----- + +// CHECK-LABEL: memref_expand_dyn_shape +func.func @memref_expand_dyn_shape(%arg: memref, %size: index) -> memref { + // expected-error@+1 {{failed to legalize operation 'memref.expand_shape'}} + %0 = memref.expand_shape %arg [[0, 1]] output_shape [%size, 5] : memref into memref + return %0 : memref +} + +// ----- + +// CHECK-LABEL: memref_collapse_dyn_shape +func.func @memref_collapse_dyn_shape(%arg: memref) -> memref { + // expected-error@+1 {{failed to legalize operation 'memref.collapse_shape'}} + %0 = memref.collapse_shape %arg [[0, 1]] : memref into memref + return %0 : memref +} diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index ffb0e10d80893a..96e4486f5a8191 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -73,3 +73,22 @@ func.func @memref_index_values(%i: index, %j: index) -> index { // CHECK: return %[[CAST_RET]] : index return %1 : index } + +// ----- + +// CHECK-LABEL: memref_expand_shape +func.func @memref_expand_shape(%arg: memref<10xi32>) -> memref<2x5xi32> { + // CHECK: emitc.cast %{{[^ ]*}} : !emitc.array<10xi32> to !emitc.array<2x5xi32> ref + %0 = memref.expand_shape %arg [[0, 1]] output_shape [2, 5] : memref<10xi32> into memref<2x5xi32> + return %0 : memref<2x5xi32> +} + + +// ----- + +// CHECK-LABEL: memref_collapse_shape +func.func @memref_collapse_shape(%arg: memref<2x5xi32>) -> memref<10xi32> { + // CHECK: emitc.cast %{{[^ ]*}} : !emitc.array<2x5xi32> to !emitc.array<10xi32> ref + %0 = memref.collapse_shape %arg [[0, 1]] : memref<2x5xi32> into memref<10xi32> + return %0 : memref<10xi32> +} diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index aa2c969b05cc23..31e065155f0922 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -138,7 +138,7 @@ func.func @cast_tensor(%arg : tensor) { // ----- func.func @cast_array(%arg : !emitc.array<4xf32>) { - // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<4xf32>' and result type '!emitc.array<4xf32>' are cast incompatible}} + // expected-error @+1 {{'emitc.cast' op cast of array must bear a reference}} %1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32> return } diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 7b11c230e9a9dd..482b08a0b68687 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -39,6 +39,11 @@ func.func @cast(%arg0: i32) { return } +func.func @cast_array(%arg : !emitc.array<4xf32>) { + %1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32> ref + return +} + func.func @c() { %1 = "emitc.constant"(){value = 42 : i32} : () -> i32 %2 = "emitc.constant"(){value = 42 : index} : () -> !emitc.size_t diff --git a/mlir/test/Target/Cpp/cast.mlir b/mlir/test/Target/Cpp/cast.mlir index 7254f84e237f40..c4d26ebdcdec9a 100644 --- a/mlir/test/Target/Cpp/cast.mlir +++ b/mlir/test/Target/Cpp/cast.mlir @@ -28,3 +28,12 @@ func.func @cast_ptr(%arg0 : !emitc.ptr>) { %1 = emitc.cast %arg0 : !emitc.ptr> to !emitc.ptr return } + +// CHECK-LABEL: void cast_array +func.func @cast_array(%arg0: !emitc.array<10xi32>) { + // CHECK-NEXT: int32_t (&[[V1:[^ ]*]])[2][5] = (int32_t (&)[2][5]) [[V0:[^ ]*]] + %1 = emitc.cast %arg0 : !emitc.array<10xi32> to !emitc.array<2x5xi32> ref + // CHECK-NEXT: int32_t (&[[V2:[^ ]*]])[10] = (int32_t (&)[10]) [[V1]] + %2 = emitc.cast %1 : !emitc.array<2x5xi32> to !emitc.array<10xi32> ref + return +}