diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/CMakeLists.txt b/compiler/plugins/input/Torch/torch-iree/InputConversion/CMakeLists.txt index 9d157194e40e..6e184d66e2a4 100644 --- a/compiler/plugins/input/Torch/torch-iree/InputConversion/CMakeLists.txt +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/CMakeLists.txt @@ -36,6 +36,7 @@ iree_cc_library( "Passes.h" SRCS "ConvertTMTensorToLinalgExt.cpp" + "SetStrictSymbolicShapes.cpp" "Passes.cpp" DEPS ::PassHeaders diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp index 9d44ec42e049..b9438516dd13 100644 --- a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp @@ -26,7 +26,8 @@ namespace { #include "torch-iree/InputConversion/Passes.h.inc" // IWYU pragma: export } // namespace -void createTorchToIREEPipeline(OpPassManager &pm) { +void createTorchToIREEPipeline( + OpPassManager &pm, const TorchToIREELoweringPipelineOptions &options) { // This pipeline adapted from // createTorchBackendToLinalgOnTensorsBackendPipeline. Keep in sync with // additions there. Lower to linalg + guards which is the input to codegen @@ -35,6 +36,13 @@ void createTorchToIREEPipeline(OpPassManager &pm) { // model) and those constants get somewhat obscured by TorchToArith. llvm::ArrayRef emptyArrayRef; + if (options.strictSymbolicShapes) { + pm.addNestedPass(createSetStrictSymbolicShapesPass()); + // Run canonicalization in case any previously non-strict dynamic code can + // now be simplified. + pm.addNestedPass(createCanonicalizerPass()); + } + pm.addNestedPass( torch::Torch::createDecomposeComplexOpsPass(emptyArrayRef)); pm.addNestedPass(torch::createConvertTorchToTMTensorPass()); @@ -68,7 +76,7 @@ void registerTMTensorConversionPasses() { // Generated. registerPasses(); - mlir::PassPipelineRegistration<>( + mlir::PassPipelineRegistration( "torch-to-iree", "Pipeline to lower from the Torch backend contract to legal IREE input.", createTorchToIREEPipeline); diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.h b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.h index 11c1a4577649..4afb0cd1744a 100644 --- a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.h +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.h @@ -14,13 +14,24 @@ namespace mlir { namespace iree_compiler { namespace TorchInput { +struct TorchToIREELoweringPipelineOptions + : public PassPipelineOptions { + Option strictSymbolicShapes{ + *this, "strict-symbolic-shapes", + llvm::cl::desc("Use strict symbolic shapes."), llvm::cl::init(true)}; +}; + std::unique_ptr> createConvertTMTensorToLinalgExtPass(); +std::unique_ptr> +createSetStrictSymbolicShapesPass(); + // Creates a pipeline that lowers from the torch backend contract to IREE. // This is based on the torch-backend-to-linalg-on-tensors-backend-pipeline // pipeline in torch-mlir but includes IREE specific lowerings. -void createTorchToIREEPipeline(OpPassManager &pm); +void createTorchToIREEPipeline( + OpPassManager &pm, const TorchToIREELoweringPipelineOptions &options); //===----------------------------------------------------------------------===// // Register all Passes diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.td b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.td index 0391039a5f7e..eb9a5abf0572 100644 --- a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.td +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.td @@ -15,4 +15,10 @@ def ConvertTMTensorToLinalgExt : let constructor = "mlir::iree_compiler::TorchInput::createConvertTMTensorToLinalgExtPass()"; } +def SetStrictSymbolicShapesPass : + Pass<"torch-iree-set-strict-symbolic-shapes", "func::FuncOp"> { + let summary = "Adds the attribute indicating strict symbolic shapes in Torch IR"; + let constructor = "mlir::iree_compiler::TorchInput::createSetStrictSymbolicShapesPass()"; +} + #endif // TORCH_IREE_INPUTCONVERSION_PASSES diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/SetStrictSymbolicShapes.cpp b/compiler/plugins/input/Torch/torch-iree/InputConversion/SetStrictSymbolicShapes.cpp new file mode 100644 index 000000000000..c33b8510d9bd --- /dev/null +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/SetStrictSymbolicShapes.cpp @@ -0,0 +1,43 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- SetStrictSymbolicShapes.cpp - Pass to set strict symbolic shapes -=====// +// +// Adds an attribute to all functions in the module indicating all contained +// operations can be treated as if the symbolic shapes are strict, thereby +// eliminating the need for special dynamic size-1 broadcast handling. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/StringRef.h" +#include "torch-iree/InputConversion/PassDetail.h" +#include "torch-iree/InputConversion/Passes.h" + +static const llvm::StringLiteral kStrictSymbolsMarker = + "torch.assume_strict_symbolic_shapes"; + +namespace mlir { +namespace iree_compiler { +namespace TorchInput { + +namespace { +struct SetStrictSymbolicShapesPass + : public SetStrictSymbolicShapesPassBase { + + void runOnOperation() override { + getOperation()->setAttr(kStrictSymbolsMarker, UnitAttr::get(&getContext())); + } +}; +} // namespace + +std::unique_ptr> +createSetStrictSymbolicShapesPass() { + return std::make_unique(); +} + +} // namespace TorchInput +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/CMakeLists.txt index 0c10eceff007..18d0de0072b9 100644 --- a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/CMakeLists.txt +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/CMakeLists.txt @@ -2,6 +2,7 @@ iree_lit_test_suite( NAME lit SRCS + "assume_strict_symbols.mlir" "attention.mlir" "scan.mlir" "scatter.mlir" diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/assume_strict_symbols.mlir b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/assume_strict_symbols.mlir new file mode 100644 index 000000000000..672e9203c1b1 --- /dev/null +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/assume_strict_symbols.mlir @@ -0,0 +1,14 @@ +// RUN: iree-opt --split-input-file --torch-iree-set-strict-symbolic-shapes %s | FileCheck %s + +module { + // CHECK: func @forward() {{.*}} attributes {torch.assume_strict_symbolic_shapes} + func.func @forward() -> !torch.int { + %int0 = torch.constant.int 0 + return %int0 : !torch.int + } + // CHECK: func @other_forward() {{.*}} attributes {torch.assume_strict_symbolic_shapes} + func.func @other_forward() -> !torch.int { + %int1 = torch.constant.int 1 + return %int1 : !torch.int + } +} diff --git a/compiler/plugins/input/Torch/torch-iree/PluginRegistration.cpp b/compiler/plugins/input/Torch/torch-iree/PluginRegistration.cpp index f363a037631d..82d514e66b84 100644 --- a/compiler/plugins/input/Torch/torch-iree/PluginRegistration.cpp +++ b/compiler/plugins/input/Torch/torch-iree/PluginRegistration.cpp @@ -18,7 +18,14 @@ namespace mlir::iree_compiler { namespace { struct TorchOptions { - void bindOptions(OptionsBinder &binder) {} + bool strictSymbolicShapes = true; + void bindOptions(OptionsBinder &binder) { + static llvm::cl::OptionCategory category("Torch Input"); + binder.opt( + "iree-torch-use-strict-symbolic-shapes", strictSymbolicShapes, + llvm::cl::cat(category), + llvm::cl::desc("Forces dynamic shapes to be treated as strict")); + } }; // The shark-turbine plugin provides dialects, passes and opt-in options. @@ -42,7 +49,9 @@ struct TorchSession bool extendCustomInputConversionPassPipeline( OpPassManager &passManager, std::string_view typeMnemonic) override { if (typeMnemonic == "torch") { - TorchInput::createTorchToIREEPipeline(passManager); + TorchInput::TorchToIREELoweringPipelineOptions torchOptions; + torchOptions.strictSymbolicShapes = options.strictSymbolicShapes; + TorchInput::createTorchToIREEPipeline(passManager, torchOptions); passManager.addNestedPass( TorchInput::createConvertTMTensorToLinalgExtPass()); return true;