Skip to content

Commit

Permalink
feat(compiler/simu): support signed integers
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Apr 30, 2024
1 parent d5dbf20 commit c2ec1a4
Show file tree
Hide file tree
Showing 6 changed files with 499 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,21 @@ uint64_t sim_neg_lwe_u64(uint64_t plaintext);
///
/// \param lhs left operand
/// \param rhs right operand
/// \param loc
/// \param loc location of the operation
/// \param is_signed tell if operands are known to be signed
/// \return uint64_t
uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc);
uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, bool is_signed);

/// \brief simulate the multiplication of a noisy plaintext with an integer
///
/// The function also checks for overflow and print a warning when it happens
///
/// \param lhs left operand
/// \param rhs right operand
/// \param loc
/// \param loc location of the operation
/// \param is_signed tell if operands are known to be signed
/// \return uint64_t
uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc);
uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, bool is_signed);

/// \brief simulate a keyswitch on a noisy plaintext
///
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ inline void forwardOptimizerID(mlir::Operation *source,
destination->setAttr("TFHE.OId", optimizerIdAttr);
}

// Set the `signed` attribute to true if the type is signed
inline void markOpIfSigned(mlir::Operation *op,
FHE::FheIntegerInterface resultType) {
auto isSigned = resultType.isSigned();
if (isSigned) {
op->setAttr("signed", mlir::BoolAttr::get(op->getContext(), true));
}
}

inline void
forwardLinearlyOptimizerIDS(mlir::Operation &source,
std::vector<mlir::Value> &destinations) {
Expand Down Expand Up @@ -185,6 +194,29 @@ struct AddEintIntOpPattern : public ScalarOpPattern<FHE::AddEintIntOp> {
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
encodedInt);
forwardOptimizerID(op, newOp);
markOpIfSigned(newOp, op.getType().cast<FHE::FheIntegerInterface>());

return mlir::success();
}
};

/// Rewriter for the `FHE::add_eint` operation.
struct AddEintOpPattern : public mlir::OpConversionPattern<FHE::AddEintOp> {
AddEintOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpConversionPattern<FHE::AddEintOp>(converter, context, benefit) {
}

mlir::LogicalResult
matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

// Write the new op
auto newOp = rewriter.replaceOpWithNewOp<TFHE::AddGLWEOp>(
op, getTypeConverter()->convertType(op.getType()),
adaptor.getOperands());
forwardOptimizerID(op, newOp);
markOpIfSigned(newOp, op.getType().cast<FHE::FheIntegerInterface>());

return mlir::success();
}
Expand Down Expand Up @@ -225,6 +257,7 @@ struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
encodedInt);
forwardOptimizerID(op, newOp);
markOpIfSigned(newOp, op.getType().cast<FHE::FheIntegerInterface>());

return mlir::success();
};
Expand Down Expand Up @@ -252,6 +285,7 @@ struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
op, getTypeConverter()->convertType(op.getType()), encodedInt,
adaptor.getB());
forwardOptimizerID(op, newOp);
markOpIfSigned(newOp, op.getType().cast<FHE::FheIntegerInterface>());

return mlir::success();
};
Expand Down Expand Up @@ -281,6 +315,7 @@ struct SubEintOpPattern : public ScalarOpPattern<FHE::SubEintOp> {
op, getTypeConverter()->convertType(op.getType()), lhsOperand,
negative.getResult());
forwardOptimizerID(op, newOp);
markOpIfSigned(newOp, op.getType().cast<FHE::FheIntegerInterface>());

return mlir::success();
};
Expand Down Expand Up @@ -310,6 +345,7 @@ struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
op, getTypeConverter()->convertType(op.getType()), eintOperand,
castedCleartext);
forwardOptimizerID(op, newOp);
markOpIfSigned(newOp, op.getType().cast<FHE::FheIntegerInterface>());

return mlir::success();
}
Expand Down Expand Up @@ -804,12 +840,11 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
FHE::NegEintOp, TFHE::NegGLWEOp, true>,
// |_ `FHE::not`
mlir::concretelang::GenericOneToOneOpConversionPattern<
FHE::BoolNotOp, TFHE::NegGLWEOp, true>,
// |_ `FHE::add_eint`
mlir::concretelang::GenericOneToOneOpConversionPattern<
FHE::AddEintOp, TFHE::AddGLWEOp, true>>(&getContext(), converter);
FHE::BoolNotOp, TFHE::NegGLWEOp, true>>(&getContext(), converter);
// |_ `FHE::add_eint_int`
patterns.add<lowering::AddEintIntOpPattern,
// |_ `FHE::add_eint`
lowering::AddEintOpPattern,
// |_ `FHE::sub_int_eint`
lowering::SubIntEintOpPattern,
// |_ `FHE::sub_eint_int`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,21 @@ struct NegOpPattern : public mlir::OpConversionPattern<TFHE::NegGLWEOp> {
}
};

int locationStringCtr = 0;
mlir::Value globalStringValueFromLoc(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc) {
std::string locString;
auto ros = llvm::raw_string_ostream(locString);
loc.print(ros);

std::string msgName;
std::stringstream stream;
stream << "loc_" << rand();
stream >> msgName;
return mlir::LLVM::createGlobalString(loc, rewriter, msgName, locString,
mlir::LLVM::linkage::Linkage::Linkonce,
false);
locString.append("\0");
auto locStrWithNullByte =
llvm::StringRef(locString.c_str(), locString.size() + 1);

std::stringstream msgName;
msgName << "str_loc_" << locationStringCtr++;
return mlir::LLVM::createGlobalString(
loc, rewriter, msgName.str(), locStrWithNullByte,
mlir::LLVM::linkage::Linkage::Linkonce, false);
}

template <typename AddOp, typename AddOpAdaptor>
Expand All @@ -122,20 +124,30 @@ struct AddOpPattern : public mlir::OpConversionPattern<AddOp> {
const std::string funcName = "sim_add_lwe_u64";

auto locString = globalStringValueFromLoc(rewriter, addOp.getLoc());
// check if operation has been tagged as signed
auto isSigned = false;
mlir::Attribute signedAttr = adaptor.getAttributes().get("signed");
if (signedAttr && signedAttr.cast<mlir::BoolAttr>().getValue()) {
isSigned = true;
}
mlir::Value isSignedCst = rewriter.create<mlir::arith::ConstantIntOp>(
addOp.getLoc(), isSigned, 1);

if (insertForwardDeclaration(
addOp, rewriter, funcName,
rewriter.getFunctionType(
{rewriter.getIntegerType(64), rewriter.getIntegerType(64),
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())},
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
rewriter.getIntegerType(1)},
{rewriter.getIntegerType(64)}))
.failed()) {
return mlir::failure();
}

rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
addOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)},
mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString}));
mlir::ValueRange(
{adaptor.getA(), adaptor.getB(), locString, isSignedCst}));

return mlir::success();
}
Expand All @@ -156,19 +168,30 @@ struct MulOpPattern : public mlir::OpConversionPattern<TFHE::MulGLWEIntOp> {

auto locString = globalStringValueFromLoc(rewriter, mulOp.getLoc());

// check if operation has been tagged as signed
auto isSigned = false;
mlir::Attribute signedAttr = adaptor.getAttributes().get("signed");
if (signedAttr && signedAttr.cast<mlir::BoolAttr>().getValue()) {
isSigned = true;
}
mlir::Value isSignedCst = rewriter.create<mlir::arith::ConstantIntOp>(
mulOp.getLoc(), isSigned, 1);

if (insertForwardDeclaration(
mulOp, rewriter, funcName,
rewriter.getFunctionType(
{rewriter.getIntegerType(64), rewriter.getIntegerType(64),
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())},
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
rewriter.getIntegerType(1)},
{rewriter.getIntegerType(64)}))
.failed()) {
return mlir::failure();
}

rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
mulOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)},
mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString}));
mlir::ValueRange(
{adaptor.getA(), adaptor.getB(), locString, isSignedCst}));

return mlir::success();
}
Expand All @@ -186,8 +209,10 @@ struct SubIntGLWEOpPattern : public mlir::OpRewritePattern<TFHE::SubGLWEIntOp> {
mlir::Value negated = rewriter.create<TFHE::NegGLWEOp>(
subOp.getLoc(), subOp.getB().getType(), subOp.getB());

rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(subOp, subOp.getType(),
negated, subOp.getA());
rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
subOp, subOp.getType(), mlir::ValueRange({negated, subOp.getA()}),
// to forward the signed attr if set
subOp.getOperation()->getAttrs());

return mlir::success();
}
Expand Down
Loading

0 comments on commit c2ec1a4

Please sign in to comment.