Skip to content

Commit

Permalink
[Codegen][GPU] Add pass to expand multi_mma op shapes to intrinsic layout (iree-org#18139)

Browse files Browse the repository at this point in the history

This PR adds a new pass to explicitly materialize the dimensions of
intrinsic layouts for `iree_gpu.multi_mma` ops. This means adding an
expand_shape on each of the inputs to go from the `OpaqueMmaLayout`
shape to the `ConcreteMmaLayout` shape. This makes it easy to extract
the correct data from the tensors when it is time to distribute the
multi_mma op to lanes, since the shape will match the number of offsets
and sizes needed for the slice.

---------

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Aug 7, 2024
1 parent 352e05f commit 235e110
Show file tree
Hide file tree
Showing 10 changed files with 388 additions and 0 deletions.
79 changes: 79 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand Down Expand Up @@ -752,6 +753,84 @@ LogicalResult MMAAttr::populateOperandOffsetsSizesStrides(
return success();
}

// Computes the reassociation indices and expanded tensor type needed to
// expand_shape the inner tile of `operand` from its opaque layout shape
// ([M, K] for Lhs, [K, N] for Rhs, [M, N] for Acc) to the concrete
// intrinsic layout shape for this MMA attribute. Fails if the permutation
// rank does not match the inner tile rank, or if the operand's innermost
// dimensions do not match the opaque layout sizes.
LogicalResult MMAAttr::materializeOperandConcreteShape(
    OpBuilder &builder, IREE::GPU::MMAFragment fragment, Value operand,
    std::optional<ArrayRef<int64_t>> permutation,
    SmallVector<ReassociationIndices> &reassociations,
    RankedTensorType &resultType) const {
  OpaqueMmaLayout opaqueLayout =
      getOpaqueMFMALayout(operand.getContext(), getIntrinsic().getValue());
  // TODO(Max191): The `getConcreteMFMALayout` function creates some
  // `PerDimLayoutAttr` that are not used by this function. This means that
  // any pass that uses `materializeOperandConcreteShape` needs to be
  // dependent on the VectorExt dialect. Ideally, the `getConcreteMFMALayout`
  // function should be refactored so we can reuse the shape information of
  // the layout without needing to create any `PerDimLayoutAttr`.
  ConcreteMmaLayout layout =
      getConcreteMFMALayout(operand.getContext(), getIntrinsic().getValue());

  // Collect, per inner dimension of this fragment, the concrete (multi-dim)
  // sizes and the matching opaque (single-dim) size.
  SmallVector<ArrayRef<int64_t>> concreteSizes;
  SmallVector<int64_t, 2> opaqueSizes;
  switch (fragment) {
  case IREE::GPU::MMAFragment::Lhs: {
    concreteSizes.push_back(layout.aMLayout.getShapes());
    concreteSizes.push_back(layout.aKLayout.getShapes());
    opaqueSizes.push_back(opaqueLayout.mSize);
    opaqueSizes.push_back(opaqueLayout.kSize);
    break;
  }
  case IREE::GPU::MMAFragment::Rhs: {
    concreteSizes.push_back(layout.bKLayout.getShapes());
    concreteSizes.push_back(layout.bNLayout.getShapes());
    opaqueSizes.push_back(opaqueLayout.kSize);
    opaqueSizes.push_back(opaqueLayout.nSize);
    break;
  }
  case IREE::GPU::MMAFragment::Acc: {
    concreteSizes.push_back(layout.cMLayout.getShapes());
    concreteSizes.push_back(layout.cNLayout.getShapes());
    opaqueSizes.push_back(opaqueLayout.mSize);
    opaqueSizes.push_back(opaqueLayout.nSize);
    break;
  }
  }
  if (permutation.has_value()) {
    // The permutation must cover exactly the inner tile dimensions.
    if (permutation.value().size() != opaqueSizes.size()) {
      return failure();
    }
    applyPermutationToVector(concreteSizes, permutation.value());
    applyPermutationToVector(opaqueSizes, permutation.value());
  }

  // Inner tile must have sizes matching the opaque layout.
  auto operandType = llvm::cast<RankedTensorType>(operand.getType());
  ArrayRef<int64_t> operandShape = operandType.getShape();
  SmallVector<int64_t, 2> innerShape(operandShape.end() - opaqueSizes.size(),
                                     operandShape.end());
  if (!llvm::equal(opaqueSizes, innerShape)) {
    return failure();
  }

  // Expand the shape of the inner tile to reflect the MMA thread layout.
  // Outer dimensions are kept as-is (identity reassociations); each inner
  // opaque dimension expands into its run of concrete sizes.
  // NOTE: derive the outer rank from `opaqueSizes.size()` rather than a
  // hard-coded 2, keeping it consistent with the `innerShape` slice above.
  SmallVector<int64_t, 4> resultShape(operandShape.begin(),
                                      operandShape.end() - opaqueSizes.size());
  SmallVector<ReassociationIndices> reInds =
      llvm::map_to_vector(llvm::seq<int64_t>(resultShape.size()),
                          [](int64_t idx) -> ReassociationIndices {
                            return ReassociationIndices({idx});
                          });
  int idx = reInds.size();
  for (ArrayRef<int64_t> sizes : concreteSizes) {
    resultShape.append(SmallVector<int64_t>(sizes));
    reInds.push_back(
        llvm::to_vector(llvm::seq<int64_t>(idx, idx + sizes.size())));
    idx += sizes.size();
  }

  reassociations = reInds;
  resultType = operandType.clone(resultShape);
  return success();
}

//===----------------------------------------------------------------------===//
// MMA Schedule Attributes
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class IREEGPU_MmaVectorLayoutAttr<string attrname, string mmaintrinsic> :
"getSubgroupSize",
"buildMmaOperation",
"populateOperandOffsetsSizesStrides",
"materializeOperandConcreteShape",
]>
]> {
let cppNamespace = "::mlir::iree_compiler::IREE::GPU";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,26 @@ def IREEGPU_MmaInterfaceAttr : AttrInterface<"MmaInterfaceAttr"> {
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Constructs the offsets/sizes/strides for extracting the per-thread
slice of the given operand fragment.
}],
/*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"materializeOperandConcreteShape",
/*args=*/(ins
"::mlir::OpBuilder&":$builder,
"::mlir::iree_compiler::IREE::GPU::MMAFragment":$fragment,
"::mlir::Value":$operand,
"std::optional<::llvm::ArrayRef<int64_t>>":$permutation,
"::llvm::SmallVector<::mlir::SmallVector<int64_t, 2>>&":$reassociations,
"::mlir::RankedTensorType&":$result_type
),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
}]
>,
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ iree_gentbl_cc_library(
iree_compiler_cc_library(
name = "GPUTransforms",
srcs = [
"ConcretizeMmaShapes.cpp",
"DistributeMmaToLanes.cpp",
"FuseAndHoistParallelLoops.cpp",
"LowerIREEGPUOps.cpp",
Expand All @@ -69,6 +70,7 @@ iree_compiler_cc_library(
":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
"//compiler/src/iree/compiler/Codegen/Transforms",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ iree_cc_library(
"Passes.h.inc"
"Transforms.h"
SRCS
"ConcretizeMmaShapes.cpp"
"DistributeMmaToLanes.cpp"
"FuseAndHoistParallelLoops.cpp"
"LowerIREEGPUOps.cpp"
Expand Down Expand Up @@ -80,6 +81,7 @@ iree_cc_library(
MLIRVectorUtils
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
iree::compiler::Codegen::Transforms
PUBLIC
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::GPU {

#define GEN_PASS_DEF_CONCRETIZEMMASHAPESPASS
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"

namespace {
/// Pass wrapper generated from the ConcretizeMmaShapesPass tablegen base.
/// Pattern selection is controlled by the pass options (see
/// runOnOperation).
struct ConcretizeMmaShapesPass final
    : impl::ConcretizeMmaShapesPassBase<ConcretizeMmaShapesPass> {
  using ConcretizeMmaShapesPassBase::ConcretizeMmaShapesPassBase;
  void runOnOperation() override;
};
} // namespace

namespace {
/// Rewrites the `fragment` operand of a tensor-semantics `multi_mma` op by
/// inserting a tensor.expand_shape that materializes the concrete intrinsic
/// layout shape for that operand. For the Acc fragment, the op result is
/// collapsed back to the original type so consumer types still match.
///
/// NOTE(review): wrapped in an anonymous namespace — the original declared
/// this file-local pattern with external linkage after the pass's anonymous
/// namespace had closed, contrary to LLVM coding-standards guidance for
/// .cpp-local class definitions.
struct ConcretizeMmaOperandShape final : OpRewritePattern<MultiMmaOp> {
  using OpRewritePattern::OpRewritePattern;

  ConcretizeMmaOperandShape(MLIRContext *context, MMAFragment fragment)
      : OpRewritePattern<MultiMmaOp>(context), fragment(fragment) {}

  LogicalResult matchAndRewrite(MultiMmaOp mmaOp,
                                PatternRewriter &rewriter) const override {
    // expand_shape only applies to tensors; skip vector-semantics ops.
    if (!mmaOp.hasTensorSemantics()) {
      return failure();
    }

    // Get the right operand and permutation for the `fragment`.
    Value operand;
    std::optional<ArrayRef<int64_t>> permutation;
    switch (fragment) {
    case MMAFragment::Lhs:
      operand = mmaOp.getLhs();
      permutation = mmaOp.getLhsPermutation();
      break;
    case MMAFragment::Rhs:
      operand = mmaOp.getRhs();
      permutation = mmaOp.getRhsPermutation();
      break;
    case MMAFragment::Acc:
      operand = mmaOp.getAcc();
      permutation = mmaOp.getAccPermutation();
      break;
    }

    // Get the reassociation indices and result type of the expand_shape op.
    // The kind attribute decides how the opaque inner tile maps to the
    // concrete intrinsic layout; failure means this operand cannot (or need
    // not) be expanded.
    MmaInterfaceAttr kind = mmaOp.getKind();
    SmallVector<ReassociationIndices> reassociations;
    RankedTensorType concreteType;
    if (failed(kind.materializeOperandConcreteShape(rewriter, fragment, operand,
                                                    permutation, reassociations,
                                                    concreteType))) {
      return failure();
    }

    // Create the expand_shape.
    Location loc = mmaOp->getLoc();
    Value concreteOperand = rewriter
                                .create<tensor::ExpandShapeOp>(
                                    loc, concreteType, operand, reassociations)
                                .getResult();

    // Expand the permutation for the new inner dimensions of the expanded
    // multi_mma operand. Only the rewritten fragment's permutation changes;
    // the others are passed through unchanged.
    auto expandPerm =
        [&](std::optional<ArrayRef<int64_t>> perm, MMAFragment frag,
            int64_t outerRank) -> std::optional<DenseI64ArrayAttr> {
      if (!perm.has_value()) {
        return std::nullopt;
      }
      if (frag != fragment) {
        return rewriter.getDenseI64ArrayAttr(perm.value());
      }
      // Rebase the inner-dimension reassociations to start at index 0.
      SmallVector<ReassociationIndices> innerReInds(
          reassociations.begin() + outerRank, reassociations.end());
      for (auto &reInd : innerReInds) {
        for (auto &idx : reInd) {
          idx -= outerRank;
        }
      }
      // Permute the expanded groups, then flatten into a single permutation.
      SmallVector<int64_t> expandedPerm;
      for (auto reInd : applyPermutation(innerReInds, perm.value())) {
        expandedPerm.append(reInd);
      }
      return rewriter.getDenseI64ArrayAttr(expandedPerm);
    };
    std::optional<DenseI64ArrayAttr> lhsPerm = expandPerm(
        mmaOp.getLhsPermutation(), MMAFragment::Lhs, mmaOp.getLhsOuterRank());
    std::optional<DenseI64ArrayAttr> rhsPerm = expandPerm(
        mmaOp.getRhsPermutation(), MMAFragment::Rhs, mmaOp.getRhsOuterRank());
    std::optional<DenseI64ArrayAttr> accPerm = expandPerm(
        mmaOp.getAccPermutation(), MMAFragment::Acc, mmaOp.getAccOuterRank());

    // Create the new multi_mma op with the concrete type.
    auto concreteMmaOp = rewriter.create<MultiMmaOp>(
        loc,
        /*lhs=*/fragment == MMAFragment::Lhs ? concreteOperand : mmaOp.getLhs(),
        /*rhs=*/fragment == MMAFragment::Rhs ? concreteOperand : mmaOp.getRhs(),
        /*acc=*/fragment == MMAFragment::Acc ? concreteOperand : mmaOp.getAcc(),
        mmaOp.getIndexingMaps(), mmaOp.getIteratorTypes(), mmaOp.getKind(),
        lhsPerm, rhsPerm, accPerm);

    // Carry over any lowering config attached to the original op.
    if (auto config = getLoweringConfig(mmaOp)) {
      setLoweringConfig(concreteMmaOp, config);
    }

    if (fragment != MMAFragment::Acc) {
      rewriter.replaceOp(mmaOp, concreteMmaOp);
      return success();
    }

    // For the Acc operand, the result needs to be collapsed back to the
    // original type so that types match with consumers.
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        mmaOp, mmaOp.getAccType(), concreteMmaOp.getResult(), reassociations);

    return success();
  }

private:
  // Which multi_mma operand this pattern instance rewrites.
  MMAFragment fragment;
};
} // namespace

void ConcretizeMmaShapesPass::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();

RewritePatternSet patterns(context);
if (concretizeInputs) {
patterns.insert<ConcretizeMmaOperandShape>(context, MMAFragment::Lhs);
patterns.insert<ConcretizeMmaOperandShape>(context, MMAFragment::Rhs);
}
if (concretizeResult) {
patterns.insert<ConcretizeMmaOperandShape>(context, MMAFragment::Acc);
}
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

} // namespace mlir::iree_compiler::IREE::GPU
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,24 @@ def DistributeMmaToLanesPass :
];
}

// Function-level pass that expands iree_gpu.multi_mma operand shapes from
// their opaque layout to the concrete intrinsic layout, making it easy to
// extract the correct data when distributing the op to lanes.
def ConcretizeMmaShapesPass :
    InterfacePass<"iree-gpu-concretize-mma-shapes", "mlir::FunctionOpInterface"> {
  let summary = "Expands the inner dimensions of iree_gpu.multi_mma ops to match the thread layout";
  let dependentDialects = [
    "::mlir::tensor::TensorDialect",
    "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
    "::mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect",
  ];
  let options = [
    // Both options default to on; they allow concretizing only the inputs or
    // only the result when a pipeline needs one side untouched.
    Option<"concretizeInputs", "concretize-inputs",
           "bool", /*default=*/"true",
           "Expand the inner dimensions for the lhs and rhs operands of the multi_mma ops.">,
    Option<"concretizeResult", "concretize-result",
           "bool", /*default=*/"true",
           "Expand the inner dimensions for the acc operand of the multi_mma ops.">,
  ];
}

def FuseAndHoistParallelLoopsPass :
InterfacePass<"iree-gpu-fuse-and-hoist-parallel-loops", "mlir::FunctionOpInterface"> {
let summary = "Greedily fuses and hoists parallel loops.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ iree_lit_test_suite(
name = "lit",
srcs = enforce_glob(
[
"concretize_mma_shapes.mlir",
"distribute_mma_to_lanes.mlir",
"fuse_and_hoist_forall.mlir",
"pack_to_intrinsics.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ iree_lit_test_suite(
NAME
lit
SRCS
"concretize_mma_shapes.mlir"
"distribute_mma_to_lanes.mlir"
"fuse_and_hoist_forall.mlir"
"pack_to_intrinsics.mlir"
Expand Down
Loading

0 comments on commit 235e110

Please sign in to comment.