From 5f7cb9e253406572d3a9442a103b33b67704ad08 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu <liuyuanqiang.yqliu@bytedance.com>
Date: Sat, 11 May 2024 15:33:37 +0800
Subject: [PATCH 1/5] [Stablehlo] lowering aten.randn & aten.normal_functional
 to mhlo.rng … (#3328)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

…NORMAL

* split lowering of uniform, randn, normal from Basic.cpp into Rng.cpp
---
 lib/Conversion/TorchToStablehlo/Basic.cpp     |  32 +---
 .../TorchToStablehlo/CMakeLists.txt           |   1 +
 .../TorchToStablehlo/PopulatePatterns.h       |   5 +
 lib/Conversion/TorchToStablehlo/Rng.cpp       | 137 ++++++++++++++++++
 .../TorchToStablehlo/TorchToStablehlo.cpp     |   2 +
 test/Conversion/TorchToStablehlo/basic.mlir   |  27 ----
 test/Conversion/TorchToStablehlo/rng.mlir     |  78 ++++++++++
 7 files changed, 224 insertions(+), 58 deletions(-)
 create mode 100644 lib/Conversion/TorchToStablehlo/Rng.cpp
 create mode 100644 test/Conversion/TorchToStablehlo/rng.mlir

diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index f5844e442d29..377795d843d9 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -1819,36 +1819,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
-template <>
-LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
-    AtenUniformOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value self = adaptor.getSelf();
-  Value generator = adaptor.getGenerator();
-  Location loc = op.getLoc();
-
-  if (!isa<Torch::NoneType>(generator.getType()))
-    return rewriter.notifyMatchFailure(
-        op, "The generator has to be None because only global default "
-            "generator is supported");
-
-  auto elements = cast<RankedTensorType>(self.getType()).getShape();
-  if (llvm::any_of(elements,
-                   [](int64_t dim) { return dim == ShapedType::kDynamic; }))
-    return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
-  auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
-      loc, rewriter.getI64TensorAttr(elements));
-  auto outTy = getTypeConverter()->convertType(op.getType());
-  auto outElemTy = cast<TensorType>(outTy).getElementType();
-  Value from =
-      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
-  Value to =
-      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
-  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
-      op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM);
-  return success();
-}
-
 // Converts `aten.empty.memory_format` to `tensor.empty` op.
 template <>
 LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
@@ -2240,7 +2210,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenToDtypeOp);
   INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
   INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
-  INSERT_ATENOP_PATTERN(AtenUniformOp);
+
   INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
   INSERT_ATENOP_PATTERN(AtenFillScalarOp);
   INSERT_ATENOP_PATTERN(AtenFlipOp);
diff --git a/lib/Conversion/TorchToStablehlo/CMakeLists.txt b/lib/Conversion/TorchToStablehlo/CMakeLists.txt
index 566f1d15b6ad..b200063e1785 100644
--- a/lib/Conversion/TorchToStablehlo/CMakeLists.txt
+++ b/lib/Conversion/TorchToStablehlo/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
   Linear.cpp
   ViewLike.cpp
   Reduction.cpp
+  Rng.cpp
   Pooling.cpp
   Utils.cpp
 
diff --git a/lib/Conversion/TorchToStablehlo/PopulatePatterns.h b/lib/Conversion/TorchToStablehlo/PopulatePatterns.h
index fc28acfde29f..112d5d0ed374 100644
--- a/lib/Conversion/TorchToStablehlo/PopulatePatterns.h
+++ b/lib/Conversion/TorchToStablehlo/PopulatePatterns.h
@@ -62,6 +62,11 @@ void populatePoolingOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const TorchToStablehloOptions &options);
 
+void populateRngOpPatternsAndLegality(TypeConverter &typeConverter,
+                                      RewritePatternSet &patterns,
+                                      ConversionTarget &target,
+                                      const TorchToStablehloOptions &options);
+
 } // namespace torch_to_stablehlo
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToStablehlo/Rng.cpp b/lib/Conversion/TorchToStablehlo/Rng.cpp
new file mode 100644
index 000000000000..06448794dddd
--- /dev/null
+++ b/lib/Conversion/TorchToStablehlo/Rng.cpp
@@ -0,0 +1,137 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
+
+#include "../PassDetail.h"
+#include "./PopulatePatterns.h"
+
+#include "stablehlo/dialect/StablehloOps.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+using namespace mlir::torch::torch_to_stablehlo;
+
+template <>
+LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
+    AtenUniformOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value self = adaptor.getSelf();
+  Value generator = adaptor.getGenerator();
+  Location loc = op.getLoc();
+
+  if (!isa<Torch::NoneType>(generator.getType()))
+    return rewriter.notifyMatchFailure(
+        op, "The generator has to be None because only global default "
+            "generator is supported");
+
+  auto elements = cast<RankedTensorType>(self.getType()).getShape();
+  if (llvm::any_of(elements,
+                   [](int64_t dim) { return dim == ShapedType::kDynamic; }))
+    return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
+  auto shape_tensor = rewriter.create<stablehlo::ConstantOp>(
+      loc, rewriter.getI64TensorAttr(elements));
+  auto outTy = getTypeConverter()->convertType(op.getType());
+  auto outElemTy = cast<TensorType>(outTy).getElementType();
+  Value from =
+      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy);
+  Value to =
+      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy);
+  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
+      op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenRandnGeneratorOp>::matchAndRewrite(
+    AtenRandnGeneratorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value generator = adaptor.getGenerator();
+  Location loc = op.getLoc();
+
+  if (!isa<Torch::NoneType>(generator.getType())) {
+    return rewriter.notifyMatchFailure(
+        op, "The generator has to be None because only global default "
+            "generator is supported");
+  }
+  llvm::SmallVector<int64_t> shape;
+  if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(shape))) {
+    return rewriter.notifyMatchFailure(op, "size must be constant");
+  }
+
+  auto outTy = getTypeConverter()->convertType(op.getType());
+  auto outElemTy = cast<TensorType>(outTy).getElementType();
+  auto scalarTy = RankedTensorType::get({}, outElemTy);
+  if (!isa<mlir::FloatType>(outElemTy)) {
+    return rewriter.notifyMatchFailure(op,
+                                       "only support output with float type");
+  }
+
+  Value shapeTensor = rewriter.create<stablehlo::ConstantOp>(
+      loc, rewriter.getI64TensorAttr(shape));
+  Value mean = rewriter.create<stablehlo::ConstantOp>(
+      loc, DenseFPElementsAttr::get(scalarTy, 0.0));
+  Value var = rewriter.create<stablehlo::ConstantOp>(
+      loc, DenseFPElementsAttr::get(scalarTy, 1.0));
+
+  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
+      op, outTy, mean, var, shapeTensor, stablehlo::RngDistribution::NORMAL);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenNormalFunctionalOp>::matchAndRewrite(
+    AtenNormalFunctionalOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value self = adaptor.getSelf();
+  Value generator = adaptor.getGenerator();
+  Location loc = op.getLoc();
+
+  if (!isa<Torch::NoneType>(generator.getType()))
+    return rewriter.notifyMatchFailure(
+        op, "The generator has to be None because only global default "
+            "generator is supported");
+
+  auto elements = cast<RankedTensorType>(self.getType()).getShape();
+  if (llvm::any_of(elements,
+                   [](int64_t dim) { return dim == ShapedType::kDynamic; }))
+    return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
+  auto shapeTensor = rewriter.create<stablehlo::ConstantOp>(
+      loc, rewriter.getI64TensorAttr(elements));
+  auto outTy = getTypeConverter()->convertType(op.getType());
+  auto outElemTy = cast<TensorType>(outTy).getElementType();
+  Value mean =
+      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getMean(), outElemTy);
+  Value std =
+      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStd(), outElemTy);
+  rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
+      op, outTy, mean, std, shapeTensor, stablehlo::RngDistribution::NORMAL);
+  return success();
+}
+
+void mlir::torch::torch_to_stablehlo::populateRngOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, const TorchToStablehloOptions &options) {
+  MLIRContext *context = patterns.getContext();
+
+#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
+
+  INSERT_ATENOP_PATTERN(AtenUniformOp);
+  INSERT_ATENOP_PATTERN(AtenRandnGeneratorOp);
+  INSERT_ATENOP_PATTERN(AtenNormalFunctionalOp);
+#undef INSERT_ATENOP_PATTERN
+}
diff --git a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp
index 4bcc02344e7d..9a3360bf9069 100644
--- a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp
+++ b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp
@@ -75,6 +75,8 @@ class ConvertTorchToStablehlo
         typeConverter, patterns, target, options);
     torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
         typeConverter, patterns, target, options);
+    torch_to_stablehlo::populateRngOpPatternsAndLegality(
+        typeConverter, patterns, target, options);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir
index d8ec0fa6495f..30f8716ebdf0 100644
--- a/test/Conversion/TorchToStablehlo/basic.mlir
+++ b/test/Conversion/TorchToStablehlo/basic.mlir
@@ -291,33 +291,6 @@ func.func @torch.runtime.assert(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.uniform(
-// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 0.000000e+00
-// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
-// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
-// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
-// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
-// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
-// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
-// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
-// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
-// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
-// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = UNIFORM : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
-// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
-// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
-func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vtensor<[32, 64],f64> {
-  %none = torch.constant.none
-  %float0 = torch.constant.float 0.0
-  %float1 = torch.constant.float 1.0
-  %0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
-  return %0 : !torch.vtensor<[32, 64],f64>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor(
 // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si32>,
 // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> {
diff --git a/test/Conversion/TorchToStablehlo/rng.mlir b/test/Conversion/TorchToStablehlo/rng.mlir
new file mode 100644
index 000000000000..66beee5456d5
--- /dev/null
+++ b/test/Conversion/TorchToStablehlo/rng.mlir
@@ -0,0 +1,78 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
+
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.uniform(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 0.000000e+00
+// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
+// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
+// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
+// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
+// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
+// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
+// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
+// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
+// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
+// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = UNIFORM : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
+func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vtensor<[32, 64],f64> {
+  %none = torch.constant.none
+  %float0 = torch.constant.float 0.0
+  %float1 = torch.constant.float 1.0
+  %0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
+  return %0 : !torch.vtensor<[32, 64],f64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.randn.generator
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[INT32:.*]] = torch.constant.int 32
+// CHECK: %[[INT64:.*]] = torch.constant.int 64
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
+// CHECK: %[[SHAPE:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
+// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f64>
+// CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f64>
+// CHECK: %[[RNG:.*]] = stablehlo.rng %[[VAL_0]], %[[VAL_1]], %[[SHAPE]], distribution = NORMAL : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
+// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[RNG]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
+// CHECK: return %[[RET]] : !torch.vtensor<[32,64],f64>
+func.func @torch.aten.randn.generator() -> !torch.vtensor<[32, 64],f64> {
+  %none = torch.constant.none
+  %int32 = torch.constant.int 32
+  %int64 = torch.constant.int 64
+  %size = torch.prim.ListConstruct %int32, %int64 : (!torch.int, !torch.int) -> !torch.list<int>
+  %0 = torch.aten.randn.generator %size, %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[32, 64], f64>
+  return %0 : !torch.vtensor<[32, 64],f64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.normal_functional(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FLOAT_0:.*]] = torch.constant.float 2.000000e+00
+// CHECK: %[[VAL_0:.*]] = torch_c.to_f64 %[[FLOAT_0]]
+// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[VAL_1:.*]] = torch_c.to_f64 %[[FLOAT_1]]
+// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
+// CHECK: %[[ELEM_0:.*]] = tensor.from_elements %[[VAL_0]] : tensor<1xf64>
+// CHECK: %[[VAL_3:.*]] = stablehlo.convert %[[ELEM_0]] : tensor<1xf64>
+// CHECK: %[[VAL_4:.*]] = stablehlo.reshape %[[VAL_3]] : (tensor<1xf64>) -> tensor<f64>
+// CHECK: %[[ELEM_1:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xf64>
+// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[ELEM_1]] : tensor<1xf64>
+// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf64>) -> tensor<f64>
+// CHECK: %[[VAL_7:.*]] = stablehlo.rng %[[VAL_4]], %[[VAL_6]], %[[VAL_2]], distribution = NORMAL : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<32x64xf64>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<32x64xf64> -> !torch.vtensor<[32,64],f64>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[32,64],f64>
+func.func @torch.aten.normal_functional(%arg0: !torch.vtensor<[32, 64], f64>) -> !torch.vtensor<[32, 64], f64> {
+  %none = torch.constant.none
+  %mean = torch.constant.float 2.0
+  %std = torch.constant.float 1.0
+  %0 = torch.aten.normal_functional %arg0, %mean, %std, %none : !torch.vtensor<[32, 64], f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64], f64>
+  return %0 : !torch.vtensor<[32, 64],f64>
+}

From 0b7cbf5e601cb9b2b646df7ab19957ba4293d6c7 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu <liuyuanqiang.yqliu@bytedance.com>
Date: Sat, 11 May 2024 17:40:04 +0800
Subject: [PATCH 2/5] [Stablehlo] fix aten.randn's lowering with f32 element
 type (#3329)

---
 lib/Conversion/TorchToStablehlo/Rng.cpp   |  9 ++++++---
 test/Conversion/TorchToStablehlo/rng.mlir | 22 ++++++++++++++++++++++
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/lib/Conversion/TorchToStablehlo/Rng.cpp b/lib/Conversion/TorchToStablehlo/Rng.cpp
index 06448794dddd..3cd440c957e9 100644
--- a/lib/Conversion/TorchToStablehlo/Rng.cpp
+++ b/lib/Conversion/TorchToStablehlo/Rng.cpp
@@ -12,6 +12,7 @@
 
 #include "../PassDetail.h"
 #include "./PopulatePatterns.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -73,18 +74,20 @@ LogicalResult ConvertAtenOp<AtenRandnGeneratorOp>::matchAndRewrite(
 
   auto outTy = getTypeConverter()->convertType(op.getType());
   auto outElemTy = cast<TensorType>(outTy).getElementType();
-  auto scalarTy = RankedTensorType::get({}, outElemTy);
   if (!isa<mlir::FloatType>(outElemTy)) {
     return rewriter.notifyMatchFailure(op,
                                        "only support output with float type");
   }
+  auto scalarTy = RankedTensorType::get({}, outElemTy);
 
   Value shapeTensor = rewriter.create<stablehlo::ConstantOp>(
       loc, rewriter.getI64TensorAttr(shape));
   Value mean = rewriter.create<stablehlo::ConstantOp>(
-      loc, DenseFPElementsAttr::get(scalarTy, 0.0));
+      loc,
+      DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(outElemTy, 0.0)));
   Value var = rewriter.create<stablehlo::ConstantOp>(
-      loc, DenseFPElementsAttr::get(scalarTy, 1.0));
+      loc,
+      DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(outElemTy, 1.0)));
 
   rewriter.replaceOpWithNewOp<stablehlo::RngOp>(
       op, outTy, mean, var, shapeTensor, stablehlo::RngDistribution::NORMAL);
diff --git a/test/Conversion/TorchToStablehlo/rng.mlir b/test/Conversion/TorchToStablehlo/rng.mlir
index 66beee5456d5..31241caacb28 100644
--- a/test/Conversion/TorchToStablehlo/rng.mlir
+++ b/test/Conversion/TorchToStablehlo/rng.mlir
@@ -52,6 +52,28 @@ func.func @torch.aten.randn.generator() -> !torch.vtensor<[32, 64],f64> {
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.randn.generator$f32
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[INT32:.*]] = torch.constant.int 32
+// CHECK: %[[INT64:.*]] = torch.constant.int 64
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
+// CHECK: %[[SHAPE:.*]] = stablehlo.constant dense<[32, 64]> : tensor<2xi64>
+// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %[[RNG:.*]] = stablehlo.rng %[[VAL_0]], %[[VAL_1]], %[[SHAPE]], distribution = NORMAL : (tensor<f32>, tensor<f32>, tensor<2xi64>) -> tensor<32x64xf32>
+// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[RNG]] : tensor<32x64xf32> -> !torch.vtensor<[32,64],f32>
+// CHECK: return %[[RET]] : !torch.vtensor<[32,64],f32>
+func.func @torch.aten.randn.generator$f32() -> !torch.vtensor<[32, 64],f32> {
+  %none = torch.constant.none
+  %int32 = torch.constant.int 32
+  %int64 = torch.constant.int 64
+  %size = torch.prim.ListConstruct %int32, %int64 : (!torch.int, !torch.int) -> !torch.list<int>
+  %0 = torch.aten.randn.generator %size, %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[32, 64], f32>
+  return %0 : !torch.vtensor<[32, 64],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.normal_functional(
 // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[32,64],f64>) -> !torch.vtensor<[32,64],f64> {
 // CHECK: %[[NONE:.*]] = torch.constant.none

From 75d1d72059cf2731ddfd5e44f8646cd8cb6ebe66 Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Sun, 12 May 2024 22:49:59 -0500
Subject: [PATCH 3/5] Generalize Operand Quantization in FuseQuantizeOps
 (#3327)

This change enables more customization with operand quantization, and
generalizes the patterns QuantizeOperands and QuantizeTransposedOperands
to QuantizeOperandsPastCommutingOps. This allows for passing quantization
through operations which are functionally unaffected by quantization,
such as view-like ops.

The purpose of this change is to address a myriad of quantization issues
seen in quantized ONNX models that have some reshape-like operations
sandwiched in between a dequant and something like a matmul (whose other
operand is immediately quantizable).
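
As a sketch of the rewrite (illustrative IR only: types, scales, and zero
points are elided, and the value names are invented for this example, but the
op spellings match the Torch dialect ops handled below):

    // before: the shape ops run on the dequantized float tensor, so the
    // matmul only ever sees float operands
    %dq = torch.aten.dequantize.self %mptqt
    %r  = torch.aten.reshape %dq, %sizes
    %t  = torch.aten.transpose.int %r, %int1, %int2
    %mm = torch.aten.matmul %t, %other

    // after: the same shape ops are replayed on the integer values, and a
    // fresh _make_per_tensor_quantized_tensor is emitted next to the matmul
    %r  = torch.aten.reshape %int_repr, %sizes
    %t  = torch.aten.transpose.int %r, %int1, %int2
    %q  = torch.aten._make_per_tensor_quantized_tensor %t, %scale, %zp
    %mm = torch.aten.matmul %q, %other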
---
 .../Torch/Transforms/FuseQuantizedOps.cpp    | 181 ++++++++++--------
 test/Dialect/Torch/fuse-quantized-ops.mlir   |  84 ++++++--
 2 files changed, 168 insertions(+), 97 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
index 0c352d31ca80..7870ff63cb40 100644
--- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
+++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -13,6 +13,7 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include <stack>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -27,98 +28,112 @@ template <typename SrcOp> struct QuantInfo {
 template <> struct QuantInfo<AtenReluOp> {
   static constexpr unsigned operandsToQuantize[1] = {0};
 };
 
-template <typename SrcOp>
-class QuantizeOperands : public OpRewritePattern<SrcOp> {
-public:
-  using OpRewritePattern<SrcOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(SrcOp op,
-                                PatternRewriter &rewriter) const override {
-    llvm::SmallVector<Value> operands(op->getOperands());
-
-    bool dequanted = false;
-    auto f = [&dequanted](Value operand) {
-      if (auto dequant = operand.getDefiningOp<AtenDequantizeSelfOp>()) {
-        operand = dequant.getOperand();
-        dequanted = true;
-      }
-      if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
-        operand = dequant.getOperand();
-        dequanted = true;
-      }
-      return operand;
-    };
-
-    for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
-      operands[i] = f(operands[i]);
-    }
-    if (!dequanted) {
-      return rewriter.notifyMatchFailure(op, "no dequantizations found");
-    }
-
-    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
-    return success();
-  }
-};
+// A QCommutingOp is an Op satisfying:
+// 1. Has at most one tensor operand at index 0
+// 2. Has a single output, which is a tensor
+// 3. Satisfies the commutation relation:
+//      [MPTQT -> Dequant -> Op(float)] = [Op(int) -> MPTQT -> Dequant]
+// where MPTQT = "Aten_MakePerTensorQuantizedTensorOp"
+// and Dequant = "AtenDequantizeSelfOp" or "AtenDequantizeTensorOp"
+bool isQCommutingOp(mlir::Operation *op) {
+  // if adding a new commuting op here, be sure to add a
+  // RemoveUnused pattern for that op to clean up afterwards
+  return llvm::isa<AtenSliceTensorOp, AtenReshapeOp, AtenTransposeIntOp>(op);
+}
 
-template <typename SrcOp>
-class QuantizeTransposedOperands : public OpRewritePattern<SrcOp> {
+// The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
+// -> Op1 -> Op2 -> ... Opk -> SrcOp] to [op0 -> Int(Op1) -> Int(Op2) -> ... ->
+// Int(Opk) -> MPTQT -> SrcOp] for any sequence of q commuting ops
+// {Op1,Op2,...,Opk} with k <= depth.
+// With depth = 0, this conversion will simply fuse any immediately quantizable
+// operands: [MPTQT -> Dequant -> SrcOp (float operands)] to [MPTQT -> SrcOp(int
+// operands)]
+template <typename SrcOp, unsigned depth>
+class QuantizeOperandsPastCommutingOps : public OpRewritePattern<SrcOp> {
 public:
   using OpRewritePattern<SrcOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(SrcOp op,
                                 PatternRewriter &rewriter) const override {
+    mlir::Location loc = op.getLoc();
     llvm::SmallVector<Value> operands(op->getOperands());
-    unsigned numOperands = operands.size();
     bool dequanted = false;
-    for (unsigned i = 0; i < numOperands; i++) {
-      if (auto trans = operands[i].getDefiningOp<AtenTransposeIntOp>()) {
-        auto transOperands = trans.getOperands();
-        Value dequantOperand;
-        if (auto dequant =
-                transOperands[0].getDefiningOp<AtenDequantizeSelfOp>()) {
-          dequantOperand = dequant.getOperand();
-          if (auto quant =
-                  dequantOperand
-                      .getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
-            auto quantOperands = quant.getOperands();
-            auto qType = quantOperands[0]
-                             .getType()
-                             .cast<ValueTensorType>()
-                             .getOptionalDtype();
-            auto torchQType =
-                cast<ValueTensorType>(quant.getType()).getOptionalDtype();
-            auto transQTy =
-                rewriter.getType<ValueTensorType>(trans.getResult()
-                                                      .getType()
-                                                      .cast<ValueTensorType>()
-                                                      .getOptionalSizes(),
-                                                  qType);
-            auto newQuantTy =
-                rewriter.getType<ValueTensorType>(trans.getResult()
-                                                      .getType()
-                                                      .cast<ValueTensorType>()
-                                                      .getOptionalSizes(),
-                                                  torchQType);
-            Value newTrans = rewriter.create<AtenTransposeIntOp>(
-                op.getLoc(), transQTy, quantOperands[0], transOperands[1],
-                transOperands[2]);
-            Value newQuant =
-                rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
-                    op.getLoc(), newQuantTy, newTrans, quantOperands[1],
-                    quantOperands[2]);
-            operands[i] = newQuant;
-            dequanted = true;
-          }
+
+    for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
+      Value operand = operands[i];
+      std::stack<mlir::Operation *> commutingOpStack;
+      Value dequantOpd, MPTQTOpd;
+      for (unsigned k = 0; k < depth + 1; k++) {
+        auto currOp = operand.getDefiningOp();
+        // Case 0 : currOp is a nullptr (e.g., operand is a block argument)
+        if (!currOp)
+          break;
+        // Case 1 : currOp is a q commuting op (continue loop)
+        if (isQCommutingOp(currOp)) {
+          commutingOpStack.push(currOp);
+          // set operand to currOp for next k-iteration
+          operand = currOp->getOperand(0);
+          continue;
+        }
+        // Case 2 : currOp is a dequant op (end loop)
+        if (llvm::isa<AtenDequantizeSelfOp, AtenDequantizeTensorOp>(currOp)) {
+          dequantOpd = currOp->getOperand(0);
+          auto MPTQTOp =
+              dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
+          MPTQTOpd = MPTQTOp.getOperand(0);
         }
+        // either a dequant was found or chain broken, so break loop
+        break;
+      }
+
+      // move to next operand if this trace was unsuccessful
+      if (!MPTQTOpd)
+        continue;
+
+      // a successful trace occurred, so set dequanted to true
+      dequanted = true;
+
+      // rewrite stack
+      Value oldOpd = MPTQTOpd;
+      Type intDType =
+          cast<ValueTensorType>(MPTQTOpd.getType()).getOptionalDtype();
+      while (!commutingOpStack.empty()) {
+        // get front of the commuting op stack and replace its first operand
+        // with oldOpd
+        auto currOp = commutingOpStack.top();
+        commutingOpStack.pop();
+        llvm::SmallVector<Value> currOperands(currOp->getOperands());
+        currOperands[0] = oldOpd;
+        // get new result type
+        auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]);
+        auto intType =
+            rewriter.getType<ValueTensorType>(oldType.getSizes(), intDType);
+        // rewrite currOp to have new operands and result type
+        // store this as oldOpd for next loop
+        oldOpd = rewriter
+                     .create(loc, (currOp->getName()).getIdentifier(),
+                             currOperands, intType, currOp->getAttrs())
+                     ->getResult(0);
       }
+
+      // stack is empty, so oldOpd is now the corrected version of the
+      // SrcOp's original operand
+      // convert operand -> SrcOp to oldOpd -> newMPTQTOp -> SrcOp
+      auto MPTQTOperands = dequantOpd.getDefiningOp()->getOperands();
+      auto qTorchType =
+          cast<ValueTensorType>(dequantOpd.getType()).getOptionalDtype();
+      auto newMPTQTType = rewriter.getType<ValueTensorType>(
+          cast<ValueTensorType>(operands[i].getType()).getSizes(), qTorchType);
+      operands[i] = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
+          loc, newMPTQTType, oldOpd, MPTQTOperands[1], MPTQTOperands[2]);
     }
+
     if (!dequanted) {
-      return rewriter.notifyMatchFailure(
-          op, "no dequantized transpose inputs found.");
+      return rewriter.notifyMatchFailure(op, "No dequantizations found.");
     }
+
     rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
     return success();
   }
 };
@@ -356,11 +371,13 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
         RemoveUnused<AtenDequantizeSelfOp>,
         RemoveUnused<AtenDequantizeTensorOp>,
         RemoveUnused<AtenQuantizePerTensorOp>,
-        RemoveUnused<AtenTransposeIntOp>, QuantizeOperands<AtenConvolutionOp>,
-        QuantizeOperands<AtenMatmulOp>, QuantizeOperands<AtenReluOp>,
-        QuantizeTransposedOperands<AtenMatmulOp>,
-        QuantizeAccumulator<AtenMatmulOp>, QuantizeOperands<AtenMmOp>,
-        QuantizeTransposedOperands<AtenMmOp>, QuantizeAccumulator<AtenMmOp>,
+        RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenSliceTensorOp>,
+        RemoveUnused<AtenReshapeOp>,
+        QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 0>,
+        QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
+        QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
+        QuantizeOperandsPastCommutingOps<AtenMmOp, 1>,
+        QuantizeAccumulator<AtenMatmulOp>, QuantizeAccumulator<AtenMmOp>,
         QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
             context);
 
diff --git a/test/Dialect/Torch/fuse-quantized-ops.mlir b/test/Dialect/Torch/fuse-quantized-ops.mlir
index f98cb842f5d3..594295d4e86d 100644
--- a/test/Dialect/Torch/fuse-quantized-ops.mlir
+++ b/test/Dialect/Torch/fuse-quantized-ops.mlir
@@ -28,6 +28,60 @@ func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si
 
 // -----
 
+// CHECK-LABEL: @matmul_commuting
+func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.vtensor<[1,1024,1024],f32> {
+  %float5.000000e-01 = torch.constant.float 5.000000e-01
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int-128 = torch.constant.int -128
+  %int2 = torch.constant.int 2
+  %int128 = torch.constant.int 128
+  %int1024 = torch.constant.int 1024
+  %int12 = torch.constant.int 12
+  %0 = torch.aten._make_per_tensor_quantized_tensor %arg0, %float5.000000e-01, %int-128 : !torch.vtensor<[2,128,32,32],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,128,32,32],!torch.qint8>
+  %1 = torch.aten.dequantize.self %0 : !torch.vtensor<[2,128,32,32],!torch.qint8> -> !torch.vtensor<[2,128,32,32],f32>
+  %2 = torch.aten.slice.Tensor %1, %int0, %int0, %int1, %int1 : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
+  %3 = torch.aten.slice.Tensor %1, %int0, %int1, %int2, %int1 : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
+  %4 = torch.prim.ListConstruct %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.aten.reshape %2, %4 : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
+  %6 = torch.aten.reshape %3, %4 : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
+  %7 = torch.aten.transpose.int %5, %int1, %int2 : !torch.vtensor<[1,128,1024],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],f32>
+  %8 = torch.aten.quantize_per_tensor %7, %float5.000000e-01, %int0, %int12 : !torch.vtensor<[1,1024,128],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  %9 = torch.aten.int_repr %8 : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],si8>
+  %10 = torch.aten._make_per_tensor_quantized_tensor %9, %float5.000000e-01, %int0 : !torch.vtensor<[1,1024,128],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  %11 = torch.aten.dequantize.self %10 : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],f32>
+  %12 = torch.aten.matmul %11, %6 : !torch.vtensor<[1,1024,128],f32>, !torch.vtensor<[1,128,1024],f32> -> !torch.vtensor<[1,1024,1024],f32>
+
+  // CHECK-DAG: %[[QUARTER:.+]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[IN128:.+]] = torch.constant.int -128
+  // CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
+  // CHECK-DAG: %[[I128:.+]] = torch.constant.int 128
+  // CHECK-DAG: %[[I1024:.+]] = torch.constant.int 1024
+  // CHECK-DAG: %[[I12:.+]] = torch.constant.int 12
+  // CHECK-DAG: %[[MPTQT0:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[IN128]] : !torch.vtensor<[2,128,32,32],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,128,32,32],!torch.qint8>
+  // CHECK-DAG: %[[DQ0:.+]] = torch.aten.dequantize.self %[[MPTQT0]] : !torch.vtensor<[2,128,32,32],!torch.qint8> -> !torch.vtensor<[2,128,32,32],f32>
+  // CHECK-DAG: %[[SLICE0:.+]] = torch.aten.slice.Tensor %[[DQ0]], %[[I0]], %[[I0]], %[[I1]], %[[I1]] : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
+  // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[I1]], %[[I128]], %[[I1024]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[RESHAPE0:.+]] = torch.aten.reshape %[[SLICE0]], %[[LIST]] : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
+  // CHECK-DAG: %[[TR0:.+]] = torch.aten.transpose.int %[[RESHAPE0]], %[[I1]], %[[I2]] : !torch.vtensor<[1,128,1024],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],f32>
+  // CHECK-DAG: %[[Q0:.+]] = torch.aten.quantize_per_tensor %[[TR0]], %[[HALF]], %[[I0]], %[[I12]] : !torch.vtensor<[1,1024,128],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  // CHECK-DAG: %[[IR0:.+]] = torch.aten.int_repr %[[Q0]] : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],si8>
+  // CHECK-DAG: %[[MPTQT1:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[IR0]], %[[HALF]], %[[I0]] : !torch.vtensor<[1,1024,128],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  // CHECK-DAG: %[[SLICE1:.+]] = torch.aten.slice.Tensor %arg0, %[[I0]], %[[I1]], %[[I2]], %[[I1]] : !torch.vtensor<[2,128,32,32],si8>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],si8>
+  // CHECK-DAG: %[[RESHAPE1:.+]] = torch.aten.reshape %[[SLICE1]], %[[LIST]] : !torch.vtensor<[1,128,32,32],si8>, !torch.list<int> -> !torch.vtensor<[1,128,1024],si8>
+  // CHECK-DAG: %[[MPTQT2:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[RESHAPE1]], %[[HALF]], %[[IN128]] : !torch.vtensor<[1,128,1024],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1024],!torch.qint8>
+  // CHECK-DAG: %[[MATMUL:.+]] = torch.aten.matmul %[[MPTQT1]], %[[MPTQT2]] : !torch.vtensor<[1,1024,128],!torch.qint8>, !torch.vtensor<[1,128,1024],!torch.qint8> -> !torch.vtensor<[1,1024,1024],!torch.qint32>
+  // CHECK-DAG: %[[IR1:.+]] = torch.aten.int_repr %[[MATMUL]] : !torch.vtensor<[1,1024,1024],!torch.qint32> -> !torch.vtensor<[1,1024,1024],si32>
+  // CHECK-DAG: %[[MPTQT3:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[IR1]], %[[QUARTER]], %[[I0]] : !torch.vtensor<[1,1024,1024],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,1024],!torch.qint32>
+  // CHECK-DAG: %[[DQ1:.+]] = torch.aten.dequantize.tensor %[[MPTQT3]] : !torch.vtensor<[1,1024,1024],!torch.qint32> -> !torch.vtensor<[1,1024,1024],f32>
+  return %12 : !torch.vtensor<[1,1024,1024],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @convolution_bias
 func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
   %scale = torch.constant.float 0.5
@@ -43,21 +97,21 @@ func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.
   %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list<int>
   %16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],f32>
 
-  // CHECK: %[[DTYPE:.+]] = torch.constant.int 14
-  // CHECK: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
-  // CHECK: %[[HALF:.+]] = torch.constant.float 5.000000e-01
-  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
-  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
-  // CHECK: %[[ONE:.+]] = torch.constant.int 1
-  // CHECK: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
-  // CHECK: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
-  // CHECK: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
-  // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32>
-  // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32>
-  // CHECK: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
-  // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
+  // CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 14
+  // CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
+  // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
+  // CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
+  // CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32>
+  // CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32>
+  // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
+  // CHECK-DAG: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
   return %16 : !torch.vtensor<[1,3,7,7],f32>
 }

From 20d4d16d32fcc23707fff08a62de4c7a59127c74 Mon Sep 17 00:00:00 2001
From: penguin_wwy <940375606@qq.com>
Date: Tue, 14 May 2024 00:45:19 +0800
Subject: [PATCH 4/5] [FxImporter] Add an e2e test example for FxImporter
 (#3331)

---
 README.md                                     | 17 ++++++
 projects/pt1/examples/_example_utils.py       | 52 ++++++++++++++++
 projects/pt1/examples/fximporter_resnet18.py  | 59 ++++++++++++++++++
 projects/pt1/examples/torchscript_resnet18.py | 61 ++++++-----------
 4 files changed, 141 insertions(+), 48 deletions(-)
 create mode 100644 projects/pt1/examples/_example_utils.py
 create mode 100644 projects/pt1/examples/fximporter_resnet18.py

diff --git a/README.md b/README.md
index 70268ba729f0..b9d7a47595fa 100644
--- a/README.md
+++ b/README.md
@@ -76,6 +76,23 @@ pip install torch-mlir -f https://github.com/llvm/torch-mlir-release/releases/ex
 
 ## Demos
 
+### FxImporter ResNet18
+```shell
+# Get the latest example if you haven't checked out the code
+wget https://raw.githubusercontent.com/llvm/torch-mlir/main/projects/pt1/examples/fximporter_resnet18.py
+
+# Run ResNet18 as a standalone script.
+python projects/pt1/examples/fximporter_resnet18.py
+
+# Output
+load image from https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg
+...
+PyTorch prediction
+[('Labrador retriever', 70.65674591064453), ('golden retriever', 4.988346099853516), ('Saluki, gazelle hound', 4.477451324462891)]
+torch-mlir prediction
+[('Labrador retriever', 70.6567153930664), ('golden retriever', 4.988325119018555), ('Saluki, gazelle hound', 4.477458477020264)]
+```
+
 ### TorchScript ResNet18
 
 Standalone script to Convert a PyTorch ResNet18 model to MLIR and run it on the CPU Backend:
diff --git a/projects/pt1/examples/_example_utils.py b/projects/pt1/examples/_example_utils.py
new file mode 100644
index 000000000000..8f63b4fd4a63
--- /dev/null
+++ b/projects/pt1/examples/_example_utils.py
@@ -0,0 +1,52 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+from PIL import Image
+import requests
+
+import torch
+from torchvision import transforms
+
+
+DEFAULT_IMAGE_URL = (
+    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
+)
+DEFAULT_LABEL_URL = (
+    "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt"
+)
+
+
+def load_and_preprocess_image(url: str = DEFAULT_IMAGE_URL):
+    headers = {
+        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
+    }
+    img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB")
+    # preprocessing pipeline
+    preprocess = transforms.Compose(
+        [
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ]
+    )
+    img_preprocessed = preprocess(img)
+    return torch.unsqueeze(img_preprocessed, 0)
+
+
+def load_labels(url: str = DEFAULT_LABEL_URL):
+    classes_text = requests.get(
+        url=url,
+        stream=True,
+    ).text
+    labels = [line.strip() for line in classes_text.splitlines()]
+    return labels
+
+
+def top3_possibilities(res, labels):
+    _, indexes = torch.sort(res, descending=True)
+    percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
+    top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
+    return top3
diff --git a/projects/pt1/examples/fximporter_resnet18.py b/projects/pt1/examples/fximporter_resnet18.py
new file mode 100644
index 000000000000..8776c42fa7e4
--- /dev/null
+++ b/projects/pt1/examples/fximporter_resnet18.py
@@ -0,0 +1,59 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import sys
+from pathlib import Path
+
+import torch
+import torch.utils._pytree as pytree
+import torchvision.models as models
+from torch_mlir import fx
+from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
+from torch_mlir_e2e_test.configs.utils import (
+    recursively_convert_to_numpy,
+)
+
+sys.path.append(str(Path(__file__).absolute().parent))
+from _example_utils import (
+    top3_possibilities,
+    load_and_preprocess_image,
+    load_labels,
+    DEFAULT_IMAGE_URL,
+)
+
+
+print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr)
+img = load_and_preprocess_image(DEFAULT_IMAGE_URL)
+labels = load_labels()
+
+resnet18 = models.resnet18(pretrained=True).eval()
+module = fx.export_and_import(
+    resnet18,
+    torch.ones(1, 3, 224, 224),
+    output_type="linalg-on-tensors",
+    func_name=resnet18.__class__.__name__,
+)
+backend = refbackend.RefBackendLinalgOnTensorsBackend()
+compiled = backend.compile(module)
+fx_module = backend.load(compiled)
+
+params = {
+    **dict(resnet18.named_buffers(remove_duplicate=False)),
+}
+params_flat, params_spec = pytree.tree_flatten(params)
+params_flat = list(params_flat)
+with torch.no_grad():
+    numpy_inputs = recursively_convert_to_numpy(params_flat + [img])
+
+golden_prediction = top3_possibilities(resnet18.forward(img), labels)
+print("PyTorch prediction")
+print(golden_prediction)
+
+prediction = top3_possibilities(
+    torch.from_numpy(getattr(fx_module, resnet18.__class__.__name__)(*numpy_inputs)),
+    labels,
+)
+print("torch-mlir prediction")
+print(prediction)
diff --git a/projects/pt1/examples/torchscript_resnet18.py b/projects/pt1/examples/torchscript_resnet18.py
index 0cc5b5dda96a..ea56653ca6f6 100644
--- a/projects/pt1/examples/torchscript_resnet18.py
+++ b/projects/pt1/examples/torchscript_resnet18.py
@@ -4,71 +4,36 @@
 # Also available under a BSD-style license. See LICENSE.
 
 import sys
-
-from PIL import Image
-import requests
+from pathlib import Path
 
 import torch
 import torchvision.models as models
-from torchvision import transforms
-
 from torch_mlir import torchscript
 from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
 
-
-def load_and_preprocess_image(url: str):
-    headers = {
-        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
-    }
-    img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB")
-    # preprocessing pipeline
-    preprocess = transforms.Compose(
-        [
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-        ]
-    )
-    img_preprocessed = preprocess(img)
-    return torch.unsqueeze(img_preprocessed, 0)
-
-
-def load_labels():
-    classes_text = requests.get(
-        "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
-        stream=True,
-    ).text
-    labels = [line.strip() for line in classes_text.splitlines()]
-    return labels
-
-
-def top3_possibilities(res):
-    _, indexes = torch.sort(res, descending=True)
-    percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
-    top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
-    return top3
+sys.path.append(str(Path(__file__).absolute().parent))
+from _example_utils import (
+    top3_possibilities,
+    load_and_preprocess_image,
+    load_labels,
+    DEFAULT_IMAGE_URL,
+)
 
 
 def predictions(torch_func, jit_func, img, labels):
-    golden_prediction = top3_possibilities(torch_func(img))
+    golden_prediction = top3_possibilities(torch_func(img), labels)
     print("PyTorch prediction")
     print(golden_prediction)
-    prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())))
+    prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())), labels)
     print("torch-mlir prediction")
     print(prediction)
 
 
-image_url = (
-    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
-)
-
-print("load image from " + image_url, file=sys.stderr)
-img = load_and_preprocess_image(image_url)
+print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr)
+img = load_and_preprocess_image(DEFAULT_IMAGE_URL)
 labels = load_labels()
 
-resnet18 = models.resnet18(pretrained=True)
-resnet18.train(False)
+resnet18 = models.resnet18(pretrained=True).eval()
 module = torchscript.compile(
     resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors"
 )

From 911e7235819b16f2574964c0d6112c06501d7886 Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Mon, 13 May 2024 13:01:53 -0500
Subject: [PATCH 5/5] Expands Q Commuting Ops (#3332)

After running the model tests in SHARK-TestSuite, I noticed a few model
failures due to half-fusion. Notably, RDN_pytorch_vaiq_int8 had a depth=5
convolution chain with multiple AtenViewOps.
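
In IR terms, the expanded commuting set lets the operand trace walk through
view-like ops on its way back to the dequantize (schematic sketch only: types
and most convolution operands are elided, and the view here stands in for any
op in the commuting set):

    %dq = torch.aten.dequantize.self %mptqt
    %v  = torch.aten.view %dq, %sizes        // previously ended the trace here
    %c  = torch.aten.convolution %v, %weight, %bias, ...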
---
 lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
index 7870ff63cb40..38bc4d275bf1 100644
--- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
+++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -39,7 +39,8 @@ template <> struct QuantInfo<AtenReluOp> {
 bool isQCommutingOp(mlir::Operation *op) {
   // if adding a new commuting op here, be sure to add a
   // RemoveUnused pattern for that op to clean up afterwards
-  return llvm::isa<AtenSliceTensorOp, AtenReshapeOp, AtenTransposeIntOp>(op);
+  return llvm::isa<AtenSliceTensorOp, AtenReshapeOp, AtenTransposeIntOp,
+                   AtenViewOp, AtenPadOp>(op);
 }
 
@@ -372,11 +373,12 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
         RemoveUnused<AtenDequantizeTensorOp>,
         RemoveUnused<AtenQuantizePerTensorOp>,
         RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenSliceTensorOp>,
-        RemoveUnused<AtenReshapeOp>,
-        QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 0>,
+        RemoveUnused<AtenReshapeOp>, RemoveUnused<AtenViewOp>,
+        RemoveUnused<AtenPadOp>,
+        QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>,
         QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
         QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
-        QuantizeOperandsPastCommutingOps<AtenMmOp, 1>,
+        QuantizeOperandsPastCommutingOps<AtenMmOp, 2>,
         QuantizeAccumulator<AtenMatmulOp>, QuantizeAccumulator<AtenMmOp>,
         QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
             context);