Skip to content

Commit

Permalink
[LLVMGPU] Allow reductions with dynamic parallel dimensions through W…
Browse files Browse the repository at this point in the history
…arpReduce
  • Loading branch information
qedawkins authored and github-actions[bot] committed Sep 10, 2023
1 parent 695850b commit 253cfaa
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -781,14 +781,20 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
return failure();
if (!isa<linalg::GenericOp>(op))
return failure();
// TODO(thomasraoux): Enable dynamic shape.
bool hasDynamicShape = false;
entryPoint.walk([&hasDynamicShape](linalg::LinalgOp op) {
if (op.hasDynamicShape())
hasDynamicShape = true;
// TODO: Enable dynamic shape.
auto walkResult = entryPoint.walk([](linalg::LinalgOp op) {
using utils::IteratorType;
SmallVector<IteratorType, 4> kinds = op.getIteratorTypesArray();
SmallVector<int64_t, 4> bounds = op.getStaticLoopRanges();
for (auto [kind, bound] : llvm::zip_equal(kinds, bounds)) {
if (kind == IteratorType::reduction && ShapedType::isDynamic(bound))
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (hasDynamicShape)
if (walkResult.wasInterrupted()) {
return failure();
}
SmallVector<unsigned> reductionDims;
op.getReductionDims(reductionDims);
if (reductionDims.empty())
Expand Down Expand Up @@ -831,7 +837,6 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
if (!foundSingleReductionOutput)
return failure();


SmallVector<int64_t, 4> bounds = op.getStaticLoopRanges();
int64_t dimSize = 1;
for (int64_t dim : reductionDims)
Expand Down

0 comments on commit 253cfaa

Please sign in to comment.