Skip to content

Commit

Permalink
feat(compiler): Add support for tiling of fhelinalg.transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
andidr authored and BourgerieQuentin committed Apr 12, 2024
1 parent 12ab53b commit fcfaaee
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1261,13 +1261,14 @@ struct TransposeToLinalgGeneric
mlir::Value item = blockArgs[0];
nestedBuilder.create<linalg::YieldOp>(location, item);
};
mlir::Value result =
rewriter
.create<linalg::GenericOp>(location, resultTypes, ins, outs, maps,
iteratorTypes, regionBuilder)
.getResult(0);

rewriter.replaceOp(transposeOp, {result});
linalg::GenericOp genericOp = rewriter.create<linalg::GenericOp>(
location, resultTypes, ins, outs, maps, iteratorTypes, regionBuilder);

if (transposeOp->hasAttr("tile-sizes"))
genericOp->setAttr("tile-sizes", transposeOp->getAttr("tile-sizes"));

rewriter.replaceOp(transposeOp, genericOp.getResults());
return mlir::success();
};
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class FHELinalgTilingMarkerPass
mlir::concretelang::FHELinalg::MatMulIntEintOp,
mlir::concretelang::FHELinalg::MatMulEintEintOp>(op) ||
llvm::isa<mlir::concretelang::FHELinalg::SumOp>(op) ||
llvm::isa<mlir::concretelang::FHELinalg::TransposeOp>(op) ||
llvm::isa<mlir::concretelang::FHELinalg::AddEintOp,
mlir::concretelang::FHELinalg::AddEintIntOp,
mlir::concretelang::FHELinalg::SubIntEintOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,42 @@ func.func @main(%arg0: tensor<3x8x4x!FHE.eint<7>>) -> tensor<3x8x!FHE.eint<7>> {
%0 = "FHELinalg.sum"(%arg0) { axes = [2], "tile-sizes" = [3, 8, 2] } : (tensor<3x8x4x!FHE.eint<7>>) -> tensor<3x8x!FHE.eint<7>>
return %0 : tensor<3x8x!FHE.eint<7>>
}

// -----

// CHECK: #map = affine_map<(d0) -> (d0 * -3 + 2, 3)>
// CHECK-NEXT: #map1 = affine_map<(d0) -> (d0 * -8 + 10, 8)>
// CHECK-NEXT: #map2 = affine_map<(d0) -> (d0 * 3)>
// CHECK-NEXT: #map3 = affine_map<(d0) -> (d0 * 8)>
// CHECK-NEXT: #map4 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-NEXT: #map5 = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @transpose_eint_2D(%[[Varg0:.*]]: tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>> {
// CHECK-NEXT: %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<10x2x!FHE.eint<6>>
// CHECK-NEXT: %[[V1:.*]] = scf.forall (%[[Varg1]], %[[Varg2]]) in (1, 2) shared_outs(%[[Varg3:.*]] = %[[V0]]) -> (tensor<10x2x!FHE.eint<6>>) {
// CHECK-NEXT: %[[V2:.*]] = affine.min #map(%[[Varg1]])
// CHECK-NEXT: %[[V3:.*]] = affine.min #map1(%[[Varg2]])
// CHECK-NEXT: %[[V4:.*]] = affine.apply #map2(%[[Varg1]])
// CHECK-NEXT: %[[V5:.*]] = affine.apply #map3(%[[Varg2]])
// CHECK-NEXT: %[[V6:.*]] = affine.apply #map3(%[[Varg2]])
// CHECK-NEXT: %[[V7:.*]] = affine.apply #map2(%[[Varg1]])
// CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[V4]], %[[V5]]{{\] \[}}%[[V2]], %[[V3]]{{\] \[1, 1\]}} : tensor<2x10x!FHE.eint<6>> to tensor<?x?x!FHE.eint<6>>
// CHECK-NEXT: %[[Vextracted_slice_0:.*]] = tensor.extract_slice %[[Varg3]]{{\[}}%[[V6]], %[[V7]]{{\] \[}}%[[V3]], %[[V2]]{{\] \[1, 1\]}} : tensor<10x2x!FHE.eint<6>> to tensor<?x?x!FHE.eint<6>>
// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map4, #map5{{\], iterator}}_types = {{\[}}"parallel", "parallel"{{\]}}} ins(%[[Vextracted_slice]] : tensor<?x?x!FHE.eint<6>>) outs(%[[Vextracted_slice_0]] : tensor<?x?x!FHE.eint<6>>) {
// CHECK-NEXT: ^bb0(%[[Vin:.*]]: !FHE.eint<6>, %[[Vout:.*]]: !FHE.eint<6>):
// CHECK-NEXT: linalg.yield %[[Vin]] : !FHE.eint<6>
// CHECK-NEXT: } -> tensor<?x?x!FHE.eint<6>>
// CHECK-NEXT: %[[V9:.*]] = affine.apply #map3(%[[Varg2]])
// CHECK-NEXT: %[[V10:.*]] = affine.apply #map2(%[[Varg1]])
// CHECK-NEXT: scf.forall.in_parallel {
// CHECK-NEXT: tensor.parallel_insert_slice %[[V8]] into %[[Varg3]]{{\[}}%[[V9]], %[[V10]]{{\] \[}}%[[V3]], %[[V2]]{{\] \[1, 1\]}} : tensor<?x?x!FHE.eint<6>> into tensor<10x2x!FHE.eint<6>>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V1]] : tensor<10x2x!FHE.eint<6>>
// CHECK-NEXT: }
// CHECK-NEXT: }

func.func @transpose_eint_2D(%arg0: tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>> {
%c = "FHELinalg.transpose"(%arg0) { "tile-sizes" = [3, 8, 2] } : (tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>>
return %c : tensor<10x2x!FHE.eint<6>>
}

0 comments on commit fcfaaee

Please sign in to comment.