From f4bd7322a26cb7df2fafe8986e695bed557b8da2 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 4 Apr 2024 12:11:37 +0100 Subject: [PATCH] feat(compiler): warn when there is overflow in sim (native encoding) --- .../Conversion/SimulateTFHE/Pass.h | 3 +- .../include/concretelang/Support/Pipeline.h | 1 + .../Conversion/SimulateTFHE/SimulateTFHE.cpp | 32 +++++++++++++++---- .../compiler/lib/Runtime/simulation.cpp | 22 +++++++++++-- .../compiler/lib/Support/CompilerEngine.cpp | 4 +-- .../compiler/lib/Support/Pipeline.cpp | 17 ++++++++-- 6 files changed, 65 insertions(+), 14 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/SimulateTFHE/Pass.h b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/SimulateTFHE/Pass.h index 9c91085897..6f8e4d51a7 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/SimulateTFHE/Pass.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/SimulateTFHE/Pass.h @@ -11,7 +11,8 @@ namespace mlir { namespace concretelang { /// Create a pass that simulates TFHE operations -std::unique_ptr> createSimulateTFHEPass(); +std::unique_ptr> +createSimulateTFHEPass(bool enableOverflowDetection); } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h index 0c787c9a01..e7ea47bc51 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h @@ -104,6 +104,7 @@ mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context, mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::optional &fheContext, std::function enablePass); mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context, diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp index 6ddada67c0..f74da3cd22 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp @@ -622,6 +622,10 @@ struct ZeroTensorOpPattern }; struct SimulateTFHEPass : public SimulateTFHEBase { + bool enableOverflowDetection; + SimulateTFHEPass(bool enableOverflowDetection) + : enableOverflowDetection(enableOverflowDetection) {} + void runOnOperation() final; }; @@ -696,12 +700,27 @@ void SimulateTFHEPass::runOnOperation() { BootstrapGLWEOpPattern, WopPBSGLWEOpPattern, EncodeExpandLutForBootstrapOpPattern, EncodeLutForCrtWopPBSOpPattern, - EncodePlaintextWithCrtOpPattern, NegOpPattern, - AddOpPattern, - AddOpPattern, - MulOpPattern>(&getContext(), converter); + EncodePlaintextWithCrtOpPattern, NegOpPattern>(&getContext(), + converter); patterns.insert(&getContext()); + // if overflow detection is enable, then rewrite to CAPI functions that + // performs the detection, otherwise, rewrite as simple arithmetic ops + if (enableOverflowDetection) { + patterns + .insert, + AddOpPattern, + MulOpPattern>(&getContext(), converter); + } else { + patterns.insert, + mlir::concretelang::GenericOneToOneOpConversionPattern< + TFHE::AddGLWEOp, mlir::arith::AddIOp>, + mlir::concretelang::GenericOneToOneOpConversionPattern< + TFHE::MulGLWEIntOp, mlir::arith::MulIOp>>(&getContext(), + converter); + } + patterns.add, mlir::concretelang::TypeConvertingReinstantiationPattern< @@ -743,8 +762,9 @@ void SimulateTFHEPass::runOnOperation() { namespace mlir { namespace concretelang { -std::unique_ptr> createSimulateTFHEPass() { - return std::make_unique(); +std::unique_ptr> +createSimulateTFHEPass(bool enableOverflowDetection) { + return std::make_unique(enableOverflowDetection); } } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp index f89318d535..c74f3ce948 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp @@ -18,6 +18,8 @@ using concretelang::csprng::SoftCSPRNG; thread_local auto csprng = SoftCSPRNG(0); +const uint64_t UINT63_MAX = UINT64_MAX >> 1; + inline concrete::SecurityCurve *security_curve() { return concrete::getSecurityCurve(128, concrete::BINARY); } @@ -94,7 +96,11 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, double variance = concrete_cpu_variance_blind_rotate( input_lwe_dim, glwe_dim, poly_size, base_log, level, 64, mlir::concretelang::optimizer::DEFAULT_FFT_PRECISION, variance_bsk); - return out + gaussian_noise(0, variance); + out = out + gaussian_noise(0, variance); + if (out > UINT63_MAX) { + printf("WARNING: overflow happened during LUT\n"); + } + return out; } void sim_wop_pbs_crt( @@ -183,9 +189,19 @@ void sim_wop_pbs_crt( uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; } -uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs) { return lhs + rhs; } +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs) { + if (lhs > UINT63_MAX - rhs) { + printf("WARNING: overflow happened during addition in simulation\n"); + } + return lhs + rhs; +} -uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs) { return lhs * rhs; } +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs) { + if (rhs != 0 && lhs > UINT63_MAX / rhs) { + printf("WARNING: overflow happened during multiplication in simulation\n"); + } + return lhs * rhs; +} void sim_encode_expand_lut_for_boostrap( uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index b5a258ee40..3cc6bb46ae 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -544,8 +544,8 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, } if (options.simulate) { - if (mlir::concretelang::pipeline::simulateTFHE(mlirContext, module, - this->enablePass) + if (mlir::concretelang::pipeline::simulateTFHE( + mlirContext, module, res.fheContext, this->enablePass) .failed()) { return StreamStringError("Simulating TFHE failed"); } diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index 210d2ea03d..22a1bb08b3 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -430,11 +430,24 @@ mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context, mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::optional &fheContext, std::function enablePass) { mlir::PassManager pm(&context); + + // we want to disable overflow detection if CRT is used (overflow would be + // expected) + bool enableOverflowDetection = true; + if (fheContext) { + auto solution = fheContext.value().solution; + auto optCrt = getCrtDecompositionFromSolution(solution); + if (optCrt) + enableOverflowDetection = false; + } + pipelinePrinting("TFHESimulation", pm, context); - addPotentiallyNestedPass(pm, mlir::concretelang::createSimulateTFHEPass(), - enablePass); + addPotentiallyNestedPass( + pm, mlir::concretelang::createSimulateTFHEPass(enableOverflowDetection), + enablePass); return pm.run(module.getOperation()); }