Skip to content

Commit

Permalink
refactor(compiler, simu): rewrite add/mul to CAPI calls
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Apr 22, 2024
1 parent e238067 commit 80c04d2
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,25 @@ uint64_t sim_encrypt_lwe_u64(uint64_t message, uint32_t lwe_dim, void *csprng);
/// \return uint64_t
uint64_t sim_neg_lwe_u64(uint64_t plaintext);

/// \brief simulate the addition of a noisy plaintext with another
/// plaintext (noisy or not)
///
/// The function also checks for overflow and print a warning when it happens
///
/// \param lhs left operand
/// \param rhs right operand
/// \return uint64_t
uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs);

/// \brief simulate the multiplication of a noisy plaintext with an integer
///
/// The function also checks for overflow and print a warning when it happens
///
/// \param lhs left operand
/// \param rhs right operand
/// \return uint64_t
uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs);

/// \brief simulate a keyswitch on a noisy plaintext
///
/// \param plaintext noisy plaintext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,67 @@ struct NegOpPattern : public mlir::OpConversionPattern<TFHE::NegGLWEOp> {
}
};

template <typename AddOp, typename AddOpAdaptor>
struct AddOpPattern : public mlir::OpConversionPattern<AddOp> {

AddOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter)
: mlir::OpConversionPattern<AddOp>(
typeConverter, context,
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

::mlir::LogicalResult
matchAndRewrite(AddOp addOp, AddOpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

const std::string funcName = "sim_add_lwe_u64";

if (insertForwardDeclaration(
addOp, rewriter, funcName,
rewriter.getFunctionType(
{rewriter.getIntegerType(64), rewriter.getIntegerType(64)},
{rewriter.getIntegerType(64)}))
.failed()) {
return mlir::failure();
}

rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
addOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)},
mlir::ValueRange({adaptor.getA(), adaptor.getB()}));

return mlir::success();
}
};

struct MulOpPattern : public mlir::OpConversionPattern<TFHE::MulGLWEIntOp> {

MulOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter)
: mlir::OpConversionPattern<TFHE::MulGLWEIntOp>(
typeConverter, context,
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

::mlir::LogicalResult
matchAndRewrite(TFHE::MulGLWEIntOp mulOp, TFHE::MulGLWEIntOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

const std::string funcName = "sim_mul_lwe_u64";

if (insertForwardDeclaration(
mulOp, rewriter, funcName,
rewriter.getFunctionType(
{rewriter.getIntegerType(64), rewriter.getIntegerType(64)},
{rewriter.getIntegerType(64)}))
.failed()) {
return mlir::failure();
}

rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
mulOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)},
mlir::ValueRange({adaptor.getA(), adaptor.getB()}));

return mlir::success();
}
};

struct SubIntGLWEOpPattern : public mlir::OpRewritePattern<TFHE::SubGLWEIntOp> {

SubIntGLWEOpPattern(mlir::MLIRContext *context)
Expand Down Expand Up @@ -579,14 +640,6 @@ void SimulateTFHEPass::runOnOperation() {

mlir::RewritePatternSet patterns(&getContext());

// Replace ops and convert operand and result types
patterns.insert<mlir::concretelang::GenericOneToOneOpConversionPattern<
TFHE::AddGLWEIntOp, mlir::arith::AddIOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
TFHE::AddGLWEOp, mlir::arith::AddIOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
TFHE::MulGLWEIntOp, mlir::arith::MulIOp>>(&getContext(),
converter);
// Convert operand and result types
patterns.insert<mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::bufferization::AllocTensorOp, true>,
Expand Down Expand Up @@ -643,8 +696,10 @@ void SimulateTFHEPass::runOnOperation() {
BootstrapGLWEOpPattern, WopPBSGLWEOpPattern,
EncodeExpandLutForBootstrapOpPattern,
EncodeLutForCrtWopPBSOpPattern,
EncodePlaintextWithCrtOpPattern, NegOpPattern>(&getContext(),
converter);
EncodePlaintextWithCrtOpPattern, NegOpPattern,
AddOpPattern<TFHE::AddGLWEOp, TFHE::AddGLWEOp::Adaptor>,
AddOpPattern<TFHE::AddGLWEIntOp, TFHE::AddGLWEIntOp::Adaptor>,
MulOpPattern>(&getContext(), converter);
patterns.insert<SubIntGLWEOpPattern>(&getContext());

patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ void sim_wop_pbs_crt(

uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; }

uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs) { return lhs + rhs; }

uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs) { return lhs * rhs; }

void sim_encode_expand_lut_for_boostrap(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *in_allocated,
Expand Down

0 comments on commit 80c04d2

Please sign in to comment.