Skip to content

Commit

Permalink
feat(compiler): warn when there is overflow in sim (native encoding)
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Apr 22, 2024
1 parent 80c04d2 commit f4bd732
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
namespace mlir {
namespace concretelang {
/// Create a pass that simulates TFHE operations
std::unique_ptr<OperationPass<ModuleOp>> createSimulateTFHEPass();
std::unique_ptr<OperationPass<ModuleOp>>
createSimulateTFHEPass(bool enableOverflowDetection);
} // namespace concretelang
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context,

mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,10 @@ struct ZeroTensorOpPattern
};

struct SimulateTFHEPass : public SimulateTFHEBase<SimulateTFHEPass> {
bool enableOverflowDetection;
SimulateTFHEPass(bool enableOverflowDetection)
: enableOverflowDetection(enableOverflowDetection) {}

void runOnOperation() final;
};

Expand Down Expand Up @@ -696,12 +700,27 @@ void SimulateTFHEPass::runOnOperation() {
BootstrapGLWEOpPattern, WopPBSGLWEOpPattern,
EncodeExpandLutForBootstrapOpPattern,
EncodeLutForCrtWopPBSOpPattern,
EncodePlaintextWithCrtOpPattern, NegOpPattern,
AddOpPattern<TFHE::AddGLWEOp, TFHE::AddGLWEOp::Adaptor>,
AddOpPattern<TFHE::AddGLWEIntOp, TFHE::AddGLWEIntOp::Adaptor>,
MulOpPattern>(&getContext(), converter);
EncodePlaintextWithCrtOpPattern, NegOpPattern>(&getContext(),
converter);
patterns.insert<SubIntGLWEOpPattern>(&getContext());

// if overflow detection is enable, then rewrite to CAPI functions that
// performs the detection, otherwise, rewrite as simple arithmetic ops
if (enableOverflowDetection) {
patterns
.insert<AddOpPattern<TFHE::AddGLWEOp, TFHE::AddGLWEOp::Adaptor>,
AddOpPattern<TFHE::AddGLWEIntOp, TFHE::AddGLWEIntOp::Adaptor>,
MulOpPattern>(&getContext(), converter);
} else {
patterns.insert<mlir::concretelang::GenericOneToOneOpConversionPattern<
TFHE::AddGLWEIntOp, mlir::arith::AddIOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
TFHE::AddGLWEOp, mlir::arith::AddIOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
TFHE::MulGLWEIntOp, mlir::arith::MulIOp>>(&getContext(),
converter);
}

patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::func::ReturnOp>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
Expand Down Expand Up @@ -743,8 +762,9 @@ void SimulateTFHEPass::runOnOperation() {

namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>> createSimulateTFHEPass() {
return std::make_unique<SimulateTFHEPass>();
std::unique_ptr<OperationPass<ModuleOp>>
createSimulateTFHEPass(bool enableOverflowDetection) {
return std::make_unique<SimulateTFHEPass>(enableOverflowDetection);
}
} // namespace concretelang
} // namespace mlir
22 changes: 19 additions & 3 deletions compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ using concretelang::csprng::SoftCSPRNG;

thread_local auto csprng = SoftCSPRNG(0);

const uint64_t UINT63_MAX = UINT64_MAX >> 1;

inline concrete::SecurityCurve *security_curve() {
return concrete::getSecurityCurve(128, concrete::BINARY);
}
Expand Down Expand Up @@ -94,7 +96,11 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated,
double variance = concrete_cpu_variance_blind_rotate(
input_lwe_dim, glwe_dim, poly_size, base_log, level, 64,
mlir::concretelang::optimizer::DEFAULT_FFT_PRECISION, variance_bsk);
return out + gaussian_noise(0, variance);
out = out + gaussian_noise(0, variance);
if (out > UINT63_MAX) {
printf("WARNING: overflow happened during LUT\n");
}
return out;
}

void sim_wop_pbs_crt(
Expand Down Expand Up @@ -183,9 +189,19 @@ void sim_wop_pbs_crt(

uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; }

uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs) { return lhs + rhs; }
uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs) {
if (lhs > UINT63_MAX - rhs) {
printf("WARNING: overflow happened during addition in simulation\n");
}
return lhs + rhs;
}

uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs) { return lhs * rhs; }
uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs) {
if (rhs != 0 && lhs > UINT63_MAX / rhs) {
printf("WARNING: overflow happened during multiplication in simulation\n");
}
return lhs * rhs;
}

void sim_encode_expand_lut_for_boostrap(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,8 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
}

if (options.simulate) {
if (mlir::concretelang::pipeline::simulateTFHE(mlirContext, module,
this->enablePass)
if (mlir::concretelang::pipeline::simulateTFHE(
mlirContext, module, res.fheContext, this->enablePass)
.failed()) {
return StreamStringError("Simulating TFHE failed");
}
Expand Down
17 changes: 15 additions & 2 deletions compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,24 @@ mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context,

mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);

// we want to disable overflow detection if CRT is used (overflow would be
// expected)
bool enableOverflowDetection = true;
if (fheContext) {
auto solution = fheContext.value().solution;
auto optCrt = getCrtDecompositionFromSolution(solution);
if (optCrt)
enableOverflowDetection = false;
}

pipelinePrinting("TFHESimulation", pm, context);
addPotentiallyNestedPass(pm, mlir::concretelang::createSimulateTFHEPass(),
enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createSimulateTFHEPass(enableOverflowDetection),
enablePass);

return pm.run(module.getOperation());
}
Expand Down

0 comments on commit f4bd732

Please sign in to comment.