Skip to content

Commit

Permalink
Yet more IREEGPUAttrs cleanup: drop get{A,B,C}SingleSubgroupLayout
Browse files Browse the repository at this point in the history
…methods (iree-org#19169)

These methods existed before we added the unified
`getSingleSubgroupLayout` taking a `MMAFragment` argument. Now they can
go away. The genuinely polymorphic callers — which were what motivated
making this an interface method in the first place — are handled by a new
overload of `getSingleSubgroupLayout` taking a `MMAInterfaceAttr`.

---------

Signed-off-by: Benoit Jacob <[email protected]>
  • Loading branch information
bjacob authored Nov 16, 2024
1 parent e10342d commit 29c451b
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,19 @@ static LogicalResult isIntrinsicLayoutCompatible(
auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
auto [lhsK, rhsK] = opInfo.getOperandKIndex();
auto [accM, accN] = opInfo.getResultMNIndex();
if (failed(isSubgroupLayoutCompatible(getASingleSubgroupLayout(intrinsic),
lhsLayout, lhsM, lhsK))) {
if (failed(isSubgroupLayoutCompatible(
getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Lhs),
lhsLayout, lhsM, lhsK))) {
return failure();
}
if (failed(isSubgroupLayoutCompatible(getBSingleSubgroupLayout(intrinsic),
rhsLayout, rhsK, rhsN))) {
if (failed(isSubgroupLayoutCompatible(
getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Rhs),
rhsLayout, rhsK, rhsN))) {
return failure();
}
if (failed(isSubgroupLayoutCompatible(getCSingleSubgroupLayout(intrinsic),
accLayout, accM, accN))) {
if (failed(isSubgroupLayoutCompatible(
getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Acc),
accLayout, accM, accN))) {
return failure();
}
return success();
Expand Down
60 changes: 6 additions & 54 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ static bool is_AMD_WMMA(MMAIntrinsic intrinsic) {

static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
// Not using Wave64 at all at the moment, so the only place where the
// subgroup size is CDNA* architectures.
// subgroup size is 64 is on CDNA* architectures.
return is_AMD_MFMA(intrinsic) ? 64 : 32;
}

Expand Down Expand Up @@ -292,38 +292,14 @@ OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
return getOpaqueMMALayout<IREE::GPU::MMAIntrinsic>(context, intrinsic);
}

//===----------------------------------------------------------------------===//
// MmaInterface Attribute Helper Functions
//===----------------------------------------------------------------------===//

MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
return mmaAttr.getASingleSubgroupLayout();
}
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
return vmmaAttr.getASingleSubgroupLayout();
}
assert(false && "unhandled MMA Interface type.");
return {};
}

MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
return mmaAttr.getBSingleSubgroupLayout();
}
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
return vmmaAttr.getBSingleSubgroupLayout();
}
assert(false && "unhandled MMA Interface type.");
return {};
}

MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
MMASingleSubgroupLayout getSingleSubgroupLayout(MmaInterfaceAttr mmaKind,
MMAFragment fragment) {
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
return mmaAttr.getCSingleSubgroupLayout();
return getSingleSubgroupLayout(mmaAttr.getIntrinsic().getValue(), fragment);
}
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
return vmmaAttr.getCSingleSubgroupLayout();
return getSingleSubgroupLayout(vmmaAttr.getIntrinsic().getValue(),
fragment);
}
assert(false && "unhandled MMA Interface type.");
return {};
Expand Down Expand Up @@ -407,18 +383,6 @@ FailureOr<IREE::GPU::MMAScope> MMAAttr::getMmaScope() const {
return IREE::GPU::MMAScope::Subgroup;
}

MMASingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Lhs);
}

MMASingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Rhs);
}

MMASingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc);
}

// Get virtual intrinsics that is composed/based on queried op.
SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
switch (getIntrinsic().getValue()) {
Expand Down Expand Up @@ -1098,18 +1062,6 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
return {};
}

MMASingleSubgroupLayout VirtualMMAAttr::getASingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Lhs);
}

MMASingleSubgroupLayout VirtualMMAAttr::getBSingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Rhs);
}

MMASingleSubgroupLayout VirtualMMAAttr::getCSingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc);
}

//===----------------------------------------------------------------------===//
// Target Attributes
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 7 additions & 10 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ namespace mlir::iree_compiler::IREE::GPU {
// semantics in that case are that threads within the subgroup whose thread-ids
// differ by a multiple of `P`, are accessing the same elements.
//
// Example observed in RDNA3 WMMA Wave64 intrinsics:
// If the subgroup size is 64 but the product `P` of `thread` sizes is 32, that
// means that each element is being accessed by 2 threads (2 = 64/32), and the
// threads accessing the same element are those whose tids are exactly 32 apart.
// Example observed in RDNA3 WMMA Wave32 intrinsics:
// If the subgroup size is 32 but the product `P` of `thread` sizes is 16, that
// means that each element is being accessed by 2 threads (2 = 32/16), and the
// threads accessing the same element are those whose tids are exactly 16 apart.
struct MMASingleSubgroupLayout {
// Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
// outer-most in the layout. This happens when a MMA op, seen on a single
Expand All @@ -54,7 +54,7 @@ struct MMASingleSubgroupLayout {
// Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
// inner-most in the layout. This happens when a MMA op, seen on a single
// thread, has an operand that consists of multiple elements, and these elems
// are NOT contiguous.
// are contiguous.
// This is not used by every MMA op; ops which don't use that simply have 1's.
SmallVector<int64_t, 2> element;
};
Expand All @@ -65,11 +65,8 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
MMAFragment fragment);

MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind);

MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind);

MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind);
MMASingleSubgroupLayout getSingleSubgroupLayout(MmaInterfaceAttr mmaKind,
MMAFragment fragment);

// Struct describing the shape of a MMA operation, but not the detailed layout.
// TODO(bjacob): the only user outside of IREEGPUAttrs.cpp is
Expand Down
22 changes: 0 additions & 22 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,6 @@ class IREEGPU_MmaVectorLayoutAttr<string attrname, string mmaintrinsic> :
"getMNKShape",
"getSubgroupSize",
"getMmaScope",
"getASingleSubgroupLayout",
"getBSingleSubgroupLayout",
"getCSingleSubgroupLayout",
"buildMmaOperation",
"populateOperandOffsetsSizesStrides",
]>
Expand Down Expand Up @@ -225,14 +222,6 @@ def IREEGPU_MMAAttr : IREEGPU_MmaVectorLayoutAttr<"MMA", "MMAIntrinsicAttr"> {
let extraClassDeclaration = [{
int64_t getBlockSize() const;

// Returns the A/B/C matrix's partial nested layout shape inside a single
// subgroup. Shape at each outer/thread/element level is a 2-D value,
// following canonical matmul order--(M, K) for A, (K, N) for B, and
// (M, N) for C.
MMASingleSubgroupLayout getASingleSubgroupLayout() const;
MMASingleSubgroupLayout getBSingleSubgroupLayout() const;
MMASingleSubgroupLayout getCSingleSubgroupLayout() const;

SmallVector<VirtualMMAIntrinsic> getVirtualIntrinsics() const;
}];
}
Expand Down Expand Up @@ -287,9 +276,6 @@ def IREEGPU_VirtualMMAAttr :
"getMNKShape",
"getSubgroupSize",
"getMmaScope",
"getASingleSubgroupLayout",
"getBSingleSubgroupLayout",
"getCSingleSubgroupLayout",
"populateOperandOffsetsSizesStrides",
"buildMmaOperation",
]>
Expand Down Expand Up @@ -319,14 +305,6 @@ def IREEGPU_VirtualMMAAttr :
let extraClassDeclaration = [{
int64_t getBlockSize() const;

// Returns the A/B/C matrix's partial nested layout shape inside a single
// subgroup. Shape at each outer/thread/element level is a 2-D value,
// following canonical matmul order--(M, K) for A, (K, N) for B, and
// (M, N) for C.
MMASingleSubgroupLayout getASingleSubgroupLayout() const;
MMASingleSubgroupLayout getBSingleSubgroupLayout() const;
MMASingleSubgroupLayout getCSingleSubgroupLayout() const;

// Factor to unroll K from native MMA/intrinsic size to virtual size.
// e.g MFMA_F32_16x16x16 has K of 16, while VMFMA_F32_16x16x32 has K of 32
// in this example, unrollK = 32/16 = 2.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,20 @@ LogicalResult materializeOperandConcreteShape(
SmallVector<ReassociationIndices> &reassociations,
RankedTensorType &resultType) {

SmallVector<int64_t, 2> outerSizes;
MMASingleSubgroupLayout layout = getSingleSubgroupLayout(mma, fragment);
SmallVector<int64_t, 2> outerSizes = layout.outer;
SmallVector<int64_t, 2> opaqueSizes;
auto [m, n, k] = mma.getMNKShape();
switch (fragment) {
case IREE::GPU::MMAFragment::Lhs: {
outerSizes = mma.getASingleSubgroupLayout().outer;
opaqueSizes.append({m, k});
break;
}
case IREE::GPU::MMAFragment::Rhs: {
outerSizes = mma.getBSingleSubgroupLayout().outer;
opaqueSizes.append({k, n});
break;
}
case IREE::GPU::MMAFragment::Acc: {
outerSizes = mma.getCSingleSubgroupLayout().outer;
opaqueSizes.append({m, n});
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,12 @@ getContractionLayout(IREE::GPU::MMAScheduleAttr schedule,
cSubgroupStrides[dim] = subgroupNStrides[i];
}

auto cLayout = createNestedLayout(context, cRank, m, n,
/*subgroupCount=*/cSubgroupSizes,
/*subgroupStrides=*/cSubgroupStrides,
/*batchCount=*/cBatchSizes,
getCSingleSubgroupLayout(mmaAttr));
IREE::VectorExt::NestedLayoutAttr cLayout = createNestedLayout(
context, cRank, m, n,
/*subgroupCount=*/cSubgroupSizes,
/*subgroupStrides=*/cSubgroupStrides,
/*batchCount=*/cBatchSizes,
getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Acc));
LLVM_DEBUG({ llvm::dbgs() << "C layout: " << cLayout << "\n"; });

// A matrix layout
Expand All @@ -339,11 +340,12 @@ getContractionLayout(IREE::GPU::MMAScheduleAttr schedule,
}
aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK;

auto aLayout = createNestedLayout(context, aRank, afm, afk,
/*subgroupCount=*/aSubgroupSizes,
/*subgroupStrides=*/aSubgroupStrides,
/*batchCount=*/aBatchSizes,
getASingleSubgroupLayout(mmaAttr));
IREE::VectorExt::NestedLayoutAttr aLayout = createNestedLayout(
context, aRank, afm, afk,
/*subgroupCount=*/aSubgroupSizes,
/*subgroupStrides=*/aSubgroupStrides,
/*batchCount=*/aBatchSizes,
getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Lhs));
LLVM_DEBUG({ llvm::dbgs() << "A layout: " << aLayout << "\n"; });

int64_t bRank = opInfo.getBRank();
Expand All @@ -363,11 +365,12 @@ getContractionLayout(IREE::GPU::MMAScheduleAttr schedule,
}
bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK;

auto bLayout = createNestedLayout(context, bRank, bfk, bfn,
/*subgroupCount=*/bSubgroupSizes,
/*subgroupStrides=*/bSubgroupStrides,
/*batchCount=*/bBatchSizes,
getBSingleSubgroupLayout(mmaAttr));
IREE::VectorExt::NestedLayoutAttr bLayout = createNestedLayout(
context, bRank, bfk, bfn,
/*subgroupCount=*/bSubgroupSizes,
/*subgroupStrides=*/bSubgroupStrides,
/*batchCount=*/bBatchSizes,
getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Rhs));
LLVM_DEBUG({ llvm::dbgs() << "B layout: " << bLayout << "\n"; });

std::tuple<VectorLayoutInterface, VectorLayoutInterface,
Expand Down Expand Up @@ -618,11 +621,11 @@ static LogicalResult setAttentionMatmulAnchor(RewriterBase &rewriter,
auto pvIntrinsic =
cast<IREE::GPU::MmaInterfaceAttr>(pvSchedule.getIntrinsic());
IREE::GPU::MMASingleSubgroupLayout lhsLayout =
getASingleSubgroupLayout(pvIntrinsic);
getSingleSubgroupLayout(pvIntrinsic, IREE::GPU::MMAFragment::Lhs);
IREE::GPU::MMASingleSubgroupLayout rhsLayout =
getBSingleSubgroupLayout(pvIntrinsic);
getSingleSubgroupLayout(pvIntrinsic, IREE::GPU::MMAFragment::Rhs);
IREE::GPU::MMASingleSubgroupLayout outLayout =
getCSingleSubgroupLayout(qkIntrinsic);
getSingleSubgroupLayout(qkIntrinsic, IREE::GPU::MMAFragment::Acc);

auto matchLayout = [](IREE::GPU::MMASingleSubgroupLayout layoutA,
IREE::GPU::MMASingleSubgroupLayout layoutB) -> bool {
Expand Down

0 comments on commit 29c451b

Please sign in to comment.