Skip to content

Commit

Permalink
refactor(compiler): Remove attribute "tile-sizes" from operations aft…
Browse files Browse the repository at this point in the history
…er tiling

The tiling infrastructure preserves attributes of tiled
`linalg.generic` operations, such that the attribute for the tile
sizes specified for the `linalg.generic` operation before tiling is
copied to the `linalg.generic` operation that is part of the generated
IR for a single tile.

This change causes the attribute to be removed after tiling, since it
does not make sense to preserve the attribute for per-tile operations.
  • Loading branch information
andidr authored and BourgerieQuentin committed Apr 12, 2024
1 parent a8231ce commit 12ab53b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class GenericTilingPattern
if (lres.succeeded()) {
res.value().tileOp->setAttr(kTransformMarker, rewriter.getUnitAttr());
res.value().tiledOp->setAttr(kTransformMarker, rewriter.getUnitAttr());
res.value().tiledOp->removeAttr("tile-sizes");
rewriter.replaceOp(op.getOperation(), res.value().tileOp->getResults());
}

Expand Down Expand Up @@ -141,6 +142,7 @@ class GenericTilingPattern
if (lres.succeeded()) {
res.value().parallelTiledOp->setAttr(kTransformMarker,
rewriter.getUnitAttr());
res.value().parallelTiledOp->removeAttr("tile-sizes");
res.value().mergeOp->setAttr(kTransformMarker, rewriter.getUnitAttr());
res.value().initialOp->setAttr(kTransformMarker,
rewriter.getUnitAttr());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
// CHECK-NEXT: %[[V7:.*]] = affine.apply #map1(%[[Varg2]])
// CHECK-NEXT: %[[Vextracted_slice_0:.*]] = tensor.extract_slice %[[Varg0:.*]]{{\[0,}} %[[V6]]{{\] \[8, 2\] \[1, 1\]}} : tensor<8x4x!FHE.eint<6>> to tensor<8x2x!FHE.eint<6>>
// CHECK-NEXT: %[[Vextracted_slice_1:.*]] = tensor.extract_slice %[[Varg1:.*]]{{\[}}%[[V7]], 0{{\] \[2, 2\] \[1, 1\]}} : tensor<4x2xi7> to tensor<2x2xi7>
// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map3, #map4{{\], iterator}}_types = {{\[}}"parallel", "parallel", "reduction"{{\]}}} ins(%[[Vextracted_slice_0]], %[[Vextracted_slice_1]] : tensor<8x2x!FHE.eint<6>>, tensor<2x2xi7>) outs(%[[Vextracted_slice]] : tensor<8x2x!FHE.eint<6>>) attrs = {"tile-sizes" = {{\[0, 0, 2\]}}} {
// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map3, #map4{{\], iterator}}_types = {{\[}}"parallel", "parallel", "reduction"{{\]}}} ins(%[[Vextracted_slice_0]], %[[Vextracted_slice_1]] : tensor<8x2x!FHE.eint<6>>, tensor<2x2xi7>) outs(%[[Vextracted_slice]] : tensor<8x2x!FHE.eint<6>>) {
// CHECK-NEXT: ^bb0(%[[Vin:.*]]: !FHE.eint<6>, %[[Vin_2:.*]]: i7, %[[Vout:.*]]: !FHE.eint<6>):
// CHECK-NEXT: %[[V9:.*]] = "FHE.mul_eint_int"(%[[Vin]], %[[Vin_2]]) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
// CHECK-NEXT: %[[V10:.*]] = "FHE.add_eint"(%[[Vout]], %[[V9]]) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
Expand All @@ -30,7 +30,7 @@ func.func @tiled_2(%a: tensor<8x4x!FHE.eint<6>>, %b: tensor<4x2xi7>) -> tensor<8
// CHECK-NEXT: %[[V7:.*]] = affine.apply #map1(%[[Varg2]])
// CHECK-NEXT: %[[Vextracted_slice_0:.*]] = tensor.extract_slice %[[Varg0:.*]]{{\[0,}} %[[V6]]{{\] \[8, 4\] \[1, 1\]}} : tensor<8x4x!FHE.eint<6>> to tensor<8x4x!FHE.eint<6>>
// CHECK-NEXT: %[[Vextracted_slice_1:.*]] = tensor.extract_slice %[[Varg1:.*]]{{\[}}%[[V7]], 0{{\] \[4, 2\] \[1, 1\]}} : tensor<4x2xi7> to tensor<4x2xi7>
// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map3, #map4{{\], iterator}}_types = {{\[}}"parallel", "parallel", "reduction"{{\]}}} ins(%[[Vextracted_slice_0]], %[[Vextracted_slice_1]] : tensor<8x4x!FHE.eint<6>>, tensor<4x2xi7>) outs(%[[Vextracted_slice]] : tensor<8x2x!FHE.eint<6>>) attrs = {"tile-sizes" = {{\[0, 0, 4\]}}} {
// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map3, #map4{{\], iterator}}_types = {{\[}}"parallel", "parallel", "reduction"{{\]}}} ins(%[[Vextracted_slice_0]], %[[Vextracted_slice_1]] : tensor<8x4x!FHE.eint<6>>, tensor<4x2xi7>) outs(%[[Vextracted_slice]] : tensor<8x2x!FHE.eint<6>>) {
// CHECK-NEXT: ^bb0(%[[Vin:.*]]: !FHE.eint<6>, %[[Vin_2:.*]]: i7, %[[Vout:.*]]: !FHE.eint<6>):
// CHECK-NEXT: %[[V9:.*]] = "FHE.mul_eint_int"(%[[Vin]], %[[Vin_2]]) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
// CHECK-NEXT: %[[V10:.*]] = "FHE.add_eint"(%[[Vout]], %[[V9]]) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
Expand All @@ -57,7 +57,7 @@ func.func @tiled_one_big_tile(%a: tensor<8x4x!FHE.eint<6>>, %b: tensor<4x2xi7>)
// CHECK-NEXT: %[[V7:.*]] = affine.apply #map(%[[Varg4]])
// CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[V2]], %[[V3]], %[[V4]]{{\] \[2, 3, 2\] \[1, 1, 1\]}} : tensor<2x3x4x!FHE.eint<2>> to tensor<2x3x2x!FHE.eint<2>>
// CHECK-NEXT: %[[Vextracted_slice_0:.*]] = tensor.extract_slice %[[Varg5]]{{\[}}%[[V5]], %[[V6]], %[[V7]]{{\] \[2, 3, 2\] \[1, 1, 1\]}} : tensor<2x3x4x!FHE.eint<2>> to tensor<2x3x2x!FHE.eint<2>>
// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map2{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Vextracted_slice]] : tensor<2x3x2x!FHE.eint<2>>) outs(%[[Vextracted_slice_0]] : tensor<2x3x2x!FHE.eint<2>>) attrs = {"tile-sizes" = {{\[2, 3, 2\]}}} {
// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map2, #map2{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Vextracted_slice]] : tensor<2x3x2x!FHE.eint<2>>) outs(%[[Vextracted_slice_0]] : tensor<2x3x2x!FHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%[[Vin:.*]]: !FHE.eint<2>, %[[Vout:.*]]: !FHE.eint<2>):
// CHECK-NEXT: %[[V12:.*]] = "FHE.apply_lookup_table"(%[[Vin]], %[[Varg1]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
// CHECK-NEXT: linalg.yield %[[V12]] : !FHE.eint<2>
Expand All @@ -76,7 +76,7 @@ func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4

// -----

// CHECK: %[[res:.*]] = linalg.generic {indexing_maps = [#[[map:.*]], #[[map2:.*]], #[[map3:.*]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[extracted_slice:.*]], %[[extracted_slice_0:.*]] : tensor<2x3x2x!FHE.eint<7>>, tensor<2x3x2xindex>) outs(%[[extracted_slice_1:.*]] : tensor<2x3x2x!FHE.eint<7>>) attrs = {"tile-sizes" = [2, 3, 2]}
// CHECK: %[[res:.*]] = linalg.generic {indexing_maps = [#[[map:.*]], #[[map2:.*]], #[[map3:.*]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[extracted_slice:.*]], %[[extracted_slice_0:.*]] : tensor<2x3x2x!FHE.eint<7>>, tensor<2x3x2xindex>) outs(%[[extracted_slice_1:.*]] : tensor<2x3x2x!FHE.eint<7>>)
func.func @apply_mapped_lookup_table(
%input: tensor<2x3x4x!FHE.eint<7>>,
%luts: tensor<10x128xi64>,
Expand All @@ -88,7 +88,7 @@ func.func @apply_mapped_lookup_table(

// -----

// CHECK: %[[res:.*]] = linalg.generic {indexing_maps = [#[[map:.*]], #[[map2:.*]]], iterator_types = ["parallel", "parallel"]} ins(%[[extracted_slice:.*]] : tensor<3x3x!FHE.eint<2>>) outs(%[[extracted_slice_0:.*]] : tensor<3x3x!FHE.eint<2>>) attrs = {"tile-sizes" = [3, 3]} {
// CHECK: %[[res:.*]] = linalg.generic {indexing_maps = [#[[map:.*]], #[[map2:.*]]], iterator_types = ["parallel", "parallel"]} ins(%[[extracted_slice:.*]] : tensor<3x3x!FHE.eint<2>>) outs(%[[extracted_slice_0:.*]] : tensor<3x3x!FHE.eint<2>>) {
func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x3x4xi8>) -> tensor<3x3x!FHE.eint<2>> {
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1) { "tile-sizes" = [3, 3] }: (tensor<3x3x!FHE.eint<2>>, tensor<3x3x4xi8>) -> tensor<3x3x!FHE.eint<2>>
return %1: tensor<3x3x!FHE.eint<2>>
Expand Down

0 comments on commit 12ab53b

Please sign in to comment.