Skip to content

Commit

Permalink
Refresh transform dialect (#868)
Browse files Browse the repository at this point in the history
  • Loading branch information
chelini authored Jan 17, 2024
1 parent 30633f2 commit e30bd80
Show file tree
Hide file tree
Showing 17 changed files with 232 additions and 233 deletions.
12 changes: 0 additions & 12 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,6 @@ def ConvertCheckToLoops : Pass<"convert-check-to-loops", "func::FuncOp"> {
let dependentDialects = ["scf::SCFDialect"];
}

def TransformDialectInterpreter : Pass<"transform-dialect-interpreter", "ModuleOp"> {
let summary = "Apply transform dialect operations one by one";
let description = [{
Copy and paste from 'TestTransformDialectInterpreter.cpp'. Apply the transform
schedule.
}];
}

def ConvertPerfToLoops : Pass<"convert-perf-to-loops", "func::FuncOp"> {
let summary = "Convert perf to loops";
let description = [{
Expand All @@ -89,10 +81,6 @@ def ConvertPerfToFunc : Pass<"convert-perf-to-func", "ModuleOp"> {
"tensor::TensorDialect"];
}

def TransformDropSchedule : Pass<"transform-drop-schedule", "ModuleOp"> {
let summary = "Drop the transform schedule";
}

def PackVNNI : Pass<"pack-vnni", "func::FuncOp"> {
let summary = "Convert matmul/brgemm to vnni layout";
let description = [{
Expand Down
3 changes: 0 additions & 3 deletions lib/TPP/DefaultTppPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,6 @@ struct DefaultTppPasses
void constructPipeline() override {
pm.clear();

// Default pipeline does not support transforms yet
pm.addPass(createTransformDropSchedule());

if (linalgToLoops) {
// Lower linalg directly to loops.
// Skip all TPP transformations.
Expand Down
1 change: 0 additions & 1 deletion lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ add_mlir_library(TPPTransforms
RewriteToBatchReduceGemm.cpp
TileConsumerAndFuseProducers.cpp
ToBlockLayoutAndBack.cpp
TransformDialectInterpreter.cpp
TransformUtils.cpp
CombineXsmmPass.cpp

Expand Down
60 changes: 0 additions & 60 deletions lib/TPP/Transforms/TransformDialectInterpreter.cpp

This file was deleted.

40 changes: 26 additions & 14 deletions test/Dialect/Transform/transform-collapse.mlir
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// RUN: tpp-opt -transform-dialect-interpreter -split-input-file -verify-diagnostics %s | FileCheck %s
// RUN: tpp-opt -transform-interpreter -split-input-file -verify-diagnostics %s | FileCheck %s

transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.collapse %0 [[0, 1], [2], [3, 4]] : !transform.any_op -> !transform.any_op
transform.yield
}
}

// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
Expand All @@ -28,11 +30,13 @@ func.func @parallel(%arg0: tensor<5x5x4x3x3xf32>, %arg1: tensor<5x5x4x3x3xf32>)

// -----

transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.collapse %0 [[0, 1], [2]] : !transform.any_op -> !transform.any_op
}
transform.yield
}
}

// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
Expand All @@ -59,10 +63,12 @@ func.func @parallel(%arg0: tensor<5x5x5xf32>, %arg1: tensor<5x5x5xf32>) -> tenso
// -----

// This must fail as we attempt to collapse dimensions of different types.
transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.collapse %0 [[0, 1, 2]] : !transform.any_op -> !transform.any_op
transform.yield
}
}

#map0 = affine_map<(i, j, k) -> (i, j)>
Expand All @@ -83,10 +89,12 @@ func.func @matmul(%arg0: tensor<3x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<
// -----

// This must fail as the reassociation dimensions do not match the number of loops.
transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.collapse %0 [[0, 1]] : !transform.any_op -> !transform.any_op
transform.yield
}
}

#map0 = affine_map<(i, j, k) -> (i, j)>
Expand All @@ -106,10 +114,12 @@ func.func @matmul(%arg0: tensor<3x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<

// -----

transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.collapse %0 [[0, 1], [2]] : !transform.any_op -> !transform.any_op
transform.yield
}
}

#map0 = affine_map<(i, j, k) -> (i, j, k)>
Expand All @@ -130,10 +140,12 @@ func.func @parallel(%arg0: tensor<3x3x3xf32> , %arg1: tensor<3x3x3xf32>) -> tens

// -----

transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.collapse %0 [[0, 1], [2]] : !transform.any_op -> !transform.any_op
transform.yield
}
}

#map0 = affine_map<(i, j, k) -> (i, j)>
Expand Down
50 changes: 32 additions & 18 deletions test/Dialect/Transform/transform-convolutions.mlir
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// RUN: tpp-opt %s -transform-dialect-interpreter -canonicalize -split-input-file | FileCheck %s
// RUN: tpp-opt %s -transform-interpreter -canonicalize -split-input-file | FileCheck %s

// Map a linalg.conv_2d_nhwc_hwcf to a matmul operation.
// Unit filter.
transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%1 = transform.structured.generalize %0 : (!transform.any_op) -> !transform.any_op
Expand Down Expand Up @@ -33,6 +33,8 @@ transform.sequence failures(propagate) {
%2 = transform.structured.interchange %1 iterator_interchange = [ 0, 1, 4, 5, 2, 3, 6 ]
: (!transform.any_op) -> !transform.any_op
transform.structured.rewrite_conv_to_matmul %2 : !transform.any_op
transform.yield
}
}

func.func @conv1(%arg0: memref<1x4x4x3xf32>, %arg1: memref<1x1x3x8xf32>, %arg2: memref<1x4x4x8xf32>) {
Expand All @@ -57,13 +59,15 @@ func.func @conv1(%arg0: memref<1x4x4x3xf32>, %arg1: memref<1x1x3x8xf32>, %arg2:

// Map a linalg.conv_2d_nhwc_hwcf to a matmul operation.
// Non-unit filter.
transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.generalize %0 : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.interchange %1 iterator_interchange = [ 0, 1, 4, 5, 2, 3, 6 ]
: (!transform.any_op) -> !transform.any_op
transform.structured.rewrite_conv_to_matmul %2 : !transform.any_op
transform.yield
}
}

func.func @conv2(%arg0: memref<1x4x4x3xf32>, %arg1: memref<2x2x3x8xf32>, %arg2: memref<1x3x3x8xf32>) {
Expand Down Expand Up @@ -92,13 +96,15 @@ func.func @conv2(%arg0: memref<1x4x4x3xf32>, %arg1: memref<2x2x3x8xf32>, %arg2:
// -----

// Unit filter but non-static dims.
transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.generalize %0 : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.interchange %1 iterator_interchange = [ 0, 1, 4, 5, 2, 3, 6 ]
: (!transform.any_op) -> !transform.any_op
transform.structured.rewrite_conv_to_matmul %2 : !transform.any_op
transform.yield
}
}

func.func @conv3(%arg0: memref<?x?x?x?xf32>, %arg1: memref<1x1x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
Expand Down Expand Up @@ -130,8 +136,8 @@ func.func @conv3(%arg0: memref<?x?x?x?xf32>, %arg1: memref<1x1x?x?xf32>, %arg2:

// -----

transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// Original layout: [N][K][P][Q] = [N][C][H][W] * [K][C][R][S]
// New layout: [N][K'][P][Q][k] = [N][C'][H][W][c] * [K'][C'][R][S][c][k]
Expand Down Expand Up @@ -163,6 +169,8 @@ transform.sequence failures(propagate) {
%4 = transform.structured.interchange %3 iterator_interchange = [0, 1, 4, 2, 3, 5]
: (!transform.any_op) -> !transform.any_op
transform.structured.rewrite_to_brgemm %4 : !transform.any_op
transform.yield
}
}

func.func @conv(%i: tensor<14x512x28x28xf32>, %f: tensor<1024x512x1x1xf32>,
Expand Down Expand Up @@ -272,8 +280,8 @@ func.func @walk(%arg0: tensor<1x1x64x64xf32>, %arg1: tensor<3x3x64x64xf32>, %arg
return %9 : tensor<1x56x56x64xf32>
}

transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
: (!transform.any_op) -> !transform.any_op
// Blocks all the convs
Expand All @@ -287,7 +295,7 @@ transform.sequence failures(propagate) {
: (!transform.any_op) -> !transform.any_op
%4 = transform.structured.get_blocked_convolutions %3
: (!transform.any_op) -> (!transform.op<"linalg.generic">)
%blocked_matmuls:2 = split_handle %4
%blocked_matmuls:2 = transform.split_handle %4
: (!transform.op<"linalg.generic">)
-> (!transform.op<"linalg.generic">, !transform.op<"linalg.generic">)
%first_relu = transform.get_consumers_of_result %blocked_matmuls#0[0]
Expand All @@ -308,8 +316,8 @@ transform.sequence failures(propagate) {
!transform.any_op,
!transform.any_op,
!transform.any_op)
%6 = get_producer_of_operand %5[0] : (!transform.any_op) -> !transform.any_op
%convs:2 = split_handle %6 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%6 = transform.get_producer_of_operand %5[0] : (!transform.any_op) -> !transform.any_op
%convs:2 = transform.split_handle %6 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

// Map the conv to linalg.matmul
// With R = S = 3 we map to linalg.matmul
Expand All @@ -325,17 +333,21 @@ transform.sequence failures(propagate) {
%9 = transform.structured.interchange %8 iterator_interchange = [0, 1, 4, 2, 3, 5]
: (!transform.any_op) -> !transform.any_op
transform.structured.rewrite_to_brgemm %9 : !transform.any_op
transform.yield
}
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.generalize %0 : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.interchange %1 iterator_interchange = [0, 1, 4, 5, 2, 3, 6]
: (!transform.any_op) -> !transform.any_op
transform.structured.rewrite_conv_to_matmul %2 : !transform.any_op
transform.yield
}
}

func.func @conv2d_stride(%arg0: tensor<1x113x113x64xf32>, %arg1: tensor<3x3x64x256xf32>, %arg2: tensor<1x56x56x256xf32>) -> tensor<1x56x56x256xf32> {
Expand Down Expand Up @@ -367,13 +379,15 @@ func.func @conv2d_stride(%arg0: tensor<1x113x113x64xf32>, %arg1: tensor<3x3x64x2

// -----

transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.pack_ext %0 blocking_factors = [32, 32] : !transform.any_op -> !transform.any_op
%2 = transform.structured.interchange %1 iterator_interchange = [0, 1, 2, 5, 6, 7, 3, 4, 8]
: (!transform.any_op) -> !transform.any_op
transform.structured.rewrite_conv_to_matmul %2 : !transform.any_op
transform.yield
}
}

func.func @conv2d_stride(%arg0: tensor<1x113x113x64xf32>, %arg1: tensor<3x3x64x256xf32>, %arg2: tensor<1x56x56x256xf32>) -> tensor<1x56x56x256xf32> {
Expand Down
20 changes: 0 additions & 20 deletions test/Dialect/Transform/transform-drop-schedule.mlir

This file was deleted.

Loading

0 comments on commit e30bd80

Please sign in to comment.