diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 3bd1bec3ebfa23..4a0565c901b31d 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -1665,6 +1665,7 @@ cc_library( ":sink_constants_to_control_flow", ":symbolic_shape_optimization", ":test_passes", + ":tile_loops_pass", ":transforms_pass_details", ":transforms_pass_inc_gen", ":userange_analysis", @@ -1743,6 +1744,7 @@ cc_library( ], deps = [ ":transforms_pass_inc_gen", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", ], @@ -1796,6 +1798,7 @@ cc_library( deps = [ ":hlo", ":transforms_pass_inc_gen", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:BufferizationTransforms", @@ -1818,6 +1821,7 @@ cc_library( deps = [ ":hlo", ":transforms_pass_inc_gen", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:BufferizationTransforms", "@llvm-project//mlir:FuncDialect", @@ -1841,6 +1845,7 @@ cc_library( ":hlo", ":lhlo", ":transforms_pass_inc_gen", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:BufferizationTransforms", "@llvm-project//mlir:CopyOpInterface", @@ -1862,6 +1867,7 @@ cc_library( ":shape_component_analysis", ":transforms_pass_inc_gen", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -1996,6 +2002,7 @@ cc_library( ":hlo", ":transforms_pass_inc_gen", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", @@ -2019,6 +2026,7 @@ cc_library( deps = [ ":hlo", ":transforms_pass_inc_gen", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -2042,6 +2050,7 @@ 
cc_library( ":transforms_pass_inc_gen", "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:kernel_gen_passes_inc_gen", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -2052,6 +2061,30 @@ cc_library( ], ) +cc_library( + name = "tile_loops_pass", + srcs = [ + "lib/Transforms/tile_loops_pass.cc", + ], + hdrs = [ + "include/mlir-hlo/Transforms/PassDetail.h", + "include/mlir-hlo/Transforms/passes.h", + ], + deps = [ + ":hlo", + ":transforms_pass_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", + ], +) + CAPI_HEADERS = [ "include/mlir-hlo-c/Attributes.h", "include/mlir-hlo-c/Dialects.h", diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/PassDetail.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/PassDetail.h index b916a752411832..6951a42d2a1876 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/PassDetail.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/PassDetail.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef MLIR_HLO_TRANSFORMS_PASSDETAIL_H #define MLIR_HLO_TRANSFORMS_PASSDETAIL_H +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Pass/Pass.h" diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/passes.h index 2634791af06a6f..e8f31a72f87be9 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/passes.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/passes.h @@ -95,6 +95,11 @@ std::unique_ptr> CreateFinalBufferizePass( std::unique_ptr> CreatePropagateStaticShapesToKernelPass(Type pointer_type = {}); +// Creates a TileLoopsPass with tile sizes provided through `tile_sizes` +// and unroll factors provided through `unroll_factors`. +std::unique_ptr> CreateTileLoopsPass( + ArrayRef tile_sizes = {}, ArrayRef unroll_factors = {}); + namespace hlo { std::unique_ptr> CreateOneShotBufferizePass(); } // namespace hlo diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/passes.td index 7591c626f7420f..bc04d26d8595f4 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/passes.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/passes.td @@ -61,6 +61,25 @@ def BufferPacking : Pass<"buffer-packing", "func::FuncOp"> { ]; } +def TileLoopsPass : Pass<"tile-loops", "func::FuncOp"> { + let summary = "Tiles parallel loops."; + let description = [{ The pass converts an `scf.parallel` loop into a nested, + "tiled", `scf.parallel` loop with 2 to 3 levels of nesting. 
The 3rd level of + nesting represents operation unrolling within a tile and is only applied on + simple memory access patterns (ones resulting from same shape, scalar, and/or + constant operands).}]; + let constructor = "CreateTileLoopsPass()"; + let options = [ + ListOption<"tile_sizes_", "tile-sizes", "int64_t", "The size of the tile " + "in each dimension, expressed as the number of " + "`unroll_factors_` in that dimension.", "llvm::cl::ZeroOrMore">, + ListOption<"unroll_factors_", "unroll-factors", "int64_t", "The unroll " + "factor in each dimension, expressed as the number of elements " + "in that dimension.", "llvm::cl::ZeroOrMore">, + ]; + let dependentDialects = ["AffineDialect"]; +} + def MemoryCount : Pass<"memory-count", "func::FuncOp"> { let summary = "Test pass to count the allocated memory of a module."; let description = [{A test pass that prints the size of allocated memory of a diff --git a/tensorflow/compiler/mlir/hlo/lib/Transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Transforms/CMakeLists.txt index 58e309b4b6c388..f595d599dd4f35 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Transforms/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/lib/Transforms/CMakeLists.txt @@ -24,6 +24,7 @@ add_mlir_library(MLIRBufferTransforms lower_index_cast_pass.cc symbolic_shape_optimization.cc shape_simplification.cc + tile_loops_pass.cc DEPENDS LMHLOTransformsPassIncGen diff --git a/tensorflow/compiler/mlir/hlo/lib/Transforms/tile_loops_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Transforms/tile_loops_pass.cc new file mode 100644 index 00000000000000..e0d65e7097eee3 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Transforms/tile_loops_pass.cc @@ -0,0 +1,100 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements the logic for converting `scf.parallel` loops into +// tiled loops. + +#include "mlir-hlo/Transforms/PassDetail.h" +#include "mlir-hlo/Transforms/passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" + +namespace mlir { + +using ::llvm::to_vector; +using ::mlir::scf::ParallelOp; + +namespace { + +// This is the implementation of the TileLoops pass declared in +// include/mlir-hlo/Transforms/passes.td +class TileLoopsPass : public TileLoopsPassBase { + public: + // Creates a TileLoopsPass with tile sizes provided through `tile_sizes` + // and unroll factors provided through `unroll_factors`. + explicit TileLoopsPass(ArrayRef tile_sizes, + ArrayRef unroll_factors) { + tile_sizes_ = tile_sizes; + unroll_factors_ = unroll_factors; + } + + void runOnOperation() override; +}; + +} // namespace + +// Checks if the access pattern in the `scf.parallel` loop `ploop` is "complex". +// I.e., its memory load patterns include more than just scalar accesses, and +// accesses with offsets corresponding to loop induction variables. 
+static bool IsComplexAccessPattern(ParallelOp ploop) { + for (Operation& nested : ploop.getBody()->without_terminator()) { + if (auto load_op = llvm::dyn_cast(nested)) { + if (!load_op.getMemRefType().getLayout().isIdentity() || + (!load_op.getIndices().empty() && + load_op.getIndices() != ploop.getInductionVars())) { + return true; + } + } + } + return false; +} + +void TileLoopsPass::runOnOperation() { + auto unrolled_tile = [&]() -> SmallVector { + if (tile_sizes_.size() != unroll_factors_.size()) return {}; + auto multiply = [](std::tuple tuple) { + return std::get<0>(tuple) * std::get<1>(tuple); + }; + return to_vector<4>( + llvm::map_range(llvm::zip(tile_sizes_, unroll_factors_), multiply)); + }(); + + SmallVector innermostPloops; + getInnermostParallelLoops(this->getOperation().getOperation(), + innermostPloops); + + for (ParallelOp ploop : innermostPloops) { + // Do not unroll if the multiplier has the wrong rank, or if we have complex + // memory access patterns. + if (unrolled_tile.empty() || IsComplexAccessPattern(ploop)) { + tileParallelLoop(ploop, tile_sizes_, /*noMinMaxBounds=*/false); + continue; + } + auto tiled_loops = + tileParallelLoop(ploop, unrolled_tile, /*noMinMaxBounds=*/false); + tileParallelLoop(tiled_loops.second, unroll_factors_, + /*noMinMaxBounds=*/false); + } +} + +std::unique_ptr> CreateTileLoopsPass( + ArrayRef tile_sizes, ArrayRef unroll_factors) { + return std::make_unique(tile_sizes, unroll_factors); +} + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/tests/tile_loops.mlir b/tensorflow/compiler/mlir/hlo/tests/tile_loops.mlir new file mode 100644 index 00000000000000..0897d343a18e71 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/tile_loops.mlir @@ -0,0 +1,50 @@ +// RUN: mlir-hlo-opt --tile-loops="tile-sizes=2 unroll-factors=4" %s | \ +// RUN: FileCheck %s + +// CHECK-LABEL: func @parallel_loop +func.func @parallel_loop(%arg0: memref<16xf32>, %arg1: memref<16xf32>) { + %c0 = arith.constant 0 : index + %c1 = 
arith.constant 1 : index + %c16 = arith.constant 16 : index + %0 = memref.alloc() {alignment = 128 : i64} : memref<16xf32> + scf.parallel (%arg2) = (%c0) to (%c16) step (%c1) { + // CHECK: %[[C8:.*]] = arith.constant 8 + // CHECK: %[[TILE:.*]] = arith.muli {{.*}} %[[C8]] + // CHECK: scf.parallel {{.*}} step (%[[TILE]]) + // CHECK: %[[C4:.*]] = arith.constant 4 + // CHECK: %[[UNROLL:.*]] = arith.muli {{.*}} %[[C4]] + // CHECK: scf.parallel {{.*}} to (%[[TILE]]) step (%[[UNROLL]]) + // CHECK: scf.parallel + %2 = memref.load %arg0[%arg2] : memref<16xf32> + %3 = math.log %2 : f32 + memref.store %3, %0[%arg2] : memref<16xf32> + scf.yield + } + %1 = bufferization.to_tensor %0 : memref<16xf32> + memref.tensor_store %1, %arg1 : memref<16xf32> + "lmhlo.terminator"() : () -> () +} + +// CHECK-LABEL: func @complex_access +func.func @complex_access(%arg0: memref<16xf32>, %arg1: memref<4xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %0 = memref.alloc() {alignment = 128 : i64} : memref<4xf32> + scf.parallel (%arg2) = (%c0) to (%c4) step (%c1) { + // CHECK: %[[C2:.*]] = arith.constant 2 + // CHECK: %[[TILE:.*]] = arith.muli {{.*}} %[[C2]] + // CHECK: scf.parallel {{.*}} step (%[[TILE]]) + // CHECK: scf.parallel + // We should see only 2 loops for complex access patterns + // CHECK-NOT: scf.parallel + %idx = arith.muli %arg2, %c4 : index + %2 = memref.load %arg0[%idx] : memref<16xf32> + %3 = math.log %2 : f32 + memref.store %3, %0[%arg2] : memref<4xf32> + scf.yield + } + %1 = bufferization.to_tensor %0 : memref<4xf32> + memref.tensor_store %1, %arg1 : memref<4xf32> + "lmhlo.terminator"() : () -> () +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 827dd572e25038..76287bdfcdea19 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -45,17 +45,9 @@ cc_library( 
hdrs = ["kernel_creator.h"], copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]), deps = [ - ":compile_cache_item_proto_cc", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:all_passes", - "//tensorflow/compiler/mlir/hlo:hlo_legalize_shape_ops_to_standard", - "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", - "//tensorflow/compiler/mlir/hlo:legalize_to_linalg", "//tensorflow/compiler/mlir/hlo:legalize_trigonometric_to_approximation", - "//tensorflow/compiler/mlir/hlo:lhlo", - "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg", - "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine", - "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", "//tensorflow/compiler/mlir/hlo:shape_simplification", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", @@ -64,10 +56,8 @@ cc_library( "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes", "//tensorflow/compiler/mlir/xla:xla_legalize_tf_no_fallback", "//tensorflow/core:lib", - "//tensorflow/core/platform:cuda_libdevice_path", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineToStandard", - "@llvm-project//mlir:AffineUtils", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:ArithmeticTransforms", "@llvm-project//mlir:BufferizationTransforms", @@ -79,28 +69,18 @@ cc_library( "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMToLLVMIRTranslation", - "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MemRefTransforms", - "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:NVVMToLLVMIRTranslation", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ROCDLDialect", "@llvm-project//mlir:ROCDLToLLVMIRTranslation", "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:SCFDialect", 
"@llvm-project//mlir:SCFToGPUPass", "@llvm-project//mlir:SCFToStandard", - "@llvm-project//mlir:SCFTransforms", - "@llvm-project//mlir:SCFUtils", - "@llvm-project//mlir:Shape", "@llvm-project//mlir:ShapeToStandard", - "@llvm-project//mlir:ShapeTransforms", - "@llvm-project//mlir:StandardOpsTransforms", - "@llvm-project//mlir:ToLLVMIRTranslation", "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorToLLVM", ], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index dfdeec782be091..992b3fd69b7c44 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -31,24 +31,16 @@ limitations under the License. #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project -#include "mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/Bufferization/Transforms/Passes.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project #include "mlir/Dialect/GPU/Passes.h" // from @llvm-project -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project #include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project -#include 
"mlir/Dialect/SCF/Transforms.h" // from @llvm-project -#include "mlir/Dialect/SCF/Utils/Utils.h" // from @llvm-project -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project @@ -56,20 +48,16 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/statusor.h" namespace tensorflow { @@ -136,70 +124,6 @@ struct CollapseParallelLoopsTo1D } }; -class TileLoops - : public mlir::PassWrapper> { - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileLoops) - - explicit TileLoops(llvm::ArrayRef tile_sizes, - 
llvm::ArrayRef unroll_factors) { - tile_sizes_ = llvm::to_vector<4>(tile_sizes); - outer_tile_ = tile_sizes_; - - // We have to anticipate later unrolling in tiling to make sure that we get - // the requested tiling after unrolling. - if (unroll_factors.size() == tile_sizes.size()) { - inner_tile_ = llvm::to_vector<4>(unroll_factors); - for (auto en : llvm::enumerate(unroll_factors)) { - outer_tile_[en.index()] *= en.value(); - } - } - } - - void runOnOperation() override { - llvm::SmallVector innermostPloops; - mlir::getInnermostParallelLoops(this->getOperation().getOperation(), - innermostPloops); - auto is_simple_access_pattern = [](ParallelOp ploop) { - for (mlir::Operation& nested : ploop.getBody()->without_terminator()) { - if (auto load_op = llvm::dyn_cast(nested)) { - if (!load_op.getMemRefType().getLayout().isIdentity() || - (!load_op.getIndices().empty() && - load_op.getIndices() != ploop.getInductionVars())) { - return false; - } - } - } - return true; - }; - - for (ParallelOp ploop : innermostPloops) { - // Support unrolling only for simple memory access patterns (that result - // from same shape operands, scalar operands, and/or constant operands). - if (!is_simple_access_pattern(ploop)) { - tileParallelLoop(ploop, tile_sizes_, /*noMinMaxBounds=*/false); - continue; - } - auto tiled_loops = - tileParallelLoop(ploop, outer_tile_, /*noMinMaxBounds=*/false); - // Tile twice if the inner_tile is non-empty. - if (!inner_tile_.empty()) { - tileParallelLoop(tiled_loops.second, inner_tile_, - /*noMinMaxBounds=*/false); - } - } - } - - private: - // Outer tile size = unroll_factor.empty() ? tile_sizes : tile_sizes * - // unroll_factors. - llvm::SmallVector outer_tile_; - // Inner tile size if the unrolling factors were specified. - llvm::SmallVector inner_tile_; - // Original tile sizes. 
- llvm::SmallVector tile_sizes_; -}; - Status LowerTFToJITInvocation(mlir::ModuleOp module, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, @@ -328,7 +252,7 @@ Status LowerTFtoLoops(mlir::ModuleOp module, llvm::ArrayRef tile_sizes, // provide benefits to CPU and tiling is handled by vectorization. pm.addNestedPass(std::make_unique()); pm.addNestedPass( - std::make_unique(tile_sizes, unroll_factors)); + mlir::CreateTileLoopsPass(tile_sizes, unroll_factors)); } pm.addNestedPass(::mlir::createCanonicalizerPass());