diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index c0fc8cc3ca..25f6ba47da 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -1261,13 +1261,14 @@ struct TransposeToLinalgGeneric mlir::Value item = blockArgs[0]; nestedBuilder.create(location, item); }; - mlir::Value result = - rewriter - .create(location, resultTypes, ins, outs, maps, - iteratorTypes, regionBuilder) - .getResult(0); - rewriter.replaceOp(transposeOp, {result}); + linalg::GenericOp genericOp = rewriter.create( + location, resultTypes, ins, outs, maps, iteratorTypes, regionBuilder); + + if (transposeOp->hasAttr("tile-sizes")) + genericOp->setAttr("tile-sizes", transposeOp->getAttr("tile-sizes")); + + rewriter.replaceOp(transposeOp, genericOp.getResults()); return mlir::success(); }; }; diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp index f816e814c5..5f4f403196 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp @@ -196,6 +196,7 @@ class FHELinalgTilingMarkerPass mlir::concretelang::FHELinalg::MatMulIntEintOp, mlir::concretelang::FHELinalg::MatMulEintEintOp>(op) || llvm::isa(op) || + llvm::isa(op) || llvm::isa>) -> tensor<3x8x!FHE.eint<7>> { %0 = "FHELinalg.sum"(%arg0) { axes = [2], "tile-sizes" = [3, 8, 2] } : (tensor<3x8x4x!FHE.eint<7>>) -> tensor<3x8x!FHE.eint<7>> return %0 : tensor<3x8x!FHE.eint<7>> } + +// ----- + +// CHECK: #map = affine_map<(d0) -> (d0 * -3 + 2, 3)> +// CHECK-NEXT: #map1 = affine_map<(d0) -> (d0 * -8 + 10, 8)> 
+// CHECK-NEXT: #map2 = affine_map<(d0) -> (d0 * 3)> +// CHECK-NEXT: #map3 = affine_map<(d0) -> (d0 * 8)> +// CHECK-NEXT: #map4 = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-NEXT: #map5 = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-NEXT: module { +// CHECK-NEXT: func.func @transpose_eint_2D(%[[Varg0:.*]]: tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>> { +// CHECK-NEXT: %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<10x2x!FHE.eint<6>> +// CHECK-NEXT: %[[V1:.*]] = scf.forall (%[[Varg1:.*]], %[[Varg2:.*]]) in (1, 2) shared_outs(%[[Varg3:.*]] = %[[V0]]) -> (tensor<10x2x!FHE.eint<6>>) { +// CHECK-NEXT: %[[V2:.*]] = affine.min #map(%[[Varg1]]) +// CHECK-NEXT: %[[V3:.*]] = affine.min #map1(%[[Varg2]]) +// CHECK-NEXT: %[[V4:.*]] = affine.apply #map2(%[[Varg1]]) +// CHECK-NEXT: %[[V5:.*]] = affine.apply #map3(%[[Varg2]]) +// CHECK-NEXT: %[[V6:.*]] = affine.apply #map3(%[[Varg2]]) +// CHECK-NEXT: %[[V7:.*]] = affine.apply #map2(%[[Varg1]]) +// CHECK-NEXT: %[[Vextracted_slice:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[V4]], %[[V5]]{{\] \[}}%[[V2]], %[[V3]]{{\] \[1, 1\]}} : tensor<2x10x!FHE.eint<6>> to tensor> +// CHECK-NEXT: %[[Vextracted_slice_0:.*]] = tensor.extract_slice %[[Varg3]]{{\[}}%[[V6]], %[[V7]]{{\] \[}}%[[V3]], %[[V2]]{{\] \[1, 1\]}} : tensor<10x2x!FHE.eint<6>> to tensor> +// CHECK-NEXT: %[[V8:.*]] = linalg.generic {indexing_maps = {{\[}}#map4, #map5{{\], iterator}}_types = {{\[}}"parallel", "parallel"{{\]}}} ins(%[[Vextracted_slice]] : tensor>) outs(%[[Vextracted_slice_0]] : tensor>) { +// CHECK-NEXT: ^bb0(%[[Vin:.*]]: !FHE.eint<6>, %[[Vout:.*]]: !FHE.eint<6>): +// CHECK-NEXT: linalg.yield %[[Vin]] : !FHE.eint<6> +// CHECK-NEXT: } -> tensor> +// CHECK-NEXT: %[[V9:.*]] = affine.apply #map3(%[[Varg2]]) +// CHECK-NEXT: %[[V10:.*]] = affine.apply #map2(%[[Varg1]]) +// CHECK-NEXT: scf.forall.in_parallel { +// CHECK-NEXT: tensor.parallel_insert_slice %[[V8]] into %[[Varg3]]{{\[}}%[[V9]], %[[V10]]{{\] \[}}%[[V3]], %[[V2]]{{\] \[1, 1\]}} : tensor> into 
tensor<10x2x!FHE.eint<6>> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V1]] : tensor<10x2x!FHE.eint<6>> +// CHECK-NEXT: } +// CHECK-NEXT: } + +func.func @transpose_eint_2D(%arg0: tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>> { + %c = "FHELinalg.transpose"(%arg0) { "tile-sizes" = [3, 8, 2] } : (tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>> + return %c : tensor<10x2x!FHE.eint<6>> +}