[VectorDistribution] Set layouts before generalization and folding (iree-org#18186)

This patch allows setting layout anchors for linalg operations before
generalization and unit-dim folding. It introduces two things:

- Unit-dim folding for to_layout ops: required so that unit-dim folding can
  still run after anchors are set, which in turn is needed for 1x1ConvToMatmul
  (see the IR sketch below).
- Teaching ConfigureTensorLayout to set layouts for convolutions with unit
  filter dims.
Groverkss authored Aug 28, 2024
1 parent afd04e4 commit dbbc56e
Showing 13 changed files with 261 additions and 32 deletions.
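For illustration, here is a minimal sketch of what the new to_layout unit-dim folding produces (the op syntax, the value names, and the #layout_4d/#layout_2d attributes are schematic placeholders, not taken from this patch). A tensor-semantics to_layout whose type carries unit dims becomes a rank-reducing tensor.extract_slice, a to_layout on the reduced type with the layout projected along the dropped dims, and a tensor.insert_slice that restores the original shape:

// Before (schematic):
%0 = iree_vector_ext.to_layout %in to layout(#layout_4d) : tensor<1x16x1x16xf16>

// After folding (schematic):
%slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 16, 1, 16] [1, 1, 1, 1]
    : tensor<1x16x1x16xf16> to tensor<16x16xf16>
%laid_out = iree_vector_ext.to_layout %slice to layout(#layout_2d) : tensor<16x16xf16>
%empty = tensor.empty() : tensor<1x16x1x16xf16>
%1 = tensor.insert_slice %laid_out into %empty[0, 0, 0, 0] [1, 16, 1, 16] [1, 1, 1, 1]
    : tensor<16x16xf16> into tensor<1x16x1x16xf16>

Here #layout_2d stands for #layout_4d with the unit dims projected out via VectorLayoutInterface::project. The pattern ships in the new iree-vector-ext-fold-unit-extent-dims pass, which the pipeline runs right before the generic LinalgFoldUnitExtentDims pass.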
14 changes: 4 additions & 10 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -974,14 +974,8 @@ NestedLayoutAttr createNestedLayout(MLIRContext *context, int64_t rank,
FailureOr<std::tuple<VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface>>
MMAScheduleAttr::getContractionLayout(linalg::LinalgOp contractOp) const {
auto maybeOpInfo = VectorContractOpInfo::inferFromIndexingMaps(
contractOp.getIndexingMapsArray());
if (failed(maybeOpInfo)) {
return failure();
}
VectorContractOpInfo opInfo = maybeOpInfo.value();

MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo,
linalg::LinalgOp contractOp) const {
LLVM_DEBUG({
llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n";
llvm::errs() << "For schedule: " << *this << "\n";
@@ -1000,8 +994,8 @@ MMAScheduleAttr::getContractionLayout(linalg::LinalgOp contractOp) const {
return failure();
}

int64_t batchCount = opInfo.getBatchCount();
if (batchCount == 1 && bounds[0] != 1) {
if (!llvm::all_of(opInfo.getBatchDims(),
[&bounds](int64_t dim) { return bounds[dim] == 1; })) {
LLVM_DEBUG({ llvm::errs() << "non-unit batch dimension\n"; });
return failure();
}
@@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -251,7 +251,8 @@ def IREEGPU_MmaScheduleAttr : AttrDef<IREEGPU_Dialect, "MMASchedule"> {
::mlir::FailureOr<::std::tuple<VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface>>
getContractionLayout(::mlir::linalg::LinalgOp contractOp) const;
getContractionLayout(::mlir::iree_compiler::VectorContractOpInfo &opInfo,
::mlir::linalg::LinalgOp contractOp) const;
}];
}

@@ -31,6 +31,7 @@ iree_compiler_cc_library(
name = "VectorExtTransforms",
srcs = [
"Passes.cpp",
"VectorExtFoldUnitExtentDims.cpp",
"VectorizeIREEVectorExtOps.cpp",
],
hdrs = [
@@ -27,6 +27,7 @@ iree_cc_library(
"Passes.h.inc"
SRCS
"Passes.cpp"
"VectorExtFoldUnitExtentDims.cpp"
"VectorizeIREEVectorExtOps.cpp"
DEPS
::PassesIncGen
@@ -19,4 +19,13 @@ def VectorizeIREEVectorExtOpsPass :
];
}

def VectorExtFoldUnitExtentDimsPass :
Pass<"iree-vector-ext-fold-unit-extent-dims", ""> {
let summary = "Folds unit dims for iree_vector_ext ops";
let dependentDialects = [
"::mlir::tensor::TensorDialect",
"::mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect"
];
}

#endif // IREE_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES
@@ -0,0 +1,105 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::VectorExt {

#define GEN_PASS_DEF_VECTOREXTFOLDUNITEXTENTDIMSPASS
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc"

namespace {

struct DropToLayoutUnitDims final
: OpRewritePattern<IREE::VectorExt::ToLayoutOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp,
PatternRewriter &rewriter) const override {
if (!toLayoutOp.hasTensorSemantics()) {
return rewriter.notifyMatchFailure(toLayoutOp,
"requires tensor semantics");
}

Location loc = toLayoutOp.getLoc();
ShapedType inputTy = toLayoutOp.getType();
ArrayRef<int64_t> shape = inputTy.getShape();

// Find list of dims to drop and the target shape.
SmallVector<bool> unitDims(shape.size(), false);
SmallVector<int64_t> targetShape;
bool hasUnitDims = false;
for (auto [idx, size] : llvm::enumerate(shape)) {
if (size == 1) {
unitDims[idx] = true;
hasUnitDims = true;
continue;
}
targetShape.push_back(size);
}

if (!hasUnitDims) {
return rewriter.notifyMatchFailure(toLayoutOp, "no unit dims present");
}

// Drop unit dims using extract_slice.
FailureOr<Value> rankReducingExtract =
tensor::ExtractSliceOp::rankReduceIfNeeded(
rewriter, loc, toLayoutOp.getInput(), targetShape);
assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");

// Find the rank reduced layout.
VectorLayoutInterface newLayout = toLayoutOp.getLayout().project(unitDims);

Value rankReducedValue = rankReducingExtract.value();
auto newToLayoutOp = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, rankReducedValue.getType(), rankReducedValue, newLayout,
toLayoutOp.getSharedMemoryConversion());
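// Carry any discardable attributes from the original op over to the new one.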
newToLayoutOp->setDiscardableAttrs(
toLayoutOp->getDiscardableAttrDictionary());

// Expand to preserve output shape using insert_slice.
// Here, since the shape comes from the result of a to_layout op, it will
// always be static.
Value dest =
rewriter.create<tensor::EmptyOp>(loc, shape, inputTy.getElementType());

int64_t rank = inputTy.getRank();
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, loc, dest);
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
toLayoutOp, newToLayoutOp.getResult(), dest, offsets, sizes, strides);

return success();
}
};

} // namespace

namespace {
struct VectorExtFoldUnitExtentDimsPass final
: impl::VectorExtFoldUnitExtentDimsPassBase<
VectorExtFoldUnitExtentDimsPass> {
void runOnOperation() override {

MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
patterns.add<DropToLayoutUnitDims>(ctx);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace

} // namespace mlir::iree_compiler::IREE::VectorExt
@@ -145,6 +145,7 @@ struct AMDGPUPrepareForChainedMatmulPass final
contractOp.getLoc(), rhs, lhs, acc,
rewriter.getAffineMapArrayAttr({rhsMap, lhsMap, accMap}),
contractOp.getIteratorTypesAttr());
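// Preserve discardable attributes from the original contraction on the swapped op.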
swappedOp->setDiscardableAttrs(contractOp->getDiscardableAttrDictionary());

acc = cast<VectorValue>(swappedOp.getResult());
acc = swapDims(rewriter, acc, accN, accM);
@@ -35,7 +35,12 @@ LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
assert(linalg::isaContractionOpInterface(contract) &&
"cannot set contraction anchor on non contraction op");

auto layouts = schedule.getContractionLayout(contract);
FailureOr<VectorContractOpInfo> opInfo =
VectorContractOpInfo::inferFromIndexingMaps(
contract.getIndexingMapsArray());
assert(succeeded(opInfo) && "contraction should have been inferred");

auto layouts = schedule.getContractionLayout(opInfo.value(), contract);
if (failed(layouts)) {
return contract->emitError("cannot get concrete layout for contraction");
}
@@ -77,6 +82,84 @@ LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
return success();
}

LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
RewriterBase &rewriter,
linalg::LinalgOp conv) {
// TODO: Add SIMT fallback.
if (!schedule) {
return conv->emitError("missing mma schedule for convolution");
}

// This function should only be called on a convolution op.
FailureOr<linalg::ConvolutionDimensions> convDims =
linalg::inferConvolutionDims(conv);
assert(succeeded(convDims) &&
"cannot set convolution anchor on non convolution op");

// Only convs with unit filter dims can be directly converted to matmul.
SmallVector<int64_t> shape = conv.getStaticLoopRanges();
if (!llvm::all_of(convDims->filterLoop,
[&shape](unsigned dim) { return shape[dim] == 1; })) {
return failure();
}

llvm::SmallBitVector filterDims(conv.getNumLoops(), false);
for (unsigned idx : convDims->filterLoop) {
filterDims.set(idx);
}

SmallVector<AffineMap> maps = conv.getIndexingMapsArray();
for (AffineMap &map : maps) {
map = projectDims(map, filterDims, /*compressDimsFlag=*/false);
}
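// Illustrative sketch (not from this patch): for a 1x1 NHWC convolution with
// loops (n, oh, ow, oc, fh, fw, ic), the filter loops {fh, fw} are unit, so
// projecting them out of the indexing maps turns expressions such as oh + fh
// into oh and ow + fw into ow. What remains are the maps of a plain
// contraction over (n, oh, ow, oc, ic), which VectorContractOpInfo can infer
// a batch/M/N/K structure from below.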

FailureOr<VectorContractOpInfo> opInfo =
VectorContractOpInfo::inferFromIndexingMaps(maps);
assert(succeeded(opInfo) &&
"unit filter dim convolution should have been inferred");

auto layouts = schedule.getContractionLayout(opInfo.value(), conv);
if (failed(layouts)) {
return conv->emitError("cannot get concrete layout for convolution");
}

auto [aLayout, bLayout, cLayout] = *layouts;
Location loc = conv.getLoc();

Value lhs = conv->getOperand(0);
Value rhs = conv->getOperand(1);
Value acc = conv->getOperand(2);

// Set layouts for lhs, rhs and acc.
rewriter.setInsertionPoint(conv);
auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, lhs.getType(), lhs, aLayout);
auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, rhs.getType(), rhs, bLayout);
auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, acc.getType(), acc, cLayout);

// Promote matmul lhs and rhs.
// TODO: We should read this from the lowering_config on the operation.
// TODO: This is a hack until layout analysis is improved. The layout analysis
// should decide where to put these shared memory conversions.
layoutedLhs.setSharedMemoryConversion(true);
layoutedRhs.setSharedMemoryConversion(true);

conv->setOperand(0, layoutedLhs.getResult());
conv->setOperand(1, layoutedRhs.getResult());
conv->setOperand(2, layoutedAcc.getResult());

// Set layout for result.
rewriter.setInsertionPointAfter(conv);
auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, conv->getResult(0).getType(), conv->getResult(0), cLayout);
rewriter.replaceAllUsesExcept(conv->getResult(0), toLayout.getResult(),
toLayout);

return success();
}

struct LLVMGPUConfigureTensorLayoutsPass final
: impl::LLVMGPUConfigureTensorLayoutsPassBase<
LLVMGPUConfigureTensorLayoutsPass> {
@@ -94,13 +177,16 @@ struct LLVMGPUConfigureTensorLayoutsPass final
auto scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
configDict.get(scheduleAttrName));

// Vector layout option setter aimed at contractions. For now, layout
// setting for other problems like reductions is TODO.
// Vector layout option setter aimed at contractions and convolutions. For
// now, layout setting for other problems like reductions is TODO.
SmallVector<linalg::LinalgOp> contracts;
SmallVector<linalg::LinalgOp> convs;

func->walk([&](linalg::LinalgOp linalgOp) {
if (linalg::isaContractionOpInterface(linalgOp)) {
contracts.push_back(linalgOp);
} else if (succeeded(linalg::inferConvolutionDims(linalgOp))) {
convs.push_back(linalgOp);
}
});

Expand All @@ -111,6 +197,12 @@ struct LLVMGPUConfigureTensorLayoutsPass final
return signalPassFailure();
}
}

for (linalg::LinalgOp conv : convs) {
if (failed(setConvolutionAnchor(scheduleAttr, rewriter, conv))) {
return signalPassFailure();
}
}
}
};
} // namespace
7 changes: 5 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -820,6 +820,9 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// Set anchors at tensor level for vector distribution later.
funcPassManager.addPass(createLLVMGPUConfigureTensorLayoutsPass());
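// Running this before the generalization and unit-dim folding passes below
// lets anchors be placed on convolutions with unit filter dims; the unit dims
// left on the inserted to_layout ops are folded away by the VectorExt
// unit-dim folding pass added further down.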

// Generalize all named ops so that we can fold away unit extent dims. By this
// point, all tiling is finished so the tiling configurations on those ops can
// be safely dropped. This additionally allows vectorization of convolution to
@@ -829,14 +832,14 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
if (!usePadToModelSharedMemcpy) {
LinalgFoldUnitExtentDimsPassOptions options;
options.useRankReducingSlices = true;
funcPassManager.addPass(
IREE::VectorExt::createVectorExtFoldUnitExtentDimsPass());
funcPassManager.addPass(mlir::createLinalgFoldUnitExtentDimsPass(options));
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
}

funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());
// Set anchors at tensor level for vector distribution later.
funcPassManager.addPass(createLLVMGPUConfigureTensorLayoutsPass());

// Linalg -> Vector
addGPUVectorizationPasses(funcPassManager);
@@ -715,19 +715,17 @@ hal.executable private @attention_20x4096x64x4096x64 {
// CHECK-SAME: translation_info = #[[$TRANSLATION]]

// CHECK: scf.for %{{.*}} = %c0 to %c4096 step %c64
// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x4x1x1x4x1xf16>, vector<2x1x4xf32>)
// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x4x1x1x4x1xf16>)
// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: scf.yield

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>,
#hal.descriptor_set.binding<3, storage_buffer>
]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
hal.executable private @attention_multiple_m_transpose {
hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
@@ -740,10 +738,10 @@ hal.executable private @attention_multiple_m_transpose {
func.func @attention_multiple_m_transpose() {
%cst = arith.constant 1.0 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<24x64x4608x128xf16>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>>
%3 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x4608x24x128xf16>>
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<24x64x4608x128xf16>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>>
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x4608x24x128xf16>>
%4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [24, 64, 4608, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x64x4608x128xf16>> -> tensor<24x64x4608x128xf16>
%5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
@@ -769,6 +767,6 @@ hal.executable private @attention_multiple_m_transpose {
// CHECK-SAME: translation_info = #[[$TRANSLATION]]

// CHECK: scf.for %{{.*}} = %c0 to %c72 step %c1
// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x8x1x1x4x1xf16>, vector<2x1x4xf32>)
// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf16>)
// CHECK-COUNT-128: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: scf.yield