[ROCM] Enable WarpReduction on ROCM + Matvec on GPU.
raikonenfnu authored and github-actions[bot] committed Sep 14, 2023
1 parent dcb3353 commit 1431d8f
Showing 3 changed files with 92 additions and 14 deletions.
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
@@ -117,7 +117,6 @@ struct ConvertToROCDLPass : public ConvertToROCDLBase<ConvertToROCDLPass> {
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
}
ConvertToDynamicSharedMemory(m);
}
};

101 changes: 90 additions & 11 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -30,6 +30,7 @@ using namespace mlir::iree_compiler;

static constexpr unsigned cudaWarpSize = 32;
static constexpr StringLiteral kCudaTarget = "cuda";
static constexpr StringLiteral kRocmTarget = "rocm";
namespace mlir {
namespace iree_compiler {
llvm::cl::opt<std::string> clGPUCodegenTransformDialectFileName(
@@ -162,11 +163,19 @@ bool isCudaTarget(func::FuncOp entryPoint) {
return false;
}

static TargetInfo getTargetInfo(func::FuncOp entryPoint) {
bool isRocmTarget(func::FuncOp entryPoint) {
if (auto variantOp =
entryPoint->getParentOfType<IREE::HAL::ExecutableVariantOp>()) {
IREE::HAL::ExecutableTargetAttr targetAttr = variantOp.getTarget();
if (auto backend = targetAttr.getBackend()) {
return backend.getValue().str() == kRocmTarget;
}
}
return false;
}

static TargetInfo getCudaTargetInfo(func::FuncOp entryPoint) {
TargetInfo info;
// TODO: fill out target info for other vendors.
if (!isCudaTarget(entryPoint))
return info;
// All CUDA targets are assumed to have warp shuffle support.
info.hasWarpShuffle = true;
StringRef targetName = getTargetArch(entryPoint);
@@ -190,6 +199,34 @@ static TargetInfo getTargetInfo(func::FuncOp entryPoint) {
return info;
}

// TODO: Plumb WarpSize into TargetInfo for wave64 systems.
static TargetInfo getRocmTargetInfo(func::FuncOp entryPoint) {
TargetInfo info;
StringRef targetName = getTargetArch(entryPoint);
// If no target name is set, assume all the features are off.
if (targetName == "")
return info;
if (!StringRef(targetName).starts_with("gfx")) {
entryPoint.emitError("unknown target name ") << targetName;
return info;
}
// Assume all gfx targets have warp shuffle.
info.hasWarpShuffle = true;
// TODO: Check and enable for WMMA once pipeline is available.
return info;
}
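
The TODO above points at wave64: several AMD architectures execute 64-wide waves, while this file hard-codes the CUDA-style warp size of 32. A minimal sketch of what plumbing a warp size through TargetInfo could look like; the struct layout, names, and the gfx9 heuristic are illustrative assumptions, not part of this commit:

#include <string_view>

// Hypothetical sketch, not IREE code: a TargetInfo-style struct carrying
// the warp size so configuration logic can stop assuming 32.
struct TargetInfoSketch {
  bool hasWarpShuffle = false;
  int warpSize = 32;  // would replace hard-coded cudaWarpSize uses
};

static TargetInfoSketch makeRocmTargetInfo(std::string_view gfxArch) {
  TargetInfoSketch info;
  info.hasWarpShuffle = true;
  // CDNA-class gfx9 targets default to wave64; RDNA (gfx10+) defaults to
  // wave32. A coarse prefix check, for illustration only.
  if (gfxArch.substr(0, 4) == "gfx9")
    info.warpSize = 64;
  return info;
}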

static TargetInfo getTargetInfo(func::FuncOp entryPoint) {
TargetInfo info;
// TODO: fill out target info for other vendors.
if (isCudaTarget(entryPoint)) {
info = getCudaTargetInfo(entryPoint);
} else if (isRocmTarget(entryPoint)) {
info = getRocmTargetInfo(entryPoint);
}
return info;
}
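
A standalone mirror of this dispatch (illustrative, not IREE code): backends other than cuda and rocm fall through and return a default-constructed TargetInfo, i.e. every feature off.

#include <iostream>
#include <string>

struct TargetInfoMirror {
  bool hasWarpShuffle = false;
};

static TargetInfoMirror getTargetInfoMirror(const std::string &backend) {
  TargetInfoMirror info;
  if (backend == "cuda")
    info.hasWarpShuffle = true;  // all CUDA targets assumed to support it
  else if (backend == "rocm")
    info.hasWarpShuffle = true;  // all gfx targets assumed to support it
  return info;
}

int main() {
  std::cout << getTargetInfoMirror("rocm").hasWarpShuffle << "\n";    // 1
  std::cout << getTargetInfoMirror("vulkan").hasWarpShuffle << "\n";  // 0
}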

static bool supportsTensorCore(func::FuncOp entryPoint, linalg::LinalgOp op,
const TargetInfo &targetInfo) {
// Limit tensor core pipeline to matmul as not all combinations of transpose
@@ -254,6 +291,20 @@ static LogicalResult setContractConfig(func::FuncOp entryPoint,
if (!linalg::isaContractionOpInterface(op) || op.getNumParallelLoops() < 2) {
return failure();
}

// Also exclude the case of matvec, which has only one non-unit parallel dim.
// They should go down different pipelines.
int nonUnitParallelDimCount = 0;
SmallVector<int64_t, 4> bounds = op.getStaticLoopRanges();
SmallVector<utils::IteratorType, 4> kinds = op.getIteratorTypesArray();
for (auto [kind, bound] : llvm::zip(kinds, bounds)) {
if (kind == utils::IteratorType::parallel)
nonUnitParallelDimCount += bound != 1;
}
if (!isa<linalg::MatmulOp, linalg::BatchMatmulOp>(op) &&
nonUnitParallelDimCount == 1)
return failure();
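
The effect of this check can be reproduced in isolation. A minimal standalone sketch (illustrative, not IREE code) of the same counting loop for a 256x128 matvec, whose loop bounds are [256, 128] with iterator kinds [parallel, reduction]:

#include <cstdint>
#include <iostream>
#include <vector>

enum class IteratorType { parallel, reduction };

int main() {
  std::vector<int64_t> bounds = {256, 128};
  std::vector<IteratorType> kinds = {IteratorType::parallel,
                                     IteratorType::reduction};
  int nonUnitParallelDimCount = 0;
  for (size_t i = 0; i < kinds.size(); ++i)
    if (kinds[i] == IteratorType::parallel && bounds[i] != 1)
      ++nonUnitParallelDimCount;
  std::cout << nonUnitParallelDimCount << "\n";  // prints 1
}

With exactly one non-unit parallel dimension the op is rejected by setContractConfig, leaving matvec-like contractions to go down the reduction pipelines instead.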

// Don't consider operations that don't have a broadcast, those should go
// through reductions.
if (llvm::any_of(op.getIndexingMapsArray(),
@@ -754,13 +805,24 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
}
SmallVector<unsigned> reductionDims;
op.getReductionDims(reductionDims);
if (reductionDims.size() != 1 || reductionDims[0] != op.getNumLoops() - 1)
if (reductionDims.empty())
return failure();

// Make sure reduction dimensions are the innermost ones.
for (int i = 0; i < reductionDims.size(); ++i) {
if (reductionDims[reductionDims.size() - 1 - i] !=
op.getNumLoops() - 1 - i) {
return failure();
}
}
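
A standalone sketch of this innermost-dims check (illustrative, not IREE code): with four loops, reduction dims {2, 3} are the innermost pair and pass, while {1, 3} leave a parallel loop between the reductions and fail.

#include <iostream>
#include <vector>

static bool reductionDimsAreInnermost(const std::vector<unsigned> &dims,
                                      unsigned numLoops) {
  // Walk from the last reduction dim backwards; each must line up with
  // the corresponding innermost loop.
  for (size_t i = 0; i < dims.size(); ++i)
    if (dims[dims.size() - 1 - i] != numLoops - 1 - i)
      return false;
  return true;
}

int main() {
  std::cout << reductionDimsAreInnermost({2, 3}, 4) << "\n";  // 1
  std::cout << reductionDimsAreInnermost({1, 3}, 4) << "\n";  // 0
}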

if (op.getRegionOutputArgs().size() != 1)
return failure();


// Only support projected permutations; this could be extended to projected
// permutations with broadcast.

if (llvm::any_of(op.getDpsInputOperands(), [&](OpOperand *input) {
return !op.getMatchingIndexingMap(input).isProjectedPermutation();
}))
@@ -783,8 +845,12 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
if (!foundSingleReductionOutput)
return failure();

std::optional<int64_t> dimSize = getLinalgDimSize(op, reductionDims[0]);
if (!dimSize || *dimSize % cudaWarpSize != 0)

SmallVector<int64_t, 4> bounds = op.getStaticLoopRanges();
int64_t dimSize = 1;
for (int64_t dim : reductionDims)
dimSize *= bounds[dim];
if (dimSize % cudaWarpSize != 0)
return failure();

const Type elementType =
@@ -797,14 +863,15 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8)
return failure();


const unsigned largestLoadSizeInBits = 128;
unsigned vectorSize = largestLoadSizeInBits / bitWidth;
while ((*dimSize / vectorSize) % cudaWarpSize != 0)
while ((dimSize / vectorSize) % cudaWarpSize != 0)
vectorSize /= 2;
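
A worked sketch of the vector-size selection (illustrative values, not IREE code): f16 data gives vectorSize = 128 / 16 = 8; the loop then halves it until dimSize / vectorSize is a multiple of the warp size, and the earlier dimSize % cudaWarpSize check guarantees termination by vectorSize = 1.

#include <cstdint>
#include <iostream>

int main() {
  const unsigned cudaWarpSize = 32;
  const unsigned largestLoadSizeInBits = 128;
  const unsigned bitWidth = 16;                            // f16
  const int64_t dimSize = 1056;                            // 33 * 32
  unsigned vectorSize = largestLoadSizeInBits / bitWidth;  // 8
  while ((dimSize / vectorSize) % cudaWarpSize != 0)
    vectorSize /= 2;
  std::cout << vectorSize << "\n";  // prints 1: of 8, 4, 2, 1, only
                                    // 1056 / 1 is a multiple of 32
}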

// TODO: Add reduction tiling to handle larger reductions.
const int64_t maxWorkgroupSize = 1024;
int64_t groupSize = *dimSize / vectorSize;
int64_t groupSize = dimSize / vectorSize;
if (groupSize > maxWorkgroupSize) {
groupSize = llvm::APIntOps::GreatestCommonDivisor(
{64, uint64_t(groupSize)}, {64, uint64_t(maxWorkgroupSize)})
@@ -817,8 +884,20 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
// Tile all the parallel dimensions to 1.
SmallVector<int64_t> workgroupTileSizes(numLoops, 1);
SmallVector<int64_t> reductionTileSizes(numLoops, 0);
reductionTileSizes.push_back(groupSize * vectorSize);
SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
int64_t remainingGroupSize = groupSize;
for (int i = reductionDims.size() - 1; i >= 0; --i) {
int64_t dim = reductionDims[i];
int64_t bound = bounds[dim];
if (i == reductionDims.size() - 1)
bound /= vectorSize;
APInt size = llvm::APIntOps::GreatestCommonDivisor(
{64, uint64_t(remainingGroupSize)}, {64, uint64_t(bound)});
reductionTileSizes[dim] = size.getSExtValue();
if (i == reductionDims.size() - 1)
reductionTileSizes[dim] *= vectorSize;
remainingGroupSize /= size.getSExtValue();
}
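
A worked, standalone mirror of this distribution loop (illustrative values, not IREE code; std::gcd stands in for llvm::APIntOps::GreatestCommonDivisor): with bounds [1, 8, 512], reduction dims {1, 2}, vectorSize = 4, and groupSize = 256, the innermost dim takes 128 threads * 4 elements = 512 and the remaining 2 threads tile the next reduction dim.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> bounds = {1, 8, 512};
  std::vector<int64_t> reductionDims = {1, 2};  // innermost two loops
  const int64_t vectorSize = 4;
  const int64_t groupSize = 256;
  std::vector<int64_t> reductionTileSizes(bounds.size(), 0);
  int64_t remainingGroupSize = groupSize;
  for (int i = (int)reductionDims.size() - 1; i >= 0; --i) {
    int64_t dim = reductionDims[i];
    int64_t bound = bounds[dim];
    if (i == (int)reductionDims.size() - 1)
      bound /= vectorSize;  // innermost dim is also vectorized
    int64_t size = std::gcd(remainingGroupSize, bound);
    reductionTileSizes[dim] = size;
    if (i == (int)reductionDims.size() - 1)
      reductionTileSizes[dim] *= vectorSize;
    remainingGroupSize /= size;
  }
  for (int64_t s : reductionTileSizes)
    std::cout << s << " ";  // prints: 0 2 512
  std::cout << "\n";
}
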
TileSizesListType tileSizes;
tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
tileSizes.emplace_back(std::move(reductionTileSizes)); // reduction level
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -381,6 +381,8 @@ void addGPUTransposePassPipeline(OpPassManager &pm) {
void addGPUWarpReductionPassPipeline(OpPassManager &pm) {
tileAndDistributeToWorkgroup(pm);
auto &nestedModulePM = pm.nest<ModuleOp>();
nestedModulePM.addNestedPass<func::FuncOp>(
createRematerializeParallelOpsPass());
nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
nestedModulePM.addNestedPass<func::FuncOp>(createGPUTileReductionPass());
nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
@@ -588,8 +590,6 @@ void addGPUTransformDialectPasses(OpPassManager &passManager) {

void buildLLVMGPUTransformPassPipeline(OpPassManager &pm, bool useROCM) {
addCommonTargetExecutablePreprocessingPasses(pm.nest<ModuleOp>());
pm.nest<ModuleOp>().addNestedPass<func::FuncOp>(
createRematerializeParallelOpsPass());
pm.addPass(createLLVMGPULowerExecutableTargetPass());
OpPassManager &nestedModulePM = pm.nest<ModuleOp>();
//===--------------------------------------------------------------------===//