diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index ae9c749fb0adb..077b49fad544f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -781,14 +781,20 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, return failure(); if (!isa(op)) return failure(); - // TODO(thomasraoux): Enable dynamic shape. - bool hasDynamicShape = false; - entryPoint.walk([&hasDynamicShape](linalg::LinalgOp op) { - if (op.hasDynamicShape()) - hasDynamicShape = true; + // TODO: Enable dynamic shape. + auto walkResult = entryPoint.walk([](linalg::LinalgOp op) { + using utils::IteratorType; + SmallVector kinds = op.getIteratorTypesArray(); + SmallVector bounds = op.getStaticLoopRanges(); + for (auto [kind, bound] : llvm::zip_equal(kinds, bounds)) { + if (kind == IteratorType::reduction && ShapedType::isDynamic(bound)) + return WalkResult::interrupt(); + } + return WalkResult::advance(); }); - if (hasDynamicShape) + if (walkResult.wasInterrupted()) { return failure(); + } SmallVector reductionDims; op.getReductionDims(reductionDims); if (reductionDims.empty()) @@ -831,7 +837,6 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, if (!foundSingleReductionOutput) return failure(); - SmallVector bounds = op.getStaticLoopRanges(); int64_t dimSize = 1; for (int64_t dim : reductionDims)