From 12ab53ba9248be4bca23d5f3847790ff2d66c678 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Thu, 11 Apr 2024 14:28:12 +0200 Subject: [PATCH] refactor(compiler): Remove attribute "tile-sizes" from operations after tiling The tiling infrastructure preserves attributes of tiled `linalg.generic` operations, such that the attribute for the tile sizes specified for the `linalg.generic` operation before tiling is copied to the `linalg.generic` operation that is part of the generated IR for a single tile. This change causes the attribute to be removed after tiling, since it does not make sense to preserve the attribute for per-tile operations. --- .../lib/Dialect/FHELinalg/Transforms/Tiling.cpp | 2 ++ .../tests/check_tests/Dialect/FHELinalg/tiling.mlir | 10 +++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp index b258115583..f816e814c5 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp @@ -101,6 +101,7 @@ class GenericTilingPattern if (lres.succeeded()) { res.value().tileOp->setAttr(kTransformMarker, rewriter.getUnitAttr()); res.value().tiledOp->setAttr(kTransformMarker, rewriter.getUnitAttr()); + res.value().tiledOp->removeAttr("tile-sizes"); rewriter.replaceOp(op.getOperation(), res.value().tileOp->getResults()); } @@ -141,6 +142,7 @@ class GenericTilingPattern if (lres.succeeded()) { res.value().parallelTiledOp->setAttr(kTransformMarker, rewriter.getUnitAttr()); + res.value().parallelTiledOp->removeAttr("tile-sizes"); res.value().mergeOp->setAttr(kTransformMarker, rewriter.getUnitAttr()); res.value().initialOp->setAttr(kTransformMarker, rewriter.getUnitAttr()); diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tiling.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tiling.mlir index 5fdd4b3b2d..24fde20ead 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tiling.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tiling.mlir @@ -6,7 +6,7 @@ // CHECK-NEXT: %[[V7:.*]] = affine.apply #map1(%[[Varg2]]) // CHECK-NEXT: %[[Vextracted_slice_0:.*]] = tensor.extract_slice %[[Varg0:.*]]{{\[0,}} %[[V6]]{{\] \[8, 2\] \[1, 1\]}} : tensor<8x4x!FHE.eint<6>> to tensor<8x2x!FHE.eint<6>> // CHECK-NEXT: %[[Vextracted_slice_1:.*]] = tensor.extract_slice %[[Varg1:.*]]{{\[}}%[[V7]], 0{{\] \[2, 2\] \[1, 1\]}} : tensor<4x2xi7> to tensor<2x2xi7> -// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map3, #map4{{\], iterator}}_types = {{\[}}"parallel", "parallel", "reduction"{{\]}}} ins(%[[Vextracted_slice_0]], %[[Vextracted_slice_1]] : tensor<8x2x!FHE.eint<6>>, tensor<2x2xi7>) outs(%[[Vextracted_slice]] : tensor<8x2x!FHE.eint<6>>) attrs = {"tile-sizes" = {{\[0, 0, 2\]}}} { +// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map3, #map4{{\], iterator}}_types = {{\[}}"parallel", "parallel", "reduction"{{\]}}} ins(%[[Vextracted_slice_0]], %[[Vextracted_slice_1]] : tensor<8x2x!FHE.eint<6>>, tensor<2x2xi7>) outs(%[[Vextracted_slice]] : tensor<8x2x!FHE.eint<6>>) { // CHECK-NEXT: ^bb0(%[[Vin:.*]]: !FHE.eint<6>, %[[Vin_2:.*]]: i7, %[[Vout:.*]]: !FHE.eint<6>): // CHECK-NEXT: %[[V9:.*]] = "FHE.mul_eint_int"(%[[Vin]], %[[Vin_2]]) : (!FHE.eint<6>, i7) -> !FHE.eint<6> // CHECK-NEXT: %[[V10:.*]] = "FHE.add_eint"(%[[Vout]], %[[V9]]) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6> @@ -30,7 +30,7 @@ func.func @tiled_2(%a: tensor<8x4x!FHE.eint<6>>, %b: tensor<4x2xi7>) -> tensor<8 // CHECK-NEXT: %[[V7:.*]] = affine.apply #map1(%[[Varg2]]) // CHECK-NEXT: %[[Vextracted_slice_0:.*]] = tensor.extract_slice %[[Varg0:.*]]{{\[0,}} %[[V6]]{{\] \[8, 4\] \[1, 1\]}} : tensor<8x4x!FHE.eint<6>> to tensor<8x4x!FHE.eint<6>> // CHECK-NEXT: %[[Vextracted_slice_1:.*]] = tensor.extract_slice %[[Varg1:.*]]{{\[}}%[[V7]], 0{{\] \[4, 2\] \[1, 1\]}} : tensor<4x2xi7> to tensor<4x2xi7> -// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map3, #map4{{\], iterator}}_types = {{\[}}"parallel", "parallel", "reduction"{{\]}}} ins(%[[Vextracted_slice_0]], %[[Vextracted_slice_1]] : tensor<8x4x!FHE.eint<6>>, tensor<4x2xi7>) outs(%[[Vextracted_slice]] : tensor<8x2x!FHE.eint<6>>) attrs = {"tile-sizes" = {{\[0, 0, 4\]}}} { +// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map3, #map4{{\], iterator}}_types = {{\[}}"parallel", "parallel", "reduction"{{\]}}} ins(%[[Vextracted_slice_0]], %[[Vextracted_slice_1]] : tensor<8x4x!FHE.eint<6>>, tensor<4x2xi7>) outs(%[[Vextracted_slice]] : tensor<8x2x!FHE.eint<6>>) { // CHECK-NEXT: ^bb0(%[[Vin:.*]]: !FHE.eint<6>, %[[Vin_2:.*]]: i7, %[[Vout:.*]]: !FHE.eint<6>): // CHECK-NEXT: %[[V9:.*]] = "FHE.mul_eint_int"(%[[Vin]], %[[Vin_2]]) : (!FHE.eint<6>, i7) -> !FHE.eint<6> // CHECK-NEXT: %[[V10:.*]] = "FHE.add_eint"(%[[Vout]], %[[V9]]) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6> @@ -57,7 +57,7 @@ func.func @tiled_one_big_tile(%a: tensor<8x4x!FHE.eint<6>>, %b: tensor<4x2xi7>) // CHECK-NEXT: %[[V7:.*]] = affine.apply #map(%[[Varg4]]) // CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[V2]], %[[V3]], %[[V4]]{{\] \[2, 3, 2\] \[1, 1, 1\]}} : tensor<2x3x4x!FHE.eint<2>> to tensor<2x3x2x!FHE.eint<2>> // CHECK-NEXT: %[[Vextracted_slice_0:.*]] = tensor.extract_slice %[[Varg5]]{{\[}}%[[V5]], %[[V6]], %[[V7]]{{\] \[2, 3, 2\] \[1, 1, 1\]}} : tensor<2x3x4x!FHE.eint<2>> to tensor<2x3x2x!FHE.eint<2>> -// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map2{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Vextracted_slice]] : tensor<2x3x2x!FHE.eint<2>>) outs(%[[Vextracted_slice_0]] : tensor<2x3x2x!FHE.eint<2>>) attrs = {"tile-sizes" = {{\[2, 3, 2\]}}} { +// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map2{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Vextracted_slice]] : tensor<2x3x2x!FHE.eint<2>>) outs(%[[Vextracted_slice_0]] : tensor<2x3x2x!FHE.eint<2>>) { // CHECK-NEXT: ^bb0(%[[Vin:.*]]: !FHE.eint<2>, %[[Vout:.*]]: !FHE.eint<2>): // CHECK-NEXT: %[[V12:.*]] = "FHE.apply_lookup_table"(%[[Vin]], %[[Varg1]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> // CHECK-NEXT: linalg.yield %[[V12]] : !FHE.eint<2> @@ -76,7 +76,7 @@ func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4 // ----- -// CHECK: %[[res:.*]] = linalg.generic {indexing_maps = [#[[map:.*]], #[[map2:.*]], #[[map3:.*]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[extracted_slice:.*]], %[[extracted_slice_0:.*]] : tensor<2x3x2x!FHE.eint<7>>, tensor<2x3x2xindex>) outs(%[[extracted_slice_1:.*]] : tensor<2x3x2x!FHE.eint<7>>) attrs = {"tile-sizes" = [2, 3, 2]} +// CHECK: %[[res:.*]] = linalg.generic {indexing_maps = [#[[map:.*]], #[[map2:.*]], #[[map3:.*]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[extracted_slice:.*]], %[[extracted_slice_0:.*]] : tensor<2x3x2x!FHE.eint<7>>, tensor<2x3x2xindex>) outs(%[[extracted_slice_1:.*]] : tensor<2x3x2x!FHE.eint<7>>) func.func @apply_mapped_lookup_table( %input: tensor<2x3x4x!FHE.eint<7>>, %luts: tensor<10x128xi64>, @@ -88,7 +88,7 @@ func.func @apply_mapped_lookup_table( // ----- -// CHECK: %[[res:.*]] = linalg.generic {indexing_maps = [#[[map:.*]], #[[map2:.*]]], iterator_types = ["parallel", "parallel"]} ins(%[[extracted_slice:.*]] : tensor<3x3x!FHE.eint<2>>) outs(%[[extracted_slice_0:.*]] : tensor<3x3x!FHE.eint<2>>) attrs = {"tile-sizes" = [3, 3]} { +// CHECK: %[[res:.*]] = linalg.generic {indexing_maps = [#[[map:.*]], #[[map2:.*]]], iterator_types = ["parallel", "parallel"]} ins(%[[extracted_slice:.*]] : tensor<3x3x!FHE.eint<2>>) outs(%[[extracted_slice_0:.*]] : tensor<3x3x!FHE.eint<2>>) { func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x3x4xi8>) -> tensor<3x3x!FHE.eint<2>> { %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1) { "tile-sizes" = [3, 3] }: (tensor<3x3x!FHE.eint<2>>, tensor<3x3x4xi8>) -> tensor<3x3x!FHE.eint<2>> return %1: tensor<3x3x!FHE.eint<2>>