Skip to content

Commit

Permalink
feat(compiler/simu): add loc in overflow warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Apr 22, 2024
1 parent f4bd732 commit bd08168
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,19 @@ uint64_t sim_neg_lwe_u64(uint64_t plaintext);
///
/// \param lhs left operand
/// \param rhs right operand
/// \param loc
/// \return uint64_t
uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs);
uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc);

/// \brief simulate the multiplication of a noisy plaintext with an integer
///
/// The function also checks for overflow and print a warning when it happens
///
/// \param lhs left operand
/// \param rhs right operand
/// \param loc
/// \return uint64_t
uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs);
uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc);

/// \brief simulate a keyswitch on a noisy plaintext
///
Expand All @@ -68,13 +70,14 @@ uint64_t sim_keyswitch_lwe_u64(uint64_t plaintext, uint32_t level,
/// \param level
/// \param base_log
/// \param glwe_dim
/// \param loc
/// \return uint64_t
uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated,
uint64_t *tlu_aligned, uint64_t tlu_offset,
uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dim, uint32_t poly_size,
uint32_t level, uint32_t base_log,
uint32_t glwe_dim);
uint32_t glwe_dim, char *loc);

/// simulate a WoP PBS
void sim_wop_pbs_crt(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ struct NegOpPattern : public mlir::OpConversionPattern<TFHE::NegGLWEOp> {
}
};

mlir::Value globalStringValueFromLoc(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc) {
std::string locString;
auto ros = llvm::raw_string_ostream(locString);
loc.print(ros);

std::string msgName;
std::stringstream stream;
stream << "loc_" << rand();
stream >> msgName;
return mlir::LLVM::createGlobalString(loc, rewriter, msgName, locString,
mlir::LLVM::linkage::Linkage::Linkonce,
false);
}

template <typename AddOp, typename AddOpAdaptor>
struct AddOpPattern : public mlir::OpConversionPattern<AddOp> {

Expand All @@ -106,18 +121,21 @@ struct AddOpPattern : public mlir::OpConversionPattern<AddOp> {

const std::string funcName = "sim_add_lwe_u64";

auto locString = globalStringValueFromLoc(rewriter, addOp.getLoc());

if (insertForwardDeclaration(
addOp, rewriter, funcName,
rewriter.getFunctionType(
{rewriter.getIntegerType(64), rewriter.getIntegerType(64)},
{rewriter.getIntegerType(64), rewriter.getIntegerType(64),
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())},
{rewriter.getIntegerType(64)}))
.failed()) {
return mlir::failure();
}

rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
addOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)},
mlir::ValueRange({adaptor.getA(), adaptor.getB()}));
mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString}));

return mlir::success();
}
Expand All @@ -136,18 +154,21 @@ struct MulOpPattern : public mlir::OpConversionPattern<TFHE::MulGLWEIntOp> {

const std::string funcName = "sim_mul_lwe_u64";

auto locString = globalStringValueFromLoc(rewriter, mulOp.getLoc());

if (insertForwardDeclaration(
mulOp, rewriter, funcName,
rewriter.getFunctionType(
{rewriter.getIntegerType(64), rewriter.getIntegerType(64)},
{rewriter.getIntegerType(64), rewriter.getIntegerType(64),
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())},
{rewriter.getIntegerType(64)}))
.failed()) {
return mlir::failure();
}

rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
mulOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)},
mlir::ValueRange({adaptor.getA(), adaptor.getB()}));
mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString}));

return mlir::success();
}
Expand Down Expand Up @@ -500,6 +521,8 @@ struct BootstrapGLWEOpPattern
mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
bsOp.getLoc(), dynamicLutType, adaptor.getLookupTable());

auto locString = globalStringValueFromLoc(rewriter, bsOp.getLoc());

// uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t
// *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t
// tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t
Expand All @@ -510,7 +533,8 @@ struct BootstrapGLWEOpPattern
{rewriter.getIntegerType(64), dynamicLutType,
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32)},
rewriter.getIntegerType(32),
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())},
{rewriter.getIntegerType(64)}))
.failed()) {
return mlir::failure();
Expand All @@ -520,7 +544,7 @@ struct BootstrapGLWEOpPattern
bsOp, funcName, this->getTypeConverter()->convertType(resultType),
mlir::ValueRange({adaptor.getCiphertext(), castedLUT,
inputLweDimensionCst, polySizeCst, levelsCst,
baseLogCst, glweDimensionCst}));
baseLogCst, glweDimensionCst, locString}));

return mlir::success();
}
Expand Down Expand Up @@ -638,7 +662,8 @@ void SimulateTFHEPass::runOnOperation() {
target.addLegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::func::CallOp, mlir::memref::GetGlobalOp,
mlir::memref::CastOp, mlir::bufferization::AllocTensorOp,
mlir::tensor::CastOp>();
mlir::tensor::CastOp, mlir::LLVM::GlobalOp,
mlir::LLVM::AddressOfOp, mlir::LLVM::GEPOp>();
// Make sure that no ops from `TFHE` remain after the lowering
target.addIllegalDialect<TFHE::TFHEDialect>();

Expand Down
15 changes: 9 additions & 6 deletions compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated,
uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dim, uint32_t poly_size,
uint32_t level, uint32_t base_log,
uint32_t glwe_dim) {
uint32_t glwe_dim, char *loc) {
auto tlu = tlu_aligned + tlu_offset;

// modulus switching
Expand Down Expand Up @@ -98,7 +98,7 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated,
mlir::concretelang::optimizer::DEFAULT_FFT_PRECISION, variance_bsk);
out = out + gaussian_noise(0, variance);
if (out > UINT63_MAX) {
printf("WARNING: overflow happened during LUT\n");
printf("WARNING at %s: overflow happened during LUT\n", loc);
}
return out;
}
Expand Down Expand Up @@ -189,16 +189,19 @@ void sim_wop_pbs_crt(

uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; }

uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs) {
uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) {
if (lhs > UINT63_MAX - rhs) {
printf("WARNING: overflow happened during addition in simulation\n");
printf("WARNING at %s: overflow happened during addition in simulation\n",
loc);
}
return lhs + rhs;
}

uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs) {
uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) {
if (rhs != 0 && lhs > UINT63_MAX / rhs) {
printf("WARNING: overflow happened during multiplication in simulation\n");
printf("WARNING at %s: overflow happened during multiplication in "
"simulation\n",
loc);
}
return lhs * rhs;
}
Expand Down

0 comments on commit bd08168

Please sign in to comment.