diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp
index 2c71aa9444fa..fcfec190ab41 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp
@@ -195,7 +195,7 @@ class VectorReductionToGPUPass
       bool expandSubgroupReduction,
       std::function<int(func::FuncOp)> getWarpSize)
       : expandSubgroupReduction(expandSubgroupReduction),
-        getWarpSize(getWarpSize) {}
+        getWarpSize(std::move(getWarpSize)) {}
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/vector_reduction_to_gpu.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/vector_reduction_to_gpu.mlir
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/vector_reduction_to_gpu.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/vector_reduction_to_gpu.mlir
 // CHECK: vector.transfer_write {{.*}} : vector<1xf32>, memref<128x32xf32>
 // CHECK: return
+
+// -----
+
+// Check that the multi-row matvec gets distributed across subgroup threads.
+
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+hal.executable private @multirow {
+  hal.executable.variant @rocm target(#executable_target_rocm_hsaco_fb) {
+    hal.executable.export @multirow layout(#pipeline_layout) attributes {
+      workgroup_size = [64 : index, 1 : index, 1 : index]
+    }
+    builtin.module {
+      func.func @multirow() {
+        %cst = arith.constant dense<0.000000e+00> : vector<4x512xf16>
+        %c0 = arith.constant 0 : index
+        %cst_0 = arith.constant dense<0.000000e+00> : vector<1x4xf16>
+        %c4096 = arith.constant 4096 : index
+        %c512 = arith.constant 512 : index
+        %cst_1 = arith.constant 0.000000e+00 : f16
+        %id = gpu.thread_id x
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>
+        memref.assume_alignment %0, 64 : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>
+        memref.assume_alignment %1, 64 : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
+        memref.assume_alignment %2, 64 : memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %3 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
+        %4 = scf.for %arg0 = %c0 to %c4096 step %c512 iter_args(%arg1 = %cst) -> (vector<4x512xf16>) {
+          %8 = vector.transfer_read %0[%c0, %arg0], %cst_1 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (0, d1)>} : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x512xf16>
+          %9 = vector.transfer_read %1[%3, %arg0], %cst_1 {in_bounds = [true, true]} : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x512xf16>
+          %10 = arith.mulf %8, %9 : vector<4x512xf16>
+          %11 = arith.addf %arg1, %10 : vector<4x512xf16>
+          scf.yield %11 : vector<4x512xf16>
+        }
+        %5 = vector.broadcast %4 : vector<4x512xf16> to vector<1x4x512xf16>
+        %6 = vector.multi_reduction <add>, %5, %cst_0 [2] : vector<1x4x512xf16> to vector<1x4xf16>
+        %7 = vector.extract %6[0] : vector<4xf16> from vector<1x4xf16>
+        vector.transfer_write %7, %2[%c0, %3] {in_bounds = [true]} : vector<4xf16>, memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
+        return
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: func.func @multirow() {
+// CHECK: scf.for {{.*}} -> (vector<4x8xf16>) {
+// CHECK: vector.transfer_read {{.*}} : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x8xf16>
+// CHECK: vector.transfer_read {{.*}} : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x8xf16>
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<4x8xf16>
+// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4x8xf16>
+// CHECK: }
+// CHECK: gpu.shuffle xor
+// CHECK: scf.if {{.*}} {
+// CHECK: vector.transfer_write {{.*}} : vector<4xf16>, memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
+// CHECK: }
+// CHECK-NEXT: return
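For intuition about what the CHECK lines above encode: the pass distributes the vector<4x512xf16> accumulator of one 64-lane subgroup so that each lane owns 512 / 64 = 8 contiguous f16 elements per row (the vector<4x8xf16> slices), and the trailing gpu.shuffle xor ops form a butterfly tree that combines the 64 per-lane partial sums. A minimal HIP-style sketch of one lane's schedule, assuming hypothetical names (rowSum, row, vec) that do not appear in the patch:

// Illustrative sketch only, not part of the patch: one lane's view of the
// distributed sum for a single 4096-wide row.
__device__ float rowSum(const __half *row, const __half *vec, int lane) {
  float partial = 0.f;
  for (int base = 0; base < 4096; base += 512) {   // scf.for step %c512
    for (int i = 0; i < 8; ++i) {                  // one vector<8xf16> chunk
      int idx = base + lane * 8 + i;
      partial += __half2float(row[idx]) * __half2float(vec[idx]);
    }
  }
  for (int mask = 32; mask > 0; mask >>= 1)        // gpu.shuffle xor butterfly
    partial += __shfl_xor(partial, mask, 64);      // 64-wide wavefront
  return partial;                                  // full row sum in every lane
}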
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index ef26385f57ea..577c248cc09e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -22,7 +22,9 @@
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
@@ -924,6 +926,25 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
   if ((groupSize / subgroupSize) > subgroupSize)
     return failure();
 
+  // With just one subgroup per workgroup, make each subgroup do more work and
+  // process a few reductions along the last parallel dimension.
+  // TODO: We should also check that this will result in data reuse for at least
+  // one argument.
+  // TODO: This is experimental for matvec (matmul_transpose_b) on ROCm only for
+  // now.
+  if (numDynamicReductionDims == 0 && numParallelDims == 2 &&
+      isRocmTarget(entryPoint)) {
+    if (*parallelSize && !parallelDims.empty() && groupSize == subgroupSize) {
+      int maxParallelFactor = 4; // Keeping this conservative for now.
+      int64_t lastParallelBound = bounds[parallelDims.back()];
+      if (!ShapedType::isDynamic(lastParallelBound) &&
+          (lastParallelBound % maxParallelFactor == 0) &&
+          lastParallelBound > maxParallelFactor) {
+        workgroupTileSizes.back() = maxParallelFactor;
+      }
+    }
+  }
+
   std::array<int64_t, 3> workgroupSize = {groupSize, 1, 1};
   SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
   int64_t remainingGroupSize = groupSize;
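The guarded block above only fires in a narrow case: static shapes, exactly two parallel dimensions, a ROCm target, and a single subgroup per workgroup. A standalone sketch of the decision logic, where pickParallelFactor is a hypothetical helper for illustration and not an IREE API:

#include <cstdint>

// Hypothetical helper mirroring the new heuristic: widen the workgroup tile
// along the last parallel dimension only when the bound is static, divisible
// by the factor, and strictly larger than it.
int64_t pickParallelFactor(int64_t lastParallelBound, bool isDynamic) {
  const int64_t maxParallelFactor = 4;  // conservative, as in the patch
  if (!isDynamic && lastParallelBound % maxParallelFactor == 0 &&
      lastParallelBound > maxParallelFactor)
    return maxParallelFactor;
  return 1;  // keep one row per workgroup otherwise
}

// E.g. pickParallelFactor(32000, false) == 4 for the 32000x4096 matvec below,
// while pickParallelFactor(3, false) == 1 leaves small shapes untouched.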
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir
index 47b34315160d..2cfa7a8b3aeb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir
@@ -50,3 +50,50 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gf
 // CHECK: func.func @dynamic_batch_matvec()
 // CHECK: linalg.batch_matmul
 // CHECK-SAME: lowering_config = #[[$CONFIG]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+
+hal.executable @vmt {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}>) {
+  hal.executable.export @vmt layout(#pipeline_layout)
+  builtin.module {
+    func.func @vmt() {
+      %c0 = arith.constant 0 : index
+      %cst = arith.constant 0.000000e+00 : f16
+      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x4096xf16>>
+      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>>
+      %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x32000xf16>>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x4096xf16>> -> tensor<1x4096xf16>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32000, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>> -> tensor<32000x4096xf16>
+      %5 = tensor.empty() : tensor<1x32000xf16>
+      %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<1x32000xf16>) -> tensor<1x32000xf16>
+      %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<1x4096xf16>, tensor<32000x4096xf16>) outs(%6 : tensor<1x32000xf16>) {
+      ^bb0(%in: f16, %in_0: f16, %out: f16):
+        %8 = arith.mulf %in, %in_0 : f16
+        %9 = arith.addf %out, %8 : f16
+        linalg.yield %9 : f16
+      } -> tensor<1x32000xf16>
+      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1, 32000], strides = [1, 1] : tensor<1x32000xf16> -> !flow.dispatch.tensor<writeonly:tensor<1x32000xf16>>
+      return
+    }
+  }
+}
+}
+
+// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 4], [0, 0, 512]{{\]}}>
+// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUWarpReduction>
+// CHECK-LABEL: hal.executable.export public @vmt
+// CHECK-SAME: subgroup_size = 64 : index
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index]
+// CHECK: func.func @vmt()
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[$CONFIG]]
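As a closing sanity check on the expected @vmt configuration, assuming the reduction tile is derived from the widest 128-bit per-lane load as in the warp-reduction path (an inference, not spelled out in this diff):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t subgroupSize = 64;      // gfx940 wavefront; one per workgroup
  const int64_t loadBitsPerLane = 128;  // widest per-lane vector load
  const int64_t f16Bits = 16;
  // 64 lanes x (128 / 16) f16 elements = 512: the [0, 0, 512] reduction tile
  // and the scf.for step %c512 in the @multirow test above.
  assert(subgroupSize * (loadBitsPerLane / f16Bits) == 512);
  // Workgroup tile [1, 4]: the 32000 output rows map to 32000 / 4 = 8000
  // workgroups instead of 32000.
  assert(32000 / 4 == 8000);
  return 0;
}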