Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overflow detection in simulation #777

Merged
merged 5 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
namespace mlir {
namespace concretelang {
/// Create a pass that simulates TFHE operations
std::unique_ptr<OperationPass<ModuleOp>> createSimulateTFHEPass();
std::unique_ptr<OperationPass<ModuleOp>>
createSimulateTFHEPass(bool enableOverflowDetection);
} // namespace concretelang
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,29 @@ uint64_t sim_encrypt_lwe_u64(uint64_t message, uint32_t lwe_dim, void *csprng);
/// \return uint64_t
uint64_t sim_neg_lwe_u64(uint64_t plaintext);

/// \brief simulate the addition of a noisy plaintext with another
/// plaintext (noisy or not)
///
/// The function also checks for overflow and print a warning when it happens
///
/// \param lhs left operand
/// \param rhs right operand
/// \param loc location of the operation
/// \param is_signed tell if operands are known to be signed
/// \return uint64_t
uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, bool is_signed);

/// \brief simulate the multiplication of a noisy plaintext with an integer
///
/// The function also checks for overflow and print a warning when it happens
///
/// \param lhs left operand
/// \param rhs right operand
/// \param loc location of the operation
/// \param is_signed tell if operands are known to be signed
/// \return uint64_t
uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, bool is_signed);

/// \brief simulate a keyswitch on a noisy plaintext
///
/// \param plaintext noisy plaintext
Expand All @@ -49,13 +72,14 @@ uint64_t sim_keyswitch_lwe_u64(uint64_t plaintext, uint32_t level,
/// \param level
/// \param base_log
/// \param glwe_dim
/// \param loc
/// \return uint64_t
uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated,
uint64_t *tlu_aligned, uint64_t tlu_offset,
uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dim, uint32_t poly_size,
uint32_t level, uint32_t base_log,
uint32_t glwe_dim);
uint32_t glwe_dim, char *loc);

/// simulate a WoP PBS
void sim_wop_pbs_crt(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ mlir::LogicalResult optimizeTFHE(mlir::MLIRContext &context,

mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ inline void forwardOptimizerID(mlir::Operation *source,
destination->setAttr("TFHE.OId", optimizerIdAttr);
}

// Set the `signed` attribute to true if the type is signed
inline void markOpIfSigned(mlir::Operation *op,
FHE::FheIntegerInterface resultType) {
auto isSigned = resultType.isSigned();
if (isSigned) {
op->setAttr("signed", mlir::BoolAttr::get(op->getContext(), true));
}
}

inline void
forwardLinearlyOptimizerIDS(mlir::Operation &source,
std::vector<mlir::Value> &destinations) {
Expand Down Expand Up @@ -185,6 +194,29 @@ struct AddEintIntOpPattern : public ScalarOpPattern<FHE::AddEintIntOp> {
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
encodedInt);
forwardOptimizerID(op, newOp);
markOpIfSigned(newOp, op.getType().cast<FHE::FheIntegerInterface>());

return mlir::success();
}
};

/// Rewriter for the `FHE::add_eint` operation.
struct AddEintOpPattern : public mlir::OpConversionPattern<FHE::AddEintOp> {
AddEintOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpConversionPattern<FHE::AddEintOp>(converter, context, benefit) {
}

mlir::LogicalResult
matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

// Write the new op
auto newOp = rewriter.replaceOpWithNewOp<TFHE::AddGLWEOp>(
op, getTypeConverter()->convertType(op.getType()),
adaptor.getOperands());
forwardOptimizerID(op, newOp);
markOpIfSigned(newOp, op.getType().cast<FHE::FheIntegerInterface>());

return mlir::success();
}
Expand Down Expand Up @@ -225,6 +257,7 @@ struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
encodedInt);
forwardOptimizerID(op, newOp);
markOpIfSigned(newOp, op.getType().cast<FHE::FheIntegerInterface>());

return mlir::success();
};
Expand Down Expand Up @@ -252,6 +285,7 @@ struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
op, getTypeConverter()->convertType(op.getType()), encodedInt,
adaptor.getB());
forwardOptimizerID(op, newOp);
markOpIfSigned(newOp, op.getType().cast<FHE::FheIntegerInterface>());

return mlir::success();
};
Expand Down Expand Up @@ -281,6 +315,7 @@ struct SubEintOpPattern : public ScalarOpPattern<FHE::SubEintOp> {
op, getTypeConverter()->convertType(op.getType()), lhsOperand,
negative.getResult());
forwardOptimizerID(op, newOp);
markOpIfSigned(newOp, op.getType().cast<FHE::FheIntegerInterface>());

return mlir::success();
};
Expand Down Expand Up @@ -310,6 +345,7 @@ struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
op, getTypeConverter()->convertType(op.getType()), eintOperand,
castedCleartext);
forwardOptimizerID(op, newOp);
markOpIfSigned(newOp, op.getType().cast<FHE::FheIntegerInterface>());

return mlir::success();
}
Expand Down Expand Up @@ -804,12 +840,11 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
FHE::NegEintOp, TFHE::NegGLWEOp, true>,
// |_ `FHE::not`
mlir::concretelang::GenericOneToOneOpConversionPattern<
FHE::BoolNotOp, TFHE::NegGLWEOp, true>,
// |_ `FHE::add_eint`
mlir::concretelang::GenericOneToOneOpConversionPattern<
FHE::AddEintOp, TFHE::AddGLWEOp, true>>(&getContext(), converter);
FHE::BoolNotOp, TFHE::NegGLWEOp, true>>(&getContext(), converter);
// |_ `FHE::add_eint_int`
patterns.add<lowering::AddEintIntOpPattern,
// |_ `FHE::add_eint`
lowering::AddEintOpPattern,
// |_ `FHE::sub_int_eint`
lowering::SubIntEintOpPattern,
// |_ `FHE::sub_eint_int`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,111 @@ struct NegOpPattern : public mlir::OpConversionPattern<TFHE::NegGLWEOp> {
}
};

int locationStringCtr = 0;
mlir::Value globalStringValueFromLoc(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc) {
std::string locString;
auto ros = llvm::raw_string_ostream(locString);
loc.print(ros);
locString.append("\0");
auto locStrWithNullByte =
llvm::StringRef(locString.c_str(), locString.size() + 1);

std::stringstream msgName;
msgName << "str_loc_" << locationStringCtr++;
return mlir::LLVM::createGlobalString(
loc, rewriter, msgName.str(), locStrWithNullByte,
mlir::LLVM::linkage::Linkage::Linkonce, false);
}

template <typename AddOp, typename AddOpAdaptor>
struct AddOpPattern : public mlir::OpConversionPattern<AddOp> {

AddOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter)
: mlir::OpConversionPattern<AddOp>(
typeConverter, context,
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

::mlir::LogicalResult
matchAndRewrite(AddOp addOp, AddOpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

const std::string funcName = "sim_add_lwe_u64";

auto locString = globalStringValueFromLoc(rewriter, addOp.getLoc());
// check if operation has been tagged as signed
auto isSigned = false;
mlir::Attribute signedAttr = adaptor.getAttributes().get("signed");
if (signedAttr && signedAttr.cast<mlir::BoolAttr>().getValue()) {
isSigned = true;
}
mlir::Value isSignedCst = rewriter.create<mlir::arith::ConstantIntOp>(
addOp.getLoc(), isSigned, 1);

if (insertForwardDeclaration(
addOp, rewriter, funcName,
youben11 marked this conversation as resolved.
Show resolved Hide resolved
rewriter.getFunctionType(
{rewriter.getIntegerType(64), rewriter.getIntegerType(64),
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
rewriter.getIntegerType(1)},
{rewriter.getIntegerType(64)}))
.failed()) {
return mlir::failure();
}

rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
addOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)},
mlir::ValueRange(
{adaptor.getA(), adaptor.getB(), locString, isSignedCst}));

return mlir::success();
}
};

struct MulOpPattern : public mlir::OpConversionPattern<TFHE::MulGLWEIntOp> {

MulOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter)
: mlir::OpConversionPattern<TFHE::MulGLWEIntOp>(
typeConverter, context,
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

::mlir::LogicalResult
matchAndRewrite(TFHE::MulGLWEIntOp mulOp, TFHE::MulGLWEIntOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

const std::string funcName = "sim_mul_lwe_u64";

auto locString = globalStringValueFromLoc(rewriter, mulOp.getLoc());

// check if operation has been tagged as signed
auto isSigned = false;
mlir::Attribute signedAttr = adaptor.getAttributes().get("signed");
if (signedAttr && signedAttr.cast<mlir::BoolAttr>().getValue()) {
isSigned = true;
}
mlir::Value isSignedCst = rewriter.create<mlir::arith::ConstantIntOp>(
mulOp.getLoc(), isSigned, 1);

if (insertForwardDeclaration(
mulOp, rewriter, funcName,
rewriter.getFunctionType(
{rewriter.getIntegerType(64), rewriter.getIntegerType(64),
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
rewriter.getIntegerType(1)},
{rewriter.getIntegerType(64)}))
.failed()) {
return mlir::failure();
}

rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
mulOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)},
mlir::ValueRange(
{adaptor.getA(), adaptor.getB(), locString, isSignedCst}));

return mlir::success();
}
};

struct SubIntGLWEOpPattern : public mlir::OpRewritePattern<TFHE::SubGLWEIntOp> {

SubIntGLWEOpPattern(mlir::MLIRContext *context)
Expand All @@ -104,8 +209,10 @@ struct SubIntGLWEOpPattern : public mlir::OpRewritePattern<TFHE::SubGLWEIntOp> {
mlir::Value negated = rewriter.create<TFHE::NegGLWEOp>(
subOp.getLoc(), subOp.getB().getType(), subOp.getB());

rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(subOp, subOp.getType(),
negated, subOp.getA());
rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
subOp, subOp.getType(), mlir::ValueRange({negated, subOp.getA()}),
// to forward the signed attr if set
subOp.getOperation()->getAttrs());

return mlir::success();
}
Expand Down Expand Up @@ -439,6 +546,8 @@ struct BootstrapGLWEOpPattern
mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
bsOp.getLoc(), dynamicLutType, adaptor.getLookupTable());

auto locString = globalStringValueFromLoc(rewriter, bsOp.getLoc());

// uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t
// *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t
// tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t
Expand All @@ -449,7 +558,8 @@ struct BootstrapGLWEOpPattern
{rewriter.getIntegerType(64), dynamicLutType,
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32)},
rewriter.getIntegerType(32),
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())},
{rewriter.getIntegerType(64)}))
.failed()) {
return mlir::failure();
Expand All @@ -459,7 +569,7 @@ struct BootstrapGLWEOpPattern
bsOp, funcName, this->getTypeConverter()->convertType(resultType),
mlir::ValueRange({adaptor.getCiphertext(), castedLUT,
inputLweDimensionCst, polySizeCst, levelsCst,
baseLogCst, glweDimensionCst}));
baseLogCst, glweDimensionCst, locString}));

return mlir::success();
}
Expand Down Expand Up @@ -561,6 +671,10 @@ struct ZeroTensorOpPattern
};

struct SimulateTFHEPass : public SimulateTFHEBase<SimulateTFHEPass> {
bool enableOverflowDetection;
SimulateTFHEPass(bool enableOverflowDetection)
: enableOverflowDetection(enableOverflowDetection) {}

void runOnOperation() final;
};

Expand All @@ -573,20 +687,13 @@ void SimulateTFHEPass::runOnOperation() {
target.addLegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::func::CallOp, mlir::memref::GetGlobalOp,
mlir::memref::CastOp, mlir::bufferization::AllocTensorOp,
mlir::tensor::CastOp>();
mlir::tensor::CastOp, mlir::LLVM::GlobalOp,
mlir::LLVM::AddressOfOp, mlir::LLVM::GEPOp>();
// Make sure that no ops from `TFHE` remain after the lowering
target.addIllegalDialect<TFHE::TFHEDialect>();

mlir::RewritePatternSet patterns(&getContext());

// Replace ops and convert operand and result types
patterns.insert<mlir::concretelang::GenericOneToOneOpConversionPattern<
TFHE::AddGLWEIntOp, mlir::arith::AddIOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
TFHE::AddGLWEOp, mlir::arith::AddIOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
TFHE::MulGLWEIntOp, mlir::arith::MulIOp>>(&getContext(),
converter);
// Convert operand and result types
patterns.insert<mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::bufferization::AllocTensorOp, true>,
Expand Down Expand Up @@ -647,6 +754,23 @@ void SimulateTFHEPass::runOnOperation() {
converter);
patterns.insert<SubIntGLWEOpPattern>(&getContext());

// if overflow detection is enable, then rewrite to CAPI functions that
// performs the detection, otherwise, rewrite as simple arithmetic ops
if (enableOverflowDetection) {
patterns
.insert<AddOpPattern<TFHE::AddGLWEOp, TFHE::AddGLWEOp::Adaptor>,
AddOpPattern<TFHE::AddGLWEIntOp, TFHE::AddGLWEIntOp::Adaptor>,
MulOpPattern>(&getContext(), converter);
} else {
patterns.insert<mlir::concretelang::GenericOneToOneOpConversionPattern<
TFHE::AddGLWEIntOp, mlir::arith::AddIOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
TFHE::AddGLWEOp, mlir::arith::AddIOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
TFHE::MulGLWEIntOp, mlir::arith::MulIOp>>(&getContext(),
converter);
}

patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::func::ReturnOp>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
Expand Down Expand Up @@ -688,8 +812,9 @@ void SimulateTFHEPass::runOnOperation() {

namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>> createSimulateTFHEPass() {
return std::make_unique<SimulateTFHEPass>();
std::unique_ptr<OperationPass<ModuleOp>>
createSimulateTFHEPass(bool enableOverflowDetection) {
return std::make_unique<SimulateTFHEPass>(enableOverflowDetection);
}
} // namespace concretelang
} // namespace mlir
Loading
Loading