Skip to content

Commit

Permalink
[VectorDistribution] Use to_layout to set anchors for LLVMGPUVectorDi…
Browse files Browse the repository at this point in the history
…stribute pass (iree-org#18044)

This patch makes LLVMGPUVectorDistribute pass use to_layout operations
to set layout anchors instead of directly setting them on the analysis.
This allows for better readability of what anchors are being set.

This patch allows the layout anchoring and the distribution to be split
up, which will be done in future patches.
  • Loading branch information
Groverkss authored Jul 31, 2024
1 parent 2c53b4a commit 388ebd2
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,6 @@ LogicalResult distributeVectorOps(Operation *root,
// Run the analysis and determine the layouts.
LLVM_DEBUG(llvm::dbgs() << "Running Layout Analysis\n");
VectorLayoutAnalysis analysis(root);
if (failed(options.setAnchorOps(analysis)))
return failure();
if (failed(analysis.run()))
return failure();
LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Succeded\n");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,6 @@ class VectorLayoutOptions {

virtual ~VectorLayoutOptions() = default;

/// Set the anchor ops in the analysis rooted on the root operation.
virtual LogicalResult setAnchorOps(VectorLayoutAnalysis &analysis) = 0;

bool verifyConversion() const { return fullConversion; }

protected:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1100,10 +1100,6 @@ class TestVectorLayoutOptions : public VectorLayoutOptions {
public:
TestVectorLayoutOptions(Operation *root)
: VectorLayoutOptions(root, /*fullConversion=*/false) {}

LogicalResult setAnchorOps(VectorLayoutAnalysis &analysis) override {
return success();
}
};

DiagnosedSilenceableFailure
Expand Down
11 changes: 0 additions & 11 deletions compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1023,17 +1023,6 @@ DistributionLayout *EnforceLayout::getLatticeElement(Value val) {
/// VectorLayoutAnalysis
/// ==========================================================================

LogicalResult VectorLayoutAnalysis::setAnchor(Value val,
VectorLayoutInterface layout) {
auto typedVal = dyn_cast<TypedValue<VectorType>>(val);
assert(typedVal && "expected value to be a vector type");
if (layout.isValidLayout(typedVal).failed()) {
return failure();
}
anchors[typedVal] = cast<VectorLayoutInterface>(layout);
return success();
}

LogicalResult VectorLayoutAnalysis::run() {
// The order of loading matters here, because propagateLayout does anchoring
// initialization which needs the lattice to know both enforcement and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,6 @@ class VectorLayoutAnalysis {
public:
VectorLayoutAnalysis(Operation *root) : root(root) {}

/// Fix the layout for a specific value. Returns failure if the layout set is
/// invalid for the value.
LogicalResult setAnchor(Value val, VectorLayoutInterface layout);

/// Run the analysis. The analysis expects that the user has set some anchor
/// points and is trying to infer the layout of other values.
LogicalResult run();
Expand Down
111 changes: 73 additions & 38 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,34 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions {
subgroupSize);
}

LogicalResult setAnchorOps(VectorLayoutAnalysis &analysis) override {
LogicalResult setAnchorOps(RewriterBase &rewriter) {
MLIRContext *context = root->getContext();
WalkResult walkResult = root->walk([&](Operation *op) {
LogicalResult setResult =
llvm::TypeSwitch<Operation *, LogicalResult>(op)
.Case([&](vector::ContractionOp contract) {
return setContractionAnchor(context, analysis, contract);
})
.Case([&](vector::TransferReadOp transfer) {
return setTransferReadAnchor(context, analysis, transfer);
})
.Default([](Operation *) { return success(); });
return failed(setResult) ? WalkResult::interrupt()
: WalkResult::advance();
SmallVector<vector::TransferReadOp> reads;
SmallVector<vector::ContractionOp> contracts;

root->walk([&](Operation *op) {
llvm::TypeSwitch<Operation *>(op)
.Case([&](vector::TransferReadOp transfer) {
reads.push_back(transfer);
})
.Case([&](vector::ContractionOp contract) {
contracts.push_back(contract);
});
});
return failure(walkResult.wasInterrupted());

for (vector::TransferReadOp read : reads) {
if (failed(setTransferReadAnchor(context, rewriter, read))) {
return failure();
}
}

for (vector::ContractionOp contract : contracts) {
if (failed(setContractionAnchor(context, rewriter, contract))) {
return failure();
}
}

return success();
}

RewritePatternSet &getPatterns() { return patterns; }
Expand All @@ -92,7 +104,7 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions {
// supported mma type from the cached list of mma types and populates the
// necessary distribution pattern for those contractions.
LogicalResult setContractionAnchor(MLIRContext *context,
VectorLayoutAnalysis &analysis,
RewriterBase &rewriter,
vector::ContractionOp contract) {
// TODO: Add SIMT fallback.
if (!schedule) {
Expand All @@ -105,19 +117,29 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions {
}

auto [aLayout, bLayout, cLayout] = *layouts;
if (analysis.setAnchor(contract.getLhs(), aLayout).failed()) {
return failure();
}
if (analysis.setAnchor(contract.getRhs(), bLayout).failed()) {
return failure();
}
if (analysis.setAnchor(contract.getAcc(), cLayout).failed()) {
return failure();
}
if (analysis.setAnchor(contract.getResult(), cLayout).failed()) {
return failure();
}
Location loc = contract.getLoc();

// Set layouts for lhs, rhs and acc.
rewriter.setInsertionPoint(contract);
Value layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, contract.getLhsType(), contract.getLhs(), aLayout);
Value layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, contract.getRhsType(), contract.getRhs(), bLayout);
Value layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, contract.getAccType(), contract.getAcc(), cLayout);
contract->setOperand(0, layoutedLhs);
contract->setOperand(1, layoutedRhs);
contract->setOperand(2, layoutedAcc);

// Set layout for result.
rewriter.setInsertionPointAfter(contract);
auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, contract.getResultType(), contract.getResult(), cLayout);
rewriter.replaceAllUsesExcept(contract, toLayout.getResult(), toLayout);

// Set intrinsic kind.
contract->setAttr("iree.amdgpu.mma", schedule.getIntrinsic());

if (printLayout) {
llvm::outs() << "contract A vector layout: " << aLayout << "\n";
llvm::outs() << "contract B vector layout: " << bLayout << "\n";
Expand Down Expand Up @@ -168,7 +190,7 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions {
//
// *_order = [0, 1]>
LogicalResult setTransferReadAnchor(MLIRContext *context,
VectorLayoutAnalysis &analysis,
RewriterBase &rewriter,
vector::TransferReadOp transfer) {

// Get the forward slice of the transfer to approximate whether it will take
Expand Down Expand Up @@ -332,9 +354,13 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions {
auto layout = IREE::VectorExt::NestedLayoutAttr::get(
context, subgroupCounts, batchSizes, outerSizes, threadCounts,
elementSizes, subgroupStrides, threadStrides);
if (analysis.setAnchor(transfer.getResult(), layout).failed()) {
return failure();
}

Location loc = transfer.getLoc();
rewriter.setInsertionPointAfter(transfer);
auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, transfer.getResult().getType(), transfer.getResult(), layout);
rewriter.replaceAllUsesExcept(transfer, toLayout.getResult(), toLayout);

if (printLayout) {
llvm::outs() << "transfer '" << transfer << "' vector layout: " << layout
<< "\n";
Expand Down Expand Up @@ -403,16 +429,18 @@ struct LLVMGPUVectorDistributePass
AffineExpr linearId =
x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z;

OpBuilder builder(func);
builder.setInsertionPointToStart(&func.getFunctionBody().front());
IRRewriter rewriter(func);
rewriter.setInsertionPointToStart(&func.getFunctionBody().front());
SmallVector<OpFoldResult> threadGrid = {
builder.createOrFold<gpu::ThreadIdOp>(func.getLoc(), gpu::Dimension::x),
builder.createOrFold<gpu::ThreadIdOp>(func.getLoc(), gpu::Dimension::y),
builder.createOrFold<gpu::ThreadIdOp>(func.getLoc(),
gpu::Dimension::z)};
rewriter.createOrFold<gpu::ThreadIdOp>(func.getLoc(),
gpu::Dimension::x),
rewriter.createOrFold<gpu::ThreadIdOp>(func.getLoc(),
gpu::Dimension::y),
rewriter.createOrFold<gpu::ThreadIdOp>(func.getLoc(),
gpu::Dimension::z)};

Value linearThreadIdVal = affine::makeComposedAffineApply(
builder, func.getLoc(), linearId, threadGrid);
rewriter, func.getLoc(), linearId, threadGrid);

std::optional<int64_t> subgroupSize = getSubgroupSize(func);
if (!subgroupSize) {
Expand All @@ -424,6 +452,13 @@ struct LLVMGPUVectorDistributePass
ContractionVectorLayoutOptions options(func, workgroupSize, scheduleAttr,
linearThreadIdVal,
subgroupSize.value(), testLayout);

// Set anchor layouts.
if (failed(options.setAnchorOps(rewriter))) {
func->emitError() << "failed to set anchors";
return signalPassFailure();
}

if (failed(distributeVectorOps(func, options.getPatterns(), options))) {
func->emitOpError() << "failed to distribute";
return signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1474,10 +1474,6 @@ class TransformVectorLayoutOptions : public VectorLayoutOptions {
public:
TransformVectorLayoutOptions(Operation *root, bool fullConversion)
: VectorLayoutOptions(root, fullConversion) {}

LogicalResult setAnchorOps(VectorLayoutAnalysis &analysis) override {
return success();
}
};

DiagnosedSilenceableFailure
Expand Down

0 comments on commit 388ebd2

Please sign in to comment.