From 1c0c5a6ff64bedb1cb1275eaba16aab3fc26acdf Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram <96096277+nirvedhmeshram@users.noreply.github.com> Date: Wed, 21 Aug 2024 16:57:53 -0500 Subject: [PATCH] [Flow] Convert from tensor.cast to flow.tensor.reshape early (#18256) The reason to do this is that tensor.cast ops can end up in dispatch regions and break the logic there. If they are converted to flow.tensor.reshape before dispatch formation then the correct thing happens and the resulting flow.tensor.reshape ops are left out of the dispatch regions. Fixes: https://github.com/iree-org/iree/issues/18229 --- .../Flow/Conversion/TensorToFlow/Patterns.cpp | 5 ++ .../Flow/Conversion/TensorToFlow/Patterns.h | 4 ++ .../Dialect/Flow/Transforms/Canonicalizer.cpp | 7 ++ .../test/dispatch_linalg_on_tensors.mlir | 68 ++++++++----------- .../Transforms/test/flow_canonicalize.mlir | 24 +++++++ 5 files changed, 70 insertions(+), 38 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp index d5794afd504e..df25f922f4af 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp @@ -326,4 +326,9 @@ void populateTensorToFlowConversionPatterns(MLIRContext *context, ConvertTensorReshapePattern>(context); } +void populateTensorDialectCastOpPattern(MLIRContext *context, + RewritePatternSet &patterns) { + patterns.insert(context); +} + } // namespace mlir::iree_compiler::IREE::Flow diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.h b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.h index a0bb75b48128..95f8c47a3dd6 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.h @@ -16,6 
+16,10 @@ namespace mlir::iree_compiler::IREE::Flow { void populateTensorToFlowConversionPatterns(MLIRContext *context, RewritePatternSet &patterns); +// Add pattern to convert tensor.cast -> flow.tensor.reshape. +void populateTensorDialectCastOpPattern(MLIRContext *context, + RewritePatternSet &patterns); + } // namespace mlir::iree_compiler::IREE::Flow #endif // IREE_COMPILER_DIALECT_FLOW_CONVERSION_TENSORTOFLOW_PATTERNS_H_ diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp index b19bccc14da3..e7da87363c9a 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp @@ -4,6 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -103,6 +104,11 @@ struct CanonicalizerPass CanonicalizerPass>::CanonicalizerPassBase; /// Initialize the canonicalizer by building the set of patterns used during /// execution. + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + LogicalResult initialize(MLIRContext *context) override { // Inherit the same config defaults from the upstream canonicalizer pass. config.useTopDownTraversal = true; @@ -117,6 +123,7 @@ struct CanonicalizerPass // Pull in some borderline/downstream canonicalizations for the Flow // compilation phase. 
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(owningPatterns); + IREE::Flow::populateTensorDialectCastOpPattern(context, owningPatterns); owningPatterns.add(context); patterns = diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir index ea22db1bf38e..d5fd174a88ab 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-convert-tensor-to-flow, canonicalize, iree-flow-materialize-default-workgroup-count-region), cse, iree-flow-canonicalize, cse)" %s | FileCheck %s +// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-canonicalize, iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-convert-tensor-to-flow, canonicalize, iree-flow-materialize-default-workgroup-count-region), cse, iree-flow-canonicalize, cse)" %s | FileCheck %s util.func public @tile_matmul_alone(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { %1 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) @@ -231,17 +231,17 @@ util.func public @always_fuse_cast // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<4x?xf32> // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index // CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK: %[[RESHAPE:.*]] = flow.tensor.reshape 
%[[ARG0]] +// CHECK-SAME: tensor{%[[M]], %[[C4]]} -> tensor{%[[M]]} // CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[RESULT1:.+]] = flow.dispatch.workgroups[%[[M]], %[[K]], %[[N1]]] -// CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[M]], %[[K]], %[[N1]]) -// CHECK: tensor.cast +// CHECK: %[[RESULT1:.+]] = flow.dispatch.workgroups[%[[N1]], %[[M]]] +// CHECK-SAME: (%[[RESHAPE]], %[[ARG1]], %[[N1]], %[[M]]) // CHECK: flow.return // CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ARG2]], %[[C1]] -// CHECK: %[[RESULT2:.+]] = flow.dispatch.workgroups[%[[M]], %[[K]], %[[N2]]] -// CHECK-SAME: (%[[ARG0]], %[[ARG2]], %[[M]], %[[K]], %[[N2]]) -// CHECK: tensor.cast +// CHECK: %[[RESULT2:.+]] = flow.dispatch.workgroups[%[[N2]], %[[M]]] +// CHECK-SAME: (%[[RESHAPE]], %[[ARG2]], %[[N2]], %[[M]]) // CHECK: flow.return // CHECK: util.return %[[RESULT1]], %[[RESULT2]] @@ -512,26 +512,21 @@ util.func public @inline_dag_1( // CHECK-NOT: linalg. // CHECK-NOT: tensor.extract_slice // CHECK: flow.dispatch.workgroups -// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> +// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> // CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG8:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG9:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG11:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG12:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> -// CHECK: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]] -// CHECK: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]] -// CHECK: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]] -// CHECK: %[[INIT:.+]] = tensor.empty -// CHECK-DAG: 
%[[OP1:.+]] = tensor.cast %[[LEAF1]] -// CHECK-DAG: %[[OP2:.+]] = tensor.cast %[[LEAF2]] -// CHECK-DAG: %[[OP3:.+]] = tensor.extract_slice %[[OP1]][0, 0] -// CHECK-DAG: %[[OP4:.+]] = tensor.extract_slice %[[OP1]][0, 10] -// CHECK-DAG: %[[OP5:.+]] = tensor.extract_slice %[[OP1]][0, 20] +// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> +// CHECK-DAG: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 0] +// CHECK-DAG: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 10] +// CHECK-DAG: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 20] +// CHECK-DAG: %[[LEAF4:.+]] = flow.dispatch.tensor.load %[[ARG5]] +// CHECK-DAG: %[[LEAF5:.+]] = flow.dispatch.tensor.load %[[ARG6]] +// CHECK-DAG: %[[INIT:.+]] = tensor.empty // CHECK: linalg.generic -// CHECK-SAME: ins(%[[LEAF3]], %[[OP5]], %[[OP2]], %[[OP4]], %[[OP3]] : +// CHECK-SAME: ins(%[[LEAF4]], %[[LEAF3]], %[[LEAF5]], %[[LEAF2]], %[[LEAF1]] : // CHECK-SAME: outs(%[[INIT]] : // ----- @@ -572,24 +567,21 @@ util.func public @inline_dag_2( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x?xf32> // CHECK: flow.dispatch.workgroups -// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> +// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> // CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG8:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG9:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG11:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> -// CHECK: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], {{.*}} -// CHECK: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}} -// CHECK: %[[LEAF3:.+]] 
= flow.dispatch.tensor.load %[[ARG6]], {{.*}} -// CHECK: %[[INIT:.+]] = tensor.empty -// CHECK-DAG: %[[OP1:.+]] = tensor.cast %[[LEAF1]] -// CHECK-DAG: %[[OP3:.+]] = tensor.extract_slice %[[OP1]][0, 0] -// CHECK-DAG: %[[OP4:.+]] = tensor.extract_slice %[[OP1]][0, 10] -// CHECK-DAG: %[[OP5:.+]] = tensor.extract_slice %[[OP1]][0, 20] +// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor> +// CHECK-DAG: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 0] +// CHECK-DAG: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 10] +// CHECK-DAG: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 20] +// CHECK-DAG: %[[LEAF4:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}} +// CHECK-DAG: %[[LEAF5:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}} +// CHECK-DAG: %[[INIT:.+]] = tensor.empty // CHECK: linalg.generic -// CHECK-SAME: ins(%[[LEAF3]], %[[OP5]], %[[LEAF2]], %[[OP4]], %[[OP3]] : +// CHECK-SAME: ins(%[[LEAF4]], %[[LEAF3]], %[[LEAF5]], %[[LEAF2]], %[[LEAF1]] : // CHECK-SAME: outs(%[[INIT]] : // ----- diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir index 81203a5db24c..effadc47da5d 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir @@ -82,3 +82,27 @@ util.func public @dont_merge_constant_padding_different_vals( // CHECK-LABEL: util.func public @dont_merge_constant_padding_different_vals // CHECK: tensor.pad // CHECK: tensor.pad + +// ----- + +util.func public @tensor_cast_to_reshape(%reshape_17 : tensor, %65 : tensor, %0 : index, %1 : index) -> tensor { + %cast = tensor.cast %reshape_17 : tensor to tensor + %66 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], + iterator_types = 
["parallel", "parallel", "parallel", "parallel"]} + ins(%cast : tensor) outs(%65 : tensor) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor + %cast_18 = tensor.cast %66 : tensor to tensor + util.return %cast_18 : tensor +} + +// CHECK-LABEL: util.func public @tensor_cast_to_reshape +// CHECK: flow.tensor.reshape +// CHECK-SAME: tensor +// CHECK-SAME: -> tensor +// CHECK: linalg.generic +// CHECK: flow.tensor.reshape +// CHECK-SAME: tensor +// CHECK-SAME: -> tensor