[LLVMGPU] Teach KernelConfig to set MMA schedules per op in LoweringConfig (iree-org#18984)

The main motivation for this change is to enable different
intrinsics/layouts on different ops inside the same function/dispatch,
especially for attention. To that end, we move the MMA scheduling
information, such as mma_intrinsic, subgroup_m_count, and
subgroup_n_count, from the translation info attached to the function
onto the lowering_config of each op. A quick summary of what was needed
to achieve that:

1. Introduce setMmaKind, set/get subgroupMCount, and set/get
subgroupNCount on IREE::GPU::LoweringConfigAttr (a sketch of the new API
follows this list).
2. Move configuration of the QK matmul's schedule from
LLVMGPUConfigureTensorLayout into KernelConfig.
3. Now that qk and pv may use different intrinsics, update the
information used to decide transposeIntrinsic and reuseIntrinsic in
LLVMGPUConfigureTensorLayout.
4. Update tests to configure MMAs through the lowering config.
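
As a rough illustration of the new surface area, here is a minimal sketch of
how a caller such as KernelConfig can now assemble a per-op lowering config
carrying the MMA schedule. The helper name buildConfigWithSchedule is
hypothetical; the setters are the ones introduced in IREEGPUAttrs.cpp below.

// Hypothetical helper; not part of this patch. Assumes the accessors
// introduced by this change.
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;
using namespace mlir::iree_compiler;

static IREE::GPU::LoweringConfigAttr
buildConfigWithSchedule(MLIRContext *context, IREE::GPU::MmaInterfaceAttr kind,
                        int64_t subgroupMCount, int64_t subgroupNCount) {
  SmallVector<NamedAttribute> attrs;
  // Scheduling information now lives on the per-op lowering config rather
  // than on the function-level translation info.
  IREE::GPU::LoweringConfigAttr::setMmaKind(context, attrs, kind);
  IREE::GPU::LoweringConfigAttr::setSubgroupMCount(context, attrs,
                                                   subgroupMCount);
  IREE::GPU::LoweringConfigAttr::setSubgroupNCount(context, attrs,
                                                   subgroupNCount);
  return IREE::GPU::LoweringConfigAttr::get(
      context, DictionaryAttr::get(context, attrs));
}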

---------

Signed-off-by: Stanley Winata <[email protected]>
raikonenfnu authored Nov 4, 2024
1 parent ec7528c commit a5537bc
Showing 12 changed files with 244 additions and 157 deletions.
45 changes: 45 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -1618,6 +1618,51 @@ IREE::GPU::MmaInterfaceAttr LoweringConfigAttr::getMmaKind() const {
return getAttributes().getAs<IREE::GPU::MmaInterfaceAttr>(kMmaKindName);
}

void LoweringConfigAttr::setMmaKind(MLIRContext *context,
SmallVectorImpl<NamedAttribute> &attrs,
IREE::GPU::MmaInterfaceAttr kind) {
attrs.emplace_back(StringAttr::get(context, kMmaKindName), kind);
}

// TODO: Merge subgroup counts functionality into subgroup tiling level
// lowering, when we have it implemented.
constexpr StringLiteral kSubgroupMCountName = "subgroup_m_count";
constexpr StringLiteral kSubgroupNCountName = "subgroup_n_count";

std::optional<int64_t> LoweringConfigAttr::getSubgroupMCount() const {
auto subgroup_m_count_attr =
getAttributes().getAs<IntegerAttr>(kSubgroupMCountName);
if (!subgroup_m_count_attr) {
return std::nullopt;
}
return subgroup_m_count_attr.getInt();
}

std::optional<int64_t> LoweringConfigAttr::getSubgroupNCount() const {
auto subgroup_n_count_attr =
getAttributes().getAs<IntegerAttr>(kSubgroupNCountName);
if (!subgroup_n_count_attr) {
return std::nullopt;
}
return subgroup_n_count_attr.getInt();
}

void LoweringConfigAttr::setSubgroupMCount(
MLIRContext *context, SmallVectorImpl<NamedAttribute> &attrs,
int64_t subgroup_m_count) {
attrs.emplace_back(
StringAttr::get(context, kSubgroupMCountName),
IntegerAttr::get(IntegerType::get(context, 64), subgroup_m_count));
}

void LoweringConfigAttr::setSubgroupNCount(
MLIRContext *context, SmallVectorImpl<NamedAttribute> &attrs,
int64_t subgroup_n_count) {
attrs.emplace_back(
StringAttr::get(context, kSubgroupNCountName),
IntegerAttr::get(IntegerType::get(context, 64), subgroup_n_count));
}

constexpr StringLiteral kPromoteOperandsName = "promote_operands";

std::optional<SmallVector<int64_t>>
compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -57,8 +57,23 @@ def IREEGPU_LoweringConfigAttr :
"The configured fields, including tiling levels">:$attributes
);
let extraClassDeclaration = [{
/// Helper to retrieve a target mma intrinsic if present.
/// Helper to retrieve/set a target mma intrinsic.
::mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr getMmaKind() const;
static void setMmaKind(MLIRContext *context,
SmallVectorImpl<NamedAttribute> &attrs,
::mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr kind);

// TODO: Merge subgroup counts functionality into subgroup tiling level
// lowering, when we have it implemented.
/// Helper to retrieve/set a target subgroup M/N counts.
std::optional<int64_t> getSubgroupMCount() const;
std::optional<int64_t> getSubgroupNCount() const;
static void setSubgroupMCount(MLIRContext *context,
SmallVectorImpl<NamedAttribute> &attrs,
int64_t subgroup_m_count);
static void setSubgroupNCount(MLIRContext *context,
SmallVectorImpl<NamedAttribute> &attrs,
int64_t subgroup_n_count);

/// Helper to retrieve/set a list of operand indices to promote.
std::optional<SmallVector<int64_t>> getPromotedOperandList() const;
54 changes: 35 additions & 19 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -416,18 +416,17 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1});
IREE::GPU::LoweringConfigAttr::setMmaKind(context, attrs,
mmaAttrs[schedule->index]);
IREE::GPU::LoweringConfigAttr::setSubgroupMCount(
context, attrs, schedule->mSubgroupCounts[0]);
IREE::GPU::LoweringConfigAttr::setSubgroupNCount(
context, attrs, schedule->nSubgroupCounts[0]);

auto configDict = DictionaryAttr::get(context, attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);

// Attach the MMA schedule as an attribute to the entry point export function
// for later access in the pipeline.
SmallVector<NamedAttribute, 1> pipelineAttrs;
auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
schedule->nSubgroupCounts[0]);
pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
scheduleAttr);

// Prefetch shared memory if requested.
if (clLLVMGPUEnablePrefetch) {
@@ -682,18 +681,19 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1});
IREE::GPU::LoweringConfigAttr::setMmaKind(context, attrs,
mmaAttrs[schedule->index]);
IREE::GPU::LoweringConfigAttr::setSubgroupMCount(
context, attrs, schedule->mSubgroupCounts[0]);
IREE::GPU::LoweringConfigAttr::setSubgroupNCount(
context, attrs, schedule->nSubgroupCounts[0]);

auto configDict = DictionaryAttr::get(context, attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);

// Attach the MMA schedule as an attribute to the entry point export function
// for later access in the pipeline.
SmallVector<NamedAttribute, 1> pipelineAttrs;
auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
schedule->nSubgroupCounts[0]);
pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
scheduleAttr);

// Prefetch shared memory if requested.
if (clLLVMGPUEnablePrefetch) {
@@ -902,9 +902,32 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
SmallVector<NamedAttribute, 2> qkConfig;
SmallVector<NamedAttribute, 2> pvConfig;

// On attention subgroup distribution:
// The subgroup distribution in attention is controlled by the second matmul
// (parallel dimension distribution is almost always controlled by the last
// reduction operation in a dispatch). Since VectorDistribution doesn't have
// logic to set subgroup and thread layouts separately, we explicitly set the
// subgroup count for the first matmul as well, corresponding to what the
// second matmul dictates.

// Configuring for qk matmul.
// subgroup_n count for qk matmul is always 1, since we do not tile K1.
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, qkConfig,
{0, 1});
IREE::GPU::LoweringConfigAttr::setMmaKind(context, qkConfig,
mmaAttrs[schedule->index]);
IREE::GPU::LoweringConfigAttr::setSubgroupMCount(
context, qkConfig, schedule->mSubgroupCounts[0]);
IREE::GPU::LoweringConfigAttr::setSubgroupNCount(context, qkConfig, 1);

// Configuring for pv matmul.
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, pvConfig, {1});
IREE::GPU::LoweringConfigAttr::setMmaKind(context, pvConfig,
mmaAttrs[schedule->index]);
IREE::GPU::LoweringConfigAttr::setSubgroupMCount(
context, pvConfig, schedule->mSubgroupCounts[0]);
IREE::GPU::LoweringConfigAttr::setSubgroupNCount(
context, pvConfig, schedule->nSubgroupCounts[0]);

SmallVector<NamedAttribute, 2> qkAttrs;
SmallVector<NamedAttribute, 2> pvAttrs;
@@ -938,14 +961,7 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
auto configDict = b.getDictionaryAttr(attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);

// Attach the MMA schedule as an attribute to the entry point export function
// for later access in the pipeline.
SmallVector<NamedAttribute, 1> pipelineAttrs;
auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
schedule->nSubgroupCounts[0]);
pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
scheduleAttr);

// TODO: We do not turn prefetching on even when requested by the prefetching
// flag because there is a shared memory allocation between the two matmuls, which
compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -48,6 +48,29 @@ static SmallVector<bool> getPromotedOperands(Operation *op) {
return promotedOperands;
}

static IREE::GPU::MmaInterfaceAttr getIntrinsic(Operation *op) {
auto config = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
assert(config && "Cannot find intrinsic from unconfigured op.");

IREE::GPU::MmaInterfaceAttr mmaIntrinsic = config.getMmaKind();
assert(mmaIntrinsic && "Cannot find intrinsic in lowering config.");
return mmaIntrinsic;
}

static int64_t getSubgroupMCount(Operation *op) {
auto config = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
assert(config && "Cannot get subgroup_m_count from unconfigured op.");

return *config.getSubgroupMCount();
}

static int64_t getSubgroupNCount(Operation *op) {
auto config = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
assert(config && "Cannot get subgroup_n_count from unconfigured op.");

return *config.getSubgroupNCount();
}

static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
SmallVector<bool> promotedOperands,
RewriterBase &rewriter,
@@ -264,14 +287,19 @@ transposeSchedule(RewriterBase &rewriter, IREE::GPU::MMAScheduleAttr schedule) {
schedule.getSubgroupMCount());
}

static LogicalResult
setAttentionMatmulAnchor(IREE::GPU::MMAScheduleAttr schedule,
RewriterBase &rewriter, linalg::LinalgOp qkMatmul,
linalg::LinalgOp pvMatmul) {
// TODO: Add SIMT fallback.
if (!schedule) {
return pvMatmul->emitError("missing mma schedule for contraction");
}
static LogicalResult setAttentionMatmulAnchor(RewriterBase &rewriter,
linalg::LinalgOp qkMatmul,
linalg::LinalgOp pvMatmul) {

IREE::GPU::MMAScheduleAttr qkSchedule =
rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(getIntrinsic(qkMatmul),
getSubgroupMCount(qkMatmul),
getSubgroupNCount(qkMatmul));

IREE::GPU::MMAScheduleAttr pvSchedule =
rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(getIntrinsic(pvMatmul),
getSubgroupMCount(pvMatmul),
getSubgroupNCount(pvMatmul));

// Check if the intrinsic output for qkMatmul can be reused for pvMatmul.
// We know that pvMatmul takes the result of qkMatmul as its lhs.
@@ -280,13 +308,14 @@ setAttentionMatmulAnchor(IREE::GPU::MMAScheduleAttr schedule,
bool reuseIntrinsicOutput = false;
bool transposeIntrinsic = false;

auto intrinsic = cast<IREE::GPU::MMAAttr>(schedule.getIntrinsic());
auto qkIntrinsic = cast<IREE::GPU::MMAAttr>(qkSchedule.getIntrinsic());
auto pvIntrinsic = cast<IREE::GPU::MMAAttr>(pvSchedule.getIntrinsic());
IREE::GPU::MMASingleSubgroupLayout lhsLayout =
intrinsic.getASingleSubgroupLayout();
pvIntrinsic.getASingleSubgroupLayout();
IREE::GPU::MMASingleSubgroupLayout rhsLayout =
intrinsic.getBSingleSubgroupLayout();
pvIntrinsic.getBSingleSubgroupLayout();
IREE::GPU::MMASingleSubgroupLayout outLayout =
intrinsic.getCSingleSubgroupLayout();
qkIntrinsic.getCSingleSubgroupLayout();

auto matchLayout = [](IREE::GPU::MMASingleSubgroupLayout layoutA,
IREE::GPU::MMASingleSubgroupLayout layoutB) -> bool {
Expand All @@ -305,15 +334,6 @@ setAttentionMatmulAnchor(IREE::GPU::MMAScheduleAttr schedule,
transposeIntrinsic = true;
}

// subgroup_n count for attention matmul is always 1, because it is the
// reduction dimension. The subgroup_n count is, in reality, for the pvMatmul.
IREE::GPU::MMAScheduleAttr qkSchedule =
rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
schedule.getIntrinsic(),
/*subgroup_m_count=*/schedule.getSubgroupMCount(),
/*subgroup_n_count=*/1);
IREE::GPU::MMAScheduleAttr pvSchedule = schedule;

SmallVector<bool> promotedQKOperands = getPromotedOperands(qkMatmul);
SmallVector<bool> promotedPVOperands = getPromotedOperands(pvMatmul);

@@ -488,12 +508,6 @@ struct LLVMGPUConfigureTensorLayoutsPass final
return signalPassFailure();
}

llvm::StringLiteral scheduleAttrName =
IREE::GPU::MMAScheduleAttr::getMnemonic();
DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
auto scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
configDict.get(scheduleAttrName));

// Vector layout option setter aimed at contractions and convolutions. For
// now, layout setting for other problems like reductions is TODO.
SmallVector<linalg::LinalgOp> contracts;
@@ -529,23 +543,28 @@

for (linalg::LinalgOp contract : contracts) {
SmallVector<bool> promotedOperands = getPromotedOperands(contract);
if (failed(setContractionAnchor(scheduleAttr, promotedOperands, rewriter,
contract))) {
auto contractionSchedule = rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
getIntrinsic(contract), getSubgroupMCount(contract),
getSubgroupNCount(contract));
if (failed(setContractionAnchor(contractionSchedule, promotedOperands,
rewriter, contract))) {
return signalPassFailure();
}
}

for (linalg::LinalgOp conv : convs) {
SmallVector<bool> promotedOperands = getPromotedOperands(conv);
if (failed(setConvolutionAnchor(scheduleAttr, promotedOperands, rewriter,
auto convSchedule = rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
getIntrinsic(conv), getSubgroupMCount(conv), getSubgroupNCount(conv));
if (failed(setConvolutionAnchor(convSchedule, promotedOperands, rewriter,
conv))) {
return signalPassFailure();
}
}

if (attentionQKMatmul && attentionPVMatmul) {
if (failed(setAttentionMatmulAnchor(
scheduleAttr, rewriter, attentionQKMatmul, attentionPVMatmul))) {
if (failed(setAttentionMatmulAnchor(rewriter, attentionQKMatmul,
attentionPVMatmul))) {
return signalPassFailure();
}
}
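On the consumer side, LLVMGPUConfigureTensorLayout no longer reads a single
function-wide mma_schedule from the translation info; each op's schedule is
rebuilt from its lowering config. Here is a minimal sketch of that per-op
reconstruction, assuming the op was configured by KernelConfig as above. The
helper name getScheduleFor is made up; getLoweringConfig and the accessors are
the ones used in the diff.

// Hypothetical helper mirroring what the pass now does per op; assumes the
// surrounding pass already includes the lowering-config utilities.
static IREE::GPU::MMAScheduleAttr getScheduleFor(RewriterBase &rewriter,
                                                 Operation *op) {
  auto config = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
  assert(config && config.getMmaKind() &&
         "op must carry an MMA-configured lowering config");
  // The subgroup counts are optional on the attribute; ops configured by
  // KernelConfig set both.
  return rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
      config.getMmaKind(), *config.getSubgroupMCount(),
      *config.getSubgroupNCount());
}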