Duplicate fill on contractions (#784)
Duplicate fill operations when the user is a contraction and the fill can
later be folded into the contraction by `fold-xsmm-flags`. Duplication
avoids the `memref.copy` operations that bufferization would otherwise
introduce. Example:

```mlir
%0 = tensor.empty()
%1 = linalg.fill ins(...) outs(%0) // fill with zeros.
%2 = linalg.matmul ins(...) outs(%1)
%3 = linalg.matmul ins(...) outs(%1)
```
Without this PR, it bufferizes as:

```mlir
%0 = memref.alloc()
%1 = memref.alloc()
linalg.fill ins(...) outs(%0) // fill with zeros.
memref.copy %0, %1
linalg.matmul ins(...) outs(%0)
linalg.matmul ins(...) outs(%1)
```

With this PR, the IR looks like:

```mlir
// No copies; the fills are folded as beta = 0.
%0 = memref.alloc()
%1 = memref.alloc()
xsmm.matmul ins(...) outs(%0) // beta = 0
xsmm.matmul ins(...) outs(%1) // beta = 0
```
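
At the tensor level, the rewrite performed by the new `duplicate-fill` pass
is small; the following is a minimal sketch mirroring the `matmuls` test
added below (`%arg0`/`%arg1` stand in for function arguments):

```mlir
// Before -duplicate-fill: one zero fill feeds both matmuls.
%cst = arith.constant 0.0 : f32
%0 = tensor.empty() : tensor<32x32xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>)
                   outs(%1 : tensor<32x32xf32>) -> tensor<32x32xf32>
%3 = linalg.matmul ins(%arg0, %2 : tensor<32x32xf32>, tensor<32x32xf32>)
                   outs(%1 : tensor<32x32xf32>) -> tensor<32x32xf32>
```

becomes

```mlir
// After -duplicate-fill: the second matmul gets its own clone of the fill,
// so bufferization can give each matmul a distinct buffer without a copy.
%cst = arith.constant 0.0 : f32
%0 = tensor.empty() : tensor<32x32xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>)
                   outs(%1 : tensor<32x32xf32>) -> tensor<32x32xf32>
%4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32>
%3 = linalg.matmul ins(%arg0, %2 : tensor<32x32xf32>, tensor<32x32xf32>)
                   outs(%4 : tensor<32x32xf32>) -> tensor<32x32xf32>
```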

The PR has a minor performance impact; the only notable improvement is for
`fp32_mha_tensorflow_seq_len_32`. The IR also looks cleaner, with one less
allocation and all the beta flags properly folded.
`fp32_mha_tensorflow_seq_len_1024` does not improve because its
dimensionality allows fusion to distribute the fill; see: b1167fe.

This PR is part of #783
chelini authored Nov 20, 2023
1 parent 4091bf8 commit 48455d0
Showing 3 changed files with 263 additions and 2 deletions.
17 changes: 15 additions & 2 deletions include/TPP/Passes.td
@@ -292,18 +292,31 @@ def ConvInitSimplify : Pass<"conv-init-simplify", "func::FuncOp"> {
def Bufferize : Pass<"bufferize", "ModuleOp"> {
let summary = "Bufferize tensor to memref for the entire module";
let options = [
Option<"dealloc", "dealloc",
"bool", /*default=*/"1",
Option<"dealloc", "dealloc", "bool",
/*default=*/"true",
"Enables automatic deallocation.">,
Option<"testAnalysisOnly", "test-analysis-only", "bool",
/*default=*/"false",
"Only runs inplaceability analysis (for testing purposes only)">,
Option<"printConflicts", "print-conflicts", "bool",
/*default=*/"false",
"Annotates IR with RaW conflicts. Requires test-analysis-only.">,
Option<"duplicateFill", "duplicate-fill", "bool",
/*default=*/"true",
"Enable duplication of fill operation (for testing only).">
];
}

def DuplicateFill : Pass<"duplicate-fill", "func::FuncOp"> {
let summary = "Duplicate fill operations";
let description = [{
Duplicate linalg.fill operations to avoid memref.copy after
bufferization. This can trigger later folding of the fill.
We duplicate only zero fills whose users are contraction operations.
}];
let dependentDialects = [ "linalg::LinalgDialect" ];
}

def Cleanup : Pass<"cleanup", "func::FuncOp"> {
let summary = "General IR cleanup e.g., canonicalization, CSE etc.";
}
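
The last condition is the interesting one: only zero fills with multiple
uses are cloned, and only at their contraction users. A hypothetical sketch
of IR the pass leaves untouched (both cases are covered by the tests added
below; `%e`, `%a`, and `%b` stand in for suitable values):

```mlir
// Non-zero fill: it cannot later be folded into a beta = 0 flag by
// `fold-xsmm-flags`, so it is not duplicated.
%one = arith.constant 1.0 : f32
%f = linalg.fill ins(%one : f32) outs(%e : tensor<32x64xf32>) -> tensor<32x64xf32>

// Zero fill with a single use: no memref.copy arises after bufferization,
// so there is nothing to duplicate.
%zero = arith.constant 0.0 : f32
%g = linalg.fill ins(%zero : f32) outs(%e : tensor<32x64xf32>) -> tensor<32x64xf32>
%m = linalg.matmul ins(%a, %b : tensor<32x512xf32>, tensor<512x64xf32>)
                   outs(%g : tensor<32x64xf32>) -> tensor<32x64xf32>
```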
45 changes: 45 additions & 0 deletions lib/TPP/Transforms/Bufferize.cpp
@@ -9,6 +9,9 @@
#include "TPP/Passes.h"
#include "TPP/Transforms/Transforms.h"

#include "TPP/Transforms/Utils/TransformUtils.h"
#include "TPP/Transforms/Utils/ValueUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -42,6 +45,8 @@ namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_BUFFERIZE
#include "TPP/Passes.h.inc"
#define GEN_PASS_DEF_DUPLICATEFILL
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

@@ -72,12 +77,52 @@ struct Bufferize : public tpp::impl::BufferizeBase<Bufferize> {
void runOnOperation() override;
};

struct DuplicateFill : public tpp::impl::DuplicateFillBase<DuplicateFill> {
void runOnOperation() override;
};

void DuplicateFill::runOnOperation() {
IRRewriter rewriter(&getContext());

(void)getOperation()->walk([&](linalg::FillOp fillOp) {
if (!fillOp.hasTensorSemantics())
return WalkResult::advance();
Value fillVal = fillOp.getResult(0);
// We can fold only zero initialization. We duplicate only
// if the fill has multiple uses.
if (!utils::isZeroTensor(fillVal) || fillOp->hasOneUse())
return WalkResult::advance();
SetVector<Operation *> forwardSlice;
getForwardSlice(fillVal, &forwardSlice);
for (size_t idx = /*skip the first user; it keeps the current fill*/ 1;
idx < forwardSlice.size(); idx++)
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(forwardSlice[idx])) {
if (failed(linalgx::utils::isContraction(linalgOp)))
continue;
assert(linalgOp.getNumDpsInits() == 1);
Value outLinalg = linalgOp.getDpsInits()[0];
if (outLinalg == fillVal) {
rewriter.setInsertionPoint(linalgOp);
Operation *clonedOp = rewriter.clone(*fillOp.getOperation());
rewriter.replaceUsesWithIf(fillOp->getResults(),
clonedOp->getResults(),
[&](OpOperand &operand) {
return operand.getOwner() == linalgOp;
});
}
}
return WalkResult::advance();
});
}

void Bufferize::runOnOperation() {
ModuleOp moduleOp = getOperation();

OpPassManager passManager;

// Pre-processing.
if (this->duplicateFill)
passManager.addNestedPass<func::FuncOp>(tpp::createDuplicateFill());
passManager.addPass(bufferization::createEmptyTensorEliminationPass());
passManager.addPass(bufferization::createEmptyTensorToAllocTensorPass());

203 changes: 203 additions & 0 deletions test/Passes/pass-duplicate-fill.mlir
@@ -0,0 +1,203 @@
// RUN: tpp-opt %s -duplicate-fill -split-input-file | FileCheck %s

// Check we do not introduce additional allocations or copies.
// RUN: tpp-opt %s -bufferize -split-input-file | FileCheck %s -check-prefix=BUFF
// RUN: tpp-opt %s -bufferize="duplicate-fill=false" -split-input-file | FileCheck %s -check-prefix=BUFFNOTDUP

#map = affine_map<(d0, d1, d2) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>

// CHECK-LABEL: duplicate_zero_fill_on_contractions
// BUFF-LABEL: duplicate_zero_fill_on_contractions
// BUFFNOTDUP-LABEL: duplicate_zero_fill_on_contractions
func.func @duplicate_zero_fill_on_contractions(%arg0: tensor<32x512xf32>,
%arg1: tensor<512x64xf32>) -> tensor<32x64xf32> {
// BUFF-COUNT-2: memref.alloc
// BUFF-COUNT-1: memref.dealloc
// BUFF-NOT: memref.copy
//
// BUFFNOTDUP-COUNT-2: memref.alloc
// BUFFNOTDUP-COUNT-1: memref.dealloc
// BUFFNOTDUP-NOT: memref.copy
%cst_2 = arith.constant 0.0 : f32
%0 = tensor.empty() : tensor<32x64xf32>
%1 = linalg.fill ins(%cst_2 : f32) outs(%0 : tensor<32x64xf32>) -> tensor<32x64xf32>
// CHECK: linalg.fill
// CHECK-NEXT: linalg.generic
%3 = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "reduction", "parallel"]}
ins(%arg0, %arg1 : tensor<32x512xf32>, tensor<512x64xf32>) outs(%1 : tensor<32x64xf32>) {
^bb0(%in: f32, %in_5: f32, %out: f32):
%9 = arith.mulf %in, %in_5 : f32
%10 = arith.addf %out, %9 : f32
linalg.yield %10 : f32
} -> tensor<32x64xf32>
// CHECK: linalg.fill
// CHECK-NEXT: linalg.generic
%4 = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "reduction", "parallel"]}
ins(%arg0, %arg1 : tensor<32x512xf32>, tensor<512x64xf32>) outs(%1 : tensor<32x64xf32>) {
^bb0(%in: f32, %in_5: f32, %out: f32):
%9 = arith.mulf %in, %in_5 : f32
%10 = arith.addf %out, %9 : f32
linalg.yield %10 : f32
} -> tensor<32x64xf32>
// CHECK-NOT: linalg.fill
%5 = linalg.add ins(%3, %4 : tensor<32x64xf32>, tensor<32x64xf32>)
outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
return %5 : tensor<32x64xf32>
}

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map5 = affine_map<(d0, d1, d2) -> (d2, d0)>

func.func @mha_contractions(%arg0: tensor<64x32x512xf32>, %arg1: tensor<64x32x512xf32>,
%arg2: tensor<64x32x512xf32>) -> tensor<64x8x32x32xf32> {
%cst = arith.constant dense<2.000000e-01> : tensor<512x64xf32>
%cst_0 = arith.constant dense<1.000000e-01> : tensor<512x64xf32>
%cst_1 = arith.constant dense<1.250000e-01> : tensor<32x64xf32>
%cst_2 = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<64x8x32x32xf32>
%1 = scf.forall (%arg3, %arg4) in (64, 8) shared_outs(%arg5 = %0) -> (tensor<64x8x32x32xf32>) {
// BUFF-COUNT-2: memref.alloc
// BUFF-COUNT-2: memref.dealloc
//
// BUFFNOTDUP-COUNT-2: memref.alloc
// BUFFNOTDUP-COUNT-2: memref.dealloc
%2 = tensor.empty() : tensor<32x64xf32>
// CHECK: linalg.fill
// CHECK-NEXT: tensor.extract_slice
// CHECK-NEXT: linalg.generic
%3 = linalg.fill ins(%cst_2 : f32) outs(%2 : tensor<32x64xf32>) -> tensor<32x64xf32>
%extracted_slice = tensor.extract_slice %arg1[%arg3, 0, 0] [1, 32, 512] [1, 1, 1] : tensor<64x32x512xf32> to tensor<32x512xf32>
%4 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%extracted_slice, %cst : tensor<32x512xf32>, tensor<512x64xf32>) outs(%3 : tensor<32x64xf32>) {
^bb0(%in: f32, %in_5: f32, %out: f32):
%9 = arith.mulf %in, %in_5 : f32
%10 = arith.addf %out, %9 : f32
linalg.yield %10 : f32
} -> tensor<32x64xf32>
%extracted_slice_3 = tensor.extract_slice %arg0[%arg3, 0, 0] [1, 32, 512] [1, 1, 1] : tensor<64x32x512xf32> to tensor<32x512xf32>
// CHECK: linalg.fill
// CHECK-NEXT: linalg.generic
%5 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%extracted_slice_3, %cst_0 : tensor<32x512xf32>, tensor<512x64xf32>) outs(%3 : tensor<32x64xf32>) {
^bb0(%in: f32, %in_5: f32, %out: f32):
%9 = arith.mulf %in, %in_5 : f32
%10 = arith.addf %out, %9 : f32
linalg.yield %10 : f32
} -> tensor<32x64xf32>
%6 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%5, %cst_1 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%2 : tensor<32x64xf32>) {
^bb0(%in: f32, %in_5: f32, %out: f32):
%9 = arith.mulf %in, %in_5 : f32
linalg.yield %9 : f32
} -> tensor<32x64xf32>
%extracted_slice_4 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<64x8x32x32xf32> to tensor<32x32xf32>
%7 = linalg.fill ins(%cst_2 : f32) outs(%extracted_slice_4 : tensor<32x32xf32>) -> tensor<32x32xf32>
%8 = linalg.generic {indexing_maps = [#map, #map4, #map5], iterator_types = ["parallel", "reduction", "parallel"]} ins(%4, %6 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%7 : tensor<32x32xf32>) {
^bb0(%in: f32, %in_5: f32, %out: f32):
%9 = arith.mulf %in, %in_5 : f32
%10 = arith.addf %out, %9 : f32
linalg.yield %10 : f32
} -> tensor<32x32xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %8 into %arg5[%arg3, %arg4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<64x8x32x32xf32>
}
}
return %1 : tensor<64x8x32x32xf32>
}

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>

// CHECK-LABEL: duplicate_non_zero_fill_on_contractions
// BUFF-LABEL: duplicate_non_zero_fill_on_contractions
// BUFFNOTDUP-LABEL: duplicate_non_zero_fill_on_contractions
func.func @duplicate_non_zero_fill_on_contractions(%arg0: tensor<32x512xf32>,
%arg1: tensor<512x64xf32>) -> tensor<32x64xf32> {
%cst_2 = arith.constant 1.0 : f32
%0 = tensor.empty() : tensor<32x64xf32>
%1 = linalg.fill ins(%cst_2 : f32) outs(%0 : tensor<32x64xf32>) -> tensor<32x64xf32>
// CHECK: linalg.fill
// CHECK-NEXT: linalg.generic
%3 = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "reduction", "parallel"]}
ins(%arg0, %arg1 : tensor<32x512xf32>, tensor<512x64xf32>) outs(%1 : tensor<32x64xf32>) {
^bb0(%in: f32, %in_5: f32, %out: f32):
%9 = arith.mulf %in, %in_5 : f32
%10 = arith.addf %out, %9 : f32
linalg.yield %10 : f32
} -> tensor<32x64xf32>
// CHECK-NOT: linalg.fill
// CHECK: linalg.generic
%4 = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "reduction", "parallel"]}
ins(%arg0, %arg1 : tensor<32x512xf32>, tensor<512x64xf32>) outs(%1 : tensor<32x64xf32>) {
^bb0(%in: f32, %in_5: f32, %out: f32):
%9 = arith.mulf %in, %in_5 : f32
%10 = arith.addf %out, %9 : f32
linalg.yield %10 : f32
} -> tensor<32x64xf32>
// CHECK-NOT: linalg.fill
%5 = linalg.add ins(%3, %4 : tensor<32x64xf32>, tensor<32x64xf32>)
outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
return %5 : tensor<32x64xf32>
}

// -----

func.func @matmuls(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> {
// BUFF-COUNT-2: memref.alloc
// BUFFNOTDUP-COUNT-2: memref.alloc
%0 = tensor.empty() : tensor<32x32xf32>
%cst = arith.constant 0.0 : f32
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>)
outs(%1 : tensor<32x32xf32>) -> tensor<32x32xf32>
%3 = linalg.matmul ins(%arg0, %2 : tensor<32x32xf32>, tensor<32x32xf32>)
outs(%1 : tensor<32x32xf32>) -> tensor<32x32xf32>
return %3 : tensor<32x32xf32>
}

// CHECK-LABEL: matmuls
// CHECK-SAME: %[[ARG0:.+]]: tensor<32x32xf32>, %[[ARG1:.+]]: tensor<32x32xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x32xf32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}} : f32) outs(%[[EMPTY]] : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %[[MUL:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<32x32xf32>, tensor<32x32xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %[[FILL_1:.+]] = linalg.fill ins(%{{.+}} : f32) outs(%[[EMPTY]] : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %{{.+}} = linalg.matmul ins(%[[ARG0]], %[[MUL]] : tensor<32x32xf32>, tensor<32x32xf32>)
// CHECK-SAME: outs(%[[FILL_1]] : tensor<32x32xf32>) -> tensor<32x32xf32>

// -----

// CHECK-LABEL: matmuls_1
func.func @matmuls_1(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> {
// BUFF-COUNT-3: memref.alloc
// BUFF-COUNT-1: memref.copy
// BUFFNOTDUP-COUNT-3: memref.alloc
// BUFFNOTDUP-COUNT-1: memref.copy
// CHECK-COUNT-2: linalg.fill
%0 = tensor.empty() : tensor<32x32xf32>
%cst = arith.constant 0.0 : f32
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>)
outs(%1 : tensor<32x32xf32>) -> tensor<32x32xf32>
%3 = linalg.add ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>)
outs(%1 : tensor<32x32xf32>) -> tensor<32x32xf32>
%4 = linalg.matmul ins(%3, %2 : tensor<32x32xf32>, tensor<32x32xf32>)
outs(%1 : tensor<32x32xf32>) -> tensor<32x32xf32>
return %4 : tensor<32x32xf32>
}
