[LLVMGPU] Enable WarpReduction on ROCM + Let matvec use Warp Reduce. (iree-org#15034)

This patch does two things:

1. Enables the warp reduction config on ROCm.
2. Mirrors the SPIR-V logic for letting matvec go down the subgroup/warp
   reduce pipeline (a standalone sketch of this check follows).
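An illustrative, standalone sketch of check 2 (not IREE code): it mirrors the non-unit-parallel-dim counting added to setContractConfig in KernelConfig.cpp below. Names such as IterKind and looksLikeMatvec are hypothetical stand-ins for the linalg queries used in the patch.

// Sketch: decide whether a contraction-like op should skip the matmul config
// and be routed to the warp-reduction pipeline, mirroring the check added below.
#include <cstdint>
#include <vector>

enum class IterKind { Parallel, Reduction };

// True when the op behaves like a matvec: exactly one parallel loop has a
// trip count larger than 1, and the op is not a named (batch_)matmul.
bool looksLikeMatvec(const std::vector<IterKind> &kinds,
                     const std::vector<int64_t> &bounds,
                     bool isNamedMatmul) {
  int nonUnitParallelDimCount = 0;
  for (size_t i = 0; i < kinds.size(); ++i)
    if (kinds[i] == IterKind::Parallel && bounds[i] != 1)
      ++nonUnitParallelDimCount;
  return !isNamedMatmul && nonUnitParallelDimCount == 1;
}

// Example: loops (4096 parallel, 32 reduction, 128 reduction) -> true, so the
// dispatch skips setContractConfig and gets the warp-reduction config instead.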
raikonenfnu authored Sep 29, 2023
1 parent 750784d commit 1ba5e37
Showing 9 changed files with 302 additions and 23 deletions.
98 changes: 87 additions & 11 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -30,6 +30,7 @@ using namespace mlir::iree_compiler;

static constexpr unsigned cudaWarpSize = 32;
static constexpr StringLiteral kCudaTarget = "cuda";
static constexpr StringLiteral kRocmTarget = "rocm";
namespace mlir {
namespace iree_compiler {
llvm::cl::opt<std::string> clGPUCodegenTransformDialectFileName(
@@ -162,11 +163,19 @@ bool isCudaTarget(func::FuncOp entryPoint) {
  return false;
}

static TargetInfo getTargetInfo(func::FuncOp entryPoint) {
bool isRocmTarget(func::FuncOp entryPoint) {
  if (auto variantOp =
          entryPoint->getParentOfType<IREE::HAL::ExecutableVariantOp>()) {
    IREE::HAL::ExecutableTargetAttr targetAttr = variantOp.getTarget();
    if (auto backend = targetAttr.getBackend()) {
      return backend.getValue().str() == kRocmTarget;
    }
  }
  return false;
}

static TargetInfo getCudaTargetInfo(func::FuncOp entryPoint) {
  TargetInfo info;
  // TODO: fill out target info for other vendors.
  if (!isCudaTarget(entryPoint))
    return info;
  // All the cuda target are assumed to have warp support.
  info.hasWarpShuffle = true;
  StringRef targetName = getTargetArch(entryPoint);
@@ -190,6 +199,34 @@ static TargetInfo getTargetInfo(func::FuncOp entryPoint) {
  return info;
}

// TODO: Plumb in WarpSize into TargetInfo for wave64 systems.
static TargetInfo getRocmTargetInfo(func::FuncOp entryPoint) {
  TargetInfo info;
  StringRef targetName = getTargetArch(entryPoint);
  // If no target name is set assume all the features are off.
  if (targetName == "")
    return info;
  if (!targetName.starts_with("gfx")) {
    entryPoint.emitError("unknown target name ") << targetName;
    return info;
  }
  // Assumes all gfx has warp shuffle.
  info.hasWarpShuffle = true;
  // TODO: Check and enable for WMMA once pipeline is available.
  return info;
}

static TargetInfo getTargetInfo(func::FuncOp entryPoint) {
  TargetInfo info;
  // TODO: fill out target info for other vendors.
  if (isCudaTarget(entryPoint)) {
    info = getCudaTargetInfo(entryPoint);
  } else if (isRocmTarget(entryPoint)) {
    info = getRocmTargetInfo(entryPoint);
  }
  return info;
}

static bool supportsTensorCore(func::FuncOp entryPoint, linalg::LinalgOp op,
                               const TargetInfo &targetInfo) {
  // Limit tensor core pipeline to matmul as not all combinations of transpose
@@ -254,6 +291,20 @@ static LogicalResult setContractConfig(func::FuncOp entryPoint,
  if (!linalg::isaContractionOpInterface(op) || op.getNumParallelLoops() < 2) {
    return failure();
  }

  // Also exclude the case of matvec, which has only one non-unit parallel dim.
  // They should go down different pipelines.
  int nonUnitParallelDimCount = 0;
  SmallVector<int64_t, 4> bounds = op.getStaticLoopRanges();
  SmallVector<utils::IteratorType, 4> kinds = op.getIteratorTypesArray();
  for (auto [kind, bound] : llvm::zip(kinds, bounds)) {
    if (kind == utils::IteratorType::parallel)
      nonUnitParallelDimCount += bound != 1;
  }
  if (!isa<linalg::MatmulOp, linalg::BatchMatmulOp>(op) &&
      nonUnitParallelDimCount == 1)
    return failure();

  // Don't consider operations that don't have a broadcast, those should go
  // through reductions.
  if (llvm::any_of(op.getIndexingMapsArray(),
@@ -750,13 +801,23 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
  }
  SmallVector<unsigned> reductionDims;
  op.getReductionDims(reductionDims);
  if (reductionDims.size() != 1 || reductionDims[0] != op.getNumLoops() - 1)
  if (reductionDims.empty())
    return failure();

  // Make sure reduction dimensions are the innermost ones.
  int64_t numParallelDims = op.getNumParallelLoops();
  if (llvm::any_of(reductionDims, [&](int64_t reductionDim) {
        return reductionDim < numParallelDims;
      })) {
    return failure();
  }

  if (op.getRegionOutputArgs().size() != 1)
    return failure();

  // Only support projected permutation, this could be extended to projected
  // permutated with broadcast.

  if (llvm::any_of(op.getDpsInputOperands(), [&](OpOperand *input) {
        return !op.getMatchingIndexingMap(input).isProjectedPermutation();
      }))
@@ -779,8 +840,11 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
  if (!foundSingleReductionOutput)
    return failure();

  std::optional<int64_t> dimSize = getLinalgDimSize(op, reductionDims[0]);
  if (!dimSize || *dimSize % cudaWarpSize != 0)
  SmallVector<int64_t, 4> bounds = op.getStaticLoopRanges();
  int64_t dimSize = 1;
  for (int64_t dim : reductionDims)
    dimSize *= bounds[dim];
  if (dimSize % cudaWarpSize != 0)
    return failure();

  const Type elementType =
@@ -795,12 +859,12 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,

  const unsigned largestLoadSizeInBits = 128;
  unsigned vectorSize = largestLoadSizeInBits / bitWidth;
  while ((*dimSize / vectorSize) % cudaWarpSize != 0)
  while ((dimSize / vectorSize) % cudaWarpSize != 0)
    vectorSize /= 2;

  // TODO: Add reduction tiling to handle larger reductions.
  const int64_t maxWorkgroupSize = 1024;
  int64_t groupSize = *dimSize / vectorSize;
  int64_t groupSize = dimSize / vectorSize;
  if (groupSize > maxWorkgroupSize) {
    groupSize = llvm::APIntOps::GreatestCommonDivisor(
                    {64, uint64_t(groupSize)}, {64, uint64_t(maxWorkgroupSize)})
@@ -813,8 +877,20 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
  size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
  // Tile all the parallel dimension to 1.
  SmallVector<int64_t> workgroupTileSizes(numLoops, 1);
  SmallVector<int64_t> reductionTileSizes(numLoops, 0);
  reductionTileSizes.push_back(groupSize * vectorSize);
  SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
  int64_t remaingGroupSize = groupSize;
  for (int i = reductionDims.size() - 1; i >= 0; --i) {
    int64_t dim = reductionDims[i];
    int64_t bound = bounds[dim];
    if (i == reductionDims.size() - 1)
      bound /= vectorSize;
    APInt size = llvm::APIntOps::GreatestCommonDivisor(
        {64, uint64_t(remaingGroupSize)}, {64, uint64_t(bound)});
    reductionTileSizes[dim] = size.getSExtValue();
    if (i == reductionDims.size() - 1)
      reductionTileSizes[dim] *= vectorSize;
    remaingGroupSize /= size.getSExtValue();
  }
  TileSizesListType tileSizes;
  tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
  tileSizes.emplace_back(std::move(reductionTileSizes)); // reduction level
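For a concrete trace of the new GCD-based reduction tiling, the snippet below re-derives the configuration for the i4_dequant_matvec test added later in this commit (loop bounds 4096 parallel, 32 and 128 reduction, f16 data). It is a minimal standalone sketch, assuming the 32-lane warp size hard-coded above and using std::gcd in place of llvm::APIntOps::GreatestCommonDivisor; it is not IREE code.

#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Loop bounds of the matvec-like op: one parallel dim, two reduction dims.
  std::vector<int64_t> bounds = {4096, 32, 128};
  std::vector<int64_t> reductionDims = {1, 2};
  const int64_t warpSize = 32, maxWorkgroupSize = 1024;
  const int64_t bitWidth = 16;          // f16 elements
  int64_t vectorSize = 128 / bitWidth;  // 128-bit loads -> 8 elements
  int64_t dimSize = 1;
  for (int64_t d : reductionDims) dimSize *= bounds[d];  // 32 * 128 = 4096
  while ((dimSize / vectorSize) % warpSize != 0) vectorSize /= 2;
  int64_t groupSize = dimSize / vectorSize;  // 4096 / 8 = 512
  if (groupSize > maxWorkgroupSize)
    groupSize = std::gcd(groupSize, maxWorkgroupSize);
  std::vector<int64_t> reductionTileSizes(bounds.size(), 0);
  int64_t remaining = groupSize;
  for (int i = (int)reductionDims.size() - 1; i >= 0; --i) {
    int64_t dim = reductionDims[i];
    int64_t bound = bounds[dim];
    bool innermost = (i == (int)reductionDims.size() - 1);
    if (innermost) bound /= vectorSize;  // 128 / 8 = 16 for the inner dim
    int64_t size = std::gcd(remaining, bound);
    reductionTileSizes[dim] = innermost ? size * vectorSize : size;
    remaining /= size;
  }
  // Prints: reduction tile sizes = [0, 32, 128], group size = 512
  std::printf("reduction tile sizes = [%lld, %lld, %lld], group size = %lld\n",
              (long long)reductionTileSizes[0], (long long)reductionTileSizes[1],
              (long long)reductionTileSizes[2], (long long)groupSize);
}

Under these assumptions, a 512-thread workgroup covers the whole 32x128 reduction: the inner dimension is tiled to 128 (16 lanes of 8-element vectors) and the outer dimension to 32.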
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -374,6 +374,8 @@ void addGPUTransposePassPipeline(OpPassManager &pm) {
void addGPUWarpReductionPassPipeline(OpPassManager &pm) {
  tileAndDistributeToWorkgroup(pm);
  auto &nestedModulePM = pm.nest<ModuleOp>();
  nestedModulePM.addNestedPass<func::FuncOp>(
      createRematerializeParallelOpsPass());
  nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  nestedModulePM.addNestedPass<func::FuncOp>(createGPUTileReductionPass());
  nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
@@ -581,8 +583,6 @@ void addGPUTransformDialectPasses(OpPassManager &passManager) {

void buildLLVMGPUTransformPassPipeline(OpPassManager &pm, bool useROCM) {
  addCommonTargetExecutablePreprocessingPasses(pm.nest<ModuleOp>());
  pm.nest<ModuleOp>().addNestedPass<func::FuncOp>(
      createRematerializeParallelOpsPass());
  pm.addPass(createLLVMGPULowerExecutableTargetPass());
  OpPassManager &nestedModulePM = pm.nest<ModuleOp>();
  //===--------------------------------------------------------------------===//
6 changes: 4 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -31,8 +31,10 @@ iree_lit_test_suite(
"nvvm_extract_address_computation.mlir",
"nvvm_pipeline_test.mlir",
"nvvm_mma_sync_pipeline_test.mlir",
"reduction_pipeline_transform.mlir",
"reduction_pipeline.mlir",
"reduction_pipeline_cuda.mlir",
"reduction_pipeline_rocm.mlir",
"reduction_pipeline_transform_cuda.mlir",
"reduction_pipeline_transform_rocm.mlir",
"rocdl_pipeline_test.mlir",
"set_transform_strategy_batch_matmul.mlir",
"set_transform_strategy_convolution.mlir",
@@ -33,8 +33,10 @@ iree_lit_test_suite(
"nvvm_pipeline_test.mlir"
"pack_pipeline_test.mlir"
"pack_shared_memory_alloc.mlir"
"reduction_pipeline.mlir"
"reduction_pipeline_transform.mlir"
"reduction_pipeline_cuda.mlir"
"reduction_pipeline_rocm.mlir"
"reduction_pipeline_transform_cuda.mlir"
"reduction_pipeline_transform_rocm.mlir"
"rocdl_pipeline_test.mlir"
"set_transform_strategy_batch_matmul.mlir"
"set_transform_strategy_convolution.mlir"
@@ -39,7 +39,6 @@ hal.executable.variant @cuda, target = <"cuda", "cuda-nvptx-fb"> {
}
}

#map = affine_map<()[s0] -> (s0 * 4)>
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-LABEL: func.func @warp_reduction_dispatch
// CHECK-DAG: %[[C0I:.+]] = arith.constant 0 : i32
@@ -0,0 +1,35 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-linalg-ext-decompose-softmax)), iree-llvmgpu-lower-executable-target)))" %s | FileCheck %s

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>
hal.executable @softmax {
hal.executable.variant @rocm, target = <"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100"}> {
hal.executable.export @softmax layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @softmax() {
%c0 = arith.constant 0 : index
%cst = arith.constant -3.40282347E+38 : f32
%cst_0 = arith.constant 0.000000e+00 : f32
%cst_1 = arith.constant 1.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<12x128x40960xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<12x128x40960xf32>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<12x128x40960xf32>> -> tensor<12x128x40960xf32>
%3 = tensor.empty() : tensor<12x128x40960xf32>
%4 = iree_linalg_ext.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : tensor<12x128x40960xf32> -> !flow.dispatch.tensor<writeonly:tensor<12x128x40960xf32>>
return
}
}
}
}

// CHECK-LABEL: func.func @softmax
// CHECK-COUNT-20: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}}
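For reference, the expected count is consistent with the configuration logic above, assuming the decomposed softmax keeps two warp-reduced reductions per row (the running max and the exponential sum): with f32 data, vectorSize = 128 / 32 = 4, and groupSize = 40960 / 4 = 10240 is clamped to gcd(10240, 1024) = 1024 threads, i.e. 32 warps. Each reduction then needs log2(32) = 5 butterfly shuffles inside a warp plus 5 more to combine the 32 per-warp partials, so 2 x (5 + 5) = 20 gpu.shuffle ops, matching the CHECK-COUNT-20 line.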
@@ -462,3 +462,69 @@ hal.executable @reduction_2d_trailing_elementwise_static_dispatch_0 {
// CHECK: arith.divf {{.*}} : vector<1x2xf32>
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<1x2xf32>, memref<1x10xf32, strided<[10, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// CHECK: gpu.barrier

// -----

hal.executable private @i4_dequant_matvec {
hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_60"}> {
hal.executable.export public @i4_dequant_matvec ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer, ReadOnly>, <4, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @i4_dequant_matvec() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4096, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4096x32x128xi4>
%6 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%7 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%8 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%9 = tensor.empty() : tensor<4096xf16>
%10 = tensor.empty() : tensor<4096x32x128xf16>
%11 = linalg.fill ins(%cst : f16) outs(%9 : tensor<4096xf16>) -> tensor<4096xf16>
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5, %6, %7 : tensor<4096x32x128xi4>, tensor<4096x32xf16>, tensor<4096x32xf16>) outs(%10 : tensor<4096x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%14 = arith.extui %in : i4 to i32
%15 = arith.uitofp %14 : i32 to f16
%16 = arith.subf %15, %in_1 : f16
%17 = arith.mulf %16, %in_0 : f16
linalg.yield %17 : f16
} -> tensor<4096x32x128xf16>
%13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%8, %12 : tensor<32x128xf16>, tensor<4096x32x128xf16>) outs(%11 : tensor<4096xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%14 = arith.mulf %in, %in_0 : f16
%15 = arith.addf %14, %out : f16
linalg.yield %15 : f16
} -> tensor<4096xf16>
flow.dispatch.tensor.store %13, %4, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
}
}
}

// CHECK-LABEL: func.func @i4_dequant_matvec()
// CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<1x8xf16>
// CHECK: %[[READ0:.+]] = vector.transfer_read {{.+}} : memref<4096x32x128xi4, #hal.descriptor_type<storage_buffer>>, vector<1x8xi4>
// CHECK: %[[READ1:.+]] = vector.transfer_read {{.+}} : memref<4096x32xf16, #hal.descriptor_type<storage_buffer>>, vector<1x8xf16>
// CHECK: %[[READ2:.+]] = vector.transfer_read {{.+}} : memref<4096x32xf16, #hal.descriptor_type<storage_buffer>>, vector<1x8xf16>
// CHECK: %[[READ3:.+]] = vector.transfer_read {{.+}} : memref<32x128xf16, #hal.descriptor_type<storage_buffer>>, vector<1x8xf16>
// CHECK: %[[EXTEND:.+]] = arith.extui %[[READ0]] : vector<1x8xi4> to vector<1x8xi32>
// CHECK: %[[CVT:.+]] = arith.uitofp %[[EXTEND]] : vector<1x8xi32> to vector<1x8xf16>
// CHECK: %[[SUB:.+]] = arith.subf %[[CVT]], %[[READ1]] : vector<1x8xf16>
// CHECK: %[[MUL0:.+]] = arith.mulf %[[SUB]], %[[READ2]] : vector<1x8xf16>
// CHECK: %[[MUL1:.+]] = arith.mulf %[[READ3]], %[[MUL0]] : vector<1x8xf16>
// CHECK: %[[ADD:.+]] = arith.addf %[[MUL1]], %[[CST]] : vector<1x8xf16>

// CHECK: %[[SCAST:.+]] = vector.shape_cast %[[ADD]] : vector<1x8xf16> to vector<8xf16>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SCAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: vector.reduction <add>, %[[EXTRACT]] : vector<4xf16> into f16
// CHECK-COUNT-9: gpu.shuffle xor
// CHECK: scf.if
// CHECK: vector.transfer_write
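The shuffle count here can be read the same way, assuming the 512-thread configuration derived in the KernelConfig.cpp trace above: 512 threads form 16 warps, so the reduction needs log2(32) = 5 intra-warp butterfly shuffles plus log2(16) = 4 more to combine the 16 per-warp partials, i.e. 5 + 4 = 9 gpu.shuffle ops, matching the CHECK-COUNT-9 line.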