diff --git a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
index 61f9266bf6b8..ca807eecb6ac 100644
--- a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
@@ -24,19 +24,17 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/Iterators.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -777,16 +775,9 @@ hoistTensorReshapesOutOfDispatchRegion(
     auto shapedType = dyn_cast<ShapedType>(origResult.getType());
     assert(shapedType && "result should be shaped type");
 
-    SmallVector<OpFoldResult> outputShape;
     ValueRange dynamicDims = dispatchOp.getResultDynamicDims(index);
-    for (int64_t dim : shapedType.getShape()) {
-      if (ShapedType::isDynamic(dim)) {
-        outputShape.push_back(dynamicDims.front());
-        dynamicDims.drop_front();
-        continue;
-      }
-      outputShape.push_back(rewriter.getIndexAttr(dim));
-    }
+    SmallVector<OpFoldResult> outputShape =
+        mlir::getMixedValues(shapedType.getShape(), dynamicDims, rewriter);
     auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
         loc, origResult.getType(), returnValue,
@@ -1062,11 +1053,7 @@ void CollapseDimensionsPass::runOnOperation() {
     memref::populateResolveRankedShapedTypeResultDimsPatterns(moveReshapeOps);
     tensor::populateFoldTensorEmptyPatterns(moveReshapeOps);
     SmallVector<Operation *> candidateOps;
-    block.walk([&](Operation *op) {
-      if (isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(op)) {
-        candidateOps.push_back(op);
-      }
-    });
+    block.walk([&](Operation *op) { candidateOps.push_back(op); });
     if (failed(
             applyOpPatternsGreedily(candidateOps, std::move(moveReshapeOps)))) {
       funcOp.emitOpError(
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
index 5ae2cb71df1b..377c91b6a054 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
@@ -753,3 +753,37 @@ util.func public @uncollapsable_consumer_partial(%arg0: tensor<10x20x30x2304xf32
 // CHECK-SAME:     iterator_types = ["parallel", "parallel", "reduction"]
 // CHECK:          %[[BARRIER:.+]] = util.optimization_barrier %[[RES]]
 // CHECK:          flow.return %[[RES]]
+
+// -----
+
+util.func public @elementwise_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %cst_0 = arith.constant 0 : index
+  %cst_1 = arith.constant 1 : index
+  %0 = tensor.dim %arg0, %cst_0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %cst_1 : tensor<?x?xf32>
+  %3 = flow.dispatch.region -> (tensor<?x?xf32>{%0, %1}) {
+    %5 = tensor.empty(%0, %1) : tensor<?x?xf32>
+    %cst = arith.constant 1.000000e+02 : f32
+    %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xf32>) outs(%5 : tensor<?x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %7 = arith.addf %in, %cst : f32
+      linalg.yield %7 : f32
+    } -> tensor<?x?xf32>
+    flow.return %6 : tensor<?x?xf32>
+  }
+  util.return %3 : tensor<?x?xf32>
+}
+// CHECK-LABEL: util.func public @elementwise_dynamic
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]+]]
+// CHECK-DAG:     %[[CST0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[CST1:.+]] = arith.constant 1 : index
+// CHECK-DAG:     %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST0]]
+// CHECK-DAG:     %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST1]]
+// CHECK:         %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:           %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:        iterator_types = ["parallel"]
+// CHECK:           flow.return %[[VAL]] : tensor<?xf32>
+// CHECK:         %[[EXPAND:.+]] = tensor.expand_shape %[[DISPATCH]]
+// CHECK-SAME:      {{.+}} output_shape [%[[DIM0]], %[[DIM1]]]
+// CHECK:         util.return %[[EXPAND]] : tensor<?x?xf32>