From bd08168b960a4398962e9944dca4f22c691a29fd Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 8 Apr 2024 18:51:31 +0100 Subject: [PATCH] feat(compiler/simu): add loc in overflow warnings --- .../include/concretelang/Runtime/simulation.h | 9 +++-- .../Conversion/SimulateTFHE/SimulateTFHE.cpp | 39 +++++++++++++++---- .../compiler/lib/Runtime/simulation.cpp | 15 ++++--- 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h index c642ddf240..53afd32e13 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h @@ -31,8 +31,9 @@ uint64_t sim_neg_lwe_u64(uint64_t plaintext); /// /// \param lhs left operand /// \param rhs right operand +/// \param loc /// \return uint64_t -uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs); +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); /// \brief simulate the multiplication of a noisy plaintext with an integer /// @@ -40,8 +41,9 @@ uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs); /// /// \param lhs left operand /// \param rhs right operand +/// \param loc /// \return uint64_t -uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs); +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); /// \brief simulate a keyswitch on a noisy plaintext /// @@ -68,13 +70,14 @@ uint64_t sim_keyswitch_lwe_u64(uint64_t plaintext, uint32_t level, /// \param level /// \param base_log /// \param glwe_dim +/// \param loc /// \return uint64_t uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, uint32_t base_log, - uint32_t glwe_dim); + uint32_t glwe_dim, char *loc); /// simulate a WoP PBS void sim_wop_pbs_crt( diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp index f74da3cd22..035f9a9da8 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp @@ -92,6 +92,21 @@ struct NegOpPattern : public mlir::OpConversionPattern { } }; +mlir::Value globalStringValueFromLoc(mlir::ConversionPatternRewriter &rewriter, + mlir::Location loc) { + std::string locString; + auto ros = llvm::raw_string_ostream(locString); + loc.print(ros); + + std::string msgName; + std::stringstream stream; + stream << "loc_" << rand(); + stream >> msgName; + return mlir::LLVM::createGlobalString(loc, rewriter, msgName, locString, + mlir::LLVM::linkage::Linkage::Linkonce, + false); +} + template struct AddOpPattern : public mlir::OpConversionPattern { @@ -106,10 +121,13 @@ struct AddOpPattern : public mlir::OpConversionPattern { const std::string funcName = "sim_add_lwe_u64"; + auto locString = globalStringValueFromLoc(rewriter, addOp.getLoc()); + if (insertForwardDeclaration( addOp, rewriter, funcName, rewriter.getFunctionType( - {rewriter.getIntegerType(64), rewriter.getIntegerType(64)}, + {rewriter.getIntegerType(64), rewriter.getIntegerType(64), + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())}, {rewriter.getIntegerType(64)})) .failed()) { return mlir::failure(); @@ -117,7 +135,7 @@ struct AddOpPattern : public mlir::OpConversionPattern { rewriter.replaceOpWithNewOp( addOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, - mlir::ValueRange({adaptor.getA(), adaptor.getB()})); + mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString})); return mlir::success(); } @@ -136,10 +154,13 @@ struct MulOpPattern : public mlir::OpConversionPattern { const std::string funcName = "sim_mul_lwe_u64"; + auto locString = globalStringValueFromLoc(rewriter, mulOp.getLoc()); + if (insertForwardDeclaration( mulOp, rewriter, funcName, rewriter.getFunctionType( - {rewriter.getIntegerType(64), rewriter.getIntegerType(64)}, + {rewriter.getIntegerType(64), rewriter.getIntegerType(64), + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())}, {rewriter.getIntegerType(64)})) .failed()) { return mlir::failure(); @@ -147,7 +168,7 @@ struct MulOpPattern : public mlir::OpConversionPattern { rewriter.replaceOpWithNewOp( mulOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, - mlir::ValueRange({adaptor.getA(), adaptor.getB()})); + mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString})); return mlir::success(); } @@ -500,6 +521,8 @@ struct BootstrapGLWEOpPattern mlir::Value castedLUT = rewriter.create( bsOp.getLoc(), dynamicLutType, adaptor.getLookupTable()); + auto locString = globalStringValueFromLoc(rewriter, bsOp.getLoc()); + // uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t // *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t // tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t @@ -510,7 +533,8 @@ struct BootstrapGLWEOpPattern {rewriter.getIntegerType(64), dynamicLutType, rewriter.getIntegerType(32), rewriter.getIntegerType(32), rewriter.getIntegerType(32), rewriter.getIntegerType(32), - rewriter.getIntegerType(32)}, + rewriter.getIntegerType(32), + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())}, {rewriter.getIntegerType(64)})) .failed()) { return mlir::failure(); @@ -520,7 +544,7 @@ struct BootstrapGLWEOpPattern bsOp, funcName, this->getTypeConverter()->convertType(resultType), mlir::ValueRange({adaptor.getCiphertext(), castedLUT, inputLweDimensionCst, polySizeCst, levelsCst, - baseLogCst, glweDimensionCst})); + baseLogCst, glweDimensionCst, locString})); return mlir::success(); } @@ -638,7 +662,8 @@ void SimulateTFHEPass::runOnOperation() { target.addLegalDialect(); target.addLegalOp(); + mlir::tensor::CastOp, mlir::LLVM::GlobalOp, + mlir::LLVM::AddressOfOp, mlir::LLVM::GEPOp>(); // Make sure that no ops from `TFHE` remain after the lowering target.addIllegalDialect(); diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp index c74f3ce948..c5922d5490 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp @@ -63,7 +63,7 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, uint64_t tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, uint32_t base_log, - uint32_t glwe_dim) { + uint32_t glwe_dim, char *loc) { auto tlu = tlu_aligned + tlu_offset; // modulus switching @@ -98,7 +98,7 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, mlir::concretelang::optimizer::DEFAULT_FFT_PRECISION, variance_bsk); out = out + gaussian_noise(0, variance); if (out > UINT63_MAX) { - printf("WARNING: overflow happened during LUT\n"); + printf("WARNING at %s: overflow happened during LUT\n", loc); } return out; } @@ -189,16 +189,19 @@ void sim_wop_pbs_crt( uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; } -uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs) { +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) { if (lhs > UINT63_MAX - rhs) { - printf("WARNING: overflow happened during addition in simulation\n"); + printf("WARNING at %s: overflow happened during addition in simulation\n", + loc); } return lhs + rhs; } -uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs) { +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) { if (rhs != 0 && lhs > UINT63_MAX / rhs) { - printf("WARNING: overflow happened during multiplication in simulation\n"); + printf("WARNING at %s: overflow happened during multiplication in " + "simulation\n", + loc); } return lhs * rhs; }