From 0a94521865389269d6b7c5db5e6c37ceda7a8371 Mon Sep 17 00:00:00 2001
From: Christopher McGirr
Date: Mon, 15 Jul 2024 08:24:53 +0100
Subject: [PATCH 1/6] feat(torch.aten.mm): fold up-casts into matmul when
 supported in TOSA

---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 60 ++++++++++++++++++++--
 test/Conversion/TorchToTosa/basic.mlir     | 40 +++++++++++++++
 2 files changed, 96 insertions(+), 4 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index a25bbe402a73..776f900f5de9 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -24,6 +24,7 @@
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 #include <numeric>
 #include <optional>
@@ -1132,6 +1133,34 @@ Type getMatMulOutputType(Type inputElemTy, PatternRewriter &rewriter) {
   return outputElemTy;
 }
 
+RankedTensorType getCastedInputTypeForMatmul(Value inputValue,
+                                             PatternRewriter &rewriter) {
+  // Check to see if the inputs to the matmul were cast from another type
+  auto preCastType =
+      TypeSwitch<Operation *, RankedTensorType>(inputValue.getDefiningOp())
+          .Case([](AtenToDtypeOp op) {
+            return cast<RankedTensorType>(op->getOperand(0).getType());
+          })
+          .Case([](tosa::CastOp op) {
+            return cast<RankedTensorType>(op->getOperand(0).getType());
+          })
+          .Default([](Operation * /*op*/) { return RankedTensorType(); });
+  if (!preCastType) {
+    return preCastType;
+  }
+  // Calculate the expected accumulator type based on the input type of the cast
+  auto accumulatorType =
+      getMatMulOutputType(preCastType.getElementType(), rewriter);
+  // If the expected accumulatorType for the given input type to the cast
+  // matches the output type of the cast then we can fold the casting into the
+  // matmul. Because the casting is an up-cast and does not affect the numeric
+  // values due to rounding or saturation.
+  return accumulatorType ==
+             cast<RankedTensorType>(inputValue.getType()).getElementType()
+         ? preCastType
+         : RankedTensorType();
+}
+
 // Perform the basic n-dim matmul operation encompassing the handling of
 // broadcasting and dynamic shape propagation.
 // All PyTorch ops that leverage matrix multiplication will derive this and
@@ -1173,6 +1202,28 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
       return rewriter.notifyMatchFailure(op,
                                          "Matmul: input datatypes mismatched");
 
+    // Step: check if the inputs have been cast from a supported input type to
+    // an accumulator type and insert casts back to the original type if true
+    RankedTensorType lhsPreCastedType =
+        getCastedInputTypeForMatmul(lhs, rewriter);
+    RankedTensorType rhsPreCastedType =
+        getCastedInputTypeForMatmul(rhs, rewriter);
+    if (lhsPreCastedType && (lhsPreCastedType.getElementType() ==
+                             rhsPreCastedType.getElementType())) {
+      lhs = rewriter.create<tosa::CastOp>(
+          lhs.getLoc(),
+          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+              lhsPreCastedType),
+          lhs);
+      rhs = rewriter.create<tosa::CastOp>(
+          rhs.getLoc(),
+          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+              rhsPreCastedType),
+          rhs);
+      lhsElemTy = cast<RankedTensorType>(lhsPreCastedType).getElementType();
+      rhsElemTy = cast<RankedTensorType>(rhsPreCastedType).getElementType();
+    }
+
     auto outputElemTy = getMatMulOutputType(lhsElemTy, rewriter);
     if (!outputElemTy) {
       return rewriter.notifyMatchFailure(
@@ -1565,12 +1616,13 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
                                  matmulLhs, matmulRhs)
                       .getResult();
 
+    auto torchOpOutputType = lhsTy.getElementType();
     auto castOutputTy = RankedTensorType::get(
-        makeShapeLLVMCompatible(matmulOutputShape), lhsElemTy);
+        makeShapeLLVMCompatible(matmulOutputShape), torchOpOutputType);
     auto castResult = rewriter.createOrFold<tosa::CastOp>(
         op->getLoc(),
-        OpConversionPattern<AtenOpT>::getTypeConverter()
-            ->convertType(castOutputTy),
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            castOutputTy),
         mmOpResult);
 
     // Perform the reshape to output shape. This is always required unless max
@@ -1673,7 +1725,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
 
       // Perform reshape
       auto reshapedOpType = RankedTensorType::get(
-          makeShapeLLVMCompatible(reshapedOpShape), lhsElemTy);
+          makeShapeLLVMCompatible(reshapedOpShape), outputElemTy);
       auto reshapedOp = rewriter.create<tosa::ReshapeOp>(
           op->getLoc(),
           OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 317b5c9efe86..b02c96d23832 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -85,6 +85,46 @@ func.func @torch.aten.mm_2d(%arg0 : !torch.vtensor<[2,6],f32>, %arg1 : !torch.vt
 
 // -----
 
+// CHECK: tosa.matmul{{.*}} : (tensor<1x4x8xbf16>, tensor<1x8x16xbf16>) -> tensor<1x4x16xf32>
+func.func @torch.aten.mm_bf16(%arg0: !torch.vtensor<[4,8],bf16>, %arg1: !torch.vtensor<[8,16],bf16>) -> !torch.vtensor<[4,16],f32> {
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %int6 = torch.constant.int 6
+  %0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[4,8],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,8],f32>
+  %1 = torch.aten.to.dtype %arg1, %int6, %false, %false, %none : !torch.vtensor<[8,16],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8,16],f32>
+  %2 = torch.aten.mm %0, %1 : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[4,16],f32>
+  return %2 : !torch.vtensor<[4,16],f32>
+}
+
+// -----
+
+// CHECK: tosa.matmul{{.*}} : (tensor<1x4x8xf16>, tensor<1x8x16xf16>) -> tensor<1x4x16xf32>
+func.func @torch.aten.mm_f16(%arg0: !torch.vtensor<[4,8],f16>, %arg1: !torch.vtensor<[8,16],f16>) -> !torch.vtensor<[4,16],f32> {
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %int6 = torch.constant.int 6
+  %0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[4,8],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,8],f32>
+  %1 = torch.aten.to.dtype %arg1, %int6, %false, %false, %none : !torch.vtensor<[8,16],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8,16],f32>
+  %2 = torch.aten.mm %0, %1 : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[4,16],f32>
+  return %2 : !torch.vtensor<[4,16],f32>
+}
+
+
+// -----
+
+// CHECK: tosa.matmul{{.*}} : (tensor<1x4x8xi8>, tensor<1x8x16xi8>) -> tensor<1x4x16xi32>
+func.func @torch.aten.mm_i8(%arg0: !torch.vtensor<[4,8],si8>, %arg1: !torch.vtensor<[8,16],si8>) -> !torch.vtensor<[4,16],si32> {
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %int3 = torch.constant.int 3
+  %0 = torch.aten.to.dtype %arg0, %int3, %false, %false, %none : !torch.vtensor<[4,8],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,8],si32>
+  %1 = torch.aten.to.dtype %arg1, %int3, %false, %false, %none : !torch.vtensor<[8,16],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8,16],si32>
+  %2 = torch.aten.mm %0, %1 : !torch.vtensor<[4,8],si32>, !torch.vtensor<[8,16],si32> -> !torch.vtensor<[4,16],si32>
+  return %2 : !torch.vtensor<[4,16],si32>
+}
+
+// -----
+
 // CHECK: %[[VAL_2:.+]] = tosa.reshape %0 {new_shape = array<i64: 100, 6, 2>} : (tensor<10x10x6x2xf32>) -> tensor<100x6x2xf32>
 // CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %1 {new_shape = array<i64: 100, 2, 6>} : (tensor<10x10x2x6xf32>) -> tensor<100x2x6xf32>
 // CHECK-NEXT: %[[VAL_4:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] : (tensor<100x6x2xf32>, tensor<100x2x6xf32>) -> tensor<100x6x6xf32>
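A note on what this first patch produces: the pattern recognizes operands that
were only widened so that torch.aten.mm could accumulate at higher precision,
and re-narrows them so that tosa.matmul performs the widening itself. For the
bf16 test above, the lowered IR looks roughly like the following (a hand-written
sketch mirroring the CHECK lines later in this series, not captured compiler
output; the SSA names are invented):

  %lhs = tosa.cast %0 : (tensor<4x8xf32>) -> tensor<4x8xbf16>
  %rhs = tosa.cast %1 : (tensor<8x16xf32>) -> tensor<8x16xbf16>
  // rank-3 reshapes of %lhs and %rhs into %lhs3 and %rhs3 elided
  %mm = tosa.matmul %lhs3, %rhs3 : (tensor<1x4x8xbf16>, tensor<1x8x16xbf16>) -> tensor<1x4x16xf32>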
From e637e1496c58ab71c45a572b1f4d3d368461fed1 Mon Sep 17 00:00:00 2001
From: Christopher McGirr
Date: Mon, 15 Jul 2024 15:52:36 +0100
Subject: [PATCH 2/6] fix: address PR comments

---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 776f900f5de9..6c527f8b8736 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1151,10 +1151,10 @@ RankedTensorType getCastedInputTypeForMatmul(Value inputValue,
   // Calculate the expected accumulator type based on the input type of the cast
   auto accumulatorType =
       getMatMulOutputType(preCastType.getElementType(), rewriter);
-  // If the expected accumulatorType for the given input type to the cast
+  // If the expected accumulatorType for the given input type of the cast
   // matches the output type of the cast then we can fold the casting into the
-  // matmul. Because the casting is an up-cast and does not affect the numeric
-  // values due to rounding or saturation.
+  // matmul. The tosa matmul is defined to cast the inputs to the output type
+  // first, so we do not need explicit casts up front.
   return accumulatorType ==
              cast<RankedTensorType>(inputValue.getType()).getElementType()
          ? preCastType
@@ -1208,8 +1208,9 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
         getCastedInputTypeForMatmul(lhs, rewriter);
     RankedTensorType rhsPreCastedType =
         getCastedInputTypeForMatmul(rhs, rewriter);
-    if (lhsPreCastedType && (lhsPreCastedType.getElementType() ==
-                             rhsPreCastedType.getElementType())) {
+    if (lhsPreCastedType && rhsPreCastedType &&
+        (lhsPreCastedType.getElementType() ==
+         rhsPreCastedType.getElementType())) {
       lhs = rewriter.create<tosa::CastOp>(
           lhs.getLoc(),
           OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
@@ -1725,7 +1726,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
 
       // Perform reshape
       auto reshapedOpType = RankedTensorType::get(
-          makeShapeLLVMCompatible(reshapedOpShape), outputElemTy);
+          makeShapeLLVMCompatible(reshapedOpShape), torchOpOutputType);
       auto reshapedOp = rewriter.create<tosa::ReshapeOp>(
           op->getLoc(),
           OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
@@ -1741,7 +1742,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
           /*shape=*/{static_cast<int32_t>(transposedOpDims.size())});
 
       auto transposedOpType = RankedTensorType::get(
-          makeShapeLLVMCompatible(transposedOpShape), outputElemTy);
+          makeShapeLLVMCompatible(transposedOpShape), torchOpOutputType);
       output = rewriter
                    .create<tosa::TransposeOp>(
                        op->getLoc(),
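The added rhsPreCastedType null-check is the substantive fix in this patch:
getCastedInputTypeForMatmul returns a null RankedTensorType for an operand
that is not produced by a cast, so the old condition could call
getElementType() on a null type whenever only the lhs came from an up-cast.
A hypothetical input that exercises the mixed case (a sketch, not a test from
this series; the constants are as in the tests of PATCH 1):

  %0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[4,8],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,8],f32>
  %1 = torch.aten.mm %0, %arg1 : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[4,16],f32>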
From 511bf68352f702d9f3d45209d06f05fd1c7ac78d Mon Sep 17 00:00:00 2001
From: Christopher McGirr
Date: Mon, 15 Jul 2024 15:53:09 +0100
Subject: [PATCH 3/6] test(TorchToTosa): add more torch.aten.mm cases

---
 test/Conversion/TorchToTosa/basic.mlir | 34 +++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index b02c96d23832..7b106d7bb907 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -109,7 +109,6 @@ func.func @torch.aten.mm_f16(%arg0: !torch.vtensor<[4,8],f16>, %arg1: !torch.vte
   return %2 : !torch.vtensor<[4,16],f32>
 }
 
-
 // -----
 
 // CHECK: tosa.matmul{{.*}} : (tensor<1x4x8xi8>, tensor<1x8x16xi8>) -> tensor<1x4x16xi32>
@@ -125,6 +124,39 @@ func.func @torch.aten.mm_i8(%arg0: !torch.vtensor<[4,8],si8>, %arg1: !torch.vten
 
 // -----
 
+// expected-error @+1 {{invalid dtype 'si48' for !torch.tensor type}}
+func.func @torch.aten.mm_i16(%arg0: !torch.vtensor<[4,8],si16>, %arg1: !torch.vtensor<[8,16],si16>) -> !torch.vtensor<[4,16],si48> {
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %int3 = torch.constant.int 3
+  %0 = torch.aten.to.dtype %arg0, %int3, %false, %false, %none : !torch.vtensor<[4,8],si16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,8],si48>
+  %1 = torch.aten.to.dtype %arg1, %int3, %false, %false, %none : !torch.vtensor<[8,16],si16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8,16],si48>
+  %2 = torch.aten.mm %0, %1 : !torch.vtensor<[4,8],si48>, !torch.vtensor<[8,16],si48> -> !torch.vtensor<[4,16],si48>
+  return %2 : !torch.vtensor<[4,16],si48>
+}
+
+// -----
+
+// CHECK: %[[VAL_2:.+]] = tosa.cast %{{[0-9]+}} : (tensor<4x8xf32>) -> tensor<4x8xf16>
+// CHECK-NEXT: %[[VAL_3:.+]] = tosa.cast %{{[0-9]+}} : (tensor<8x16xf32>) -> tensor<8x16xf16>
+// CHECK-NEXT: %[[VAL_4:.+]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 4, 8>} : (tensor<4x8xf16>) -> tensor<1x4x8xf16>
+// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 8, 16>} : (tensor<8x16xf16>) -> tensor<1x8x16xf16>
+// CHECK-NEXT: %[[VAL_6:.+]] = tosa.matmul %[[VAL_4]], %[[VAL_5]] : (tensor<1x4x8xf16>, tensor<1x8x16xf16>) -> tensor<1x4x16xf32>
+// CHECK-NEXT: %[[VAL_7:.+]] = tosa.cast %[[VAL_6]] : (tensor<1x4x16xf32>) -> tensor<1x4x16xf16>
+// CHECK-NEXT: %[[VAL_8:.+]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 4, 16>} : (tensor<1x4x16xf16>) -> tensor<4x16xf16>
+
+func.func @torch.aten.mm_f32_to_f16(%arg0: !torch.vtensor<[4,8],f32>, %arg1: !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[4,16],f16> {
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %int5 = torch.constant.int 5
+  %0 = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[4,8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,8],f16>
+  %1 = torch.aten.to.dtype %arg1, %int5, %false, %false, %none : !torch.vtensor<[8,16],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8,16],f16>
+  %2 = torch.aten.mm %0, %1 : !torch.vtensor<[4,8],f16>, !torch.vtensor<[8,16],f16> -> !torch.vtensor<[4,16],f16>
+  return %2 : !torch.vtensor<[4,16],f16>
+}
+
+// -----
+
 // CHECK: %[[VAL_2:.+]] = tosa.reshape %0 {new_shape = array<i64: 100, 6, 2>} : (tensor<10x10x6x2xf32>) -> tensor<100x6x2xf32>
 // CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %1 {new_shape = array<i64: 100, 2, 6>} : (tensor<10x10x2x6xf32>) -> tensor<100x2x6xf32>
 // CHECK-NEXT: %[[VAL_4:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] : (tensor<100x6x2xf32>, tensor<100x2x6xf32>) -> tensor<100x6x6xf32>
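Two reading aids for the tests in this patch. First, the integer handed to
torch.aten.to.dtype is a PyTorch ScalarType code:

  %int3 = torch.constant.int 3   // torch.int32
  %int5 = torch.constant.int 5   // torch.float16
  %int6 = torch.constant.int 6   // torch.float32

Second, the mm_i16 case is a negative test: si48 is not a legal torch-mlir
dtype, so the expected-error fires while parsing the function signature,
before the conversion pattern is ever exercised.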
From 0ef1774479fcc0fb6f0f668ec20fcabeb0e3fde1 Mon Sep 17 00:00:00 2001
From: Christopher McGirr
Date: Tue, 16 Jul 2024 10:26:02 +0100
Subject: [PATCH 4/6] refactor(TorchToTosa): aten.mm if f16 use
 tosa.matmul(f16, f16) -> f16 rather than the f32 accumulator

---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 21 ++++++++++++++-------
 test/Conversion/TorchToTosa/basic.mlir     |  8 +++-----
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 6c527f8b8736..070865b827c5 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1116,21 +1116,25 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
-Type getMatMulOutputType(Type inputElemTy, PatternRewriter &rewriter) {
-  Type outputElemTy;
+Type getMatMulOutputType(Type inputElemTy, Type outputElemTy,
+                         PatternRewriter &rewriter) {
+  Type tosaOutputElemTy;
   if (auto floatTy = dyn_cast<mlir::FloatType>(inputElemTy)) {
+    if (inputElemTy.isF16() && outputElemTy.isF16()) {
+      return rewriter.getF16Type();
+    }
     if (floatTy.isBF16() || floatTy.isF16() || floatTy.isF32()) {
       // Always accumulate on f32
-      outputElemTy = rewriter.getF32Type();
+      tosaOutputElemTy = rewriter.getF32Type();
     }
   } else if (auto integerTy = dyn_cast<IntegerType>(inputElemTy)) {
     if (integerTy.isInteger(/*width=*/8)) {
-      outputElemTy = rewriter.getIntegerType(/*width=*/32);
+      tosaOutputElemTy = rewriter.getIntegerType(/*width=*/32);
     } else if (integerTy.isInteger(/*width=*/16)) {
-      outputElemTy = rewriter.getIntegerType(/*width=*/48);
+      tosaOutputElemTy = rewriter.getIntegerType(/*width=*/48);
     }
   }
-  return outputElemTy;
+  return tosaOutputElemTy;
 }
 
 RankedTensorType getCastedInputTypeForMatmul(Value inputValue,
@@ -1225,7 +1229,10 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
       rhsElemTy = cast<RankedTensorType>(rhsPreCastedType).getElementType();
     }
 
-    auto outputElemTy = getMatMulOutputType(lhsElemTy, rewriter);
+    auto torchMatmulOutputType =
+        cast<BaseTensorType>(op.getType()).getDtype();
+    auto outputElemTy =
+        getMatMulOutputType(lhsElemTy, torchMatmulOutputType, rewriter);
     if (!outputElemTy) {
       return rewriter.notifyMatchFailure(
           op, "Only i8 and i16 integer and bf16, f16 and "
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 7b106d7bb907..c4d804c7bc63 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -141,9 +141,8 @@ func.func @torch.aten.mm_i16(%arg0: !torch.vtensor<[4,8],si16>, %arg1: !torch.vt
 // CHECK-NEXT: %[[VAL_3:.+]] = tosa.cast %{{[0-9]+}} : (tensor<8x16xf32>) -> tensor<8x16xf16>
 // CHECK-NEXT: %[[VAL_4:.+]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 4, 8>} : (tensor<4x8xf16>) -> tensor<1x4x8xf16>
 // CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 8, 16>} : (tensor<8x16xf16>) -> tensor<1x8x16xf16>
-// CHECK-NEXT: %[[VAL_6:.+]] = tosa.matmul %[[VAL_4]], %[[VAL_5]] : (tensor<1x4x8xf16>, tensor<1x8x16xf16>) -> tensor<1x4x16xf32>
-// CHECK-NEXT: %[[VAL_7:.+]] = tosa.cast %[[VAL_6]] : (tensor<1x4x16xf32>) -> tensor<1x4x16xf16>
-// CHECK-NEXT: %[[VAL_8:.+]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 4, 16>} : (tensor<1x4x16xf16>) -> tensor<4x16xf16>
+// CHECK-NEXT: %[[VAL_6:.+]] = tosa.matmul %[[VAL_4]], %[[VAL_5]] : (tensor<1x4x8xf16>, tensor<1x8x16xf16>) -> tensor<1x4x16xf16>
+// CHECK-NEXT: %[[VAL_7:.+]] = tosa.reshape %[[VAL_6]] {new_shape = array<i64: 4, 16>} : (tensor<1x4x16xf16>) -> tensor<4x16xf16>
 
 func.func @torch.aten.mm_f32_to_f16(%arg0: !torch.vtensor<[4,8],f32>, %arg1: !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[4,16],f16> {
   %false = torch.constant.bool false
@@ -215,8 +214,7 @@ func.func @torch.aten.matmul_3d_broadcast(%arg0 : !torch.vtensor<[100,4,8],f32>,
 
 // -----
 
-// CHECK: %[[VAL_2:.+]] = tosa.matmul %0, %1 : (tensor<100x4x8xf16>, tensor<100x8x16xf16>) -> tensor<100x4x16xf32>
-// CHECK-NEXT: %[[VAL_3:.+]] = tosa.cast %[[VAL_2]] : (tensor<100x4x16xf32>) -> tensor<100x4x16xf16>
+// CHECK: %[[VAL_2:.+]] = tosa.matmul %0, %1 : (tensor<100x4x8xf16>, tensor<100x8x16xf16>) -> tensor<100x4x16xf16>
 func.func @torch.aten.bmm_3d_fp16(%arg0 : !torch.vtensor<[100,4,8],f16>, %arg1 : !torch.vtensor<[100,8,16],f16>) -> !torch.vtensor<[100,4,16],f16> {
   %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[100,4,8],f16>, !torch.vtensor<[100,8,16],f16> -> !torch.vtensor<[100,4,16],f16>
   return %0 : !torch.vtensor<[100,4,16],f16>
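The behavioral change of this patch, in IR terms (taken from the test update
above): when the torch op itself produces f16, the accumulate-in-f32 then
narrow sequence

  %2 = tosa.matmul %0, %1 : (tensor<100x4x8xf16>, tensor<100x8x16xf16>) -> tensor<100x4x16xf32>
  %3 = tosa.cast %2 : (tensor<100x4x16xf32>) -> tensor<100x4x16xf16>

collapses to a single matmul with an f16 accumulator, which the TOSA spec
permits for fp16 operands (fp16 and fp32 accumulators are both legal there):

  %2 = tosa.matmul %0, %1 : (tensor<100x4x8xf16>, tensor<100x8x16xf16>) -> tensor<100x4x16xf16>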
From 8e10629e6e7f4dced45a6ab741d495c7925eae5b Mon Sep 17 00:00:00 2001
From: Christopher McGirr
Date: Tue, 16 Jul 2024 10:26:44 +0100
Subject: [PATCH 5/6] refactor(TorchToTosa): add guard for aten.mm si16->si48

---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 25 ++++++++++++++--------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 070865b827c5..eb564bb29bd1 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1152,17 +1152,24 @@ RankedTensorType getCastedInputTypeForMatmul(Value inputValue,
   if (!preCastType) {
     return preCastType;
   }
+  Type castOutputTy =
+      cast<RankedTensorType>(inputValue.getType()).getElementType();
+  // The FxImporter does not support si48 and neither does torch-mlir, so
+  // for now we reject this case; the dialect and the importer may support
+  // it in the future.
+  if (castOutputTy.isInteger(48) &&
+      (castOutputTy.isSignedInteger() || castOutputTy.isSignlessInteger())) {
+    return RankedTensorType();
+  }
   // Calculate the expected accumulator type based on the input type of the cast
   auto accumulatorType =
-      getMatMulOutputType(preCastType.getElementType(), rewriter);
-  // If the expected accumulatorType for the given input type of the cast
-  // matches the output type of the cast then we can fold the casting into the
-  // matmul. The tosa matmul is defined to cast the inputs to the output type
-  // first, so we do not need explicit casts up front.
-  return accumulatorType ==
-             cast<RankedTensorType>(inputValue.getType()).getElementType()
-         ? preCastType
-         : RankedTensorType();
+      getMatMulOutputType(preCastType.getElementType(), castOutputTy, rewriter);
+  // If the expected accumulatorType for the given input type of the
+  // cast matches the output type of the cast then we can fold the
+  // casting into the matmul. The tosa matmul is defined to cast the
+  // inputs to the output type first, so we do not need explicit
+  // casts up front.
+  return accumulatorType == castOutputTy ? preCastType : RankedTensorType();
 }
 
 // Perform the basic n-dim matmul operation encompassing the handling of
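The signedness double-check in the new guard is deliberate: the 48-bit element
type can show up in either of the two integer conventions that appear in this
lowering, and isSignedInteger() is false for a signless type:

  !torch.vtensor<[4,16],si48>   // torch-side spelling (signed)
  tensor<4x16xi48>              // TOSA/builtin-side spelling (signless)

Either spelling must reject the fold, since neither the FxImporter nor
torch-mlir can represent si48 (see the expected-error test added in PATCH 3).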
From f0eb1b21f9448df0449e25e9814623e0310ccd8a Mon Sep 17 00:00:00 2001
From: Christopher McGirr
Date: Tue, 16 Jul 2024 10:27:49 +0100
Subject: [PATCH 6/6] refactor(TorchToTosa): remove AtenToDtype case as the op
 was already converted to tosa

---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 3 ---
 1 file changed, 3 deletions(-)
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index eb564bb29bd1..a1faef63b6d2 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1142,9 +1142,6 @@ RankedTensorType getCastedInputTypeForMatmul(Value inputValue,
   // Check to see if the inputs to the matmul were cast from another type
   auto preCastType =
       TypeSwitch<Operation *, RankedTensorType>(inputValue.getDefiningOp())
-          .Case([](AtenToDtypeOp op) {
-            return cast<RankedTensorType>(op->getOperand(0).getType());
-          })
           .Case([](tosa::CastOp op) {
             return cast<RankedTensorType>(op->getOperand(0).getType());
           })
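Closing remark on this last patch: the AtenToDtypeOp case was unreachable. The
matmul patterns operate on the adaptor's already-converted operands, so by the
time getCastedInputTypeForMatmul inspects inputValue.getDefiningOp(), any
torch.aten.to.dtype feeding the matmul has already been rewritten to its TOSA
form, e.g. (sketch):

  %0 = tosa.cast %arg0 : (tensor<4x8xbf16>) -> tensor<4x8xf32>   // lowered form of torch.aten.to.dtype

which is exactly the tosa::CastOp case that remains in the TypeSwitch.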