[LLVMGPU][ROCM][Layoutv1] Landing Implementation of WMMA on layoutV1 (iree-org#17580)

Introducing WMMA on layoutV1. This feature is a prerequisite for
getting FA2 (Flash Attention 2) on RDNA3 GPUs.
This PR brings:

1. WMMA on layoutV1.
2. Ensuring VectorX is explicitly specified, to prevent unexpected
behavior when computing the "step" setting in the SIMT index.
3. Update of amdgpu_distribute_vectors/contract distribution so that it
follows the newer `set_contraction_layout_attributes` style of setting
intrinsic types explicitly, as opposed to the less safe method of
inferring intrinsic types (see the sketch after this message).
4. Update of tests to match the above.

Co-authored-by: Groverkss <[email protected]>

---------

Co-authored-by: Kunwar Grover <[email protected]>
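
At the core of item 3 is a simple handshake: the transform op stamps the
chosen intrinsic onto the contraction as an attribute, and distribution
reads it back instead of inferring it. A condensed sketch of that flow,
with attribute names and calls taken from the hunks below (abridged, not
the verbatim change):

// Annotation side (SetContractionLayoutAttributes): record the chosen
// intrinsic and the operand anchor layouts directly on the contract op.
contract->setAttr("iree.amdgpu.mma", mmaType);
contract->setAttr("__vector_layout_test_anchor_operand_0", aLayout);
contract->setAttr("__vector_layout_test_anchor_operand_1", bLayout);

// Distribution side: recover the intrinsic, and fail the match rather
// than guess when the annotation is absent.
auto mmaAttr =
    contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
if (!mmaAttr)
  return rewriter.notifyMatchFailure(
      contractOp, "missing iree.amdgpu.mma intrinsic attribute");

In the distribution hunks below, this lookup replaces the old
inferCompatibleMFMAType call, which matched layouts against a hard-coded
list of MFMA shapes and therefore could never produce a WMMA intrinsic.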
raikonenfnu and Groverkss authored Jun 5, 2024
1 parent 9d60462 commit b44581a
Showing 10 changed files with 278 additions and 314 deletions.
@@ -290,8 +290,9 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
auto aKLayout = inner;
auto bKLayout = inner;
auto bNLayout = outer;
auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {8, 2});
auto cNLayout = outer;
auto cMLayout =
PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {8, 2, 1});
auto cNLayout = PerDimLayoutAttr::get(context, {laneX}, {16});
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
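
In the hunk above, the old two-line definition (cMLayout over
{laneY, vectorX} with sizes {8, 2}, cNLayout = outer) is the removed
code; the three lines that follow it are the replacement. A rough
illustration of how the new layout maps SIMT ids to matrix coordinates
for the 16x16 accumulator (C) tile, assuming the usual outer-to-inner
meaning of PerDimLayoutAttr dimensions (illustrative only; the variable
names below are not from the commit):

// M side: {vectorY, laneY, vectorX} with sizes {8, 2, 1} covers
// 8 * 2 * 1 = 16 rows; N side: {laneX} with size {16} covers 16 columns.
int row = vectorY * 2 + laneY; // vectorY in [0, 8), laneY in [0, 2)
int col = laneX;               // laneX in [0, 16)

Note that vectorX is now spelled out with size 1 rather than omitted, so
the SIMT "step" computation always sees an explicit innermost vector
dimension (item 2 of the commit message).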
@@ -1647,6 +1647,7 @@ transform_dialect::SetContractionLayoutAttributes::apply(
<< "invalid opaque mma layout for annotation " << mmaType;
}

contract->setAttr("iree.amdgpu.mma", mmaType);
auto [aLayout, bLayout, cLayout] = *maybeLayouts;
contract->setAttr("__vector_layout_test_anchor_operand_0", aLayout);
contract->setAttr("__vector_layout_test_anchor_operand_1", bLayout);
@@ -7,6 +7,7 @@
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"

@@ -18,83 +19,6 @@ using VectorValue = TypedValue<VectorType>;
enum class ContractMatrixType { A, B, C, D };
enum class ContractType { MM, MMT, MTM, MTMT, UNSUPPORTED };

/// We define AMD MFMA instruction layouts only for contract type MM, i.e. C(i,
/// k) += A(i, j) * B(j, k). We call these canonical layouts: layoutA, layoutB,
/// layoutC, corresponding to the A, B, C matrices.
///
/// For any other contract type, the layout is simply transposed for that
/// operand. For example, for MMT, the layouts used should be layoutA,
/// layoutB.T, layoutC. For an easier understanding of this transposition,
/// think of the transpose simply being outside the contract:
///
/// vector.contract {type = MMT} %a, %b
///
/// is equivalent to
///
/// %bt = vector.transpose %b
/// vector.contract {type = MM} %a, %bt
///
/// Now, you would assign layouts based on contract type MM, and would get the
/// right layout for %b by transposing the layout for B.
///
/// Now that we have defined what layoutA, layoutB, layoutC are, we will define
/// what the canonical layouts are for each MFMA instruction. These are
/// represented as the original matrix, with elements representing which thread
/// id in the subgroup gets which element.
/// These layouts were referenced from
/// https://github.com/ROCm/amd_matrix_instruction_calculator
///
/// The naming scheme for these operators is InputType_MxNxK_OutputType.
enum class MFMAType {
/// layoutA:
/// 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
/// 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
/// 2 2 2 2 18 18 18 18 34 34 34 34 50 50 50 50
/// ...
/// 15 15 15 15 31 31 31 31 47 47 47 47 63 63 63 63
///
/// layoutB:
/// Transpose of layoutA
///
/// layoutC:
/// Same as layoutB
F16_16x16x16_F32,
/// layoutA:
/// 0 0 0 0 32 32 32 32
/// 1 1 1 1 33 33 33 33
/// 2 2 2 2 34 34 34 34
/// ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮
/// 31 31 31 31 63 63 63 63
///
/// layoutB:
/// Transpose of layoutA
///
/// layoutC:
/// 0 1 2 ... 31
/// 0 1 2 ... 31
/// 0 1 2 ... 31
/// 0 1 2 ... 31
/// 32 33 34 ... 63
/// 32 33 34 ... 63
/// 32 33 34 ... 63
/// 32 33 34 ... 63
/// 0 1 2 ... 31
/// ⋮ ⋮ ⋮ ... ⋮
/// 32 33 34 ... 63
/// 0 1 2 ... 31
/// ⋮ ⋮ ⋮ ... ⋮
/// 32 33 34 ... 63
/// 0 1 2 ... 31
/// 0 1 2 ... 31
/// 0 1 2 ... 31
/// 0 1 2 ... 31
/// 32 33 34 ... 63
/// 32 33 34 ... 63
/// 32 33 34 ... 63
/// 32 33 34 ... 63
F16_32x32x8_F32,
};

namespace {

static bool isOperandATransposed(ContractType contractType) {
@@ -166,208 +90,6 @@ struct DistributeContractions final
return ContractType::UNSUPPORTED;
}

Value computeMMA(Value a, Value b, Value c, Location loc, OpBuilder &rewriter,
MFMAType mfmaType) const {
uint32_t m, n, k, blks;
if (mfmaType == MFMAType::F16_16x16x16_F32) {
m = n = k = 16;
} else if (mfmaType == MFMAType::F16_32x32x8_F32) {
m = n = 32;
k = 8;
}
blks = 1;
return rewriter.create<amdgpu::MFMAOp>(loc, c.getType(), m, n, k, blks, a,
b, c);
}

PerDimLayoutAttr createPerDimLayout(MLIRContext *ctx,
ArrayRef<LayoutDimension> dims,
ArrayRef<int64_t> shapes) const {
SmallVector<LayoutDimensionAttr> dimAttrs;
for (auto dim : dims)
dimAttrs.push_back(LayoutDimensionAttr::get(ctx, dim));
return PerDimLayoutAttr::get(ctx, dimAttrs, shapes);
}

std::tuple<PerDimLayoutAttr, PerDimLayoutAttr> createCanonicalLayouts16x16x16(
LayoutDimension batchRowLabel, int64_t batchRow,
LayoutDimension batchColLabel, int64_t batchCol) const {
MLIRContext *ctx = getContext();
PerDimLayoutAttr rowLayout = createPerDimLayout(
ctx, {batchRowLabel, LayoutDimension::LANEX}, {batchRow, 16});
PerDimLayoutAttr colLayout = createPerDimLayout(
ctx, {batchColLabel, LayoutDimension::LANEY, LayoutDimension::VECTORX},
{batchCol, 4, 4});
return {rowLayout, colLayout};
}

bool isCompatible16x16x16A(LayoutAttr layout, int64_t batchRow,
int64_t batchCol) const {
auto [rowLayout, colLayout] = createCanonicalLayouts16x16x16(
LayoutDimension::BATCHX, batchRow, LayoutDimension::BATCHY, batchCol);
LayoutAttr canonicalLayout =
LayoutAttr::get(getContext(), {rowLayout, colLayout});
return layout == canonicalLayout;
}

bool isCompatible16x16x16B(LayoutAttr layout, int64_t batchRow,
int64_t batchCol) const {
auto [colLayout, rowLayout] = createCanonicalLayouts16x16x16(
LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow);
LayoutAttr canonicalLayout =
LayoutAttr::get(getContext(), {rowLayout, colLayout});
return layout == canonicalLayout;
}

bool isCompatible16x16x16C(LayoutAttr layout, int64_t batchRow,
int64_t batchCol) const {
return isCompatible16x16x16B(layout, batchRow, batchCol);
}

std::tuple<PerDimLayoutAttr, PerDimLayoutAttr>
createCanonicalLayouts32x32x8(LayoutDimension batchRowLabel, int64_t batchRow,
LayoutDimension batchColLabel, int64_t batchCol,
ContractMatrixType matrixType) const {
MLIRContext *ctx = getContext();
PerDimLayoutAttr rowLayout = createPerDimLayout(
ctx, {batchRowLabel, LayoutDimension::LANEX}, {batchRow, 32});
PerDimLayoutAttr colLayout;
if (matrixType == ContractMatrixType::C) {
colLayout =
createPerDimLayout(ctx,
{batchColLabel, LayoutDimension::VECTORY,
LayoutDimension::LANEY, LayoutDimension::VECTORX},
{batchCol, 4, 2, 4});
} else {
colLayout = createPerDimLayout(
ctx,
{batchColLabel, LayoutDimension::LANEY, LayoutDimension::VECTORX},
{batchCol, 2, 4});
}
return {rowLayout, colLayout};
}

bool isCompatible32x32x8A(LayoutAttr layout, int64_t batchRow,
int64_t batchCol) const {
auto [rowLayout, colLayout] = createCanonicalLayouts32x32x8(
LayoutDimension::BATCHX, batchRow, LayoutDimension::BATCHY, batchCol,
ContractMatrixType::A);
LayoutAttr canonicalLayout =
LayoutAttr::get(getContext(), {rowLayout, colLayout});
return layout == canonicalLayout;
}

bool isCompatible32x32x8B(LayoutAttr layout, int64_t batchRow,
int64_t batchCol) const {
auto [colLayout, rowLayout] = createCanonicalLayouts32x32x8(
LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow,
ContractMatrixType::B);
LayoutAttr canonicalLayout =
LayoutAttr::get(getContext(), {rowLayout, colLayout});
return layout == canonicalLayout;
}

bool isCompatible32x32x8C(LayoutAttr layout, int64_t batchRow,
int64_t batchCol) const {
auto [colLayout, rowLayout] = createCanonicalLayouts32x32x8(
LayoutDimension::BATCHY, batchCol, LayoutDimension::BATCHX, batchRow,
ContractMatrixType::C);
LayoutAttr canonicalLayout =
LayoutAttr::get(getContext(), {rowLayout, colLayout});
return layout == canonicalLayout;
}

bool isCompatible16x16x16(LayoutAttr layout, ContractMatrixType matrixType,
int64_t batchRow, int64_t batchCol) const {
switch (matrixType) {
case ContractMatrixType::A:
return isCompatible16x16x16A(layout, batchRow, batchCol);
case ContractMatrixType::B:
return isCompatible16x16x16B(layout, batchRow, batchCol);
default:
return isCompatible16x16x16C(layout, batchRow, batchCol);
}
return false;
}

bool isCompatible32x32x8(LayoutAttr layout, ContractMatrixType matrixType,
int64_t batchRow, int64_t batchCol) const {
switch (matrixType) {
case ContractMatrixType::A:
return isCompatible32x32x8A(layout, batchRow, batchCol);
case ContractMatrixType::B:
return isCompatible32x32x8B(layout, batchRow, batchCol);
default:
return isCompatible32x32x8C(layout, batchRow, batchCol);
}
return false;
}

bool isCompatible(LayoutAttr layout, ContractMatrixType matrixType,
MFMAType mfmaType) const {
std::optional<int64_t> batchRow = layout.getBatchDim(0);
if (!batchRow)
return false;
std::optional<int64_t> batchCol = layout.getBatchDim(1);
if (!batchCol)
return false;
switch (mfmaType) {
case MFMAType::F16_16x16x16_F32:
return isCompatible16x16x16(layout, matrixType, batchRow.value(),
batchCol.value());
case MFMAType::F16_32x32x8_F32:
return isCompatible32x32x8(layout, matrixType, batchRow.value(),
batchCol.value());
default:
return false;
}
return false;
}

// If we have a prior guess of the MFMA type, only evaluate that type.
// Otherwise, evaluate all types to find a match.
std::optional<MFMAType> inferMFMAType(LayoutAttr layout,
ContractMatrixType matrixType,
std::optional<MFMAType> prior) const {
SmallVector<MFMAType> mfmaTypes;
if (prior) {
mfmaTypes.push_back(prior.value());
} else {
mfmaTypes = {MFMAType::F16_16x16x16_F32, MFMAType::F16_32x32x8_F32};
}
for (MFMAType mfmaType : mfmaTypes) {
if (isCompatible(layout, matrixType, mfmaType))
return mfmaType;
}
return std::nullopt;
}

// Inputs are LHS, RHS and ACC operands and corresponding layouts.
// Output is inferred MFMAType or none (if layout is not compatible with any
// MFMA layout).
std::optional<MFMAType>
inferCompatibleMFMAType(ArrayRef<LayoutAttr> layouts,
ContractType contractType) const {
std::optional<MFMAType> mfmaType{std::nullopt};
SmallVector<ContractMatrixType> matrixTypes{
ContractMatrixType::A, ContractMatrixType::B, ContractMatrixType::C};

// Canonical layouts for MFMA are transposes of each other.
if (isOperandATransposed(contractType)) {
matrixTypes[0] = ContractMatrixType::B;
}
if (isOperandBTransposed(contractType)) {
matrixTypes[1] = ContractMatrixType::A;
}

for (auto [layout, matrixType] : llvm::zip(layouts, matrixTypes)) {
mfmaType = inferMFMAType(layout, matrixType, mfmaType);
if (!mfmaType)
return std::nullopt;
}
return mfmaType;
}

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -408,10 +130,12 @@ struct DistributeContractions final
if (contractType == ContractType::UNSUPPORTED)
return failure();

std::optional<MFMAType> mfmaType =
inferCompatibleMFMAType(layouts, contractType);
if (!mfmaType)
return failure();
auto mmaAttr =
contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
if (!mmaAttr) {
return rewriter.notifyMatchFailure(
contractOp, "missing iree.amdgpu.mma intrinsic attribute");
}

std::optional<int64_t> rowBatch = layouts[LHS].getBatchDim(0);
if (!rowBatch)
@@ -434,10 +158,13 @@
Value bMatrix = rewriter.create<vector::ExtractOp>(
loc, getDistributed(rewriter, operands[RHS], layouts[RHS]),
getIndices(contractType, ContractMatrixType::B, k, indices[1]));
dMatrix = computeMMA(aMatrix, bMatrix, dMatrix, loc, rewriter,
mfmaType.value());
dMatrix = mmaAttr
.buildMmaOperation(rewriter, loc, dMatrix.getType(),
aMatrix, bMatrix, dMatrix)
.value();
}
vector = rewriter.create<vector::InsertOp>(loc, dMatrix, vector, indices);
return success();
};

LayoutIterator iterator(resultLayout);
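
In the two hunks above, the inferCompatibleMFMAType and computeMMA calls
are the removed code; the mmaAttr lookup and the
mmaAttr.buildMmaOperation(...) chain are their replacements. The design
point is that the attribute, not the pattern, now decides which
intrinsic to emit, which is what lets WMMA drop in without touching the
distribution loop. A minimal sketch of one tile emission under that
contract, assuming the FailureOr-style return implied by the .value()
unwrap in the hunk:

// One accumulator tile: the attribute materializes the appropriate
// intrinsic (amdgpu.mfma on CDNA, amdgpu.wmma on RDNA3) for the given
// A/B/C operands, replacing the removed computeMMA helper that could
// only build amdgpu.mfma with hard-coded 16x16x16 or 32x32x8 shapes.
FailureOr<Value> tile = mmaAttr.buildMmaOperation(
    rewriter, loc, dMatrix.getType(), aMatrix, bMatrix, dMatrix);
dMatrix = tile.value();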
@@ -29,6 +29,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Common:VectorLayoutAnalysis",
"//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
@@ -41,6 +41,7 @@ iree_cc_library(
iree::compiler::Codegen::Common
iree::compiler::Codegen::Common::GPU::CommonGPUPasses
iree::compiler::Codegen::Common::VectorLayoutAnalysis
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::IR
@@ -20,6 +20,7 @@ iree_lit_test_suite(
[
"amdgpu_chained_matmul.mlir",
"amdgpu_contraction_distribution.mlir",
"amdgpu_set_anchor_layouts.mlir",
"attention.mlir",
"attention_mfma.mlir",
"conv_pipeline_test_cuda.mlir",
@@ -16,6 +16,7 @@ iree_lit_test_suite(
SRCS
"amdgpu_chained_matmul.mlir"
"amdgpu_contraction_distribution.mlir"
"amdgpu_set_anchor_layouts.mlir"
"attention.mlir"
"attention_mfma.mlir"
"cast_address_space_function.mlir"