From b8822c8699a63cafcc80da37983fc179492f667d Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 4 Apr 2024 10:38:28 +0100 Subject: [PATCH 1/5] refactor(compiler, simu): rewrite add/mul to CAPI calls --- .../include/concretelang/Runtime/simulation.h | 19 +++++ .../Conversion/SimulateTFHE/SimulateTFHE.cpp | 75 ++++++++++++++++--- .../compiler/lib/Runtime/simulation.cpp | 4 + 3 files changed, 88 insertions(+), 10 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h index 3854258baf..c642ddf240 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h @@ -24,6 +24,25 @@ uint64_t sim_encrypt_lwe_u64(uint64_t message, uint32_t lwe_dim, void *csprng); /// \return uint64_t uint64_t sim_neg_lwe_u64(uint64_t plaintext); +/// \brief simulate the addition of a noisy plaintext with another +/// plaintext (noisy or not) +/// +/// The function also checks for overflow and print a warning when it happens +/// +/// \param lhs left operand +/// \param rhs right operand +/// \return uint64_t +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs); + +/// \brief simulate the multiplication of a noisy plaintext with an integer +/// +/// The function also checks for overflow and print a warning when it happens +/// +/// \param lhs left operand +/// \param rhs right operand +/// \return uint64_t +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs); + /// \brief simulate a keyswitch on a noisy plaintext /// /// \param plaintext noisy plaintext diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp index 211f95c8b2..6ddada67c0 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp @@ -92,6 +92,67 @@ struct NegOpPattern : public mlir::OpConversionPattern { } }; +template +struct AddOpPattern : public mlir::OpConversionPattern { + + AddOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter) + : mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(AddOp addOp, AddOpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + const std::string funcName = "sim_add_lwe_u64"; + + if (insertForwardDeclaration( + addOp, rewriter, funcName, + rewriter.getFunctionType( + {rewriter.getIntegerType(64), rewriter.getIntegerType(64)}, + {rewriter.getIntegerType(64)})) + .failed()) { + return mlir::failure(); + } + + rewriter.replaceOpWithNewOp( + addOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, + mlir::ValueRange({adaptor.getA(), adaptor.getB()})); + + return mlir::success(); + } +}; + +struct MulOpPattern : public mlir::OpConversionPattern { + + MulOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter) + : mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(TFHE::MulGLWEIntOp mulOp, TFHE::MulGLWEIntOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + const std::string funcName = "sim_mul_lwe_u64"; + + if (insertForwardDeclaration( + mulOp, rewriter, funcName, + rewriter.getFunctionType( + {rewriter.getIntegerType(64), rewriter.getIntegerType(64)}, + {rewriter.getIntegerType(64)})) + .failed()) { + return mlir::failure(); + } + + rewriter.replaceOpWithNewOp( + mulOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, + mlir::ValueRange({adaptor.getA(), adaptor.getB()})); + + return mlir::success(); + } +}; + struct SubIntGLWEOpPattern : public mlir::OpRewritePattern { SubIntGLWEOpPattern(mlir::MLIRContext *context) @@ -579,14 +640,6 @@ void SimulateTFHEPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); - // Replace ops and convert operand and result types - patterns.insert, - mlir::concretelang::GenericOneToOneOpConversionPattern< - TFHE::AddGLWEOp, mlir::arith::AddIOp>, - mlir::concretelang::GenericOneToOneOpConversionPattern< - TFHE::MulGLWEIntOp, mlir::arith::MulIOp>>(&getContext(), - converter); // Convert operand and result types patterns.insert, @@ -643,8 +696,10 @@ void SimulateTFHEPass::runOnOperation() { BootstrapGLWEOpPattern, WopPBSGLWEOpPattern, EncodeExpandLutForBootstrapOpPattern, EncodeLutForCrtWopPBSOpPattern, - EncodePlaintextWithCrtOpPattern, NegOpPattern>(&getContext(), - converter); + EncodePlaintextWithCrtOpPattern, NegOpPattern, + AddOpPattern, + AddOpPattern, + MulOpPattern>(&getContext(), converter); patterns.insert(&getContext()); patterns.add Date: Thu, 4 Apr 2024 12:11:37 +0100 Subject: [PATCH 2/5] feat(compiler): warn when there is overflow in sim (native encoding) --- .../Conversion/SimulateTFHE/Pass.h | 3 +- .../include/concretelang/Support/Pipeline.h | 1 + .../Conversion/SimulateTFHE/SimulateTFHE.cpp | 32 +++++++++++++++---- .../compiler/lib/Runtime/simulation.cpp | 22 +++++++++++-- .../compiler/lib/Support/CompilerEngine.cpp | 4 +-- .../compiler/lib/Support/Pipeline.cpp | 17 ++++++++-- 6 files changed, 65 insertions(+), 14 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/SimulateTFHE/Pass.h b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/SimulateTFHE/Pass.h index 9c91085897..6f8e4d51a7 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/SimulateTFHE/Pass.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/SimulateTFHE/Pass.h @@ -11,7 +11,8 @@ namespace mlir { namespace concretelang { /// Create a pass that simulates TFHE operations -std::unique_ptr> createSimulateTFHEPass(); +std::unique_ptr> +createSimulateTFHEPass(bool enableOverflowDetection); } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h index 49735900d8..ac48e9f188 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h @@ -104,6 +104,7 @@ mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context, mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::optional &fheContext, std::function enablePass); mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context, diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp index 6ddada67c0..f74da3cd22 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp @@ -622,6 +622,10 @@ struct ZeroTensorOpPattern }; struct SimulateTFHEPass : public SimulateTFHEBase { + bool enableOverflowDetection; + SimulateTFHEPass(bool enableOverflowDetection) + : enableOverflowDetection(enableOverflowDetection) {} + void runOnOperation() final; }; @@ -696,12 +700,27 @@ void SimulateTFHEPass::runOnOperation() { BootstrapGLWEOpPattern, WopPBSGLWEOpPattern, EncodeExpandLutForBootstrapOpPattern, EncodeLutForCrtWopPBSOpPattern, - EncodePlaintextWithCrtOpPattern, NegOpPattern, - AddOpPattern, - AddOpPattern, - MulOpPattern>(&getContext(), converter); + EncodePlaintextWithCrtOpPattern, NegOpPattern>(&getContext(), + converter); patterns.insert(&getContext()); + // if overflow detection is enable, then rewrite to CAPI functions that + // performs the detection, otherwise, rewrite as simple arithmetic ops + if (enableOverflowDetection) { + patterns + .insert, + AddOpPattern, + MulOpPattern>(&getContext(), converter); + } else { + patterns.insert, + mlir::concretelang::GenericOneToOneOpConversionPattern< + TFHE::AddGLWEOp, mlir::arith::AddIOp>, + mlir::concretelang::GenericOneToOneOpConversionPattern< + TFHE::MulGLWEIntOp, mlir::arith::MulIOp>>(&getContext(), + converter); + } + patterns.add, mlir::concretelang::TypeConvertingReinstantiationPattern< @@ -743,8 +762,9 @@ void SimulateTFHEPass::runOnOperation() { namespace mlir { namespace concretelang { -std::unique_ptr> createSimulateTFHEPass() { - return std::make_unique(); +std::unique_ptr> +createSimulateTFHEPass(bool enableOverflowDetection) { + return std::make_unique(enableOverflowDetection); } } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp index f89318d535..c74f3ce948 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp @@ -18,6 +18,8 @@ using concretelang::csprng::SoftCSPRNG; thread_local auto csprng = SoftCSPRNG(0); +const uint64_t UINT63_MAX = UINT64_MAX >> 1; + inline concrete::SecurityCurve *security_curve() { return concrete::getSecurityCurve(128, concrete::BINARY); } @@ -94,7 +96,11 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, double variance = concrete_cpu_variance_blind_rotate( input_lwe_dim, glwe_dim, poly_size, base_log, level, 64, mlir::concretelang::optimizer::DEFAULT_FFT_PRECISION, variance_bsk); - return out + gaussian_noise(0, variance); + out = out + gaussian_noise(0, variance); + if (out > UINT63_MAX) { + printf("WARNING: overflow happened during LUT\n"); + } + return out; } void sim_wop_pbs_crt( @@ -183,9 +189,19 @@ void sim_wop_pbs_crt( uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; } -uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs) { return lhs + rhs; } +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs) { + if (lhs > UINT63_MAX - rhs) { + printf("WARNING: overflow happened during addition in simulation\n"); + } + return lhs + rhs; +} -uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs) { return lhs * rhs; } +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs) { + if (rhs != 0 && lhs > UINT63_MAX / rhs) { + printf("WARNING: overflow happened during multiplication in simulation\n"); + } + return lhs * rhs; +} void sim_encode_expand_lut_for_boostrap( uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index af29041ff1..669f27ba2d 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -522,8 +522,8 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, } if (options.simulate) { - if (mlir::concretelang::pipeline::simulateTFHE(mlirContext, module, - this->enablePass) + if (mlir::concretelang::pipeline::simulateTFHE( + mlirContext, module, res.fheContext, this->enablePass) .failed()) { return StreamStringError("Simulating TFHE failed"); } diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index 1016e1901c..79306b76f4 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -426,11 +426,24 @@ mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context, mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::optional &fheContext, std::function enablePass) { mlir::PassManager pm(&context); + + // we want to disable overflow detection if CRT is used (overflow would be + // expected) + bool enableOverflowDetection = true; + if (fheContext) { + auto solution = fheContext.value().solution; + auto optCrt = getCrtDecompositionFromSolution(solution); + if (optCrt) + enableOverflowDetection = false; + } + pipelinePrinting("TFHESimulation", pm, context); - addPotentiallyNestedPass(pm, mlir::concretelang::createSimulateTFHEPass(), - enablePass); + addPotentiallyNestedPass( + pm, mlir::concretelang::createSimulateTFHEPass(enableOverflowDetection), + enablePass); return pm.run(module.getOperation()); } From 9aed30b8693fa1bfdc1e7e3eb42142d9f8152dc5 Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 8 Apr 2024 18:51:31 +0100 Subject: [PATCH 3/5] feat(compiler/simu): add loc in overflow warnings --- .../include/concretelang/Runtime/simulation.h | 9 +++-- .../Conversion/SimulateTFHE/SimulateTFHE.cpp | 39 +++++++++++++++---- .../compiler/lib/Runtime/simulation.cpp | 15 ++++--- 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h index c642ddf240..53afd32e13 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h @@ -31,8 +31,9 @@ uint64_t sim_neg_lwe_u64(uint64_t plaintext); /// /// \param lhs left operand /// \param rhs right operand +/// \param loc /// \return uint64_t -uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs); +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); /// \brief simulate the multiplication of a noisy plaintext with an integer /// @@ -40,8 +41,9 @@ uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs); /// /// \param lhs left operand /// \param rhs right operand +/// \param loc /// \return uint64_t -uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs); +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); /// \brief simulate a keyswitch on a noisy plaintext /// @@ -68,13 +70,14 @@ uint64_t sim_keyswitch_lwe_u64(uint64_t plaintext, uint32_t level, /// \param level /// \param base_log /// \param glwe_dim +/// \param loc /// \return uint64_t uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, uint32_t base_log, - uint32_t glwe_dim); + uint32_t glwe_dim, char *loc); /// simulate a WoP PBS void sim_wop_pbs_crt( diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp index f74da3cd22..035f9a9da8 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp @@ -92,6 +92,21 @@ struct NegOpPattern : public mlir::OpConversionPattern { } }; +mlir::Value globalStringValueFromLoc(mlir::ConversionPatternRewriter &rewriter, + mlir::Location loc) { + std::string locString; + auto ros = llvm::raw_string_ostream(locString); + loc.print(ros); + + std::string msgName; + std::stringstream stream; + stream << "loc_" << rand(); + stream >> msgName; + return mlir::LLVM::createGlobalString(loc, rewriter, msgName, locString, + mlir::LLVM::linkage::Linkage::Linkonce, + false); +} + template struct AddOpPattern : public mlir::OpConversionPattern { @@ -106,10 +121,13 @@ struct AddOpPattern : public mlir::OpConversionPattern { const std::string funcName = "sim_add_lwe_u64"; + auto locString = globalStringValueFromLoc(rewriter, addOp.getLoc()); + if (insertForwardDeclaration( addOp, rewriter, funcName, rewriter.getFunctionType( - {rewriter.getIntegerType(64), rewriter.getIntegerType(64)}, + {rewriter.getIntegerType(64), rewriter.getIntegerType(64), + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())}, {rewriter.getIntegerType(64)})) .failed()) { return mlir::failure(); @@ -117,7 +135,7 @@ struct AddOpPattern : public mlir::OpConversionPattern { rewriter.replaceOpWithNewOp( addOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, - mlir::ValueRange({adaptor.getA(), adaptor.getB()})); + mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString})); return mlir::success(); } @@ -136,10 +154,13 @@ struct MulOpPattern : public mlir::OpConversionPattern { const std::string funcName = "sim_mul_lwe_u64"; + auto locString = globalStringValueFromLoc(rewriter, mulOp.getLoc()); + if (insertForwardDeclaration( mulOp, rewriter, funcName, rewriter.getFunctionType( - {rewriter.getIntegerType(64), rewriter.getIntegerType(64)}, + {rewriter.getIntegerType(64), rewriter.getIntegerType(64), + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())}, {rewriter.getIntegerType(64)})) .failed()) { return mlir::failure(); @@ -147,7 +168,7 @@ struct MulOpPattern : public mlir::OpConversionPattern { rewriter.replaceOpWithNewOp( mulOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, - mlir::ValueRange({adaptor.getA(), adaptor.getB()})); + mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString})); return mlir::success(); } @@ -500,6 +521,8 @@ struct BootstrapGLWEOpPattern mlir::Value castedLUT = rewriter.create( bsOp.getLoc(), dynamicLutType, adaptor.getLookupTable()); + auto locString = globalStringValueFromLoc(rewriter, bsOp.getLoc()); + // uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t // *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t // tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t @@ -510,7 +533,8 @@ struct BootstrapGLWEOpPattern {rewriter.getIntegerType(64), dynamicLutType, rewriter.getIntegerType(32), rewriter.getIntegerType(32), rewriter.getIntegerType(32), rewriter.getIntegerType(32), - rewriter.getIntegerType(32)}, + rewriter.getIntegerType(32), + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())}, {rewriter.getIntegerType(64)})) .failed()) { return mlir::failure(); @@ -520,7 +544,7 @@ struct BootstrapGLWEOpPattern bsOp, funcName, this->getTypeConverter()->convertType(resultType), mlir::ValueRange({adaptor.getCiphertext(), castedLUT, inputLweDimensionCst, polySizeCst, levelsCst, - baseLogCst, glweDimensionCst})); + baseLogCst, glweDimensionCst, locString})); return mlir::success(); } @@ -638,7 +662,8 @@ void SimulateTFHEPass::runOnOperation() { target.addLegalDialect(); target.addLegalOp(); + mlir::tensor::CastOp, mlir::LLVM::GlobalOp, + mlir::LLVM::AddressOfOp, mlir::LLVM::GEPOp>(); // Make sure that no ops from `TFHE` remain after the lowering target.addIllegalDialect(); diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp index c74f3ce948..c5922d5490 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp @@ -63,7 +63,7 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, uint64_t tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, uint32_t base_log, - uint32_t glwe_dim) { + uint32_t glwe_dim, char *loc) { auto tlu = tlu_aligned + tlu_offset; // modulus switching @@ -98,7 +98,7 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, mlir::concretelang::optimizer::DEFAULT_FFT_PRECISION, variance_bsk); out = out + gaussian_noise(0, variance); if (out > UINT63_MAX) { - printf("WARNING: overflow happened during LUT\n"); + printf("WARNING at %s: overflow happened during LUT\n", loc); } return out; } @@ -189,16 +189,19 @@ void sim_wop_pbs_crt( uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; } -uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs) { +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) { if (lhs > UINT63_MAX - rhs) { - printf("WARNING: overflow happened during addition in simulation\n"); + printf("WARNING at %s: overflow happened during addition in simulation\n", + loc); } return lhs + rhs; } -uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs) { +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) { if (rhs != 0 && lhs > UINT63_MAX / rhs) { - printf("WARNING: overflow happened during multiplication in simulation\n"); + printf("WARNING at %s: overflow happened during multiplication in " + "simulation\n", + loc); } return lhs * rhs; } From 83cf0853b9f9d9e296f91eaaba9ff665fa147e68 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 9 Apr 2024 16:16:21 +0100 Subject: [PATCH 4/5] test(compiler): overflow in simulation --- .../compiler/lib/Runtime/simulation.cpp | 2 +- .../compiler/tests/python/overflow.py | 27 +++++++ .../compiler/tests/python/test_simulation.py | 70 +++++++++++++++++++ 3 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 compilers/concrete-compiler/compiler/tests/python/overflow.py diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp index c5922d5490..f537483141 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp @@ -98,7 +98,7 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, mlir::concretelang::optimizer::DEFAULT_FFT_PRECISION, variance_bsk); out = out + gaussian_noise(0, variance); if (out > UINT63_MAX) { - printf("WARNING at %s: overflow happened during LUT\n", loc); + printf("WARNING at %s: overflow happened during LUT in simulation\n", loc); } return out; } diff --git a/compilers/concrete-compiler/compiler/tests/python/overflow.py b/compilers/concrete-compiler/compiler/tests/python/overflow.py new file mode 100644 index 0000000000..b8ba3eaa64 --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/python/overflow.py @@ -0,0 +1,27 @@ +import sys +import shutil +import numpy as np +from concrete.compiler import LibrarySupport +from test_simulation import compile_run_assert + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: overflow.py mlir_file.mlir input_int... expected_out_int") + exit() + + with open(sys.argv[1], "r") as f: + mlir_input = f.read() + + artifact_dir = "./py_test_lib_compile_and_run" + engine = LibrarySupport.new(artifact_dir) + args = list(map(int, sys.argv[2:-1])) + expected_result = int(sys.argv[-1]) + args_and_shape = [] + for arg in args: + if isinstance(arg, int): + args_and_shape.append((arg, None)) + else: # np.array + args_and_shape.append((arg.flatten().tolist(), list(arg.shape))) + compile_run_assert(engine, mlir_input, args_and_shape, expected_result) + shutil.rmtree(artifact_dir) diff --git a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py index 6f6cf73abb..c083d94d2d 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py @@ -1,3 +1,7 @@ +import subprocess +import sys +import os +import tempfile import pytest import shutil import numpy as np @@ -239,3 +243,69 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): args_and_shape.append((arg.flatten().tolist(), list(arg.shape))) compile_run_assert(engine, mlir_input, args_and_shape, expected_result) shutil.rmtree(artifact_dir) + + +end_to_end_overflow_simu_fixture = [ + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (120, 30), + 150, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="add_eint_int", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (20, 10), + 200, + b'WARNING at loc("-":3:22): overflow happened during multiplication in simulation\n', + id="mul_eint_int", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { + %tlu = arith.constant dense<[0, 140, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> + %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (1,), + 140, + b'WARNING at loc("-":4:22): overflow happened during LUT in simulation\n', + id="apply_lookup_table", + ), +] + + +@pytest.mark.parametrize( + "mlir_input, args, expected_result, overflow_message", + end_to_end_overflow_simu_fixture, +) +def test_lib_compile_and_run_simulation_with_overflow( + mlir_input, args, expected_result, overflow_message +): + # write mlir to tmp file + mlir_file = tempfile.NamedTemporaryFile("w") + mlir_file.write(mlir_input) + mlir_file.flush() + + # prepare cmd and run + script_path = os.path.join(os.path.dirname(__file__), "overflow.py") + cmd = [sys.executable, script_path, mlir_file.name] + cmd.extend(map(str, args)) + cmd.append(str(expected_result)) + out = subprocess.check_output(cmd, env=os.environ) + + # close/remove tmp file + mlir_file.close() + + assert overflow_message == out From 463a1b3a8326f91442ef94f47e344caa06b52e84 Mon Sep 17 00:00:00 2001 From: youben11 Date: Wed, 24 Apr 2024 16:17:38 +0100 Subject: [PATCH 5/5] feat(compiler/simu): support signed integers --- .../include/concretelang/Runtime/simulation.h | 10 +- .../FHEToTFHEScalar/FHEToTFHEScalar.cpp | 43 ++- .../Conversion/SimulateTFHE/SimulateTFHE.cpp | 53 +++- .../compiler/lib/Runtime/simulation.cpp | 174 ++++++++++-- .../compiler/lib/Support/Pipeline.cpp | 4 +- .../compiler/tests/python/test_simulation.py | 263 +++++++++++++++++- 6 files changed, 497 insertions(+), 50 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h index 53afd32e13..ca661ef0b2 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h @@ -31,9 +31,10 @@ uint64_t sim_neg_lwe_u64(uint64_t plaintext); /// /// \param lhs left operand /// \param rhs right operand -/// \param loc +/// \param loc location of the operation +/// \param is_signed tell if operands are known to be signed /// \return uint64_t -uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, bool is_signed); /// \brief simulate the multiplication of a noisy plaintext with an integer /// @@ -41,9 +42,10 @@ uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); /// /// \param lhs left operand /// \param rhs right operand -/// \param loc +/// \param loc location of the operation +/// \param is_signed tell if operands are known to be signed /// \return uint64_t -uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, bool is_signed); /// \brief simulate a keyswitch on a noisy plaintext /// diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index ae88c00c32..34a21ae383 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -57,6 +57,15 @@ inline void forwardOptimizerID(mlir::Operation *source, destination->setAttr("TFHE.OId", optimizerIdAttr); } +// Set the `signed` attribute to true if the type is signed +inline void markOpIfSigned(mlir::Operation *op, + FHE::FheIntegerInterface resultType) { + auto isSigned = resultType.isSigned(); + if (isSigned) { + op->setAttr("signed", mlir::BoolAttr::get(op->getContext(), true)); + } +} + inline void forwardLinearlyOptimizerIDS(mlir::Operation &source, std::vector &destinations) { @@ -185,6 +194,29 @@ struct AddEintIntOpPattern : public ScalarOpPattern { op, getTypeConverter()->convertType(op.getType()), adaptor.getA(), encodedInt); forwardOptimizerID(op, newOp); + markOpIfSigned(newOp, op.getType().cast()); + + return mlir::success(); + } +}; + +/// Rewriter for the `FHE::add_eint` operation. +struct AddEintOpPattern : public mlir::OpConversionPattern { + AddEintOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpConversionPattern(converter, context, benefit) { + } + + mlir::LogicalResult + matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + // Write the new op + auto newOp = rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), + adaptor.getOperands()); + forwardOptimizerID(op, newOp); + markOpIfSigned(newOp, op.getType().cast()); return mlir::success(); } @@ -225,6 +257,7 @@ struct SubEintIntOpPattern : public ScalarOpPattern { op, getTypeConverter()->convertType(op.getType()), adaptor.getA(), encodedInt); forwardOptimizerID(op, newOp); + markOpIfSigned(newOp, op.getType().cast()); return mlir::success(); }; @@ -252,6 +285,7 @@ struct SubIntEintOpPattern : public ScalarOpPattern { op, getTypeConverter()->convertType(op.getType()), encodedInt, adaptor.getB()); forwardOptimizerID(op, newOp); + markOpIfSigned(newOp, op.getType().cast()); return mlir::success(); }; @@ -281,6 +315,7 @@ struct SubEintOpPattern : public ScalarOpPattern { op, getTypeConverter()->convertType(op.getType()), lhsOperand, negative.getResult()); forwardOptimizerID(op, newOp); + markOpIfSigned(newOp, op.getType().cast()); return mlir::success(); }; @@ -310,6 +345,7 @@ struct MulEintIntOpPattern : public ScalarOpPattern { op, getTypeConverter()->convertType(op.getType()), eintOperand, castedCleartext); forwardOptimizerID(op, newOp); + markOpIfSigned(newOp, op.getType().cast()); return mlir::success(); } @@ -804,12 +840,11 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { FHE::NegEintOp, TFHE::NegGLWEOp, true>, // |_ `FHE::not` mlir::concretelang::GenericOneToOneOpConversionPattern< - FHE::BoolNotOp, TFHE::NegGLWEOp, true>, - // |_ `FHE::add_eint` - mlir::concretelang::GenericOneToOneOpConversionPattern< - FHE::AddEintOp, TFHE::AddGLWEOp, true>>(&getContext(), converter); + FHE::BoolNotOp, TFHE::NegGLWEOp, true>>(&getContext(), converter); // |_ `FHE::add_eint_int` patterns.add { } }; +int locationStringCtr = 0; mlir::Value globalStringValueFromLoc(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc) { std::string locString; auto ros = llvm::raw_string_ostream(locString); loc.print(ros); - - std::string msgName; - std::stringstream stream; - stream << "loc_" << rand(); - stream >> msgName; - return mlir::LLVM::createGlobalString(loc, rewriter, msgName, locString, - mlir::LLVM::linkage::Linkage::Linkonce, - false); + locString.append("\0"); + auto locStrWithNullByte = + llvm::StringRef(locString.c_str(), locString.size() + 1); + + std::stringstream msgName; + msgName << "str_loc_" << locationStringCtr++; + return mlir::LLVM::createGlobalString( + loc, rewriter, msgName.str(), locStrWithNullByte, + mlir::LLVM::linkage::Linkage::Linkonce, false); } template @@ -122,12 +124,21 @@ struct AddOpPattern : public mlir::OpConversionPattern { const std::string funcName = "sim_add_lwe_u64"; auto locString = globalStringValueFromLoc(rewriter, addOp.getLoc()); + // check if operation has been tagged as signed + auto isSigned = false; + mlir::Attribute signedAttr = adaptor.getAttributes().get("signed"); + if (signedAttr && signedAttr.cast().getValue()) { + isSigned = true; + } + mlir::Value isSignedCst = rewriter.create( + addOp.getLoc(), isSigned, 1); if (insertForwardDeclaration( addOp, rewriter, funcName, rewriter.getFunctionType( {rewriter.getIntegerType(64), rewriter.getIntegerType(64), - mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())}, + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()), + rewriter.getIntegerType(1)}, {rewriter.getIntegerType(64)})) .failed()) { return mlir::failure(); @@ -135,7 +146,8 @@ struct AddOpPattern : public mlir::OpConversionPattern { rewriter.replaceOpWithNewOp( addOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, - mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString})); + mlir::ValueRange( + {adaptor.getA(), adaptor.getB(), locString, isSignedCst})); return mlir::success(); } @@ -156,11 +168,21 @@ struct MulOpPattern : public mlir::OpConversionPattern { auto locString = globalStringValueFromLoc(rewriter, mulOp.getLoc()); + // check if operation has been tagged as signed + auto isSigned = false; + mlir::Attribute signedAttr = adaptor.getAttributes().get("signed"); + if (signedAttr && signedAttr.cast().getValue()) { + isSigned = true; + } + mlir::Value isSignedCst = rewriter.create( + mulOp.getLoc(), isSigned, 1); + if (insertForwardDeclaration( mulOp, rewriter, funcName, rewriter.getFunctionType( {rewriter.getIntegerType(64), rewriter.getIntegerType(64), - mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())}, + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()), + rewriter.getIntegerType(1)}, {rewriter.getIntegerType(64)})) .failed()) { return mlir::failure(); @@ -168,7 +190,8 @@ struct MulOpPattern : public mlir::OpConversionPattern { rewriter.replaceOpWithNewOp( mulOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, - mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString})); + mlir::ValueRange( + {adaptor.getA(), adaptor.getB(), locString, isSignedCst})); return mlir::success(); } @@ -186,8 +209,10 @@ struct SubIntGLWEOpPattern : public mlir::OpRewritePattern { mlir::Value negated = rewriter.create( subOp.getLoc(), subOp.getB().getType(), subOp.getB()); - rewriter.replaceOpWithNewOp(subOp, subOp.getType(), - negated, subOp.getA()); + rewriter.replaceOpWithNewOp( + subOp, subOp.getType(), mlir::ValueRange({negated, subOp.getA()}), + // to forward the signed attr if set + subOp.getOperation()->getAttrs()); return mlir::success(); } diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp index f537483141..6dc43c7334 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp @@ -92,14 +92,30 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, else out = -tlu[mod_switched % poly_size]; + // get encoded info from lsb + bool is_signed = (out >> 1) & 1; + bool is_overflow = out & 1; + // discard info bits (2 lsb) + out = out & 18446744073709551612U; + + if (!is_signed && out > UINT63_MAX) { + printf("WARNING at %s: overflow (padding bit) happened during LUT in " + "simulation\n", + loc); + } + if (is_overflow) { + printf("WARNING at %s: overflow (original value didn't fit, so a modulus " + "was applied) happened " + "during LUT in " + "simulation\n", + loc); + } + double variance_bsk = security_curve()->getVariance(glwe_dim, poly_size, 64); double variance = concrete_cpu_variance_blind_rotate( input_lwe_dim, glwe_dim, poly_size, base_log, level, 64, mlir::concretelang::optimizer::DEFAULT_FFT_PRECISION, variance_bsk); out = out + gaussian_noise(0, variance); - if (out > UINT63_MAX) { - printf("WARNING at %s: overflow happened during LUT in simulation\n", loc); - } return out; } @@ -189,33 +205,145 @@ void sim_wop_pbs_crt( uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; } -uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) { - if (lhs > UINT63_MAX - rhs) { - printf("WARNING at %s: overflow happened during addition in simulation\n", - loc); +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, + bool is_signed) { + const char msg_f[] = + "WARNING at %s: overflow happened during addition in simulation\n"; + + uint64_t result = lhs + rhs; + + if (is_signed) { + // We shift left to discard the padding bit and only consider the message + // for easier overflow checking + int64_t lhs_signed = (int64_t)lhs << 1; + int64_t rhs_signed = (int64_t)rhs << 1; + if (lhs_signed > 0 && rhs_signed > INT64_MAX - lhs_signed) + printf(msg_f, loc); + else if (lhs_signed < 0 && rhs_signed < INT64_MIN - lhs_signed) + printf(msg_f, loc); + } else if (lhs > UINT63_MAX - rhs || result > UINT63_MAX) { + printf(msg_f, loc); } - return lhs + rhs; + return result; } -uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) { - if (rhs != 0 && lhs > UINT63_MAX / rhs) { - printf("WARNING at %s: overflow happened during multiplication in " - "simulation\n", - loc); +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, + bool is_signed) { + const char msg_f[] = + "WARNING at %s: overflow happened during multiplication in simulation\n"; + + uint64_t result = lhs * rhs; + + if (is_signed) { + // We shift left to discard the padding bit and only consider the message + // for easier overflow checking + int64_t lhs_signed = (int64_t)lhs << 1; + int64_t rhs_signed = (int64_t)rhs << 1; + if (lhs_signed != 0 && rhs_signed > INT64_MAX / lhs_signed) + printf(msg_f, loc); + else if (lhs_signed != 0 && rhs_signed < INT64_MIN / lhs_signed) + printf(msg_f, loc); + } else if (rhs != 0 && lhs > UINT63_MAX / rhs) { + printf(msg_f, loc); } - return lhs * rhs; + return result; } +// a copy of memref_encode_expand_lut_for_bootstrap but which encodes overflow +// and sign info into the LUT. Those information should later be discarder by +// the LUT function void sim_encode_expand_lut_for_boostrap( - uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, - uint64_t out_size, uint64_t out_stride, uint64_t *in_allocated, - uint64_t *in_aligned, uint64_t in_offset, uint64_t in_size, - uint64_t in_stride, uint32_t poly_size, uint32_t output_bits, - bool is_signed) { - return memref_encode_expand_lut_for_bootstrap( - out_allocated, out_aligned, out_offset, out_size, out_stride, - in_allocated, in_aligned, in_offset, in_size, in_stride, poly_size, - output_bits, is_signed); + uint64_t *output_lut_allocated, uint64_t *output_lut_aligned, + uint64_t output_lut_offset, uint64_t output_lut_size, + uint64_t output_lut_stride, uint64_t *input_lut_allocated, + uint64_t *input_lut_aligned, uint64_t input_lut_offset, + uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size, + uint32_t out_MESSAGE_BITS, bool is_signed) { + + assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check " + "memref_encode_expand_lut_bootstrap"); + + assert(output_lut_stride == 1 && "Runtime: stride not equal to 1, check " + "memref_encode_expand_lut_bootstrap"); + + size_t mega_case_size = output_lut_size / input_lut_size; + + assert((mega_case_size % 2) == 0); + + // compute overflow bit + std::vector overflow_info(output_lut_size, false); + uint64_t upper_bound = uint64_t(1) + << (out_MESSAGE_BITS + (is_signed ? 1 : 0)); + for (size_t i = 0; i < input_lut_size; i++) { + if (input_lut_aligned[input_lut_offset + i] >= upper_bound) { + overflow_info[i] = true; + } else { + overflow_info[i] = false; + } + } + // used to set the sign bit or not + uint64_t sign_bit_setter = 0; + if (is_signed) { + sign_bit_setter = 2; + } + + // When the bootstrap is executed on encrypted signed integers, the lut must + // be half-rotated. This map takes care about properly indexing into the input + // lut depending on what bootstrap gets executed. + std::function indexMap; + if (is_signed) { + size_t halfInputSize = input_lut_size / 2; + indexMap = [=](size_t idx) { + if (idx < halfInputSize) { + return idx + halfInputSize; + } else { + return idx - halfInputSize; + } + }; + } else { + indexMap = [=](size_t idx) { return idx; }; + } + + // The first lut value should be centered over zero. This means that half of + // it should appear at the beginning of the output lut, and half of it at the + // end (but negated). + for (size_t idx = 0; idx < mega_case_size / 2; ++idx) { + output_lut_aligned[output_lut_offset + idx] = + input_lut_aligned[input_lut_offset + indexMap(0)] + << (64 - out_MESSAGE_BITS - 1); + // set the sign bit + output_lut_aligned[output_lut_offset + idx] |= sign_bit_setter; + // set the overflow bit + output_lut_aligned[output_lut_offset + idx] |= (uint64_t)overflow_info[0]; + } + for (size_t idx = (input_lut_size - 1) * mega_case_size + mega_case_size / 2; + idx < output_lut_size; ++idx) { + output_lut_aligned[output_lut_offset + idx] = + -(input_lut_aligned[input_lut_offset + indexMap(0)] + << (64 - out_MESSAGE_BITS - 1)); + // set the sign bit + output_lut_aligned[output_lut_offset + idx] |= sign_bit_setter; + // set the overflow bit + output_lut_aligned[output_lut_offset + idx] |= + (uint64_t)overflow_info[indexMap(0)]; + } + + // Treats the other ut values. + for (size_t lut_idx = 1; lut_idx < input_lut_size; ++lut_idx) { + uint64_t lut_value = input_lut_aligned[input_lut_offset + indexMap(lut_idx)] + << (64 - out_MESSAGE_BITS - 1); + // set the sign bit + lut_value |= sign_bit_setter; + // set the overflow bit + lut_value |= (uint64_t)overflow_info[indexMap(lut_idx)]; + size_t start = mega_case_size * (lut_idx - 1) + mega_case_size / 2; + for (size_t output_idx = start; output_idx < start + mega_case_size; + ++output_idx) { + output_lut_aligned[output_lut_offset + output_idx] = lut_value; + } + } + + return; } void sim_encode_plaintext_with_crt(uint64_t *output_allocated, diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index 79306b76f4..980d4db62e 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -436,8 +436,10 @@ mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context, if (fheContext) { auto solution = fheContext.value().solution; auto optCrt = getCrtDecompositionFromSolution(solution); - if (optCrt) + if (optCrt) { enableOverflowDetection = false; + log_verbose() << "WARNING: overflow detection disabled since using CRT"; + } } pipelinePrinting("TFHESimulation", pm, context); diff --git a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py index c083d94d2d..4755b7a4f9 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py @@ -21,7 +21,7 @@ def assert_result(result, expected_result): """ assert type(expected_result) == type(result) if isinstance(expected_result, int): - assert result == expected_result + assert result == expected_result, f"{result} != {expected_result}" else: assert np.all(result == expected_result) @@ -258,6 +258,186 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="add_eint_int", ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-1, -2), + -3, + b"", + id="add_eint_int_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-60, -20), + -80, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="add_eint_int_signed_underflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (60, 20), + -48, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="add_eint_int_signed_overflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (81, 73), + 154, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="add_eint", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-81, 73), + -8, + b"", + id="add_eint_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-60, -20), + -80, # undefined behavior + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="add_eint_signed_underflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (81, 73), + -102, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="add_eint_signed_overflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (4, 7), + 256 - 3, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint_int", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (4, 7), + -3, + b"", + id="sub_eint_int_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-37, 40), + -77, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint_int_signed_underflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (33, -40), + -55, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint_int_signed_overflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (11, 18), + 256 - 7, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (11, 18), + -7, + b"", + id="sub_eint_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-44, 32), + -76, # undefined behavior + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint_signed_underflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (61, -25), + -42, # undefined behavior + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint_signed_overflow", + ), pytest.param( """ func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { @@ -270,18 +450,93 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): b'WARNING at loc("-":3:22): overflow happened during multiplication in simulation\n', id="mul_eint_int", ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + %2 = "FHE.mul_eint_int"(%1, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %2: !FHE.eint<7> + } + """, + (5, 10), + 256 - 50, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\nWARNING at loc("-":4:22): overflow happened during multiplication in simulation\n', + id="sub_mul_eint_int", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (5, -2), + -10, + b"", + id="mul_eint_int_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-33, 5), + -37, # undefined behavior + b'WARNING at loc("-":3:22): overflow happened during multiplication in simulation\n', + id="mul_eint_int_signed_underflow", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-33, -5), + -91, + b'WARNING at loc("-":3:22): overflow happened during multiplication in simulation\n', + id="mul_eint_int_signed_overflow", + ), pytest.param( """ func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { - %tlu = arith.constant dense<[0, 140, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> + %tlu = arith.constant dense<[0, 1420, -2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } """, (1,), 140, - b'WARNING at loc("-":4:22): overflow happened during LUT in simulation\n', - id="apply_lookup_table", + b'WARNING at loc("-":4:22): overflow (padding bit) happened during LUT in simulation\nWARNING at loc("-":4:22): overflow (original value didn\'t fit, so a modulus was applied) happened during LUT in simulation\n', + id="apply_lookup_table_big_value", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> { + %tlu = arith.constant dense<[0, 1400, 254, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> + %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.esint<7>, tensor<128xi64>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (2,), + -2, + b"", + id="apply_lookup_table_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> { + %tlu = arith.constant dense<[0, 1400, -2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> + %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.esint<7>, tensor<128xi64>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (1,), + -8, + b'WARNING at loc("-":4:22): overflow (original value didn\'t fit, so a modulus was applied) happened during LUT in simulation\n', + id="apply_lookup_table_signed_big_value", ), ]