Skip to content

Commit

Permalink
[Torch] Assume strict symbolic shapes (iree-org#15107)
Browse files Browse the repository at this point in the history
In nearly all real world models, dynamic numpy style broadcasting never
occurs, and managing such cases leads to troublingly pessimistic
lowerings and restricts later optimization. This defaults all dynamic
symbols coming from pytorch to be interpreted strictly, meaning it must
represent an actual size (not something that can optionally be 1).
  • Loading branch information
qedawkins authored Oct 5, 2023
1 parent 24d80e1 commit 90a0225
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ iree_cc_library(
"Passes.h"
SRCS
"ConvertTMTensorToLinalgExt.cpp"
"SetStrictSymbolicShapes.cpp"
"Passes.cpp"
DEPS
::PassHeaders
Expand Down
12 changes: 10 additions & 2 deletions compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ namespace {
#include "torch-iree/InputConversion/Passes.h.inc" // IWYU pragma: export
} // namespace

void createTorchToIREEPipeline(OpPassManager &pm) {
void createTorchToIREEPipeline(
OpPassManager &pm, const TorchToIREELoweringPipelineOptions &options) {
// This pipeline adapted from
// createTorchBackendToLinalgOnTensorsBackendPipeline. Keep in sync with
// additions there. Lower to linalg + guards which is the input to codegen
Expand All @@ -35,6 +36,13 @@ void createTorchToIREEPipeline(OpPassManager &pm) {
// model) and those constants get somewhat obscured by TorchToArith.
llvm::ArrayRef<std::string> emptyArrayRef;

if (options.strictSymbolicShapes) {
pm.addNestedPass<func::FuncOp>(createSetStrictSymbolicShapesPass());
// Run canonicalization in case any previously non-strict dynamic code can
// now be simplified.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}

pm.addNestedPass<func::FuncOp>(
torch::Torch::createDecomposeComplexOpsPass(emptyArrayRef));
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToTMTensorPass());
Expand Down Expand Up @@ -68,7 +76,7 @@ void registerTMTensorConversionPasses() {
// Generated.
registerPasses();

mlir::PassPipelineRegistration<>(
mlir::PassPipelineRegistration<TorchToIREELoweringPipelineOptions>(
"torch-to-iree",
"Pipeline to lower from the Torch backend contract to legal IREE input.",
createTorchToIREEPipeline);
Expand Down
13 changes: 12 additions & 1 deletion compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,24 @@ namespace mlir {
namespace iree_compiler {
namespace TorchInput {

struct TorchToIREELoweringPipelineOptions
: public PassPipelineOptions<TorchToIREELoweringPipelineOptions> {
Option<bool> strictSymbolicShapes{
*this, "strict-symbolic-shapes",
llvm::cl::desc("Use strict symbolic shapes."), llvm::cl::init(true)};
};

std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTMTensorToLinalgExtPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createSetStrictSymbolicShapesPass();

// Creates a pipeline that lowers from the torch backend contract to IREE.
// This is based on the torch-backend-to-linalg-on-tensors-backend-pipeline
// pipeline in torch-mlir but includes IREE specific lowerings.
void createTorchToIREEPipeline(OpPassManager &pm);
void createTorchToIREEPipeline(
OpPassManager &pm, const TorchToIREELoweringPipelineOptions &options);

//===----------------------------------------------------------------------===//
// Register all Passes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,10 @@ def ConvertTMTensorToLinalgExt :
let constructor = "mlir::iree_compiler::TorchInput::createConvertTMTensorToLinalgExtPass()";
}

def SetStrictSymbolicShapesPass :
Pass<"torch-iree-set-strict-symbolic-shapes", "func::FuncOp"> {
let summary = "Adds the attribute indicating strict symbolic shapes in Torch IR";
let constructor = "mlir::iree_compiler::TorchInput::createSetStrictSymbolicShapesPass()";
}

#endif // TORCH_IREE_INPUTCONVERSION_PASSES
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

//===- SetStrictSymbolicShapes.cpp - Pass to set strict symbolic shapes -=====//
//
// Adds an attribute to all functions in the module indicating all contained
// operations can be treated as if the symbolic shapes are strict, thereby
// eliminating the need for special dynamic size-1 broadcast handling.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/StringRef.h"
#include "torch-iree/InputConversion/PassDetail.h"
#include "torch-iree/InputConversion/Passes.h"

static const llvm::StringLiteral kStrictSymbolsMarker =
"torch.assume_strict_symbolic_shapes";

namespace mlir {
namespace iree_compiler {
namespace TorchInput {

namespace {
struct SetStrictSymbolicShapesPass
: public SetStrictSymbolicShapesPassBase<SetStrictSymbolicShapesPass> {

void runOnOperation() override {
getOperation()->setAttr(kStrictSymbolsMarker, UnitAttr::get(&getContext()));
}
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
createSetStrictSymbolicShapesPass() {
return std::make_unique<SetStrictSymbolicShapesPass>();
}

} // namespace TorchInput
} // namespace iree_compiler
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ iree_lit_test_suite(
NAME
lit
SRCS
"assume_strict_symbols.mlir"
"attention.mlir"
"scan.mlir"
"scatter.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: iree-opt --split-input-file --torch-iree-set-strict-symbolic-shapes %s | FileCheck %s

module {
// CHECK: func @forward() {{.*}} attributes {torch.assume_strict_symbolic_shapes}
func.func @forward() -> !torch.int {
%int0 = torch.constant.int 0
return %int0 : !torch.int
}
// CHECK: func @other_forward() {{.*}} attributes {torch.assume_strict_symbolic_shapes}
func.func @other_forward() -> !torch.int {
%int1 = torch.constant.int 1
return %int1 : !torch.int
}
}
13 changes: 11 additions & 2 deletions compiler/plugins/input/Torch/torch-iree/PluginRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@ namespace mlir::iree_compiler {
namespace {

struct TorchOptions {
void bindOptions(OptionsBinder &binder) {}
bool strictSymbolicShapes = true;
void bindOptions(OptionsBinder &binder) {
static llvm::cl::OptionCategory category("Torch Input");
binder.opt<bool>(
"iree-torch-use-strict-symbolic-shapes", strictSymbolicShapes,
llvm::cl::cat(category),
llvm::cl::desc("Forces dynamic shapes to be treated as strict"));
}
};

// The shark-turbine plugin provides dialects, passes and opt-in options.
Expand All @@ -42,7 +49,9 @@ struct TorchSession
bool extendCustomInputConversionPassPipeline(
OpPassManager &passManager, std::string_view typeMnemonic) override {
if (typeMnemonic == "torch") {
TorchInput::createTorchToIREEPipeline(passManager);
TorchInput::TorchToIREELoweringPipelineOptions torchOptions;
torchOptions.strictSymbolicShapes = options.strictSymbolicShapes;
TorchInput::createTorchToIREEPipeline(passManager, torchOptions);
passManager.addNestedPass<func::FuncOp>(
TorchInput::createConvertTMTensorToLinalgExtPass());
return true;
Expand Down

0 comments on commit 90a0225

Please sign in to comment.