[RVV] Optimize Generic RVV Matmul codegen (iree-org#18986)
bhbruce authored Nov 22, 2024
1 parent 205af92 commit 17fde4d
Showing 7 changed files with 253 additions and 0 deletions.
94 changes: 94 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -99,6 +99,15 @@ static llvm::cl::opt<bool> clDisableArmSMETiling(
"target (i.e., when the +sme feature flag is present)"),
llvm::cl::init(false));

static llvm::cl::opt<bool> clEnableRiscvAggressiveDist(
"iree-llvmcpu-riscv-aggressive-distribution",
llvm::cl::desc(
"Enable an aggressive strategy for picking distribution tile sizes. "
"It is currently only applied to linalg contraction ops. "
"If distConfig.minTileSizes[i] >= distConfig.maxTileSizes[i], "
"distConfig.maxTileSizes[i] is set to 2 * distConfig.minTileSizes[i]."),
llvm::cl::init(false));

using IREE::Codegen::DispatchLoweringPassPipeline;

// Encodes the pre-processing strategy to be applied on a Linalg operation
@@ -1289,6 +1298,62 @@ static void getMatmulVectorSizesUsingFullVectorHeuristics(
sizes[1] = std::max<int64_t>(sizes[1], minNumElements);
}

/// Utility to compute the tile sizes for RISC-V Vector (RVV).
/// The tile sizes are set to m = 7, n = maxNumberElementsForLMUL4, and k = 1.
///
/// Example: for a pure f32 matmul with a 512-bit vector register (VLEN = 512),
/// nativeVectorSize is VLEN * LMUL2 / 8 = 128 bytes, so
/// maxNumberElementsForLMUL4 = 128 * 2 * 8 / 32 = 64.
///
/// TODO: For now, only nonWideningLinalgElementType (f16/f32/f64) is
/// supported.
static void
getMatmulRISCVVectorSizes(mlir::FunctionOpInterface entryPointFn,
linalg::LinalgOp op, int64_t vectorSize,
SmallVectorImpl<int64_t> &sizes,
SmallVectorImpl<bool> &scalableSizeFlags) {
if (sizes.empty())
getDefaultMatmulVectorSizes(op, vectorSize, sizes, scalableSizeFlags);
// TODO: Support widening matmuls.
// Determine the n-dimension tile size from VLEN for
// nonWideningLinalgElementType.
FailureOr<Type> elementType = nonWideningLinalgElementType(op);
if (failed(elementType))
return;

// nativeVectorSize is calculated from VLEN with LMUL=2, in bytes.
int64_t nativeVectorSize = getNativeVectorSizeInBytes(entryPointFn);
int64_t elementSize;
if (elementType->isF16()) {
elementSize = 16;
} else if (elementType->isF32()) {
elementSize = 32;
} else if (elementType->isF64()) {
elementSize = 64;
} else {
// TODO: Support integer data types.
return;
}
FailureOr<linalg::ContractionDimensions> cDims =
linalg::inferContractionDims(op);
if (failed(cDims) || cDims->m.size() != 1)
return;
// Use a 7 x LMUL4 tile to fully utilize the vector registers.
sizes[0] = 7;
// Calculate tile size for the main vector dimension (N).
constexpr int64_t kByteSizeInBits = 8;
int64_t maxNumberElementsForLMUL4 =
(nativeVectorSize * 2 * kByteSizeInBits) / elementSize;
sizes[1] = maxNumberElementsForLMUL4;
sizes[2] = 1;
ArrayRef<int64_t> lhsShape = op.getShape(op.getDpsInputOperand(0));
// If m == 1 (matrix-vector product), use a 1 x LMUL8 tile instead.
if (lhsShape[cDims->m[0]] == 1) {
sizes[0] = 1;
sizes[1] *= 2;
}
}
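
To sanity-check the arithmetic above, here is a minimal standalone sketch (the helper name and parameters are illustrative, not IREE code) that reproduces the n-tile computation for the VLEN values exercised by the lit tests below:

#include <cassert>
#include <cstdint>

// Mirrors the computation above: nativeVectorSize = VLEN * LMUL2 / 8 bytes,
// then maxNumberElementsForLMUL4 = nativeVectorSize * 2 * 8 / elementBits.
int64_t rvvNTileElements(int64_t vlenBits, int64_t elementBits, bool mIsOne) {
  int64_t nativeVectorSizeBytes = vlenBits * 2 / 8; // LMUL = 2, in bytes
  int64_t lmul4Elements = nativeVectorSizeBytes * 2 * 8 / elementBits;
  return mIsOne ? 2 * lmul4Elements : lmul4Elements; // m == 1 uses LMUL8
}

int main() {
  assert(rvvNTileElements(512, 32, false) == 64);   // +zvl512b f32 matmul
  assert(rvvNTileElements(1024, 32, false) == 128); // +zvl1024b f32 matmul
  assert(rvvNTileElements(512, 32, true) == 128);   // +zvl512b f32 matvec
  return 0;
}

These values match the n tile sizes (64, 128, and 128) in the lowering configs checked by the updated lit tests further down.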

/// Utility to compute the tile sizes for AArch64 SME. Unlike other targets, the
/// tile sizes picked here must exactly match multiples of the SME hardware
/// virtual tiles, as there is currently no support for lowering non-standard
@@ -1354,6 +1419,16 @@ getMatmulVectorSizes(mlir::FunctionOpInterface entryPointFn,
}
}

if (isRISCV(targetAttr) && hasAnyVFeature(targetAttr)) {
// Use the default tile sizes for matmul_transpose_b and
// batch_matmul_transpose_b to avoid a performance drop.
if (!isa<linalg::MatmulTransposeBOp, linalg::BatchMatmulTransposeBOp>(op)) {
// Try to maximize the vector register utilization rate for matmul.
getMatmulRISCVVectorSizes(entryPointFn, op, vectorSize, matmulTileSizes,
matmulScalableFlags);
}
}

// If tile sizes were not computed by previous heuristics, use default
// hard-coded tile sizes.
if (matmulTileSizes.empty()) {
@@ -1494,6 +1569,25 @@ setRootConfig(mlir::FunctionOpInterface entryPointFn,
int64_t minTileSize = cacheTileSize != 0 ? cacheTileSize : vecTileSize;
distConfig.minTileSizes.push_back(minTileSize);
}
// FIXME: Apply the maxTileSizes modification to all targets.
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(entryPointFn);
if (isRISCV(targetAttr) && hasAnyVFeature(targetAttr)) {
LLVM_DEBUG(KD_DBGS() << "RISC-V Aggressive Distribution: "
<< clEnableRiscvAggressiveDist << "\n");
for (auto loopNum :
llvm::seq<unsigned>(static_cast<unsigned>(isBM), numLoops)) {
if (clEnableRiscvAggressiveDist) {
if (distConfig.maxTileSizes[loopNum] <=
distConfig.minTileSizes[loopNum]) {
distConfig.maxTileSizes[loopNum] =
2 * distConfig.minTileSizes[loopNum];
}
} else {
distConfig.maxTileSizes[loopNum] = std::max(
distConfig.maxTileSizes[loopNum], distConfig.minTileSizes[loopNum]);
}
}
}
SmallVector<int64_t> distTileSizes =
getDefaultDistributedLevelTileSizes(linalgOp, distConfig);

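
To make the branch above concrete: for the +zvl1024b f32 matmul in the updated lit test, the vector-level sizes give minTileSizes of roughly [7, 128] while the default maxTileSizes starts at [64, 64]; the flag is exercised in the test via iree-opt --iree-llvmcpu-riscv-aggressive-distribution=true. A minimal standalone sketch of the adjustment (illustrative function name and values, not IREE API):

#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch of the per-loop maxTileSizes adjustment above.
void adjustMaxTileSizes(std::vector<int64_t> &maxTileSizes,
                        const std::vector<int64_t> &minTileSizes,
                        bool aggressive) {
  for (size_t i = 0; i < maxTileSizes.size(); ++i) {
    if (aggressive) {
      // Give distribution room for two vector-level tiles per workgroup.
      if (maxTileSizes[i] <= minTileSizes[i])
        maxTileSizes[i] = 2 * minTileSizes[i];
    } else {
      // Default: only ensure max is not below min.
      maxTileSizes[i] = std::max(maxTileSizes[i], minTileSizes[i]);
    }
  }
}

// With min = [7, 128] and max = [64, 64]:
//   aggressive == false -> max becomes [64, 128]
//   aggressive == true  -> max becomes [64, 256]
// matching the n distribution tile of 128 (CHECK) vs 256 (CHECK-AGGRESSIVE)
// in the vl1024 test below.
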
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp
@@ -45,6 +45,12 @@ bool hasZve64xFeature(IREE::HAL::ExecutableTargetAttr targetAttr) {
return hasFeature(targetAttr, "+zve64x");
}

bool hasAnyVFeature(IREE::HAL::ExecutableTargetAttr targetAttr) {
return hasVFeature(targetAttr) || hasZve32xFeature(targetAttr) ||
hasZve32fFeature(targetAttr) || hasZve64xFeature(targetAttr) ||
hasFeature(targetAttr, "+zve64f") || hasFeature(targetAttr, "+zve64d");
}

bool hasAnySVEFeature(IREE::HAL::ExecutableTargetAttr targetAttr) {
return hasFeature(targetAttr, "+sve") || hasFeature(targetAttr, "+sve2") ||
hasFeature(targetAttr, "+v9a");
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.h
@@ -32,6 +32,10 @@ bool hasZve32fFeature(IREE::HAL::ExecutableTargetAttr targetAttr);
/// Returns true if the 'targetAttr' contains '+zve64x' in its cpu features.
bool hasZve64xFeature(IREE::HAL::ExecutableTargetAttr targetAttr);

/// Returns true if the 'targetAttr' contains any RISC-V vector feature in its
/// cpu features.
bool hasAnyVFeature(IREE::HAL::ExecutableTargetAttr targetAttr);

/// Returns true if the 'targetAttr' contains '+sve' or '+sve2' in its cpu
/// features or any other feature flag that includes them.
bool hasAnySVEFeature(IREE::HAL::ExecutableTargetAttr targetAttr);
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel
Expand Up @@ -41,6 +41,7 @@ iree_lit_test_suite(
"pipeline_pad_conv_tests.mlir",
"pipeline_pad_tests.mlir",
"pipeline_peel_and_vectorize_tests.mlir",
"pipeline_riscv_aggressive_distribution_tests.mlir",
"pipeline_split_reduction_tests.mlir",
"pipeline_tests.mlir",
"pipeline_transpose_avx2_tests.mlir",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
Expand Up @@ -36,6 +36,7 @@ iree_lit_test_suite(
"pipeline_pad_conv_tests.mlir"
"pipeline_pad_tests.mlir"
"pipeline_peel_and_vectorize_tests.mlir"
"pipeline_riscv_aggressive_distribution_tests.mlir"
"pipeline_split_reduction_tests.mlir"
"pipeline_tests.mlir"
"pipeline_transpose_avx2_tests.mlir"
39 changes: 39 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_riscv_aggressive_distribution_tests.mlir
@@ -0,0 +1,39 @@
// RUN: iree-opt --iree-llvmcpu-riscv-aggressive-distribution=true --pass-pipeline='builtin.module(iree-llvmcpu-select-lowering-strategy, func.func(iree-llvmcpu-lower-executable-target))' --split-input-file %s | FileCheck %s

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>

#executable_target_embedded_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-riscv_64", {cpu_features = "+m,+a,+f,+d,+zvl1024b,+v", data_layout = "e-m:e-p:64:64-i64:64-i256:256-n32:64-S256", native_vector_size = 256 : index, target_triple = "riscv64-unknown-unknown-eabi-elf"}>
builtin.module {
func.func @f32_rvv_matmul() attributes {hal.executable.target = #executable_target_embedded_elf_riscv_64_} {
%cst = arith.constant 0.0 : f32
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor<readonly:tensor<384x512xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor<readonly:tensor<512x256xf32>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : !flow.dispatch.tensor<writeonly:tensor<384x256xf32>>
%lhs = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<384x512xf32>> -> tensor<384x512xf32>
%rhs = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<512x256xf32>> -> tensor<512x256xf32>
%init = tensor.empty() : tensor<384x256xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<384x256xf32>) -> tensor<384x256xf32>
%res = linalg.matmul ins(%lhs, %rhs : tensor<384x512xf32>, tensor<512x256xf32>) outs(%fill : tensor<384x256xf32>) -> tensor<384x256xf32>
flow.dispatch.tensor.store %res, %2, offsets = [0, 0], sizes = [384, 256], strides = [1, 1] : tensor<384x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<384x256xf32>>
return
}
}
// CHECK-LABEL: func.func @f32_rvv_matmul(
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[c128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[c256:.+]] = arith.constant 256 : index
// CHECK-DAG: %[[c512:.+]] = arith.constant 512 : index
// CHECK: scf.for {{.*}} step %[[c7]]
// CHECK: scf.for {{.*}} step %[[c128]]
// CHECK: scf.for {{.*}} step %[[c1]]
// CHECK-COUNT-7: vector.fma
// CHECK-COUNT-7: vector.store
// CHECK: scf.for {{.*}} step %[[c128]]
// CHECK: scf.for {{.*}} step %[[c1]]
// CHECK-COUNT-4: vector.fma
// CHECK-COUNT-4: vector.store
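
A plain C++ sketch (not IR) of the loop-nest shape these CHECK lines expect for the 384x512 by 512x256 f32 matmul at VLEN = 1024, where the n tile is 128 f32 elements:

#include <cstdint>

void expectedMainLoopNest() {
  for (int64_t m = 0; m + 7 <= 384; m += 7) {  // scf.for step %c7
    for (int64_t n = 0; n < 256; n += 128) {   // scf.for step %c128
      for (int64_t k = 0; k < 512; ++k) {      // scf.for step %c1
        // 7 vector.fma ops: one 128-wide f32 FMA per row of the m-tile.
      }
      // 7 vector.store ops write the row accumulators back.
    }
  }
  // The second scf.for nest in the CHECK lines covers the rows peeled off
  // by the step-7 m loop.
}
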
@@ -1,4 +1,5 @@
// RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmcpu-select-lowering-strategy)' --split-input-file %s | FileCheck %s
// RUN: iree-opt --iree-llvmcpu-riscv-aggressive-distribution=true --pass-pipeline='builtin.module(iree-llvmcpu-select-lowering-strategy)' --split-input-file %s | FileCheck %s -check-prefixes=CHECK-AGGRESSIVE

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -30,6 +31,113 @@ func.func @matmul_riscv() attributes {hal.executable.target = #executable_target

// -----

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>

#executable_target_embedded_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-riscv_64", {cpu_features = "+m,+a,+f,+d,+zvl512b,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 128 : index, target_triple = "riscv64-unknown-unknown-eabi-elf"}>
builtin.module {
func.func @matmul_gemm_riscv_vl512() attributes {hal.executable.target = #executable_target_embedded_elf_riscv_64_} {
%cst = arith.constant 0.0 : f32
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor<readonly:tensor<384x512xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor<readonly:tensor<512x128xf32>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : !flow.dispatch.tensor<writeonly:tensor<384x128xf32>>
%lhs = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<384x512xf32>> -> tensor<384x512xf32>
%rhs = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<512x128xf32>> -> tensor<512x128xf32>
%init = tensor.empty() : tensor<384x128xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<384x128xf32>) -> tensor<384x128xf32>
%res = linalg.matmul ins(%lhs, %rhs : tensor<384x512xf32>, tensor<512x128xf32>) outs(%fill : tensor<384x128xf32>) -> tensor<384x128xf32>
flow.dispatch.tensor.store %res, %2, offsets = [0, 0], sizes = [384, 128], strides = [1, 1] : tensor<384x128xf32> -> !flow.dispatch.tensor<writeonly:tensor<384x128xf32>>
return
}
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64], [7, 64], [0, 0], [0, 0]]>
// CHECK-DAG: #[[CONFIG2:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [64, 64, 0], [0, 0, 0], [7, 64, 0], [0, 0, 1], [0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert, {{\{}}enable_loop_peeling}>
// CHECK: func.func @matmul_gemm_riscv_vl512()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.matmul
// CHECK-SAME: lowering_config = #[[CONFIG2]]

// -----

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>

#executable_target_embedded_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-riscv_64", {cpu_features = "+m,+a,+f,+d,+zvl1024b,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 256 : index, target_triple = "riscv64-unknown-unknown-eabi-elf"}>
builtin.module {
func.func @matmul_gemm_riscv_vl1024() attributes {hal.executable.target = #executable_target_embedded_elf_riscv_64_} {
%cst = arith.constant 0.0 : f32
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor<readonly:tensor<384x512xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor<readonly:tensor<512x256xf32>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : !flow.dispatch.tensor<writeonly:tensor<384x256xf32>>
%lhs = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<384x512xf32>> -> tensor<384x512xf32>
%rhs = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<512x256xf32>> -> tensor<512x256xf32>
%init = tensor.empty() : tensor<384x256xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<384x256xf32>) -> tensor<384x256xf32>
%res = linalg.matmul ins(%lhs, %rhs : tensor<384x512xf32>, tensor<512x256xf32>) outs(%fill : tensor<384x256xf32>) -> tensor<384x256xf32>
flow.dispatch.tensor.store %res, %2, offsets = [0, 0], sizes = [384, 256], strides = [1, 1] : tensor<384x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<384x256xf32>>
return
}
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 128], [7, 128], [0, 0], [0, 0]]>
// CHECK-DAG: #[[CONFIG2:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 128, 0], [64, 128, 0], [0, 0, 0], [7, 128, 0], [0, 0, 1], [0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert, {{\{}}enable_loop_peeling}>
// CHECK: func.func @matmul_gemm_riscv_vl1024()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.matmul
// CHECK-SAME: lowering_config = #[[CONFIG2]]

// CHECK-AGGRESSIVE-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[32, 256], [7, 128], [0, 0], [0, 0]]>
// CHECK-AGGRESSIVE-DAG: #[[CONFIG2:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[32, 256, 0], [32, 256, 0], [0, 0, 0], [7, 128, 0], [0, 0, 1], [0, 0, 0]]>
// CHECK-AGGRESSIVE-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert, {{\{}}enable_loop_peeling}>
// CHECK-AGGRESSIVE: func.func @matmul_gemm_riscv_vl1024()
// CHECK-AGGRESSIVE-SAME: translation_info = #[[TRANSLATION]]
// CHECK-AGGRESSIVE: linalg.matmul
// CHECK-AGGRESSIVE-SAME: lowering_config = #[[CONFIG2]]

// -----

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>

#executable_target_embedded_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-riscv_64", {cpu_features = "+m,+a,+f,+d,+zvl512b,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 128 : index, target_triple = "riscv64-unknown-unknown-eabi-elf"}>
builtin.module {
func.func @matmul_gemv_riscv_vl512() attributes {hal.executable.target = #executable_target_embedded_elf_riscv_64_} {
%cst = arith.constant 0.0 : f32
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor<readonly:tensor<1x512xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor<readonly:tensor<512x128xf32>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : !flow.dispatch.tensor<writeonly:tensor<1x128xf32>>
%lhs = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x512xf32>> -> tensor<1x512xf32>
%rhs = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<512x128xf32>> -> tensor<512x128xf32>
%init = tensor.empty() : tensor<1x128xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x128xf32>) -> tensor<1x128xf32>
%res = linalg.matmul ins(%lhs, %rhs : tensor<1x512xf32>, tensor<512x128xf32>) outs(%fill : tensor<1x128xf32>) -> tensor<1x128xf32>
flow.dispatch.tensor.store %res, %2, offsets = [0, 0], sizes = [1, 128], strides = [1, 1] : tensor<1x128xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x128xf32>>
return
}
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 128], [1, 128], [0, 0], [0, 0]]>
// CHECK-DAG: #[[CONFIG2:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 128, 0], [0, 128, 0], [0, 0, 0], [1, 128, 0], [0, 0, 1], [0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert, {{\{}}enable_loop_peeling}>
// CHECK: func.func @matmul_gemv_riscv_vl512()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.matmul
// CHECK-SAME: lowering_config = #[[CONFIG2]]

// -----

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,