From 1ba5e3786a58db6b8d00737d9745d3f60431e617 Mon Sep 17 00:00:00 2001
From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Date: Fri, 29 Sep 2023 10:33:23 -0700
Subject: [PATCH] [LLVMGPU] Enable WarpReduction on ROCM + Let matvec use Warp Reduce. (#15034)

This patch does two things:
1. Enables warp reduction config on ROCm.
2. Mirrors SPIR-V logic for letting matvec go down the subgroup/warp
   reduce pipeline.
---
 .../compiler/Codegen/LLVMGPU/KernelConfig.cpp | 98 ++++++++++++++++---
 .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 4 +-
 .../compiler/Codegen/LLVMGPU/test/BUILD.bazel | 6 +-
 .../Codegen/LLVMGPU/test/CMakeLists.txt | 6 +-
 ...line.mlir => reduction_pipeline_cuda.mlir} | 1 -
 .../LLVMGPU/test/reduction_pipeline_rocm.mlir | 35 +++++++
 ...=> reduction_pipeline_transform_cuda.mlir} | 66 +++++++++++++
 .../reduction_pipeline_transform_rocm.mlir | 98 +++++++++++++++++++
 .../compiler/Codegen/SPIRV/KernelConfig.cpp | 11 ++-
 9 files changed, 302 insertions(+), 23 deletions(-)
 rename compiler/src/iree/compiler/Codegen/LLVMGPU/test/{reduction_pipeline.mlir => reduction_pipeline_cuda.mlir} (99%)
 create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir
 rename compiler/src/iree/compiler/Codegen/LLVMGPU/test/{reduction_pipeline_transform.mlir => reduction_pipeline_transform_cuda.mlir} (82%)
 create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_rocm.mlir

diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 10e53b29e542..8e653c22b09f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -30,6 +30,7 @@ using namespace mlir::iree_compiler;
 
 static constexpr unsigned cudaWarpSize = 32;
 static constexpr StringLiteral kCudaTarget = "cuda";
+static constexpr StringLiteral kRocmTarget = "rocm";
 namespace mlir {
 namespace iree_compiler {
 llvm::cl::opt<std::string> clGPUCodegenTransformDialectFileName(
@@ -162,11 +163,19 @@ bool isCudaTarget(func::FuncOp entryPoint) {
   return false;
 }
 
-static TargetInfo getTargetInfo(func::FuncOp entryPoint) {
+bool isRocmTarget(func::FuncOp entryPoint) {
+  if (auto variantOp =
+          entryPoint->getParentOfType<IREE::HAL::ExecutableVariantOp>()) {
+    IREE::HAL::ExecutableTargetAttr targetAttr = variantOp.getTarget();
+    if (auto backend = targetAttr.getBackend()) {
+      return backend.getValue().str() == kRocmTarget;
+    }
+  }
+  return false;
+}
+
+static TargetInfo getCudaTargetInfo(func::FuncOp entryPoint) {
   TargetInfo info;
-  // TODO: fill out target info for other vendors.
-  if (!isCudaTarget(entryPoint))
-    return info;
   // All the cuda target are assumed to have warp support.
   info.hasWarpShuffle = true;
   StringRef targetName = getTargetArch(entryPoint);
@@ -190,6 +199,34 @@ static TargetInfo getTargetInfo(func::FuncOp entryPoint) {
   return info;
 }
 
+// TODO: Plumb in WarpSize into TargetInfo for wave64 systems.
+static TargetInfo getRocmTargetInfo(func::FuncOp entryPoint) {
+  TargetInfo info;
+  StringRef targetName = getTargetArch(entryPoint);
+  // If no target name is set assume all the features are off.
+  if (targetName == "")
+    return info;
+  if (!targetName.starts_with("gfx")) {
+    entryPoint.emitError("unknown target name ") << targetName;
+    return info;
+  }
+  // Assumes all gfx has warp shuffle.
+  info.hasWarpShuffle = true;
+  // TODO: Check and enable for WMMA once pipeline is available.
+ return info; +} + +static TargetInfo getTargetInfo(func::FuncOp entryPoint) { + TargetInfo info; + // TODO: fill out target info for other vendors. + if (isCudaTarget(entryPoint)) { + info = getCudaTargetInfo(entryPoint); + } else if (isRocmTarget(entryPoint)) { + info = getRocmTargetInfo(entryPoint); + } + return info; +} + static bool supportsTensorCore(func::FuncOp entryPoint, linalg::LinalgOp op, const TargetInfo &targetInfo) { // Limit tensor core pipeline to matmul as not all combinations of transpose @@ -254,6 +291,20 @@ static LogicalResult setContractConfig(func::FuncOp entryPoint, if (!linalg::isaContractionOpInterface(op) || op.getNumParallelLoops() < 2) { return failure(); } + + // Also exclude the case of matvec, which has only one non-unit parallel dim. + // They should go down different pipelines. + int nonUnitParallelDimCount = 0; + SmallVector bounds = op.getStaticLoopRanges(); + SmallVector kinds = op.getIteratorTypesArray(); + for (auto [kind, bound] : llvm::zip(kinds, bounds)) { + if (kind == utils::IteratorType::parallel) + nonUnitParallelDimCount += bound != 1; + } + if (!isa(op) && + nonUnitParallelDimCount == 1) + return failure(); + // Don't consider operations that don't have a broadcast, those should go // through reductions. if (llvm::any_of(op.getIndexingMapsArray(), @@ -750,13 +801,23 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, } SmallVector reductionDims; op.getReductionDims(reductionDims); - if (reductionDims.size() != 1 || reductionDims[0] != op.getNumLoops() - 1) + if (reductionDims.empty()) + return failure(); + + // Make sure reduction dimensions are the innermost ones. + int64_t numParallelDims = op.getNumParallelLoops(); + if (llvm::any_of(reductionDims, [&](int64_t reductionDim) { + return reductionDim < numParallelDims; + })) { return failure(); + } + if (op.getRegionOutputArgs().size() != 1) return failure(); // Only support projected permutation, this could be extended to projected // permutated with broadcast. + if (llvm::any_of(op.getDpsInputOperands(), [&](OpOperand *input) { return !op.getMatchingIndexingMap(input).isProjectedPermutation(); })) @@ -779,8 +840,11 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, if (!foundSingleReductionOutput) return failure(); - std::optional dimSize = getLinalgDimSize(op, reductionDims[0]); - if (!dimSize || *dimSize % cudaWarpSize != 0) + SmallVector bounds = op.getStaticLoopRanges(); + int64_t dimSize = 1; + for (int64_t dim : reductionDims) + dimSize *= bounds[dim]; + if (dimSize % cudaWarpSize != 0) return failure(); const Type elementType = @@ -795,12 +859,12 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, const unsigned largestLoadSizeInBits = 128; unsigned vectorSize = largestLoadSizeInBits / bitWidth; - while ((*dimSize / vectorSize) % cudaWarpSize != 0) + while ((dimSize / vectorSize) % cudaWarpSize != 0) vectorSize /= 2; // TODO: Add reduction tiling to handle larger reductions. const int64_t maxWorkgroupSize = 1024; - int64_t groupSize = *dimSize / vectorSize; + int64_t groupSize = dimSize / vectorSize; if (groupSize > maxWorkgroupSize) { groupSize = llvm::APIntOps::GreatestCommonDivisor( {64, uint64_t(groupSize)}, {64, uint64_t(maxWorkgroupSize)}) @@ -813,8 +877,20 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1; // Tile all the parallel dimension to 1. 
SmallVector workgroupTileSizes(numLoops, 1); - SmallVector reductionTileSizes(numLoops, 0); - reductionTileSizes.push_back(groupSize * vectorSize); + SmallVector reductionTileSizes(op.getNumLoops(), 0); + int64_t remaingGroupSize = groupSize; + for (int i = reductionDims.size() - 1; i >= 0; --i) { + int64_t dim = reductionDims[i]; + int64_t bound = bounds[dim]; + if (i == reductionDims.size() - 1) + bound /= vectorSize; + APInt size = llvm::APIntOps::GreatestCommonDivisor( + {64, uint64_t(remaingGroupSize)}, {64, uint64_t(bound)}); + reductionTileSizes[dim] = size.getSExtValue(); + if (i == reductionDims.size() - 1) + reductionTileSizes[dim] *= vectorSize; + remaingGroupSize /= size.getSExtValue(); + } TileSizesListType tileSizes; tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level tileSizes.emplace_back(std::move(reductionTileSizes)); // reduction level diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index d4503cc806ee..ccde7e9eef8d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -374,6 +374,8 @@ void addGPUTransposePassPipeline(OpPassManager &pm) { void addGPUWarpReductionPassPipeline(OpPassManager &pm) { tileAndDistributeToWorkgroup(pm); auto &nestedModulePM = pm.nest(); + nestedModulePM.addNestedPass( + createRematerializeParallelOpsPass()); nestedModulePM.addNestedPass(createCanonicalizerPass()); nestedModulePM.addNestedPass(createGPUTileReductionPass()); nestedModulePM.addNestedPass(createCanonicalizerPass()); @@ -581,8 +583,6 @@ void addGPUTransformDialectPasses(OpPassManager &passManager) { void buildLLVMGPUTransformPassPipeline(OpPassManager &pm, bool useROCM) { addCommonTargetExecutablePreprocessingPasses(pm.nest()); - pm.nest().addNestedPass( - createRematerializeParallelOpsPass()); pm.addPass(createLLVMGPULowerExecutableTargetPass()); OpPassManager &nestedModulePM = pm.nest(); //===--------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel index 83de7d0e89a9..3d72d666fa71 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel @@ -31,8 +31,10 @@ iree_lit_test_suite( "nvvm_extract_address_computation.mlir", "nvvm_pipeline_test.mlir", "nvvm_mma_sync_pipeline_test.mlir", - "reduction_pipeline_transform.mlir", - "reduction_pipeline.mlir", + "reduction_pipeline_cuda.mlir", + "reduction_pipeline_rocm.mlir", + "reduction_pipeline_transform_cuda.mlir", + "reduction_pipeline_transform_rocm.mlir", "rocdl_pipeline_test.mlir", "set_transform_strategy_batch_matmul.mlir", "set_transform_strategy_convolution.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt index cf76f25a60fc..0997aa853780 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt @@ -33,8 +33,10 @@ iree_lit_test_suite( "nvvm_pipeline_test.mlir" "pack_pipeline_test.mlir" "pack_shared_memory_alloc.mlir" - "reduction_pipeline.mlir" - "reduction_pipeline_transform.mlir" + "reduction_pipeline_cuda.mlir" + "reduction_pipeline_rocm.mlir" + "reduction_pipeline_transform_cuda.mlir" + "reduction_pipeline_transform_rocm.mlir" 
"rocdl_pipeline_test.mlir" "set_transform_strategy_batch_matmul.mlir" "set_transform_strategy_convolution.mlir" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir similarity index 99% rename from compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline.mlir rename to compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir index 619a226bec49..3394639c1717 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir @@ -39,7 +39,6 @@ hal.executable.variant @cuda, target = <"cuda", "cuda-nvptx-fb"> { } } -#map = affine_map<()[s0] -> (s0 * 4)> // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 4)> // CHECK-LABEL: func.func @warp_reduction_dispatch // CHECK-DAG: %[[C0I:.+]] = arith.constant 0 : i32 diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir new file mode 100644 index 000000000000..565ccc5a78e2 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir @@ -0,0 +1,35 @@ +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-linalg-ext-decompose-softmax)), iree-llvmgpu-lower-executable-target)))" %s | FileCheck %s + +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer> + ]> +]> +hal.executable @softmax { +hal.executable.variant @rocm, target = <"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100"}> { + hal.executable.export @softmax layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @softmax() { + %c0 = arith.constant 0 : index + %cst = arith.constant -3.40282347E+38 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %cst_1 = arith.constant 1.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<12x128x40960xf32> + %3 = tensor.empty() : tensor<12x128x40960xf32> + %4 = iree_linalg_ext.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32> + flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : tensor<12x128x40960xf32> -> !flow.dispatch.tensor> + return + } + } +} +} + +// CHECK-LABEL: func.func @softmax +// CHECK-COUNT-20: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_cuda.mlir similarity index 82% rename from compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir rename to compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_cuda.mlir index 321d2eb2f9cf..2ddd6ac91109 100644 --- 
a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_cuda.mlir @@ -462,3 +462,69 @@ hal.executable @reduction_2d_trailing_elementwise_static_dispatch_0 { // CHECK: arith.divf {{.*}} : vector<1x2xf32> // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<1x2xf32>, memref<1x10xf32, strided<[10, 1], offset: ?>, #hal.descriptor_type> // CHECK: gpu.barrier + +// ----- + +hal.executable private @i4_dequant_matvec { + hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_60"}> { + hal.executable.export public @i4_dequant_matvec ordinal(0) layout(#hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer, ReadOnly>, <4, storage_buffer>]>]>) { + ^bb0(%arg0: !hal.device): + %x, %y, %z = flow.dispatch.workgroup_count_from_slice + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @i4_dequant_matvec() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4096, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<4096x32x128xi4> + %6 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<4096x32xf16> + %7 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<4096x32xf16> + %8 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<32x128xf16> + %9 = tensor.empty() : tensor<4096xf16> + %10 = tensor.empty() : tensor<4096x32x128xf16> + %11 = linalg.fill ins(%cst : f16) outs(%9 : tensor<4096xf16>) -> tensor<4096xf16> + %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5, %6, %7 : tensor<4096x32x128xi4>, tensor<4096x32xf16>, tensor<4096x32xf16>) outs(%10 : tensor<4096x32x128xf16>) { + ^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16): + %14 = arith.extui %in : i4 to i32 + %15 = arith.uitofp %14 : i32 to f16 + %16 = arith.subf %15, %in_1 : f16 + %17 = arith.mulf %16, %in_0 : f16 + linalg.yield %17 : f16 + } -> tensor<4096x32x128xf16> + %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%8, %12 : tensor<32x128xf16>, tensor<4096x32x128xf16>) outs(%11 : tensor<4096xf16>) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + 
%14 = arith.mulf %in, %in_0 : f16 + %15 = arith.addf %14, %out : f16 + linalg.yield %15 : f16 + } -> tensor<4096xf16> + flow.dispatch.tensor.store %13, %4, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf16> -> !flow.dispatch.tensor> + return + } + } + } +} + +// CHECK-LABEL: func.func @i4_dequant_matvec() +// CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<1x8xf16> +// CHECK: %[[READ0:.+]] = vector.transfer_read {{.+}} : memref<4096x32x128xi4, #hal.descriptor_type>, vector<1x8xi4> +// CHECK: %[[READ1:.+]] = vector.transfer_read {{.+}} : memref<4096x32xf16, #hal.descriptor_type>, vector<1x8xf16> +// CHECK: %[[READ2:.+]] = vector.transfer_read {{.+}} : memref<4096x32xf16, #hal.descriptor_type>, vector<1x8xf16> +// CHECK: %[[READ3:.+]] = vector.transfer_read {{.+}} : memref<32x128xf16, #hal.descriptor_type>, vector<1x8xf16> +// CHECK: %[[EXTEND:.+]] = arith.extui %[[READ0]] : vector<1x8xi4> to vector<1x8xi32> +// CHECK: %[[CVT:.+]] = arith.uitofp %[[EXTEND]] : vector<1x8xi32> to vector<1x8xf16> +// CHECK: %[[SUB:.+]] = arith.subf %[[CVT]], %[[READ1]] : vector<1x8xf16> +// CHECK: %[[MUL0:.+]] = arith.mulf %[[SUB]], %[[READ2]] : vector<1x8xf16> +// CHECK: %[[MUL1:.+]] = arith.mulf %[[READ3]], %[[MUL0]] : vector<1x8xf16> +// CHECK: %[[ADD:.+]] = arith.addf %[[MUL1]], %[[CST]] : vector<1x8xf16> + +// CHECK: %[[SCAST:.+]] = vector.shape_cast %[[ADD]] : vector<1x8xf16> to vector<8xf16> +// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SCAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16> +// CHECK: vector.reduction , %[[EXTRACT]] : vector<4xf16> into f16 +// CHECK-COUNT-9: gpu.shuffle xor +// CHECK: scf.if +// CHECK: vector.transfer_write diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_rocm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_rocm.mlir new file mode 100644 index 000000000000..da0dbdf019d0 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_rocm.mlir @@ -0,0 +1,98 @@ +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))" %s | FileCheck %s + +hal.executable @group_reduction_1d { +hal.executable.variant public @rocm_hsaco_fb, target = <"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100"}> { + hal.executable.export public @group_reduction_1d ordinal(0) layout(#hal.pipeline.layout, <1, storage_buffer>]>]>) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @group_reduction_1d() { + %c0 = arith.constant 0 : index + %cst = arith.constant -0.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [64], strides = [1] : !flow.dispatch.tensor> -> tensor<64xf32> + %3 = tensor.empty() : tensor + %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor) -> tensor + %5 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%2 : tensor<64xf32>) outs(%4 : tensor) { + ^bb0(%in: f32, %out: f32): + %6 = arith.addf %in, %out : f32 + 
linalg.yield %6 : f32 + } -> tensor + flow.dispatch.tensor.store %5, %1, offsets = [], sizes = [], strides = [] : tensor -> !flow.dispatch.tensor> + return + } + } +} +} + +// CHECK-LABEL: func.func @group_reduction_1d +// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf + +// ----- + +hal.executable private @i4_dequant_matvec { + hal.executable.variant public @rocm_hsaco_fb, target = <"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100"}> { + hal.executable.export public @i4_dequant_matvec ordinal(0) layout(#hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer, ReadOnly>, <4, storage_buffer>]>]>) { + ^bb0(%arg0: !hal.device): + %x, %y, %z = flow.dispatch.workgroup_count_from_slice + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @i4_dequant_matvec() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4096, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<4096x32x128xi4> + %6 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<4096x32xf16> + %7 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<4096x32xf16> + %8 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<32x128xf16> + %9 = tensor.empty() : tensor<4096xf16> + %10 = tensor.empty() : tensor<4096x32x128xf16> + %11 = linalg.fill ins(%cst : f16) outs(%9 : tensor<4096xf16>) -> tensor<4096xf16> + %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5, %6, %7 : tensor<4096x32x128xi4>, tensor<4096x32xf16>, tensor<4096x32xf16>) outs(%10 : tensor<4096x32x128xf16>) { + ^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16): + %14 = arith.extui %in : i4 to i32 + %15 = arith.uitofp %14 : i32 to f16 + %16 = arith.subf %15, %in_1 : f16 + %17 = arith.mulf %16, %in_0 : f16 + linalg.yield %17 : f16 + } -> tensor<4096x32x128xf16> + %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%8, %12 : tensor<32x128xf16>, tensor<4096x32x128xf16>) outs(%11 : tensor<4096xf16>) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %14 = arith.mulf %in, %in_0 : f16 + %15 = arith.addf %14, %out : f16 + linalg.yield %15 : f16 + } -> tensor<4096xf16> + flow.dispatch.tensor.store %13, %4, offsets = [0], sizes = [4096], 
strides = [1] : tensor<4096xf16> -> !flow.dispatch.tensor> + return + } + } + } +} + +// CHECK-LABEL: func.func @i4_dequant_matvec() +// CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<1x8xf16> +// CHECK: %[[READ0:.+]] = vector.transfer_read {{.+}} : memref<4096x32x128xi4, #hal.descriptor_type>, vector<1x8xi4> +// CHECK: %[[READ1:.+]] = vector.transfer_read {{.+}} : memref<4096x32xf16, #hal.descriptor_type>, vector<1x8xf16> +// CHECK: %[[READ2:.+]] = vector.transfer_read {{.+}} : memref<4096x32xf16, #hal.descriptor_type>, vector<1x8xf16> +// CHECK: %[[READ3:.+]] = vector.transfer_read {{.+}} : memref<32x128xf16, #hal.descriptor_type>, vector<1x8xf16> +// CHECK: %[[EXTEND:.+]] = arith.extui %[[READ0]] : vector<1x8xi4> to vector<1x8xi32> +// CHECK: %[[CVT:.+]] = arith.uitofp %[[EXTEND]] : vector<1x8xi32> to vector<1x8xf16> +// CHECK: %[[SUB:.+]] = arith.subf %[[CVT]], %[[READ1]] : vector<1x8xf16> +// CHECK: %[[MUL0:.+]] = arith.mulf %[[SUB]], %[[READ2]] : vector<1x8xf16> +// CHECK: %[[MUL1:.+]] = arith.mulf %[[READ3]], %[[MUL0]] : vector<1x8xf16> +// CHECK: %[[ADD:.+]] = arith.addf %[[MUL1]], %[[CST]] : vector<1x8xf16> + +// CHECK: %[[SCAST:.+]] = vector.shape_cast %[[ADD]] : vector<1x8xf16> to vector<8xf16> +// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SCAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16> +// CHECK: vector.reduction , %[[EXTRACT]] : vector<4xf16> into f16 +// CHECK-COUNT-9: gpu.shuffle xor +// CHECK: scf.if +// CHECK: vector.transfer_write diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index 9c586c6e6eb7..2b19af8bdbb7 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -1202,12 +1202,13 @@ static LogicalResult setReductionConfig(const spirv::TargetEnv &targetEnv, return failure(); // Make sure reduction dimensions are the innermost ones. - for (int i = 0; i < reductionDims.size(); ++i) { - if (reductionDims[reductionDims.size() - 1 - i] != - op.getNumLoops() - 1 - i) { - return failure(); - } + int64_t numParallelDims = op.getNumParallelLoops(); + if (llvm::any_of(reductionDims, [&](int64_t reductionDim) { + return reductionDim < numParallelDims; + })) { + return failure(); } + if (op.getRegionOutputArgs().size() != 1) return failure();
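
Two of the changes above are easier to see in isolation. First, the matvec screening added to setContractConfig: a contraction whose parallel loops have at most one non-unit bound (a matvec/gemv shape) is now rejected by the contraction configs so that it falls through to setWarpReductionConfig instead (named matmul ops are still kept on the contraction path). The C++ below is a standalone, illustrative sketch of that check, not code from the patch; the enum, function name, and example shapes are invented for the example.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

enum class IteratorKind { Parallel, Reduction };

// True when the loop nest looks matvec-like: exactly one non-unit parallel
// loop. setContractConfig bails out in that case so the op is later picked
// up by the warp reduction configuration.
bool deferToWarpReduction(const std::vector<IteratorKind> &kinds,
                          const std::vector<int64_t> &bounds) {
  assert(kinds.size() == bounds.size());
  int nonUnitParallelDims = 0;
  for (size_t i = 0; i < kinds.size(); ++i)
    if (kinds[i] == IteratorKind::Parallel && bounds[i] != 1)
      ++nonUnitParallelDims;
  return nonUnitParallelDims == 1;
}

int main() {
  // Shape of the i4_dequant_matvec tests above: parallel 4096, reductions 32 and 128.
  assert(deferToWarpReduction(
      {IteratorKind::Parallel, IteratorKind::Reduction, IteratorKind::Reduction},
      {4096, 32, 128}));
  // A regular matmul-shaped nest keeps two non-unit parallel dims and is not deferred.
  assert(!deferToWarpReduction(
      {IteratorKind::Parallel, IteratorKind::Parallel, IteratorKind::Reduction},
      {128, 128, 128}));
  return 0;
}
```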
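Second, setWarpReductionConfig now distributes the thread group over several innermost reduction dimensions instead of assuming a single one, folding the vector width into the innermost tile and splitting the rest by greatest common divisor. The sketch below mirrors that arithmetic outside the compiler; the concrete numbers (32x128 f16 reduction, 128-bit loads, 32-wide warps, 512 threads) are assumed example values chosen to match the i4_dequant_matvec tests, and the variable names are not from the patch.

```cpp
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Example inputs: reduction bounds 32x128, f16 data so a 128-bit load is
  // 8 elements, and 4096 / 8 = 512 threads in the workgroup.
  std::vector<int64_t> reductionBounds = {32, 128};
  int64_t vectorSize = 8;
  int64_t groupSize = 512;

  std::vector<int64_t> tileSizes(reductionBounds.size(), 0);
  int64_t remainingGroupSize = groupSize;
  for (int i = static_cast<int>(reductionBounds.size()) - 1; i >= 0; --i) {
    int64_t bound = reductionBounds[i];
    // The innermost dimension is consumed in vectorSize-wide chunks.
    if (i == static_cast<int>(reductionBounds.size()) - 1)
      bound /= vectorSize;
    // Assign as many threads as divide both the remaining group size and
    // the (vector-scaled) bound, mirroring the GreatestCommonDivisor call.
    int64_t threads = std::gcd(remainingGroupSize, bound);
    tileSizes[i] = threads;
    if (i == static_cast<int>(reductionBounds.size()) - 1)
      tileSizes[i] *= vectorSize;  // tile counts elements, not threads
    remainingGroupSize /= threads;
  }

  // Prints "32 128": 16 threads * 8 elements on the 128 dim, 32 threads on the 32 dim.
  for (int64_t ts : tileSizes)
    std::printf("%lld ", static_cast<long long>(ts));
  std::printf("\n");
  return 0;
}
```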