Migrate TileLoopsPass into mlir-hlo
This allows us to reuse it in XLA Next, in addition to kernel_gen.

PiperOrigin-RevId: 448935170
gflegar authored and tensorflower-gardener committed May 16, 2022
1 parent fcc0207 commit 23529a7
Showing 9 changed files with 210 additions and 97 deletions.
33 changes: 33 additions & 0 deletions tensorflow/compiler/mlir/hlo/BUILD
@@ -1665,6 +1665,7 @@ cc_library(
":sink_constants_to_control_flow",
":symbolic_shape_optimization",
":test_passes",
":tile_loops_pass",
":transforms_pass_details",
":transforms_pass_inc_gen",
":userange_analysis",
@@ -1743,6 +1744,7 @@ cc_library(
],
deps = [
":transforms_pass_inc_gen",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
@@ -1796,6 +1798,7 @@ cc_library(
deps = [
":hlo",
":transforms_pass_inc_gen",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:BufferizationTransforms",
@@ -1818,6 +1821,7 @@ cc_library(
deps = [
":hlo",
":transforms_pass_inc_gen",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:BufferizationTransforms",
"@llvm-project//mlir:FuncDialect",
@@ -1841,6 +1845,7 @@ cc_library(
":hlo",
":lhlo",
":transforms_pass_inc_gen",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:BufferizationTransforms",
"@llvm-project//mlir:CopyOpInterface",
@@ -1862,6 +1867,7 @@ cc_library(
":shape_component_analysis",
":transforms_pass_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
@@ -1996,6 +2002,7 @@ cc_library(
":hlo",
":transforms_pass_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
@@ -2019,6 +2026,7 @@ cc_library(
deps = [
":hlo",
":transforms_pass_inc_gen",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
@@ -2042,6 +2050,7 @@ cc_library(
":transforms_pass_inc_gen",
"//tensorflow/compiler/mlir/tools/kernel_gen/transforms:kernel_gen_passes_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
@@ -2052,6 +2061,30 @@
],
)

cc_library(
name = "tile_loops_pass",
srcs = [
"lib/Transforms/tile_loops_pass.cc",
],
hdrs = [
"include/mlir-hlo/Transforms/PassDetail.h",
"include/mlir-hlo/Transforms/passes.h",
],
deps = [
":hlo",
":transforms_pass_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:SCFUtils",
],
)

CAPI_HEADERS = [
"include/mlir-hlo-c/Attributes.h",
"include/mlir-hlo-c/Dialects.h",
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/PassDetail.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef MLIR_HLO_TRANSFORMS_PASSDETAIL_H
#define MLIR_HLO_TRANSFORMS_PASSDETAIL_H

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Pass/Pass.h"
5 changes: 5 additions & 0 deletions tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/passes.h
@@ -95,6 +95,11 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateFinalBufferizePass(
std::unique_ptr<OperationPass<ModuleOp>>
CreatePropagateStaticShapesToKernelPass(Type pointer_type = {});

// Creates a TileLoopsPass with tile sizes provided through `tile_sizes`
// and unroll factors provided through `unroll_factors`.
std::unique_ptr<OperationPass<func::FuncOp>> CreateTileLoopsPass(
ArrayRef<int64_t> tile_sizes = {}, ArrayRef<int64_t> unroll_factors = {});

namespace hlo {
std::unique_ptr<OperationPass<ModuleOp>> CreateOneShotBufferizePass();
} // namespace hlo
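A minimal usage sketch (illustrative only, not part of this commit): a client pipeline such as kernel_gen or XLA Next could schedule the new pass on every function through a standard MLIR PassManager.

#include "mlir-hlo/Transforms/passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical helper: tiles innermost scf.parallel loops into tiles of
// 2 * 4 = 8 elements with 4x unrolling inside each tile, mirroring the
// FileCheck test below (--tile-loops="tile-sizes=2 unroll-factors=4").
void AddTileLoops(mlir::PassManager& pm) {
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::CreateTileLoopsPass(/*tile_sizes=*/{2}, /*unroll_factors=*/{4}));
}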
19 changes: 19 additions & 0 deletions tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/passes.td
@@ -61,6 +61,25 @@ def BufferPacking : Pass<"buffer-packing", "func::FuncOp"> {
];
}

def TileLoopsPass : Pass<"tile-loops", "func::FuncOp"> {
let summary = "Tiles parallel loops.";
let description = [{ The pass converts an `scf.parallel` loop into a nested,
"tiled", `scf.parallel` loop with two to three levels of nesting. The third
level of nesting represents operation unrolling within a tile, and is only
applied to simple memory access patterns (those resulting from same-shape,
scalar, and/or constant operands).}];
let constructor = "CreateTileLoopsPass()";
let options = [
ListOption<"tile_sizes_", "tile-sizes", "int64_t", "The size of the tile "
"in each dimension, expressed as the number of "
"`unroll_factors_` in that dimension.", "llvm::cl::ZeroOrMore">,
ListOption<"unroll_factors_", "unroll-factors", "int64_t", "The unroll "
"factor in each dimension, expressed as the number of elements "
"in that dimension.", "llvm::cl::ZeroOrMore">,
];
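// Note: `tile-sizes` is measured in units of `unroll-factors`, so running
// with tile-sizes=2 and unroll-factors=4 produces tiles of 2 * 4 = 8
// elements (see tests/tile_loops.mlir below).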
let dependentDialects = ["AffineDialect"];
}

def MemoryCount : Pass<"memory-count", "func::FuncOp"> {
let summary = "Test pass to count the allocated memory of a module.";
let description = [{A test pass that prints the size of allocated memory of a
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/hlo/lib/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_library(MLIRBufferTransforms
lower_index_cast_pass.cc
symbolic_shape_optimization.cc
shape_simplification.cc
tile_loops_pass.cc

DEPENDS
LMHLOTransformsPassIncGen
100 changes: 100 additions & 0 deletions tensorflow/compiler/mlir/hlo/lib/Transforms/tile_loops_pass.cc
@@ -0,0 +1,100 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements the logic for converting `scf.parallel` loops into
// tiled loops.

#include "mlir-hlo/Transforms/PassDetail.h"
#include "mlir-hlo/Transforms/passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"

namespace mlir {

using ::llvm::to_vector;
using ::mlir::scf::ParallelOp;

namespace {

// This is the implementation of the TileLoops pass declared in
// include/mlir-hlo/Transforms/passes.td
class TileLoopsPass : public TileLoopsPassBase<TileLoopsPass> {
public:
// Creates a TileLoopsPass with tile sizes provided through `tile_sizes`
// and unroll factors provided through `unroll_factors`.
explicit TileLoopsPass(ArrayRef<int64_t> tile_sizes,
ArrayRef<int64_t> unroll_factors) {
tile_sizes_ = tile_sizes;
unroll_factors_ = unroll_factors;
}

void runOnOperation() override;
};

} // namespace

// Checks whether the access pattern in the `scf.parallel` loop `ploop` is
// "complex", i.e. whether it contains memory loads other than scalar accesses
// and accesses indexed directly by the loop's induction variables.
static bool IsComplexAccessPattern(ParallelOp ploop) {
for (Operation& nested : ploop.getBody()->without_terminator()) {
if (auto load_op = llvm::dyn_cast<memref::LoadOp>(nested)) {
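      // A load is "complex" if its memref has a non-identity layout map, or
      // if it is neither a scalar access (no indices) nor indexed exactly by
      // the loop's induction variables.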
if (!load_op.getMemRefType().getLayout().isIdentity() ||
(!load_op.getIndices().empty() &&
load_op.getIndices() != ploop.getInductionVars())) {
return true;
}
}
}
return false;
}

void TileLoopsPass::runOnOperation() {
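  // The combined tile size in each dimension is tile_sizes_[i] *
  // unroll_factors_[i] elements (e.g. tile-sizes=2 with unroll-factors=4
  // gives tiles of 8 elements); the result is empty if the two option lists
  // have different ranks.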
auto unrolled_tile = [&]() -> SmallVector<int64_t, 4> {
if (tile_sizes_.size() != unroll_factors_.size()) return {};
auto multiply = [](std::tuple<int64_t, int64_t> tuple) {
return std::get<0>(tuple) * std::get<1>(tuple);
};
return to_vector<4>(
llvm::map_range(llvm::zip(tile_sizes_, unroll_factors_), multiply));
}();

SmallVector<ParallelOp, 2> innermostPloops;
getInnermostParallelLoops(this->getOperation().getOperation(),
innermostPloops);

for (ParallelOp ploop : innermostPloops) {
// Do not unroll if the multiplier has the wrong rank, or if we have complex
// memory access patterns.
if (unrolled_tile.empty() || IsComplexAccessPattern(ploop)) {
tileParallelLoop(ploop, tile_sizes_, /*noMinMaxBounds=*/false);
continue;
}
auto tiled_loops =
tileParallelLoop(ploop, unrolled_tile, /*noMinMaxBounds=*/false);
tileParallelLoop(tiled_loops.second, unroll_factors_,
/*noMinMaxBounds=*/false);
}
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateTileLoopsPass(
ArrayRef<int64_t> tile_sizes, ArrayRef<int64_t> unroll_factors) {
return std::make_unique<TileLoopsPass>(tile_sizes, unroll_factors);
}

} // namespace mlir
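The FileCheck test below exercises both paths: a simple elementwise loop that receives the full three-level tiling, and a strided ("complex") access pattern that falls back to a single round of tiling with `tile-sizes` alone.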
50 changes: 50 additions & 0 deletions tensorflow/compiler/mlir/hlo/tests/tile_loops.mlir
@@ -0,0 +1,50 @@
// RUN: mlir-hlo-opt --tile-loops="tile-sizes=2 unroll-factors=4" %s | \
// RUN: FileCheck %s

// CHECK-LABEL: func @parallel_loop
func.func @parallel_loop(%arg0: memref<16xf32>, %arg1: memref<16xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<16xf32>
scf.parallel (%arg2) = (%c0) to (%c16) step (%c1) {
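// The unit-step loop over 16 elements is tiled into an outer loop stepping
// by 2 * 4 = 8, a middle loop stepping by 4, and a unit-step innermost loop: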
// CHECK: %[[C8:.*]] = arith.constant 8
// CHECK: %[[TILE:.*]] = arith.muli {{.*}} %[[C8]]
// CHECK: scf.parallel {{.*}} step (%[[TILE]])
// CHECK: %[[C4:.*]] = arith.constant 4
// CHECK: %[[UNROLL:.*]] = arith.muli {{.*}} %[[C4]]
// CHECK: scf.parallel {{.*}} to (%[[TILE]]) step (%[[UNROLL]])
// CHECK: scf.parallel
%2 = memref.load %arg0[%arg2] : memref<16xf32>
%3 = math.log %2 : f32
memref.store %3, %0[%arg2] : memref<16xf32>
scf.yield
}
%1 = bufferization.to_tensor %0 : memref<16xf32>
memref.tensor_store %1, %arg1 : memref<16xf32>
"lmhlo.terminator"() : () -> ()
}

// CHECK-LABEL: func @complex_access
func.func @complex_access(%arg0: memref<16xf32>, %arg1: memref<4xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<4xf32>
scf.parallel (%arg2) = (%c0) to (%c4) step (%c1) {
// CHECK: %[[C2:.*]] = arith.constant 2
// CHECK: %[[TILE:.*]] = arith.muli {{.*}} %[[C2]]
// CHECK: scf.parallel {{.*}} step (%[[TILE]])
// CHECK: scf.parallel
// We should see only 2 loops for complex access patterns
// CHECK-NOT: scf.parallel
%idx = arith.muli %arg2, %c4 : index
%2 = memref.load %arg0[%idx] : memref<16xf32>
%3 = math.log %2 : f32
memref.store %3, %0[%arg2] : memref<4xf32>
scf.yield
}
%1 = bufferization.to_tensor %0 : memref<4xf32>
memref.tensor_store %1, %arg1 : memref<4xf32>
"lmhlo.terminator"() : () -> ()
}
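The remaining changes trim tensorflow/compiler/mlir/tools/kernel_gen/BUILD: with the pass now provided by mlir-hlo, kernel_creator no longer needs several direct dependencies of its own (for example, the SCF and Affine loop utilities).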
20 changes: 0 additions & 20 deletions tensorflow/compiler/mlir/tools/kernel_gen/BUILD
@@ -45,17 +45,9 @@ cc_library(
hdrs = ["kernel_creator.h"],
copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]),
deps = [
":compile_cache_item_proto_cc",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:all_passes",
"//tensorflow/compiler/mlir/hlo:hlo_legalize_shape_ops_to_standard",
"//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
"//tensorflow/compiler/mlir/hlo:legalize_to_linalg",
"//tensorflow/compiler/mlir/hlo:legalize_trigonometric_to_approximation",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg",
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/hlo:shape_simplification",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
@@ -64,10 +56,8 @@ cc_library(
"//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_no_fallback",
"//tensorflow/core:lib",
"//tensorflow/core/platform:cuda_libdevice_path",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineToStandard",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:ArithmeticTransforms",
"@llvm-project//mlir:BufferizationTransforms",
@@ -79,28 +69,18 @@ cc_library(
"@llvm-project//mlir:GPUToNVVMTransforms",
"@llvm-project//mlir:GPUTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:ROCDLDialect",
"@llvm-project//mlir:ROCDLToLLVMIRTranslation",
"@llvm-project//mlir:ReconcileUnrealizedCasts",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFToGPUPass",
"@llvm-project//mlir:SCFToStandard",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:SCFUtils",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeToStandard",
"@llvm-project//mlir:ShapeTransforms",
"@llvm-project//mlir:StandardOpsTransforms",
"@llvm-project//mlir:ToLLVMIRTranslation",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorToLLVM",
],
(The diff for the ninth changed file, 1 addition and 77 deletions, did not render.)