From 90a022526e36861e942384d36178c92f6bdb17d9 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 5 Oct 2023 09:25:09 -0400 Subject: [PATCH] [Torch] Assume strict symbolic shapes (#15107) In nearly all real world models, dynamic numpy style broadcasting never occurs, and managing such cases leads to troublingly pessimistic lowerings and restricts later optimization. This defaults all dynamic symbols coming from pytorch to be interpreted strictly, meaning it must represent an actual size (not something that can optionally be 1). --- .../torch-iree/InputConversion/CMakeLists.txt | 1 + .../torch-iree/InputConversion/Passes.cpp | 12 +++++- .../Torch/torch-iree/InputConversion/Passes.h | 13 +++++- .../torch-iree/InputConversion/Passes.td | 6 +++ .../SetStrictSymbolicShapes.cpp | 43 +++++++++++++++++++ .../InputConversion/test/CMakeLists.txt | 1 + .../test/assume_strict_symbols.mlir | 14 ++++++ .../Torch/torch-iree/PluginRegistration.cpp | 13 +++++- 8 files changed, 98 insertions(+), 5 deletions(-) create mode 100644 compiler/plugins/input/Torch/torch-iree/InputConversion/SetStrictSymbolicShapes.cpp create mode 100644 compiler/plugins/input/Torch/torch-iree/InputConversion/test/assume_strict_symbols.mlir diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/CMakeLists.txt b/compiler/plugins/input/Torch/torch-iree/InputConversion/CMakeLists.txt index 9d157194e40e..6e184d66e2a4 100644 --- a/compiler/plugins/input/Torch/torch-iree/InputConversion/CMakeLists.txt +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/CMakeLists.txt @@ -36,6 +36,7 @@ iree_cc_library( "Passes.h" SRCS "ConvertTMTensorToLinalgExt.cpp" + "SetStrictSymbolicShapes.cpp" "Passes.cpp" DEPS ::PassHeaders diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp index 9d44ec42e049..b9438516dd13 100644 --- a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp @@ -26,7 +26,8 @@ namespace { #include "torch-iree/InputConversion/Passes.h.inc" // IWYU pragma: export } // namespace -void createTorchToIREEPipeline(OpPassManager &pm) { +void createTorchToIREEPipeline( + OpPassManager &pm, const TorchToIREELoweringPipelineOptions &options) { // This pipeline adapted from // createTorchBackendToLinalgOnTensorsBackendPipeline. Keep in sync with // additions there. Lower to linalg + guards which is the input to codegen @@ -35,6 +36,13 @@ void createTorchToIREEPipeline(OpPassManager &pm) { // model) and those constants get somewhat obscured by TorchToArith. llvm::ArrayRef emptyArrayRef; + if (options.strictSymbolicShapes) { + pm.addNestedPass(createSetStrictSymbolicShapesPass()); + // Run canonicalization in case any previously non-strict dynamic code can + // now be simplified. + pm.addNestedPass(createCanonicalizerPass()); + } + pm.addNestedPass( torch::Torch::createDecomposeComplexOpsPass(emptyArrayRef)); pm.addNestedPass(torch::createConvertTorchToTMTensorPass()); @@ -68,7 +76,7 @@ void registerTMTensorConversionPasses() { // Generated. registerPasses(); - mlir::PassPipelineRegistration<>( + mlir::PassPipelineRegistration( "torch-to-iree", "Pipeline to lower from the Torch backend contract to legal IREE input.", createTorchToIREEPipeline); diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.h b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.h index 11c1a4577649..4afb0cd1744a 100644 --- a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.h +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.h @@ -14,13 +14,24 @@ namespace mlir { namespace iree_compiler { namespace TorchInput { +struct TorchToIREELoweringPipelineOptions + : public PassPipelineOptions { + Option strictSymbolicShapes{ + *this, "strict-symbolic-shapes", + llvm::cl::desc("Use strict symbolic shapes."), llvm::cl::init(true)}; +}; + std::unique_ptr> createConvertTMTensorToLinalgExtPass(); +std::unique_ptr> +createSetStrictSymbolicShapesPass(); + // Creates a pipeline that lowers from the torch backend contract to IREE. // This is based on the torch-backend-to-linalg-on-tensors-backend-pipeline // pipeline in torch-mlir but includes IREE specific lowerings. -void createTorchToIREEPipeline(OpPassManager &pm); +void createTorchToIREEPipeline( + OpPassManager &pm, const TorchToIREELoweringPipelineOptions &options); //===----------------------------------------------------------------------===// // Register all Passes diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.td b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.td index 0391039a5f7e..eb9a5abf0572 100644 --- a/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.td +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.td @@ -15,4 +15,10 @@ def ConvertTMTensorToLinalgExt : let constructor = "mlir::iree_compiler::TorchInput::createConvertTMTensorToLinalgExtPass()"; } +def SetStrictSymbolicShapesPass : + Pass<"torch-iree-set-strict-symbolic-shapes", "func::FuncOp"> { + let summary = "Adds the attribute indicating strict symbolic shapes in Torch IR"; + let constructor = "mlir::iree_compiler::TorchInput::createSetStrictSymbolicShapesPass()"; +} + #endif // TORCH_IREE_INPUTCONVERSION_PASSES diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/SetStrictSymbolicShapes.cpp b/compiler/plugins/input/Torch/torch-iree/InputConversion/SetStrictSymbolicShapes.cpp new file mode 100644 index 000000000000..c33b8510d9bd --- /dev/null +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/SetStrictSymbolicShapes.cpp @@ -0,0 +1,43 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- SetStrictSymbolicShapes.cpp - Pass to set strict symbolic shapes -=====// +// +// Adds an attribute to all functions in the module indicating all contained +// operations can be treated as if the symbolic shapes are strict, thereby +// eliminating the need for special dynamic size-1 broadcast handling. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/StringRef.h" +#include "torch-iree/InputConversion/PassDetail.h" +#include "torch-iree/InputConversion/Passes.h" + +static const llvm::StringLiteral kStrictSymbolsMarker = + "torch.assume_strict_symbolic_shapes"; + +namespace mlir { +namespace iree_compiler { +namespace TorchInput { + +namespace { +struct SetStrictSymbolicShapesPass + : public SetStrictSymbolicShapesPassBase { + + void runOnOperation() override { + getOperation()->setAttr(kStrictSymbolsMarker, UnitAttr::get(&getContext())); + } +}; +} // namespace + +std::unique_ptr> +createSetStrictSymbolicShapesPass() { + return std::make_unique(); +} + +} // namespace TorchInput +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/CMakeLists.txt index 0c10eceff007..18d0de0072b9 100644 --- a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/CMakeLists.txt +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/CMakeLists.txt @@ -2,6 +2,7 @@ iree_lit_test_suite( NAME lit SRCS + "assume_strict_symbols.mlir" "attention.mlir" "scan.mlir" "scatter.mlir" diff --git a/compiler/plugins/input/Torch/torch-iree/InputConversion/test/assume_strict_symbols.mlir b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/assume_strict_symbols.mlir new file mode 100644 index 000000000000..672e9203c1b1 --- /dev/null +++ b/compiler/plugins/input/Torch/torch-iree/InputConversion/test/assume_strict_symbols.mlir @@ -0,0 +1,14 @@ +// RUN: iree-opt --split-input-file --torch-iree-set-strict-symbolic-shapes %s | FileCheck %s + +module { + // CHECK: func @forward() {{.*}} attributes {torch.assume_strict_symbolic_shapes} + func.func @forward() -> !torch.int { + %int0 = torch.constant.int 0 + return %int0 : !torch.int + } + // CHECK: func @other_forward() {{.*}} attributes {torch.assume_strict_symbolic_shapes} + func.func @other_forward() -> !torch.int { + %int1 = torch.constant.int 1 + return %int1 : !torch.int + } +} diff --git a/compiler/plugins/input/Torch/torch-iree/PluginRegistration.cpp b/compiler/plugins/input/Torch/torch-iree/PluginRegistration.cpp index f363a037631d..82d514e66b84 100644 --- a/compiler/plugins/input/Torch/torch-iree/PluginRegistration.cpp +++ b/compiler/plugins/input/Torch/torch-iree/PluginRegistration.cpp @@ -18,7 +18,14 @@ namespace mlir::iree_compiler { namespace { struct TorchOptions { - void bindOptions(OptionsBinder &binder) {} + bool strictSymbolicShapes = true; + void bindOptions(OptionsBinder &binder) { + static llvm::cl::OptionCategory category("Torch Input"); + binder.opt( + "iree-torch-use-strict-symbolic-shapes", strictSymbolicShapes, + llvm::cl::cat(category), + llvm::cl::desc("Forces dynamic shapes to be treated as strict")); + } }; // The shark-turbine plugin provides dialects, passes and opt-in options. @@ -42,7 +49,9 @@ struct TorchSession bool extendCustomInputConversionPassPipeline( OpPassManager &passManager, std::string_view typeMnemonic) override { if (typeMnemonic == "torch") { - TorchInput::createTorchToIREEPipeline(passManager); + TorchInput::TorchToIREELoweringPipelineOptions torchOptions; + torchOptions.strictSymbolicShapes = options.strictSymbolicShapes; + TorchInput::createTorchToIREEPipeline(passManager, torchOptions); passManager.addNestedPass( TorchInput::createConvertTMTensorToLinalgExtPass()); return true;