Fix output shape bug and add test
Signed-off-by: Ian Wood <[email protected]>
IanWood1 committed Jan 14, 2025
1 parent 5aeb8df · commit 62c2e1e
Showing 2 changed files with 38 additions and 17 deletions.
compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp (21 changes: 4 additions & 17 deletions)
@@ -24,19 +24,17 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -777,16 +775,9 @@ hoistTensorReshapesOutOfDispatchRegion(
   auto shapedType = dyn_cast<ShapedType>(origResult.getType());
   assert(shapedType && "result should be shaped type");

-  SmallVector<OpFoldResult> outputShape;
   ValueRange dynamicDims = dispatchOp.getResultDynamicDims(index);
-  for (int64_t dim : shapedType.getShape()) {
-    if (ShapedType::isDynamic(dim)) {
-      outputShape.push_back(dynamicDims.front());
-      dynamicDims.drop_front();
-      continue;
-    }
-    outputShape.push_back(rewriter.getIndexAttr(dim));
-  }
+  SmallVector<OpFoldResult> outputShape =
+      mlir::getMixedValues(shapedType.getShape(), dynamicDims, rewriter);

   auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
       loc, origResult.getType(), returnValue,
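The bug being fixed is a value-semantics pitfall: ValueRange::drop_front() returns a new, shortened range and leaves the receiver untouched, so the deleted loop read the same SSA value from dynamicDims.front() for every dynamic dimension. A minimal standalone sketch of the pitfall and the reassignment that would have fixed it (buildOutputShape is a hypothetical helper mirroring the deleted loop, not code from this commit):

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper mirroring the deleted loop.
static SmallVector<OpFoldResult> buildOutputShape(ShapedType shapedType,
                                                  ValueRange dynamicDims,
                                                  Builder &builder) {
  SmallVector<OpFoldResult> outputShape;
  for (int64_t dim : shapedType.getShape()) {
    if (ShapedType::isDynamic(dim)) {
      outputShape.push_back(dynamicDims.front());
      // Buggy original: a bare `dynamicDims.drop_front();` discarded the
      // returned range, so front() yielded the first dynamic dim every time.
      // The reassignment below is what the loop needed:
      dynamicDims = dynamicDims.drop_front();
      continue;
    }
    outputShape.push_back(builder.getIndexAttr(dim));
  }
  return outputShape;
}

Rather than repair the loop, the commit replaces it with mlir::getMixedValues (declared in StaticValueUtils.h, included above), which performs exactly this interleaving of static sizes and dynamic SSA values.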
@@ -1062,11 +1053,7 @@ void CollapseDimensionsPass::runOnOperation()
   memref::populateResolveRankedShapedTypeResultDimsPatterns(moveReshapeOps);
   tensor::populateFoldTensorEmptyPatterns(moveReshapeOps);
   SmallVector<Operation *> candidateOps;
-  block.walk([&](Operation *op) {
-    if (isa<tensor::CollapseShapeOp>(op)) {
-      candidateOps.push_back(op);
-    }
-  });
+  block.walk([&](Operation *op) { candidateOps.push_back(op); });
   if (failed(
           applyOpPatternsGreedily(candidateOps, std::move(moveReshapeOps)))) {
     funcOp.emitOpError(
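The second change widens the set of ops handed to the greedy rewriter. The patterns populated just above are not all rooted at tensor.collapse_shape: the dim-resolution patterns, for instance, match tensor.dim / memref.dim ops, so under the old isa<tensor::CollapseShapeOp> filter they could never fire. A minimal sketch of this reading (runCleanupPatterns is a hypothetical wrapper, not code from this commit):

#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Seed the greedy driver with *every* op in the block so that patterns
// rooted at ops other than tensor.collapse_shape (e.g. tensor.dim
// resolution, tensor.empty folding) also get a chance to apply.
static LogicalResult runCleanupPatterns(Block &block,
                                        const FrozenRewritePatternSet &patterns) {
  SmallVector<Operation *> candidateOps;
  block.walk([&](Operation *op) { candidateOps.push_back(op); });
  return applyOpPatternsGreedily(candidateOps, patterns);
}

Seeding the driver with every op in the block lets all of the populated pattern sets reach a fixed point together.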
Test file: 34 additions & 0 deletions
@@ -753,3 +753,37 @@ util.func public @uncollapsable_consumer_partial(%arg0: tensor<10x20x30x2304xf32
 // CHECK-SAME:     iterator_types = ["parallel", "parallel", "reduction"]
 // CHECK:        %[[BARRIER:.+]] = util.optimization_barrier %[[RES]]
 // CHECK:        flow.return %[[RES]]
+
+// -----
+
+util.func @elementwise_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %cst_0 = arith.constant 0 : index
+  %cst_1 = arith.constant 1 : index
+  %0 = tensor.dim %arg0, %cst_0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %cst_1 : tensor<?x?xf32>
+  %3 = flow.dispatch.region -> (tensor<?x?xf32>{%0, %1}) {
+    %5 = tensor.empty(%0, %1) : tensor<?x?xf32>
+    %cst = arith.constant 1.000000e+02 : f32
+    %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xf32>) outs(%5 : tensor<?x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %7 = arith.addf %in, %cst : f32
+      linalg.yield %7 : f32
+    } -> tensor<?x?xf32>
+    flow.return %6 : tensor<?x?xf32>
+  }
+  util.return %3 : tensor<?x?xf32>
+}
+// CHECK-LABEL: util.func public @elementwise_dynamic
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]+]]
+// CHECK-DAG:     %[[CST0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[CST1:.+]] = arith.constant 1 : index
+// CHECK-DAG:     %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST0]]
+// CHECK-DAG:     %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST1]]
+// CHECK:         %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:           %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:        iterator_types = ["parallel"]
+// CHECK:           flow.return %[[VAL]] : tensor<?xf32>
+// CHECK:         %[[EXPAND:.+]] = tensor.expand_shape %[[DISPATCH]]
+// CHECK-SAME:      {{.+}} output_shape [%[[DIM0]], %[[DIM1]]]
+// CHECK:         util.return %[[EXPAND]] : tensor<?x?xf32>
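This test pins the fix. The dispatch result collapses to tensor<?xf32>, and hoisting the tensor.expand_shape out of the region needs the mixed static/dynamic output shape; with the old drop_front bug, both dynamic slots would have been filled with the first dynamic dimension, producing output_shape [%DIM0, %DIM0] instead of the checked [%DIM0, %DIM1]. A small sketch of what getMixedValues returns for this fully dynamic 2-D result (exampleOutputShape is illustrative only, not code from this commit):

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// For tensor<?x?xf32> both extents are dynamic, so getMixedValues consumes
// one SSA value per dynamic slot, in order: {%dim0, %dim1} -> {%dim0, %dim1}.
static SmallVector<OpFoldResult> exampleOutputShape(ValueRange dynamicDims,
                                                    Builder &builder) {
  int64_t staticShape[] = {ShapedType::kDynamic, ShapedType::kDynamic};
  return getMixedValues(staticShape, dynamicDims, builder);
}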
