[LLVMGPUVectorDistribute] Refactor vector.contract distribute (#19631)
Currently, vector.contract distribution is implemented as a standalone
pattern that closely follows vector.multi_reduce, so any improvement to
either one has to be duplicated in the other.

This commit changes the vector.contract distribution to handle just the
"contract" part of the op. It then creates a new vector.multi_reduce that
is re-distributed with partial reduction semantics, allowing improvements
to vector.multi_reduce to be reused by vector.contract.
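
As a rough standalone analogy (plain C++, hypothetical; not the actual MLIR
rewrite), the flow is: each thread reduces only its local slice, and a second
combining step completes the reduction across threads:

#include <array>
#include <numeric>

int main() {
  // Four "threads", each holding a locally distributed slice of the data.
  std::array<std::array<int, 4>, 4> localSlices = {
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}};
  // Step 1: thread-local reduction (the "contract" part).
  std::array<int, 4> partial{};
  for (int t = 0; t < 4; ++t)
    partial[t] =
        std::accumulate(localSlices[t].begin(), localSlices[t].end(), 0);
  // Step 2: a second reduction completes the partial results across
  // threads -- the role the newly created vector.multi_reduce plays.
  int result = std::accumulate(partial.begin(), partial.end(), 0);
  return result == 136 ? 0 : 1; // 1 + 2 + ... + 16 == 136
}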

Closes #19620

---------

Signed-off-by: Manupa Karunaratne <[email protected]>
manupak authored Jan 22, 2025
1 parent 6933c39 commit 03c5a0f
Showing 4 changed files with 158 additions and 117 deletions.
@@ -442,13 +442,6 @@ struct DistributeMultiReduction final
}

Type elemTy = srcVector.getType().getElementType();
unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
if (elemBitwidth != maxBitsPerShuffle) {
return rewriter.notifyMatchFailure(
multiReduceOp, llvm::formatv("unimplemented: packed shuffle",
elemBitwidth, maxBitsPerShuffle));
}

VectorValue disSrc =
getDistributed(rewriter, srcVector, signature[srcVector]);

@@ -770,24 +763,18 @@ struct DistributeMultiReduction final
int64_t maxBitsPerShuffle;
};

/// The lowering for Contract is performed in three steps (similar to above
/// multi_reduction):
/// 1. Local Contract: Each thread performs operations on its locally
/// distributed elements.
/// 2. Subgroup Reduction: Threads in each subgroup reduce the results from
/// step 1 across threads using a subgroup reduction if distribution occurs
/// along the reduction dimension.
/// 3. Accumulator Reduction: Each thread combines its intermediate results
/// with its held accumulator.
///
/// Currently, reduction across multiple warps is not supported.
/// The distribution of contract is performed by doing a local contraction
/// where each thread performs operations on its locally distributed
/// elements. The resulting vector is then interpreted in the undistributed
/// domain, where it represents a partial reduction because the contraction
/// has only been performed thread-locally. Therefore, a to-be-distributed
/// vector.multi_reduce is added to complete the contraction.
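/// As an illustrative sketch (the exact IR depends on the layouts and
/// indexing maps involved):
///   %0 = vector.contract {...} %lhs, %rhs, %acc
///          : vector<16xf16>, vector<16xf16> into f16
/// becomes a thread-local vector.contract on the distributed operands,
/// followed by a partial reduction along the lines of
///   %1 = vector.multi_reduction <add>, %partial, %acc [0]
///          : vector<Kxf16> to f16
/// which DistributeMultiReduction then distributes.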
struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
using OpDistributionPattern::OpDistributionPattern;

DistributeContract(MLIRContext *context, int64_t subgroupSize,
int64_t maxBitsPerShuffle, int64_t benefit = 1)
: OpDistributionPattern(context, benefit), subgroupSize(subgroupSize),
maxBitsPerShuffle(maxBitsPerShuffle) {}
DistributeContract(MLIRContext *context, int64_t benefit = 1)
: OpDistributionPattern(context, benefit) {}

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
DistributionSignature &signature,
Expand Down Expand Up @@ -817,6 +804,16 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
return rewriter.notifyMatchFailure(
contractOp, "missing nested layout for contraction rhs");
}
NestedLayoutAttr resLayout;
if (auto contractRes = dyn_cast<VectorValue>(contractOp.getResult())) {
resLayout = dyn_cast<NestedLayoutAttr>(signature[contractRes]);
} else {
// Create a zero-d layout because we are going to add the reduction
// dims back to handle the partial reduction.
resLayout = NestedLayoutAttr::get(
contractOp.getContext(), ArrayRef<int64_t>{}, {}, {}, {}, {}, {}, {});
}

Value disLhs = getDistributed(rewriter, contractOp.getLhs(), lhsLayout);
Value disRhs = getDistributed(rewriter, contractOp.getRhs(), rhsLayout);
@@ -838,21 +835,10 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
Location loc = contractOp.getLoc();

// Step 1: local contraction
Value localInit = getCombiningIdentityValue(
loc, rewriter, contractOp.getKind(), disAcc.getType());
vector::ContractionOp localContractOp = doDistributedContraction(
rewriter, loc, ctx, contractOp, disLhs, disRhs, disAcc);

int64_t rank = lhsLayout.getRank();
SmallVector<bool> reducedDims(rank, false);

// Identify the reduction dimension and apply it for subgroup reduction.
for (auto [index, iteratorType] :
llvm::enumerate(contractOp.getIteratorTypes())) {
if (vector::isReductionIterator(iteratorType)) {
auto map = contractOp.getIndexingMapsArray()[0];
int64_t redIdx = *(map.getResultPosition(getAffineDimExpr(index, ctx)));
reducedDims[redIdx] = true;
}
}
rewriter, loc, ctx, contractOp, disLhs, disRhs, localInit);

VectorValue localContractValue;
if (accVector) {
@@ -865,46 +851,79 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {

assert(localContractValue && "result should have been a vector");

// Flatten the local result value.
VectorType shaped = localContractValue.getType();
int64_t numElements = shaped.getNumElements();
SmallVector<int64_t> flatShape(1, numElements);
VectorType flatVecType = VectorType::get(flatShape, accElemTy);
VectorValue flat = rewriter.create<vector::ShapeCastOp>(loc, flatVecType,
localContractValue);

// Step 2: Do subgroup reduction.
FailureOr<VectorValue> threadReduced = doThreadReduction(
rewriter, lhsLayout, flat, contractOp.getKind(), reducedDims);
if (failed(threadReduced)) {
return failure();
}

// Do reduction against accumulator, which needs to be done after thread
// reduction.
VectorValue unflattened = rewriter.create<vector::ShapeCastOp>(
loc, shaped, threadReduced.value());

if (!accVector) {
disAcc = rewriter.create<vector::BroadcastOp>(loc, shaped, disAcc);
}

// Step 3: Accumulator Reduction
Value accReduction = vector::makeArithReduction(
rewriter, loc, contractOp.getKind(), unflattened, disAcc);
auto accReduced = dyn_cast<VectorValue>(accReduction);
if (!accReduced) {
return failure();
// Identify the reduction dimension and apply it for subgroup reduction.
auto lhsMap = contractOp.getIndexingMapsArray()[0];
SmallVector<int64_t> reductionSubGroupTile;
SmallVector<int64_t> reductionSubGroupStrides;
SmallVector<int64_t> reductionThreadTile;
SmallVector<int64_t> reductionThreadStrides;
SmallVector<int64_t> partialReductionDims;
for (auto [index, iteratorType] :
llvm::enumerate(contractOp.getIteratorTypes())) {
if (vector::isReductionIterator(iteratorType)) {
int64_t redLhsIdx =
*(lhsMap.getResultPosition(getAffineDimExpr(index, ctx)));
partialReductionDims.push_back(resLayout.getRank() +
reductionSubGroupTile.size());
reductionSubGroupTile.push_back(lhsLayout.getSubgroupTile()[redLhsIdx]);
reductionSubGroupStrides.push_back(
lhsLayout.getSubgroupStrides()[redLhsIdx]);
reductionThreadTile.push_back(lhsLayout.getThreadTile()[redLhsIdx]);
reductionThreadStrides.push_back(
lhsLayout.getThreadStrides()[redLhsIdx]);
}
}
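// As a hypothetical example: with iterator types [parallel, reduction]
// and an lhs indexing map (d0, d1) -> (d0, d1), the reduction iterator
// d1 maps to lhs dim 1; if lhsLayout had subgroup_tile = [1, 2] and
// thread_tile = [4, 16], this loop would collect
// reductionSubGroupTile = [2], reductionThreadTile = [16], and
// partialReductionDims = [resLayout.getRank()].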

if (resVector) {
replaceOpWithDistributedValues(rewriter, contractOp, accReduced);
} else {
Value accReducedVal = rewriter.create<vector::ExtractOp>(
loc, accReduction, SmallVector<int64_t>{0});
replaceOpWithDistributedValues(rewriter, contractOp, accReducedVal);
SmallVector<int64_t> unitBroadcastTile(reductionThreadTile.size(), 1);

// Manually infer the layout of the partial reduction by appending the
// reduction dims on subgroup and thread tiles to the layout of the
// result.
IREE::VectorExt::NestedLayoutAttr reductionLayout =
IREE::VectorExt::NestedLayoutAttr::get(
contractOp.getContext(),
/*source=*/resLayout,
/*appendSubGroupLens=*/reductionSubGroupTile,
/*appendBatchLens=*/unitBroadcastTile,
/*appendOuterLens=*/unitBroadcastTile,
/*appendThreadLens=*/reductionThreadTile,
/*appendElementLens=*/unitBroadcastTile,
/*appendSubgroupStrides=*/reductionSubGroupStrides,
/*appendThreadStrides=*/reductionThreadStrides);

VectorType partialReducedDistributedType =
VectorType::get(reductionLayout.getDistributedShape(),
localContractValue.getType().getElementType());
Value shapeCasted = rewriter.create<vector::ShapeCastOp>(
loc, partialReducedDistributedType, localContractValue);
VectorType unDistributedType =
VectorType::get(reductionLayout.getUndistributedShape(),
localContractValue.getType().getElementType());
Value undistrLocalReduced = rewriter.create<IREE::VectorExt::ToSIMDOp>(
loc, unDistributedType, shapeCasted);

// Create the partial reduction
auto partialReduction = rewriter.create<vector::MultiDimReductionOp>(
loc, contractOp.getKind(), undistrLocalReduced, acc,
partialReductionDims);
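// Attach layout signatures to the new multi_reduction so that the
// to-be-distributed op is picked up and completed by
// DistributeMultiReduction.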
{
auto unitAttr = UnitAttr::get(rewriter.getContext());
auto reduceAttrs =
SmallVector<Attribute>(partialReduction->getNumOperands(), unitAttr);
reduceAttrs[0] = reductionLayout;
ArrayAttr reduceResultsAttr =
ArrayAttr::get(rewriter.getContext(), {unitAttr});
if (auto dstLayout =
dyn_cast_or_null<NestedLayoutAttr>(signature[resVector])) {
reduceAttrs[1] = dstLayout;
reduceResultsAttr = ArrayAttr::get(rewriter.getContext(), {dstLayout});
}
ArrayAttr reduceOperandsAttr =
ArrayAttr::get(rewriter.getContext(), reduceAttrs);
setSignatureForRedistribution(rewriter, partialReduction.getOperation(),
reduceOperandsAttr, reduceResultsAttr);
}

rewriter.replaceOp(contractOp, partialReduction);
return success();
}

@@ -954,46 +973,6 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {

return localContractOp;
}

FailureOr<VectorValue> doThreadReduction(RewriterBase &rewriter,
NestedLayoutAttr layout,
VectorValue flat,
vector::CombiningKind kind,
ArrayRef<bool> reductionMask) const {
VectorType flatVecType = flat.getType();
int64_t numElements = flatVecType.getNumElements();
Location loc = flat.getLoc();

auto constOp = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(flatVecType));
auto res = llvm::cast<VectorValue>(constOp.getResult());

for (unsigned i = 0; i < numElements; ++i) {
Value extracted = rewriter.create<vector::ExtractOp>(loc, flat, i);

// Reduce across all reduction dimensions 1-by-1.
for (unsigned i = 0, e = reductionMask.size(); i != e; ++i) {
if (reductionMask[i]) {
int64_t offset = getShuffleOffset(layout, i);
int64_t width = getShuffleWidth(layout, i);
assert(offset <= std::numeric_limits<uint32_t>::max() &&
width <= std::numeric_limits<uint32_t>::max());

extracted = rewriter.create<gpu::SubgroupReduceOp>(
loc, extracted, combiningKindToAllReduce(kind),
/*uniform=*/false, /*cluster_size=*/width,
/*cluster_stride=*/offset);
}
}

res = rewriter.create<vector::InsertOp>(loc, extracted, res, i);
}

return res;
}

int64_t subgroupSize;
int64_t maxBitsPerShuffle;
};

struct DistributeTranspose final : OpDistributionPattern<vector::TransposeOp> {
@@ -1344,8 +1323,7 @@ void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
patterns.add<DistributeBroadcast, DistributeTranspose>(patterns.getContext());
patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
maxBitsPerShuffle);
patterns.add<DistributeContract>(patterns.getContext(), subgroupSize,
maxBitsPerShuffle);
patterns.add<DistributeContract>(patterns.getContext());
patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
patterns.add<DistributeStep>(patterns.getContext(), threadId, subgroupSize);
}
@@ -112,6 +112,19 @@ SmallVector<int64_t> NestedLayoutAttr::getUndistributedPackedShape() const {
return shape;
}

SmallVector<int64_t> NestedLayoutAttr::getUndistributedShape() const {
int64_t rank = getRank();
SmallVector<int64_t> shape;
shape.reserve(rank);
for (int64_t i : llvm::seq<int64_t>(rank)) {
int64_t expectedDimLen = getSubgroupTile()[i] * getBatchTile()[i] *
getOuterTile()[i] * getThreadTile()[i] *
getElementTile()[i];
shape.push_back(expectedDimLen);
}
return shape;
}
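// For instance (hypothetical tiles): with subgroup_tile = [2],
// batch_tile = [1], outer_tile = [1], thread_tile = [16], and
// element_tile = [4], the undistributed shape is
// [2 * 1 * 1 * 16 * 4] = [128].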

// Gets the rank of the undistributed vector for this layout.
int64_t NestedLayoutAttr::getRank() const {
// The layout requires that all size lists are the same length and match
@@ -198,6 +211,42 @@ NestedLayoutAttr NestedLayoutAttr::get(
normalizedThreadStrides);
}

static SmallVector<int64_t> appendDims(ArrayRef<int64_t> tileLens,
ArrayRef<int64_t> appendLens) {
SmallVector<int64_t> tileLensResult = llvm::to_vector(tileLens);
tileLensResult.insert(tileLensResult.end(), appendLens.begin(),
appendLens.end());
return tileLensResult;
}
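// For example, appendDims({4, 2}, {8}) yields {4, 2, 8}; this is used
// below to append reduction dims to each tile list of a source layout.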

NestedLayoutAttr NestedLayoutAttr::get(MLIRContext *context,
NestedLayoutAttr source,
ArrayRef<int64_t> appendSubGroupLens,
ArrayRef<int64_t> appendBatchLens,
ArrayRef<int64_t> appendOuterLens,
ArrayRef<int64_t> appendThreadLens,
ArrayRef<int64_t> appendElementLens,
ArrayRef<int64_t> appendSubgroupStrides,
ArrayRef<int64_t> appendThreadStrides) {
SmallVector<int64_t> subgroupTile =
appendDims(source.getSubgroupTile(), appendSubGroupLens);
SmallVector<int64_t> batchTile =
appendDims(source.getBatchTile(), appendBatchLens);
SmallVector<int64_t> outerTile =
appendDims(source.getOuterTile(), appendOuterLens);
SmallVector<int64_t> threadTile =
appendDims(source.getThreadTile(), appendThreadLens);
SmallVector<int64_t> elementTile =
appendDims(source.getElementTile(), appendElementLens);
SmallVector<int64_t> subgroupStrides =
appendDims(source.getSubgroupStrides(), appendSubgroupStrides);
SmallVector<int64_t> threadStrides =
appendDims(source.getThreadStrides(), appendThreadStrides);
return NestedLayoutAttr::get(context, subgroupTile, batchTile, outerTile,
threadTile, elementTile, subgroupStrides,
threadStrides);
}
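// As a hypothetical example: appending one reduction dim to a rank-1
// source layout with subgroup_tile = [1] and thread_tile = [4], using
// appendSubGroupLens = [2] and appendThreadLens = [16], produces a
// rank-2 layout with subgroup_tile = [1, 2] and thread_tile = [4, 16];
// the stride lists are extended in the same way.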

LogicalResult NestedLayoutAttr::verify(
llvm::function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> subgroupTile, ArrayRef<int64_t> batchTile,
@@ -218,7 +218,15 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
"ArrayRef<int64_t>":$threadTile,
"ArrayRef<int64_t>":$elementTile,
"ArrayRef<int64_t>":$subgroupStrides,
"ArrayRef<int64_t>":$threadStrides)>
"ArrayRef<int64_t>":$threadStrides)>,
AttrBuilder<(ins "NestedLayoutAttr":$source,
"ArrayRef<int64_t>":$appendSubGroupLens,
"ArrayRef<int64_t>":$appendBatchLens,
"ArrayRef<int64_t>":$appendOuterLens,
"ArrayRef<int64_t>":$appendThreadLens,
"ArrayRef<int64_t>":$appendElementLens,
"ArrayRef<int64_t>":$appendSubgroupStrides,
"ArrayRef<int64_t>":$appendThreadStrides)>
];

let extraClassDeclaration = [{
@@ -35,6 +35,12 @@ def VectorLayoutInterface : AttrInterface<"VectorLayoutInterface"> {
/*methodName=*/"project",
/*args=*/(ins "::llvm::ArrayRef<bool>":$droppedDims)
>,
InterfaceMethod<
/*description=*/"Get the expected undistributed shape for the given vector type.",
/*retTy=*/"SmallVector<int64_t>",
/*methodName=*/"getUndistributedShape",
/*args=*/(ins)
>,
InterfaceMethod<
/*description=*/"Get the distributed shape for the given vector type.",
/*retTy=*/"SmallVector<int64_t>",
