From e9dcc20cc5e492031bdb411d022279af0b2bc326 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 4 Apr 2024 10:38:28 +0100 Subject: [PATCH] refactor(compiler, simu): rewrite add/mul to CAPI calls --- .../include/concretelang/Runtime/simulation.h | 19 +++++ .../Conversion/SimulateTFHE/SimulateTFHE.cpp | 75 ++++++++++++++++--- .../compiler/lib/Runtime/simulation.cpp | 4 + 3 files changed, 88 insertions(+), 10 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h index 3854258baf..c642ddf240 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h @@ -24,6 +24,25 @@ uint64_t sim_encrypt_lwe_u64(uint64_t message, uint32_t lwe_dim, void *csprng); /// \return uint64_t uint64_t sim_neg_lwe_u64(uint64_t plaintext); +/// \brief simulate the addition of a noisy plaintext with another +/// plaintext (noisy or not) +/// +/// The function also checks for overflow and print a warning when it happens +/// +/// \param lhs left operand +/// \param rhs right operand +/// \return uint64_t +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs); + +/// \brief simulate the multiplication of a noisy plaintext with an integer +/// +/// The function also checks for overflow and print a warning when it happens +/// +/// \param lhs left operand +/// \param rhs right operand +/// \return uint64_t +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs); + /// \brief simulate a keyswitch on a noisy plaintext /// /// \param plaintext noisy plaintext diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp index 211f95c8b2..6ddada67c0 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp @@ -92,6 +92,67 @@ struct NegOpPattern : public mlir::OpConversionPattern { } }; +template +struct AddOpPattern : public mlir::OpConversionPattern { + + AddOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter) + : mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(AddOp addOp, AddOpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + const std::string funcName = "sim_add_lwe_u64"; + + if (insertForwardDeclaration( + addOp, rewriter, funcName, + rewriter.getFunctionType( + {rewriter.getIntegerType(64), rewriter.getIntegerType(64)}, + {rewriter.getIntegerType(64)})) + .failed()) { + return mlir::failure(); + } + + rewriter.replaceOpWithNewOp( + addOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, + mlir::ValueRange({adaptor.getA(), adaptor.getB()})); + + return mlir::success(); + } +}; + +struct MulOpPattern : public mlir::OpConversionPattern { + + MulOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter) + : mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(TFHE::MulGLWEIntOp mulOp, TFHE::MulGLWEIntOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + const std::string funcName = "sim_mul_lwe_u64"; + + if (insertForwardDeclaration( + mulOp, rewriter, funcName, + rewriter.getFunctionType( + {rewriter.getIntegerType(64), rewriter.getIntegerType(64)}, + {rewriter.getIntegerType(64)})) + .failed()) { + return mlir::failure(); + } + + rewriter.replaceOpWithNewOp( + mulOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, + mlir::ValueRange({adaptor.getA(), adaptor.getB()})); + + return mlir::success(); + } +}; + struct SubIntGLWEOpPattern : public mlir::OpRewritePattern { SubIntGLWEOpPattern(mlir::MLIRContext *context) @@ -579,14 +640,6 @@ void SimulateTFHEPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); - // Replace ops and convert operand and result types - patterns.insert, - mlir::concretelang::GenericOneToOneOpConversionPattern< - TFHE::AddGLWEOp, mlir::arith::AddIOp>, - mlir::concretelang::GenericOneToOneOpConversionPattern< - TFHE::MulGLWEIntOp, mlir::arith::MulIOp>>(&getContext(), - converter); // Convert operand and result types patterns.insert, @@ -643,8 +696,10 @@ void SimulateTFHEPass::runOnOperation() { BootstrapGLWEOpPattern, WopPBSGLWEOpPattern, EncodeExpandLutForBootstrapOpPattern, EncodeLutForCrtWopPBSOpPattern, - EncodePlaintextWithCrtOpPattern, NegOpPattern>(&getContext(), - converter); + EncodePlaintextWithCrtOpPattern, NegOpPattern, + AddOpPattern, + AddOpPattern, + MulOpPattern>(&getContext(), converter); patterns.insert(&getContext()); patterns.add