[Codegen][GPU] Add support for WMMA_I32_16x16x16_I8 (iree-org#18372)
This adds support for the signed variant of the I8 WMMA intrinsic for
RDNA3. The same instruction also supports unsigned and mixed-signedness
variants, so the integer intrinsics will need to be refactored away from
the forced-signed form in the future.
qedawkins authored Aug 27, 2024
1 parent 83be1a0 commit 7050033
Showing 8 changed files with 144 additions and 30 deletions.
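
For orientation, below is a rough sketch of the new intrinsic at the two levels this change touches: the tensor-level iree_gpu.multi_mma form exercised by the lit tests in this commit, and the per-lane vector form that buildMmaOperation lowers to an amdgpu.wmma op. The function name and SSA values are illustrative only, and the amdgpu.wmma assembly line is schematic, assuming the upstream AMDGPU dialect syntax.

#contraction_accesses = [
  affine_map<() -> ()>,
  affine_map<() -> ()>,
  affine_map<() -> ()>
]
func.func @wmma_i8_sketch(%lhs: tensor<16x16xi8>, %rhs: tensor<16x16xi8>,
                          %acc: tensor<16x16xi32>) -> tensor<16x16xi32> {
  // A single 16x16x16 tile; the kind attribute selects the new RDNA3 intrinsic.
  %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
    indexing_maps = #contraction_accesses,
    iterator_types = [],
    kind = #iree_gpu.mma_layout<WMMA_I32_16x16x16_I8>
  } : tensor<16x16xi8>, tensor<16x16xi8> into tensor<16x16xi32>
  return %0 : tensor<16x16xi32>
}

// After distribution to the 32-lane subgroup, each lane carries vector<16xi8>
// fragments of A and B and a vector<8xi32> accumulator (see getABCVectorTypes
// below), and buildMmaOperation emits the WMMA op, roughly:
//   %d = amdgpu.wmma %a * %b + %c : vector<16xi8>, vector<16xi8>, vector<8xi32>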
@@ -16,7 +16,7 @@
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],

// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>]
// GFX1100-SAME: mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>]
// GFX1100-SAME: subgroup_size_choices = [32, 64]

// GFX941: target = #iree_gpu.target<arch = "gfx941",
51 changes: 24 additions & 27 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -237,6 +237,9 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f16};
}
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return OpaqueMmaLayout{16, 16, 16, i8, i8, i32};
}
}
llvm_unreachable("unhandled mfma layout type");
return OpaqueMmaLayout{};
@@ -353,38 +356,25 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [1, 16]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
// #layout_b = #iree_vector_ext.layout<#inner, #outer>

auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {1, 16});
auto aMLayout = outer;
auto aKLayout = inner;
auto bKLayout = inner;
auto bNLayout = outer;
auto cMLayout =
PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {8, 2, 1});
auto cNLayout = PerDimLayoutAttr::get(context, {laneX}, {16});
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [1, 16]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
// #layout_b = #iree_vector_ext.layout<#inner, #outer>
int64_t vecYShape = type == MMAIntrinsic::WMMA_F16_16x16x16_F16 ? 16 : 8;
int64_t laneYShape = type == MMAIntrinsic::WMMA_F16_16x16x16_F16 ? 1 : 2;

auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {1, 16});
auto aMLayout = outer;
auto aKLayout = inner;
auto bKLayout = inner;
auto bNLayout = outer;
auto cMLayout =
PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {16, 1, 1});
auto cMLayout = PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX},
{vecYShape, laneYShape, 1});
auto cNLayout = PerDimLayoutAttr::get(context, {laneX}, {16});
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
@@ -480,7 +470,8 @@ MMAAttr::getABCVectorTypes() const {
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
auto aType = VectorType::get({16}, getAType());
auto bType = VectorType::get({16}, getBType());
auto cType = VectorType::get({8}, getCType());
@@ -514,7 +505,8 @@ int64_t MMAAttr::getBlockSize() const {
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::MFMA_I32_32x32x16_I8:
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return 1;
}
}
@@ -533,7 +525,8 @@ int64_t MMAAttr::getSubgroupSize() const {
return 64;
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return 32;
}
}
@@ -565,7 +558,8 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
/*element=*/{1, 8}};
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 0},
/*element=*/{1, 16}};
}
@@ -597,7 +591,8 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
/*element=*/{8, 1}};
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*strides=*/{0, 1},
/*element=*/{16, 1}};
}
@@ -619,7 +614,8 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return {/*outer=*/{8, 1}, /*thread=*/{2, 16}, /*strides=*/{16, 1},
/*element=*/{1, 1}};
}
@@ -670,7 +666,8 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
.getResult();
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
.getResult();
}
@@ -109,6 +109,10 @@ def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 5>;
def WMMA_F32_16x16x16_F16 : I32EnumAttrCase<"WMMA_F32_16x16x16_F16", 6>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 7>;

// TODO: The actual I8 instruction allows specifying (mixed) signedness.
// This will need to become its own class of MMA attribute.
def WMMA_I32_16x16x16_I8 : I32EnumAttrCase<"WMMA_I32_16x16x16_I8", 8>;

def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F32_16x16x4_F32,
@@ -118,7 +122,8 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
MFMA_I32_16x16x32_I8,
MFMA_I32_32x32x16_I8,
WMMA_F32_16x16x16_F16,
WMMA_F16_16x16x16_F16
WMMA_F16_16x16x16_F16,
WMMA_I32_16x16x16_I8
]>;

def MMA_LHS : I32EnumAttrCase<"Lhs", 0>;
@@ -192,6 +192,7 @@ const WgpDetails *getRDNA3WgpDetails() {
static const MMAIntrinsic rdna3MMAOps[] = {
MMAIntrinsic::WMMA_F32_16x16x16_F16,
MMAIntrinsic::WMMA_F16_16x16x16_F16,
MMAIntrinsic::WMMA_I32_16x16x16_I8,
};
static const WgpDetails rdna3Wgp = {allComputeBits,
allStorageBits,
@@ -243,3 +243,34 @@ func.func @concretize_WMMA_F16_16x16x16_F16(%lhs: tensor<16x16xf16>, %rhs: tenso
// CHECK-RESULT-SAME: : tensor<16x16xf16>, tensor<16x16xf16> into tensor<16x1x16xf16>
// CHECK-RESULT: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0, 1], [2]]
// CHECK-RESULT: return %[[COLLAPSED]]

// -----

#contraction_accesses = [
affine_map<() -> ()>,
affine_map<() -> ()>,
affine_map<() -> ()>
]
func.func @concretize_WMMA_I32_16x16x16_I8(%lhs: tensor<16x16xi8>, %rhs: tensor<16x16xi8>, %acc: tensor<16x16xi32>) -> tensor<16x16xi32> {
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
indexing_maps = #contraction_accesses,
iterator_types = [],
kind = #iree_gpu.mma_layout<WMMA_I32_16x16x16_I8>
} : tensor<16x16xi8>, tensor<16x16xi8> into tensor<16x16xi32>
return %0 : tensor<16x16xi32>
}

// CHECK-LABEL: func @concretize_WMMA_I32_16x16x16_I8
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<16x16xi8>
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<16x16xi8>
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: tensor<16x16xi32>

// CHECK-INPUTS-NOT: tensor.expand_shape
// CHECK-INPUTS: %[[MMA:.+]] = iree_gpu.multi_mma
// CHECK-INPUTS: return %[[MMA]]

// CHECK-RESULT: %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0, 1], [2]] output_shape [8, 2, 16]
// CHECK-RESULT: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
// CHECK-RESULT-SAME: : tensor<16x16xi8>, tensor<16x16xi8> into tensor<8x2x16xi32>
// CHECK-RESULT: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0, 1], [2]]
// CHECK-RESULT: return %[[COLLAPSED]]
@@ -239,7 +239,6 @@ func.func @distribute_WMMA_F16_16x16x16_F16(%lhs: tensor<16x16xf16>, %rhs: tenso
}

// CHECK-DAG: #[[$XMAP:.+]] = affine_map<(d0) -> (d0 mod 16)>
// CHECK-DAG: #[[$YMAP:.+]] = affine_map<() -> ()>

// CHECK-LABEL: func @distribute_WMMA_F16_16x16x16_F16
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<16x16xf16>
@@ -254,3 +253,44 @@ func.func @distribute_WMMA_F16_16x16x16_F16(%lhs: tensor<16x16xf16>, %rhs: tenso
// CHECK-SAME: : tensor<1x16xf16>, tensor<16x1xf16> into tensor<16x1x1xf16>
// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDX]]] [16, 1, 1]
// CHECK: mapping = [#iree_gpu.lane_id<0>]

// -----

#contraction_accesses = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
module {
func.func @matmul_wmma_i32_16x16x16_i8(%arg0: tensor<2x8x16x16xi8>, %arg1: tensor<8x2x16x16xi8>, %arg2: tensor<2x2x8x2x16xi32>) -> tensor<2x2x8x2x16xi32> {
%mm = iree_gpu.multi_mma %arg0, %arg1, %arg2 {
indexing_maps = #contraction_accesses,
iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
kind = #iree_gpu.mma_layout<WMMA_I32_16x16x16_I8>,
rhs_permutation = array<i64: 1, 0>
} : tensor<2x8x16x16xi8>, tensor<8x2x16x16xi8> into tensor<2x2x8x2x16xi32>
return %mm : tensor<2x2x8x2x16xi32>
}
}

// CHECK-DAG: #[[$XMAP:.+]] = affine_map<(d0) -> (d0 mod 16)>
// CHECK-DAG: #[[$YMAP:.+]] = affine_map<(d0) -> ((d0 floordiv 16) mod 2)>
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func @matmul_wmma_i32_16x16x16_i8
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<2x8x16x16xi8>
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x16x16xi8>
// CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x8x2x16xi32>)
// CHECK-DAG: %[[IDX:.+]] = affine.apply #[[$XMAP]](%[[LANEID]])
// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[IDX]], 0] [2, 8, 1, 16]
// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[IDX]], 0] [8, 2, 1, 16]
// CHECK-DAG: %[[IDY:.+]] = affine.apply #[[$YMAP]](%[[LANEID]])
// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[IDY]], %[[IDX]]] [2, 2, 8, 1, 1]
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: kind = #iree_gpu.mma_layout<WMMA_I32_16x16x16_I8>
// CHECK-SAME: : tensor<2x8x1x16xi8>, tensor<8x2x1x16xi8> into tensor<2x2x8x1x1xi32>
// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[IDY]], %[[IDX]]] [2, 2, 8, 1, 1]
// CHECK: mapping = [#iree_gpu.lane_id<0>]
29 changes: 29 additions & 0 deletions tests/e2e/matmul/CMakeLists.txt
@@ -2508,6 +2508,35 @@ iree_generated_e2e_runner_test(
"requires-gpu-rdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rocm_i8_large_rdna3_wmma_tb
TEST_TYPE
matmul
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=i8"
"--acc_type=i32"
"--transpose_rhs"
"--shapes=gpu_large_aligned"
"--compilation_info=LLVMGPUVectorDistributeWMMA"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
"rocm"
DRIVERS
"hip"
COMPILER_FLAGS
${IREE_HIP_TEST_COMPILER_FLAGS}
LABELS
"noasan"
"nomsan"
"notsan"
"noubsan"
"requires-gpu-rdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rdna3_experimental_dt_f32_f32
11 changes: 11 additions & 0 deletions tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -299,6 +299,13 @@ def get_rocm_test_compilation_infos(
MMASchedule("WMMA_F32_16x16x16_F16", 2, 2, 1, 1, 1),
MMASchedule("WMMA_F32_16x16x16_F16", 2, 4, 2, 1, 2),
MMASchedule("WMMA_F32_16x16x16_F16", 4, 2, 4, 2, 2),
MMASchedule("WMMA_I32_16x16x16_I8", 1, 1, 1, 1, 1),
MMASchedule("WMMA_I32_16x16x16_I8", 1, 1, 1, 1, 2),
MMASchedule("WMMA_I32_16x16x16_I8", 1, 1, 1, 2, 1),
MMASchedule("WMMA_I32_16x16x16_I8", 1, 1, 2, 1, 1),
MMASchedule("WMMA_I32_16x16x16_I8", 2, 2, 1, 1, 1),
MMASchedule("WMMA_I32_16x16x16_I8", 2, 4, 2, 1, 2),
MMASchedule("WMMA_I32_16x16x16_I8", 4, 2, 4, 2, 2),
]
else:
raise NotImplementedError("unhandled intrinsic case")
@@ -342,6 +349,10 @@ def get_rocm_test_compilation_infos(
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 16
elif schedule.intrinsic == "WMMA_I32_16x16x16_I8":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 16
else:
raise NotImplementedError("unhandled intrinsic case")
