Skip to content

Commit

Permalink
[Flow] Convert from tensor.cast to flow.tensor.reshape early (iree-org#18256)
Browse files Browse the repository at this point in the history

The reason to do this is that tensor.cast's can end up in dispatch
regions and break the logic there. If they are converted to
flow.tensor.reshape before dispatch formation then the correct thing
happens and the resulting flow.tensor.reshape ops are left out of the
dispatch.regions.
Fixes: iree-org#18229
  • Loading branch information
nirvedhmeshram authored Aug 21, 2024
1 parent dd8abf7 commit 1c0c5a6
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -326,4 +326,9 @@ void populateTensorToFlowConversionPatterns(MLIRContext *context,
ConvertTensorReshapePattern<tensor::ExpandShapeOp>>(context);
}

// Registers the pattern that rewrites `tensor.cast` into
// `flow.tensor.reshape`. Kept separate from
// populateTensorToFlowConversionPatterns so callers (e.g. the Flow
// canonicalizer) can opt into just this rewrite.
void populateTensorDialectCastOpPattern(MLIRContext *context,
                                        RewritePatternSet &patterns) {
  patterns.add<ConvertTensorCastPattern>(context);
}

} // namespace mlir::iree_compiler::IREE::Flow
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ namespace mlir::iree_compiler::IREE::Flow {
void populateTensorToFlowConversionPatterns(MLIRContext *context,
RewritePatternSet &patterns);

// Add pattern to convert tensor.cast -> flow.tensor.reshape.
void populateTensorDialectCastOpPattern(MLIRContext *context,
RewritePatternSet &patterns);

} // namespace mlir::iree_compiler::IREE::Flow

#endif // IREE_COMPILER_DIALECT_FLOW_CONVERSION_TENSORTOFLOW_PATTERNS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand Down Expand Up @@ -103,6 +104,11 @@ struct CanonicalizerPass
CanonicalizerPass>::CanonicalizerPassBase;
/// Initialize the canonicalizer by building the set of patterns used during
/// execution.

// Declare the Flow dialect as a dependency: the pattern set built in
// initialize() includes populateTensorDialectCastOpPattern, whose rewrite
// (tensor.cast -> flow.tensor.reshape) creates Flow ops, so the dialect
// must be loaded before this pass runs.
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Flow::FlowDialect>();
}

LogicalResult initialize(MLIRContext *context) override {
// Inherit the same config defaults from the upstream canonicalizer pass.
config.useTopDownTraversal = true;
Expand All @@ -117,6 +123,7 @@ struct CanonicalizerPass
// Pull in some borderline/downstream canonicalizations for the Flow
// compilation phase.
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(owningPatterns);
IREE::Flow::populateTensorDialectCastOpPattern(context, owningPatterns);
owningPatterns.add<FoldConsecutiveConstantPadding>(context);

patterns =
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-convert-tensor-to-flow, canonicalize, iree-flow-materialize-default-workgroup-count-region), cse, iree-flow-canonicalize, cse)" %s | FileCheck %s
// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-canonicalize, iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-convert-tensor-to-flow, canonicalize, iree-flow-materialize-default-workgroup-count-region), cse, iree-flow-canonicalize, cse)" %s | FileCheck %s
util.func public @tile_matmul_alone(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%1 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
Expand Down Expand Up @@ -231,17 +231,17 @@ util.func public @always_fuse_cast
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<4x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[RESHAPE:.*]] = flow.tensor.reshape %[[ARG0]]
// CHECK-SAME: tensor<?x?xf32>{%[[M]], %[[C4]]} -> tensor<?x4xf32>{%[[M]]}
// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[RESULT1:.+]] = flow.dispatch.workgroups[%[[M]], %[[K]], %[[N1]]]
// CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[M]], %[[K]], %[[N1]])
// CHECK: tensor.cast
// CHECK: %[[RESULT1:.+]] = flow.dispatch.workgroups[%[[N1]], %[[M]]]
// CHECK-SAME: (%[[RESHAPE]], %[[ARG1]], %[[N1]], %[[M]])
// CHECK: flow.return
// CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ARG2]], %[[C1]]
// CHECK: %[[RESULT2:.+]] = flow.dispatch.workgroups[%[[M]], %[[K]], %[[N2]]]
// CHECK-SAME: (%[[ARG0]], %[[ARG2]], %[[M]], %[[K]], %[[N2]])
// CHECK: tensor.cast
// CHECK: %[[RESULT2:.+]] = flow.dispatch.workgroups[%[[N2]], %[[M]]]
// CHECK-SAME: (%[[RESHAPE]], %[[ARG2]], %[[N2]], %[[M]])
// CHECK: flow.return
// CHECK: util.return %[[RESULT1]], %[[RESULT2]]

Expand Down Expand Up @@ -512,26 +512,21 @@ util.func public @inline_dag_1(
// CHECK-NOT: linalg.
// CHECK-NOT: tensor.extract_slice
// CHECK: flow.dispatch.workgroups
// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<i32>>
// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<1x?xf32>>
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<i32>>
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<1x?xf32>>
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG9:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG11:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG12:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:tensor<1x?xf32>>
// CHECK: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]]
// CHECK: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]]
// CHECK: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]]
// CHECK: %[[INIT:.+]] = tensor.empty
// CHECK-DAG: %[[OP1:.+]] = tensor.cast %[[LEAF1]]
// CHECK-DAG: %[[OP2:.+]] = tensor.cast %[[LEAF2]]
// CHECK-DAG: %[[OP3:.+]] = tensor.extract_slice %[[OP1]][0, 0]
// CHECK-DAG: %[[OP4:.+]] = tensor.extract_slice %[[OP1]][0, 10]
// CHECK-DAG: %[[OP5:.+]] = tensor.extract_slice %[[OP1]][0, 20]
// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:tensor<1x?xf32>>
// CHECK-DAG: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 0]
// CHECK-DAG: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 10]
// CHECK-DAG: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 20]
// CHECK-DAG: %[[LEAF4:.+]] = flow.dispatch.tensor.load %[[ARG5]]
// CHECK-DAG: %[[LEAF5:.+]] = flow.dispatch.tensor.load %[[ARG6]]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LEAF3]], %[[OP5]], %[[OP2]], %[[OP4]], %[[OP3]] :
// CHECK-SAME: ins(%[[LEAF4]], %[[LEAF3]], %[[LEAF5]], %[[LEAF2]], %[[LEAF1]] :
// CHECK-SAME: outs(%[[INIT]] :

// -----
Expand Down Expand Up @@ -572,24 +567,21 @@ util.func public @inline_dag_2(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x?xf32>
// CHECK: flow.dispatch.workgroups
// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<1x?xf32>>
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<i32>>
// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<1x?xf32>>
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<i32>>
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<1x?xf32>>
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG9:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG11:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:tensor<1x?xf32>>
// CHECK: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], {{.*}}
// CHECK: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}}
// CHECK: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
// CHECK: %[[INIT:.+]] = tensor.empty
// CHECK-DAG: %[[OP1:.+]] = tensor.cast %[[LEAF1]]
// CHECK-DAG: %[[OP3:.+]] = tensor.extract_slice %[[OP1]][0, 0]
// CHECK-DAG: %[[OP4:.+]] = tensor.extract_slice %[[OP1]][0, 10]
// CHECK-DAG: %[[OP5:.+]] = tensor.extract_slice %[[OP1]][0, 20]
// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:tensor<1x?xf32>>
// CHECK-DAG: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 0]
// CHECK-DAG: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 10]
// CHECK-DAG: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 20]
// CHECK-DAG: %[[LEAF4:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}}
// CHECK-DAG: %[[LEAF5:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
// CHECK-DAG: %[[INIT:.+]] = tensor.empty
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LEAF3]], %[[OP5]], %[[LEAF2]], %[[OP4]], %[[OP3]] :
// CHECK-SAME: ins(%[[LEAF4]], %[[LEAF3]], %[[LEAF5]], %[[LEAF2]], %[[LEAF1]] :
// CHECK-SAME: outs(%[[INIT]] :

// -----
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,27 @@ util.func public @dont_merge_constant_padding_different_vals(
// CHECK-LABEL: util.func public @dont_merge_constant_padding_different_vals
// CHECK: tensor.pad
// CHECK: tensor.pad

// -----

// Regression test (iree-org#18229): tensor.cast ops feeding and consuming a
// linalg.generic should be rewritten to flow.tensor.reshape by the Flow
// canonicalizer rather than surviving into dispatch-region formation.
util.func public @tensor_cast_to_reshape(%reshape_17 : tensor<?x?x?x?xf32>, %65 : tensor<?x12x?x64xf32>, %0 : index, %1 : index) -> tensor<?x?x?x?xf32> {
// Cast refining dynamic dims to static (?x?x?x? -> ?x?x12x64).
%cast = tensor.cast %reshape_17 : tensor<?x?x?x?xf32> to tensor<?x?x12x64xf32>
%66 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%cast : tensor<?x?x12x64xf32>) outs(%65 : tensor<?x12x?x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x12x?x64xf32>
// Cast generalizing static dims back to dynamic (?x12x?x64 -> ?x?x?x?).
%cast_18 = tensor.cast %66 : tensor<?x12x?x64xf32> to tensor<?x?x?x?xf32>
util.return %cast_18 : tensor<?x?x?x?xf32>
}

// CHECK-LABEL: util.func public @tensor_cast_to_reshape
// CHECK: flow.tensor.reshape
// CHECK-SAME: tensor<?x?x?x?xf32>
// CHECK-SAME: -> tensor<?x?x12x64xf32>
// CHECK: linalg.generic
// CHECK: flow.tensor.reshape
// CHECK-SAME: tensor<?x12x?x64xf32>
// CHECK-SAME: -> tensor<?x?x?x?xf32>

0 comments on commit 1c0c5a6

Please sign in to comment.