Skip to content

Commit

Permalink
[Global Opt] Add option to generalize matmul ops (#19741)
Browse files Browse the repository at this point in the history
In order to support -O* flags, generalizing matmul ops needs to be moved
out of preprocessing and into global optimization. This adds a flag
`iree-opt-generalize-matmul` and uses it during global optimization's
`GeneralizeLinalgNamedOps` pass. Also, this changes SDXL tests to use
this flag instead of the preprocessing pass.

---------

Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 authored Jan 21, 2025
1 parent b47fbdf commit 3e15a5a
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace {
// Function-level pass that rewrites a whitelist of Linalg named ops into
// linalg.generic form. With the `enable-generalize-matmul` option set, ops
// implementing the contraction interface (e.g. matmuls) are generalized too.
struct GeneralizeLinalgNamedOpsPass
: public impl::GeneralizeLinalgNamedOpsPassBase<
GeneralizeLinalgNamedOpsPass> {
// Inherit the tablegen-generated constructors so the pass can be built
// from its generated Options struct (carries enableGeneralizeMatmul).
using Base::Base;
void runOnOperation() override;
};
} // namespace
Expand Down Expand Up @@ -62,7 +63,12 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() {
auto funcOp = getOperation();
SmallVector<linalg::LinalgOp> namedOpCandidates;
funcOp.walk([&](linalg::LinalgOp linalgOp) {
if (!IREE::Flow::isNonNullAndOutsideDispatch(linalgOp)) {
if (!IREE::Flow::isNonNullAndOutsideDispatch(linalgOp) ||
isa<linalg::GenericOp>(linalgOp)) {
return;
}
if (enableGeneralizeMatmul && linalg::isaContractionOpInterface(linalgOp)) {
namedOpCandidates.push_back(linalgOp);
return;
}
if (isa_and_nonnull<linalg::AbsOp, linalg::AddOp, linalg::BroadcastOp,
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,11 @@ void buildGlobalOptimizationPassPipeline(
// dims as the unit dim folding pass updates indexing maps and is better
// at working with generics. By this point we have already done any
// specialized raising and the op names are no longer useful.
.addPass(createGeneralizeLinalgNamedOpsPass);
.addPass([&]() {
GeneralizeLinalgNamedOpsPassOptions opt;
opt.enableGeneralizeMatmul = transformOptions.options.generalizeMatmul;
return createGeneralizeLinalgNamedOpsPass(opt);
});

mainPassManager.addPass(DispatchCreation::createFoldUnitExtentDimsPass());
FunctionLikeNest(mainPassManager)
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def FuseSiluHorizontalMatmulPass:
def GeneralizeLinalgNamedOpsPass :
    InterfacePass<"iree-global-opt-generalize-linalg-named-ops", "mlir::FunctionOpInterface"> {
  let summary = "Convert some Linalg named ops into linalg.generics.";
  let options = [
    // Fixed typo ("named opt") and clarified that this extends generalization
    // to contraction ops rather than replacing the default behavior.
    Option<"enableGeneralizeMatmul", "enable-generalize-matmul", "bool",
           /*default=*/"false",
           "Also convert named matmul/contraction ops to generic ops.">,
  ];
}

def InferNumericNarrowingPass :
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Pipelines/Options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) {
"File path to create a parameter archive of splat values out of all "
"parameter backed globals."),
llvm::cl::cat(category));

binder.opt<bool>(
"iree-opt-generalize-matmul", generalizeMatmul,
llvm::cl::desc("Convert named matmul ops to linalg generic ops during "
"global optimization to enable better fusion."),
llvm::cl::cat(category));
}

void SchedulingOptions::bindOptions(OptionsBinder &binder) {
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Pipelines/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ struct GlobalOptimizationOptions {
// Strips debug assertions after any useful information has been extracted.
bool stripAssertions = false;

// Converts linalg named matmul ops to linalg generic ops.
bool generalizeMatmul = false;

void bindOptions(OptionsBinder &binder);
using FromFlags = OptionsFromFlags<GlobalOptimizationOptions>;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def SDXL_PUNET_INT8_FP8_OUT(
"--iree-dispatch-creation-enable-aggressive-fusion=true",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-opt-outer-dim-concat=true",
"--iree-opt-generalize-matmul=true",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-llvmgpu-enable-prefetch=true",
"--iree-opt-data-tiling=false",
Expand All @@ -218,7 +219,7 @@ def SDXL_PUNET_INT8_FP8_OUT(

INT8_PUNET_FLAGS = [
f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_punet.mlir",
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))",
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)",
]

ROCM_UNET_PIPELINE_FP16_COMPILE_FLAGS = [
Expand Down

0 comments on commit 3e15a5a

Please sign in to comment.