Skip to content

Commit

Permalink
[LLVMGPUVectorDistribute] Support vector.mask + vector.contract
Browse files Browse the repository at this point in the history
This commit primarily adds support for distributing
masked vector.contract ops.

Firstly, it changes the VectorLayoutInference to propagate
the layouts from contract operands into the contraction mask.
In order to do this, a new builder is added to the NestedLayoutAttr
which can extract and concat from operand layouts using the indexing
maps of the vector.contract.

Secondly, in the distribution, the distributed mask is
projected onto the operands to perform a selection between
the original operand and reduction identity to cater for
non thread-local contraction. Moreover, the distributed
mask is applied to the thread-local contraction.

Signed-off-by: Manupa Karunaratne <[email protected]>
  • Loading branch information
manupak committed Feb 4, 2025
1 parent 2ed8a16 commit 95d0103
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,36 @@ static LogicalResult populateWarpAndThreadIndices(
return success();
}

/// Project a vector based on a provided projection map.
/// Firstly, this will transpose the vector so that the sliced-out
/// dims become outermost. Then it performs a vector.extract to
/// remove the dims that are not present in the results of the map.
static VectorValue projectVector(RewriterBase &rewriter, Location loc,
                                 VectorValue val, AffineMap projectionMap) {
  // Iteration-space dims that survive the projection, in result order.
  llvm::SmallVector<int64_t> remainingDims;
  llvm::SmallDenseSet<int64_t> retainedDims;
  for (int64_t resultIdx : llvm::seq<int64_t>(projectionMap.getNumResults())) {
    int64_t iterSpacePos = projectionMap.getDimPosition(resultIdx);
    remainingDims.push_back(iterSpacePos);
    retainedDims.insert(iterSpacePos);
  }

  // Move the sliced-out dims to the front. Collect them by walking the
  // dim range in ascending order (rather than iterating the set, whose
  // order is unspecified) so the emitted transpose is deterministic.
  SmallVector<int64_t> transposePerm;
  for (int64_t dim : llvm::seq<int64_t>(projectionMap.getNumDims())) {
    if (!retainedDims.contains(dim)) {
      transposePerm.push_back(dim);
    }
  }
  int64_t numSlicedDims = transposePerm.size();
  transposePerm.append(remainingDims);
  auto transposed =
      rewriter.create<vector::TransposeOp>(loc, val, transposePerm);

  // Extract position 0 of every sliced-out (now outermost) dim, leaving
  // only the dims present in the projection map results.
  SmallVector<int64_t> extractedPos(numSlicedDims, 0);
  auto sliced =
      rewriter.create<vector::ExtractOp>(loc, transposed, extractedPos);
  return cast<VectorValue>(sliced.getResult());
}

namespace {

/// Pattern to distribute `vector.transfer_read` ops with nested layouts.
Expand Down Expand Up @@ -931,15 +961,19 @@ struct DistributeMultiReduction final
/// performed only thread locally. Therefore, a to-be-distributed
/// vector.multi_reduce
////is added to complete the contraction.
struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
using OpDistributionPattern::OpDistributionPattern;
struct DistributeContract final
: MaskedOpDistributionPattern<vector::ContractionOp> {
using MaskedOpDistributionPattern::MaskedOpDistributionPattern;

DistributeContract(MLIRContext *context, int64_t benefit = 1)
: OpDistributionPattern(context, benefit) {}
: MaskedOpDistributionPattern(context, benefit) {}

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(vector::ContractionOp contractOp,
DistributionSignature &signature, vector::MaskOp maskOp,
std::optional<DistributionSignature> &maskSignature,
PatternRewriter &rewriter) const override {
Location loc = contractOp.getLoc();
FailureOr<VectorContractOpInfo> maybeOpInfo =
VectorContractOpInfo::inferFromIndexingMaps(
contractOp.getIndexingMapsArray());
Expand Down Expand Up @@ -979,6 +1013,44 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
Value disLhs = getDistributed(rewriter, contractOp.getLhs(), lhsLayout);
Value disRhs = getDistributed(rewriter, contractOp.getRhs(), rhsLayout);

VectorValue mask = nullptr;
if (maskOp) {
auto maskLayout = dyn_cast_or_null<NestedLayoutAttr>(
maskSignature.value()[maskOp.getMask()]);
if (!maskLayout) {
return rewriter.notifyMatchFailure(maskOp,
"expected nested layout attr");
}
mask = getDistributed(rewriter, maskOp.getMask(), maskLayout);
Value passThruLhs = getCombiningIdentityValue(
loc, rewriter, contractOp.getKind(), disLhs.getType());
Value passThruRhs = getCombiningIdentityValue(
loc, rewriter, contractOp.getKind(), disRhs.getType());

VectorValue deInterleavedMask =
getDeinterleavedUnpackedForm(rewriter, mask, maskLayout);
VectorValue maskLhs = projectVector(rewriter, loc, deInterleavedMask,
contractOp.getIndexingMapsArray()[0]);
VectorValue interleavedMaskLhs =
getInterleavedPackedForm(rewriter, maskLhs, lhsLayout);

VectorValue maskRhs = projectVector(rewriter, loc, deInterleavedMask,
contractOp.getIndexingMapsArray()[1]);
VectorValue interleavedMaskRhs =
getInterleavedPackedForm(rewriter, maskRhs, rhsLayout);

disLhs = cast<VectorValue>(
rewriter
.create<arith::SelectOp>(loc, interleavedMaskLhs, disLhs,
passThruLhs)
.getResult());
disRhs = cast<VectorValue>(
rewriter
.create<arith::SelectOp>(loc, interleavedMaskRhs, disRhs,
passThruRhs)
.getResult());
}

Value acc = contractOp.getAcc();
Value res = contractOp.getResult();
auto accVector = dyn_cast<VectorValue>(acc);
Expand All @@ -993,21 +1065,25 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
Type accElemTy = getElementTypeOrSelf(acc.getType());

MLIRContext *ctx = contractOp.getContext();
Location loc = contractOp.getLoc();

// Step 1: local contraction
Value localInit = getCombiningIdentityValue(
loc, rewriter, contractOp.getKind(), disAcc.getType());
vector::ContractionOp localContractOp = doDistributedContraction(
Value localContract = doDistributedContraction(
rewriter, loc, ctx, contractOp, disLhs, disRhs, localInit);
if (mask) {
localContract =
vector::maskOperation(rewriter, localContract.getDefiningOp(), mask)
->getResult(0);
}

VectorValue localContractValue;
if (accVector) {
localContractValue = dyn_cast<VectorValue>(localContractOp.getResult());
localContractValue = dyn_cast<VectorValue>(localContract);
} else {
VectorType vecType = VectorType::get(ArrayRef{int64_t(1)}, accElemTy);
localContractValue = rewriter.create<vector::BroadcastOp>(
loc, vecType, localContractOp.getResult());
localContractValue =
rewriter.create<vector::BroadcastOp>(loc, vecType, localContract);
}

assert(localContractValue && "result should have been a vector");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,93 @@ builtin.module attributes { transform.with_named_sequence } {

// CHECK: %[[SELECT:.+]] = arith.select %[[MASK_ITL_PCK]], {{.*}}, %[[RED_IDENTITY]] : vector<2x1x2x1x2x8xi1>, vector<2x1x2x1x2x8xf16>
// CHECK: vector.mask %[[MASK_ITL_PCK]] { vector.multi_reduction <add>, %[[SELECT]], {{.*}} [0, 2, 4] : vector<2x1x2x1x2x8xf16> to vector<1x1x8xf16> } : vector<2x1x2x1x2x8xi1> -> vector<1x1x8xf16>

// -----

// Layout for the 1-D LHS operand (reduction dim only).
#lhs = #iree_vector_ext.nested_layout<
subgroup_tile = [2],
batch_tile = [2],
outer_tile = [2],
thread_tile = [16],
element_tile = [2],

subgroup_strides = [1],
thread_strides = [1]
>

// Layout for the 2-D RHS operand (parallel x reduction dims).
#rhs = #iree_vector_ext.nested_layout<
subgroup_tile = [2, 2],
batch_tile = [2, 2],
outer_tile = [2, 2],
thread_tile = [8, 16],
element_tile = [2, 2],

subgroup_strides = [2, 1],
thread_strides = [16, 1]
>

// Layout for the 1-D accumulator/result (parallel dim only).
#out = #iree_vector_ext.nested_layout<
subgroup_tile = [2],
batch_tile = [2],
outer_tile = [2],
thread_tile = [8],
element_tile = [2],

subgroup_strides = [1],
thread_strides = [1]
>

// Masked vecmat over dynamically-sized memrefs: masked reads of the
// LHS vector and RHS matrix feed a masked vector.contract whose mask
// is built from the dynamic dim sizes; the result is masked-written out.
func.func @masked_read_write_contract(%arg0 : memref<?xf16>, %arg1 : memref<?x?xf16>, %arg2 : memref<?xf16>) {
%c0 = arith.constant 0 : index
%cst_6 = arith.constant 0.000000e+00 : f16
%acc = arith.constant dense<0.000000e+00> : vector<128xf16>

// Dynamic extents bounding the reduction (d0) and parallel (d1) dims.
%reddim = memref.dim %arg0, %c0 : memref<?xf16>
%pardim = memref.dim %arg1, %c0 : memref<?x?xf16>
%arg0mask = vector.create_mask %reddim : vector<256xi1>
%arg1mask = vector.create_mask %pardim, %reddim : vector<128x256xi1>
%arg2mask = vector.create_mask %pardim : vector<128xi1>
// Mask over the contraction iteration space (reduction, parallel).
%opmask = vector.create_mask %reddim, %pardim : vector<256x128xi1>

%arg0read = vector.transfer_read %arg0[%c0], %cst_6, %arg0mask {in_bounds = [true]} : memref<?xf16>, vector<256xf16>
%arg0readl = iree_vector_ext.to_layout %arg0read to layout(#lhs) : vector<256xf16>
%arg1read = vector.transfer_read %arg1[%c0, %c0], %cst_6, %arg1mask {in_bounds = [true, true]} : memref<?x?xf16>, vector<128x256xf16>
%arg1readl = iree_vector_ext.to_layout %arg1read to layout(#rhs) : vector<128x256xf16>
// Masked contraction: d0 is the reduction dim, d1 the parallel dim.
%gemm = vector.mask %opmask { vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"], kind = #vector.kind<add>} %arg0readl, %arg1readl, %acc : vector<256xf16>, vector<128x256xf16> into vector<128xf16> } : vector<256x128xi1> -> vector<128xf16>
%gemml = iree_vector_ext.to_layout %gemm to layout(#out) : vector<128xf16>
vector.transfer_write %gemml, %arg2[%c0], %arg2mask {in_bounds = [true]} : vector<128xf16>, memref<?xf16>

return
}

// Transform-dialect driver: runs the GPU vector-distribution test pass
// on the function above.
builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @masked_read_write_contract

// CHECK-DAG: %[[RED_IDENTITY_LHS:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf16>
// CHECK-DAG: %[[RED_IDENTITY_RHS:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x2x2x2x2xf16>

// Note that this is transposed to match the second indexing map
// CHECK-DAG: %[[MASK0:.+]] = vector.create_mask %[[D1UB:.+]], %[[D0UB:.+]] : vector<8x8xi1>

// CHECK-DAG: %[[LHS_MASK_EXTRACT:.+]] = vector.extract %[[MASK0]][0] : vector<8xi1> from vector<8x8xi1>
// CHECK-DAG: %[[LHS_MASK_PACKED:.+]] = vector.shape_cast %[[LHS_MASK_EXTRACT]] : vector<8xi1> to vector<2x2x2xi1>

// CHECK-DAG: %[[RHS_MASK_PACKED:.+]] = vector.shape_cast %[[MASK0]] : vector<8x8xi1> to vector<2x2x2x2x2x2xi1>
// CHECK-DAG: %[[RHS_MASK_INTERLVD:.+]] = vector.transpose %[[RHS_MASK_PACKED]], [0, 3, 1, 4, 2, 5] : vector<2x2x2x2x2x2xi1> to vector<2x2x2x2x2x2xi1>

// CHECK-DAG: %[[LHS_SELECT:.+]] = arith.select %[[LHS_MASK_PACKED]], %{{.*}}, %[[RED_IDENTITY_LHS]] : vector<2x2x2xi1>, vector<2x2x2xf16>
// CHECK-DAG: %[[RHS_SELECT:.+]] = arith.select %[[RHS_MASK_INTERLVD]], %{{.*}}, %[[RED_IDENTITY_RHS]] : vector<2x2x2x2x2x2xi1>, vector<2x2x2x2x2x2xf16>

// This is the actual op mask.
// CHECK-DAG: %[[MASK1:.+]] = vector.create_mask %[[D0UB]], %[[D1UB]] : vector<8x8xi1>
// CHECK-DAG: %[[MASK1_PACKED:.+]] = vector.shape_cast %[[MASK1]] : vector<8x8xi1> to vector<2x2x2x2x2x2xi1>
// CHECK-DAG: %[[MASK1_INTLVD:.+]] = vector.transpose %[[MASK1_PACKED]], [0, 3, 1, 4, 2, 5] : vector<2x2x2x2x2x2xi1> to vector<2x2x2x2x2x2xi1>

// CHECK: vector.mask %[[MASK1_INTLVD]] { vector.contract {{.*}} %[[LHS_SELECT]], %[[RHS_SELECT]]
21 changes: 21 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,27 @@ void PropagateLayout::visitMaskOp(
update(result, changed);
}
}

mask.getBody()->walk([&](vector::ContractionOp contract) {
const DistributionLayout *lhs = getLatticeElement(contract.getLhs());
const DistributionLayout *rhs = getLatticeElement(contract.getRhs());
if (!lhs->isUninitialized() && !rhs->isUninitialized()) {
if (NestedLayoutAttr lhsLayout =
dyn_cast<NestedLayoutAttr>(lhs->getLayout())) {
if (NestedLayoutAttr rhsLayout =
dyn_cast<NestedLayoutAttr>(rhs->getLayout())) {
SmallVector<NestedLayoutAttr> layouts{lhsLayout, rhsLayout};
SmallVector<AffineMap> maps{contract.getIndexingMapsArray()[0],
contract.getIndexingMapsArray()[1]};
NestedLayoutAttr inferredMaskLayout =
NestedLayoutAttr::get(lhsLayout.getContext(), layouts, maps);
DistributionLayout *maskLayout = getLatticeElement(mask.getMask());
ChangeResult changed = maskLayout->resolve(inferredMaskLayout);
update(maskLayout, changed);
}
}
}
});
}

void PropagateLayout::visitOperation(Operation *op) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,39 @@ NestedLayoutAttr NestedLayoutAttr::get(MLIRContext *context,
threadStrides);
}

/// Build a layout for the full iteration space of an operation by
/// scattering each operand layout's per-result tile sizes/strides into
/// the iteration-space positions given by its indexing map. Later
/// operands overwrite earlier ones on shared dims; dims covered by no
/// operand keep their zero initialization.
NestedLayoutAttr
NestedLayoutAttr::get(MLIRContext *context,
                      ArrayRef<NestedLayoutAttr> operandLayouts,
                      ArrayRef<AffineMap> operandIndexingMaps) {
  // All indexing maps share the op's iteration space; take its rank
  // from the first map.
  int64_t numDims = operandIndexingMaps[0].getNumDims();
  SmallVector<int64_t> subgroupTile(numDims, 0);
  SmallVector<int64_t> batchTile(numDims, 0);
  SmallVector<int64_t> outerTile(numDims, 0);
  SmallVector<int64_t> threadTile(numDims, 0);
  SmallVector<int64_t> elementTile(numDims, 0);
  SmallVector<int64_t> subgroupStrides(numDims, 0);
  SmallVector<int64_t> threadStrides(numDims, 0);

  for (auto [layout, indexingMap] :
       llvm::zip(operandLayouts, operandIndexingMaps)) {
    for (int64_t resultIdx : llvm::seq<int64_t>(indexingMap.getNumResults())) {
      int64_t iterSpacePos = indexingMap.getDimPosition(resultIdx);
      subgroupTile[iterSpacePos] = layout.getSubgroupTile()[resultIdx];
      batchTile[iterSpacePos] = layout.getBatchTile()[resultIdx];
      // Fix: outer tile must come from getOuterTile(); previously the
      // batch tile was copied here, corrupting the inferred layout
      // whenever outer_tile != batch_tile.
      outerTile[iterSpacePos] = layout.getOuterTile()[resultIdx];
      threadTile[iterSpacePos] = layout.getThreadTile()[resultIdx];
      elementTile[iterSpacePos] = layout.getElementTile()[resultIdx];

      subgroupStrides[iterSpacePos] = layout.getSubgroupStrides()[resultIdx];
      threadStrides[iterSpacePos] = layout.getThreadStrides()[resultIdx];
    }
  }

  return NestedLayoutAttr::get(context, subgroupTile, batchTile, outerTile,
                               threadTile, elementTile, subgroupStrides,
                               threadStrides);
}

LogicalResult NestedLayoutAttr::verify(
llvm::function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> subgroupTile, ArrayRef<int64_t> batchTile,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,12 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
"ArrayRef<int64_t>":$appendThreadLens,
"ArrayRef<int64_t>":$appendElementLens,
"ArrayRef<int64_t>":$appendSubgroupStrides,
"ArrayRef<int64_t>":$appendThreadStrides)>
"ArrayRef<int64_t>":$appendThreadStrides)>,
// Special builder to extract and unify a new layout
// to represent the iteration space from operand
// layout
AttrBuilder<(ins "ArrayRef<NestedLayoutAttr>":$operandLayouts,
"ArrayRef<AffineMap>":$operandIndexingMaps)>
];

let extraClassDeclaration = [{
Expand Down

0 comments on commit 95d0103

Please sign in to comment.