From 9beaeac007ba207a748c32faeb3367a09342a354 Mon Sep 17 00:00:00 2001 From: Bourgerie Quentin Date: Thu, 25 Apr 2024 11:42:07 +0200 Subject: [PATCH] feat(compiler): Allow concat with only one operand --- .../concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td | 1 + .../lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp | 11 +++++++++-- .../check_tests/Dialect/FHELinalg/concat.invalid.mlir | 10 +--------- .../tests/check_tests/Dialect/FHELinalg/folding.mlir | 7 +++++++ 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index c0e9550a6a..ab4f87fc62 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -1168,6 +1168,7 @@ def FHELinalg_ConcatOp : FHELinalg_Op<"concat", [Pure]> { ); let hasVerifier = 1; + let hasFolder = 1; } def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", [Pure, BinaryEintInt, DeclareOpInterfaceMethods]> { diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index 273ffc4f7f..de929e86d9 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -573,10 +573,17 @@ static bool sameShapeExceptAxis(llvm::ArrayRef shape1, return true; } +OpFoldResult ConcatOp::fold(FoldAdaptor operands) { + if (this->getNumOperands() == 1) { + return this->getOperand(0); + } + return nullptr; +} + mlir::LogicalResult ConcatOp::verify() { unsigned numOperands = this->getNumOperands(); - if (numOperands < 2) { - this->emitOpError() << "should have at least 2 inputs"; + if (numOperands < 1) { + this->emitOpError() << "should have at least 1 input"; return mlir::failure(); } diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/concat.invalid.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/concat.invalid.mlir index d4ab4f5f9b..054dbdc54b 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/concat.invalid.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/concat.invalid.mlir @@ -3,21 +3,13 @@ // ----- func.func @main() -> tensor<0x!FHE.eint<7>> { - // expected-error @+1 {{'FHELinalg.concat' op should have at least 2 inputs}} + // expected-error @+1 {{'FHELinalg.concat' op should have at least 1 input}} %0 = "FHELinalg.concat"() : () -> tensor<0x!FHE.eint<7>> return %0 : tensor<0x!FHE.eint<7>> } // ----- -func.func @main(%x: tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { - // expected-error @+1 {{'FHELinalg.concat' op should have at least 2 inputs}} - %0 = "FHELinalg.concat"(%x) : (tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> - return %0 : tensor<4x!FHE.eint<7>> -} - -// ----- - func.func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>> { // expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equal}} %0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir index 23d8147d03..4c1411e1a3 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir @@ -160,3 +160,10 @@ func.func @to_unsigned_zero() -> tensor<4x!FHE.eint<7>> { %1 = "FHELinalg.to_unsigned"(%0) : (tensor<4x!FHE.esint<7>>) -> tensor<4x!FHE.eint<7>> return %1 : tensor<4x!FHE.eint<7>> } + +// CHECK: func.func @concat_1_operand(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { +// CHECK-NEXT: return %[[a0]] +func.func @concat_1_operand(%x: tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { + %0 = "FHELinalg.concat"(%x) : (tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> + return %0 : tensor<4x!FHE.eint<7>> +}