diff --git a/examples/psu/krtw19_psu.cc b/examples/psu/krtw19_psu.cc index 946ceb2d..6ae2873c 100644 --- a/examples/psu/krtw19_psu.cc +++ b/examples/psu/krtw19_psu.cc @@ -80,7 +80,7 @@ auto Interpolate(const std::vector& xs, } } for (size_t k{}; k != size; ++k) { - L_coeffs[k] ^= yacl::GfMul64(Li_coeffs[k], yacl::Inv64(prod)); + L_coeffs[k] ^= yacl::GfMul64(Li_coeffs[k], yacl::GfInv64(prod)); } } return L_coeffs; diff --git a/yacl/base/aligned_vector.h b/yacl/base/aligned_vector.h index 57bcd03f..a2f8dd29 100644 --- a/yacl/base/aligned_vector.h +++ b/yacl/base/aligned_vector.h @@ -28,7 +28,7 @@ namespace yacl { * @tparam ALIGNMENT_IN_BYTES Must be a positive power of 2. */ template -class AlignedAllocator { +class UninitAlignedAllocator { private: static_assert( ALIGNMENT_IN_BYTES >= alignof(ElementType), @@ -40,23 +40,24 @@ class AlignedAllocator { static std::align_val_t constexpr ALIGNMENT{ALIGNMENT_IN_BYTES}; /** - * This is only necessary because AlignedAllocator has a second template + * This is only necessary because UninitAlignedAllocator has a second template * argument for the alignment that will make the default * std::allocator_traits implementation fail during compilation. * @see https://stackoverflow.com/a/48062758/2191065 */ template struct rebind { - using other = AlignedAllocator; + using other = UninitAlignedAllocator; }; - constexpr AlignedAllocator() noexcept = default; + constexpr UninitAlignedAllocator() noexcept = default; - constexpr AlignedAllocator(const AlignedAllocator&) noexcept = default; + constexpr UninitAlignedAllocator(const UninitAlignedAllocator&) noexcept = + default; template - constexpr AlignedAllocator( - AlignedAllocator const&) noexcept {} + constexpr UninitAlignedAllocator( + UninitAlignedAllocator const&) noexcept {} [[nodiscard]] ElementType* allocate(std::size_t nElementsToAllocate) { if (nElementsToAllocate > @@ -77,20 +78,38 @@ class AlignedAllocator { * the one used in new. */ ::operator delete[](allocatedPointer, ALIGNMENT); } + + /* + * unintialised_allocator implementation (avoid meaningless initialization) + * ref: https://stackoverflow.com/a/15966795 + */ + // elide trivial default construction of objects of type ElementType only + template + typename std::enable_if< + std::is_same::value && + std::is_trivially_default_constructible::value>::type + construct(U*) {} + + // elide trivial default destruction of objects of type ElementType only + template + typename std::enable_if::value && + std::is_trivially_destructible::value>::type + destroy(U*) {} }; template -using AlignedVector = std::vector >; +using UninitAlignedVector = + std::vector >; template -bool operator==(AlignedAllocator const& a0, - AlignedAllocator const& a1) { +bool operator==(UninitAlignedAllocator const& a0, + UninitAlignedAllocator const& a1) { return a0.ALIGNMENT == a1.ALIGNMENT; } template -bool operator!=(AlignedAllocator const& a0, - AlignedAllocator const& a1) { +bool operator!=(UninitAlignedAllocator const& a0, + UninitAlignedAllocator const& a1) { return !(a0 == a1); } } // namespace yacl diff --git a/yacl/crypto/tools/BUILD.bazel b/yacl/crypto/tools/BUILD.bazel index 81ad9a00..d855dd0e 100644 --- a/yacl/crypto/tools/BUILD.bazel +++ b/yacl/crypto/tools/BUILD.bazel @@ -20,6 +20,7 @@ yacl_cc_library( name = "common", hdrs = ["common.h"], deps = [ + "//yacl/crypto/rand", "//yacl/link", "//yacl/utils:serialize", ], diff --git a/yacl/crypto/tools/common.h b/yacl/crypto/tools/common.h index 1d178be5..7e379678 100644 --- a/yacl/crypto/tools/common.h +++ b/yacl/crypto/tools/common.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include "yacl/crypto/rand/rand.h" #include "yacl/link/link.h" #include "yacl/utils/serialize.h" diff --git a/yacl/kernels/algorithms/BUILD.bazel b/yacl/kernels/algorithms/BUILD.bazel index c02e87b0..aba0e0cb 100644 --- a/yacl/kernels/algorithms/BUILD.bazel +++ b/yacl/kernels/algorithms/BUILD.bazel @@ -234,11 +234,14 @@ yacl_cc_library( hdrs = [ "ferret_ote.h", ], + copts = AES_COPT_FLAGS, deps = [ ":ot_store", "//yacl:secparam", "//yacl/base:exception", + "//yacl/crypto/hash:hash_utils", "//yacl/crypto/rand", + "//yacl/crypto/tools:common", "//yacl/crypto/tools:prg", "//yacl/crypto/tools:rp", "//yacl/kernels/algorithms:gywz_ote", @@ -272,6 +275,7 @@ yacl_cc_library( "//yacl/base:exception", "//yacl/base:int128", "//yacl/crypto/rand", + "//yacl/crypto/tools:common", "//yacl/crypto/tools:crhash", "//yacl/crypto/tools:prg", "//yacl/crypto/tools:rp", @@ -302,6 +306,7 @@ yacl_cc_library( ":sgrr_ote", "//yacl/base:exception", "//yacl/base:int128", + "//yacl/crypto/tools:common", "//yacl/crypto/tools:crhash", "//yacl/crypto/tools:prg", "//yacl/crypto/tools:rp", diff --git a/yacl/kernels/algorithms/base_vole.h b/yacl/kernels/algorithms/base_vole.h index 6198ec8f..84377aec 100644 --- a/yacl/kernels/algorithms/base_vole.h +++ b/yacl/kernels/algorithms/base_vole.h @@ -30,6 +30,9 @@ namespace yacl::crypto { // Convert OT to f2k-VOLE (non-interactive) // the type of ot_store must be COT // w = u * delta + v, where delta = send_ot.delta +// notice: +// - non-interactive method, which means that Ot2Vole could achieve malicious +// secure if send_ot/recv_ot are malicious secure. // usage: // - Vole: u in GF(2^64); w, v, delta in GF(2^64) // Ot2VoleSend / Ot2VoleRecv @@ -44,9 +47,8 @@ void inline Ot2VoleSend(OtSendStore& send_ot, absl::Span w) { YACL_ENFORCE(send_ot.Size() >= size * T_bits); - std::array basis; std::array w_buff; - + std::array basis; if (std::is_same::value) { memcpy(basis.data(), gf128_basis.data(), T_bits * sizeof(K)); } else if (std::is_same::value) { diff --git a/yacl/kernels/algorithms/ferret_ote.cc b/yacl/kernels/algorithms/ferret_ote.cc index 9178eeb4..52768ebf 100644 --- a/yacl/kernels/algorithms/ferret_ote.cc +++ b/yacl/kernels/algorithms/ferret_ote.cc @@ -25,24 +25,8 @@ namespace yacl::crypto { -namespace { - -uint128_t GenSyncedSeed(const std::shared_ptr& ctx) { - YACL_ENFORCE(ctx->WorldSize() == 2); - uint128_t seed; - - if (ctx->Rank() == 0) { - seed = SecureRandSeed(); - ctx->SendAsync(ctx->NextRank(), SerializeUint128(seed), "SEND:Seed"); - } else { - seed = DeserializeUint128(ctx->Recv(ctx->NextRank(), "RECV:Seed")); - } - return seed; -} - -} // namespace - -uint64_t FerretCotHelper(const LpnParam& lpn_param, uint64_t /*ot_num*/) { +uint64_t FerretCotHelper(const LpnParam& lpn_param, uint64_t /*ot_num*/, + bool mal) { uint64_t mpcot_cot = 0; if (lpn_param.noise_asm == LpnNoiseAsm::RegularNoise) { // for each mpcot invocation, @@ -60,12 +44,15 @@ uint64_t FerretCotHelper(const LpnParam& lpn_param, uint64_t /*ot_num*/) { // The required cots are used as: // (1) expansion seed: kFerret_lpnK // (2) mpcot cot: mp_option.cot_num (just for the first batch) - return lpn_param.k + mpcot_cot; + // (3) mpcot malicious check: 128 (fixed) + uint64_t check_cot = (mal == true) ? 128 : 0; + return lpn_param.k + mpcot_cot + check_cot; } OtSendStore FerretOtExtSend(const std::shared_ptr& ctx, const OtSendStore& base_cot, - const LpnParam& lpn_param, uint64_t ot_num) { + const LpnParam& lpn_param, uint64_t ot_num, + bool mal) { YACL_ENFORCE(ctx->WorldSize() == 2); // Make sure that OT has two parties YACL_ENFORCE(base_cot.Type() == OtStoreType::Compact); YACL_ENFORCE(base_cot.Size() >= FerretCotHelper(lpn_param, ot_num)); @@ -75,7 +62,7 @@ OtSendStore FerretOtExtSend(const std::shared_ptr& ctx, ot_num, 2 * lpn_param.t); // get constants: the number of cot needed for mpcot phase - const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n); + const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n, mal); // get constants: batch information const uint64_t cache_size = lpn_param.k + mpcot_cot_num; @@ -92,11 +79,11 @@ OtSendStore FerretOtExtSend(const std::shared_ptr& ctx, auto working_v = cot_seed.CopyCotBlocks(); // get lpn public matrix A - uint128_t seed = GenSyncedSeed(ctx); + uint128_t seed = SyncSeedSend(ctx); LocalLinearCode<10> llc(seed, lpn_param.n, lpn_param.k); // placeholder for the outputs - AlignedVector out(ot_num); + UninitAlignedVector out(ot_num); auto out_span = absl::MakeSpan(out.data(), out.size()); // For uniform noise assumption only @@ -119,7 +106,8 @@ OtSendStore FerretOtExtSend(const std::shared_ptr& ctx, auto idx_num = lpn_param.t; auto idx_range = batch_ot_num; if (lpn_param.noise_asm == LpnNoiseAsm::RegularNoise) { - MpCotRNSend(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_s); + MpCotRNSend(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_s, + mal); } else { YACL_THROW("Not Implemented!"); // MpCotUNSend(ctx, cot_mpcot, simple_map, option, working_s); @@ -133,15 +121,15 @@ OtSendStore FerretOtExtSend(const std::shared_ptr& ctx, // update v (first lpn_k of va^s) if ((ot_num - i * batch_size) > batch_ot_num) { // update v for the next batch - for (uint64_t j = 0; j < lpn_param.k; ++j) { - working_v[j] = working_s[batch_size + j]; - } + memcpy(working_v.data(), working_s.data() + batch_size, + lpn_param.k * sizeof(uint128_t)); // manually set the cot for next batch mpcot - cot_mpcot.ResetSlice(); - for (uint64_t j = 0; j < mpcot_cot_num; ++j) { - cot_mpcot.SetCompactBlock(j, working_s[batch_size + lpn_param.k + j]); - } + UninitAlignedVector mpcot_data(mpcot_cot_num); + memcpy(mpcot_data.data(), working_s.data() + batch_size + lpn_param.k, + mpcot_cot_num * sizeof(uint128_t)); + cot_mpcot = + MakeCompactOtSendStore(std::move(mpcot_data), base_cot.GetDelta()); } else { break; } @@ -152,7 +140,8 @@ OtSendStore FerretOtExtSend(const std::shared_ptr& ctx, OtRecvStore FerretOtExtRecv(const std::shared_ptr& ctx, const OtRecvStore& base_cot, - const LpnParam& lpn_param, uint64_t ot_num) { + const LpnParam& lpn_param, uint64_t ot_num, + bool mal) { YACL_ENFORCE(ctx->WorldSize() == 2); // Make sure that OT has two parties YACL_ENFORCE(base_cot.Type() == OtStoreType::Compact); YACL_ENFORCE(base_cot.Size() >= FerretCotHelper(lpn_param, ot_num)); @@ -162,7 +151,7 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr& ctx, ot_num, 2 * lpn_param.t); // get constants: the number of cot needed for mpcot phase - const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n); + const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n, mal); // get constants: batch information const uint64_t cache_size = lpn_param.k + mpcot_cot_num; @@ -170,7 +159,7 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr& ctx, const uint64_t batch_num = (ot_num + batch_size - 1) / batch_size; // F2, but we store it in uint128_t - AlignedVector u(lpn_param.k); + UninitAlignedVector u(lpn_param.k); // prepare u, w, where w = v ^ u * delta // FIX ME: "Slice" would would force to slice original OtStore from "begin" to @@ -181,11 +170,11 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr& ctx, auto working_w = cot_seed.CopyBlocks(); // get lpn public matrix A - uint128_t seed = GenSyncedSeed(ctx); + uint128_t seed = SyncSeedRecv(ctx); LocalLinearCode<10> llc(seed, lpn_param.n, lpn_param.k); // placeholder for the outputs - AlignedVector out(ot_num); + UninitAlignedVector out(ot_num); auto out_span = absl::MakeSpan(out); // For uniform noise assumption only @@ -210,7 +199,8 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr& ctx, auto idx_range = batch_ot_num; if (lpn_param.noise_asm == LpnNoiseAsm::RegularNoise) { - MpCotRNRecv(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_r); + MpCotRNRecv(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_r, + mal); } else { YACL_THROW("Not Implemented!"); // MpCotUNRecv(ctx, cot_mpcot, simple_map, option, e, working_r); @@ -223,15 +213,14 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr& ctx, // bool is_last_batch = (i == batch_num - 1); if ((ot_num - i * batch_size) > batch_ot_num) { // update u, w (first lpn_k of va^s) - for (uint64_t j = 0; j < lpn_param.k; ++j) { - working_w[j] = working_r[batch_size + j]; - } + memcpy(working_w.data(), working_r.data() + batch_size, + lpn_param.k * sizeof(uint128_t)); // manually set the cot for next batch mpcot - cot_mpcot.ResetSlice(); - for (uint64_t j = 0; j < mpcot_cot_num; ++j) { - cot_mpcot.SetBlock(j, working_r[batch_size + j + lpn_param.k]); - } + UninitAlignedVector mpcot_data(mpcot_cot_num); + memcpy(mpcot_data.data(), working_r.data() + batch_size + lpn_param.k, + mpcot_cot_num * sizeof(uint128_t)); + cot_mpcot = MakeCompactOtRecvStore(std::move(mpcot_data)); } else { break; } @@ -243,13 +232,13 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr& ctx, void FerretOtExtSend_cheetah(const std::shared_ptr& ctx, const OtSendStore& base_cot, const LpnParam& lpn_param, uint64_t ot_num, - absl::Span out) { + absl::Span out, bool mal) { YACL_ENFORCE(ctx->WorldSize() == 2); // Make sure that OT has two parties YACL_ENFORCE(base_cot.Type() == OtStoreType::Compact); YACL_ENFORCE(base_cot.Size() >= FerretCotHelper(lpn_param, ot_num)); // get constants: the number of cot needed for mpcot phase - const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n); + const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n, mal); // get constants: batch information const uint64_t cache_size = lpn_param.k + mpcot_cot_num; @@ -263,12 +252,12 @@ void FerretOtExtSend_cheetah(const std::shared_ptr& ctx, auto working_v = cot_seed.CopyCotBlocks(); // get lpn public matrix A - uint128_t seed = GenSyncedSeed(ctx); + uint128_t seed = SyncSeedSend(ctx); LocalLinearCode<10> llc(seed, lpn_param.n, lpn_param.k); // placeholder for the outputs YACL_ENFORCE(out.size() == ot_num); - // AlignedVector out(ot_num); + // UninitAlignedVector out(ot_num); auto out_span = out; // For uniform noise assumption only @@ -291,7 +280,8 @@ void FerretOtExtSend_cheetah(const std::shared_ptr& ctx, auto idx_num = lpn_param.t; auto idx_range = batch_ot_num; if (lpn_param.noise_asm == LpnNoiseAsm::RegularNoise) { - MpCotRNSend(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_s); + MpCotRNSend(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_s, + mal); } else { YACL_THROW("Not Implemented!"); // MpCotUNSend(ctx, cot_mpcot, simple_map, option, working_s); @@ -305,17 +295,15 @@ void FerretOtExtSend_cheetah(const std::shared_ptr& ctx, // update v (first lpn_k of va^s) if ((ot_num - i * batch_size) > batch_ot_num) { // update v for the next batch - // for (uint64_t j = 0; j < lpn_param.k; ++j) { - // working_v[j] = working_s[batch_size + j]; - // } memcpy(working_v.data(), working_s.data() + batch_size, lpn_param.k * sizeof(uint128_t)); // manually set the cot for next batch mpcot - cot_mpcot.ResetSlice(); - for (uint64_t j = 0; j < mpcot_cot_num; ++j) { - cot_mpcot.SetCompactBlock(j, working_s[batch_size + lpn_param.k + j]); - } + UninitAlignedVector mpcot_data(mpcot_cot_num); + memcpy(mpcot_data.data(), working_s.data() + batch_size + lpn_param.k, + mpcot_cot_num * sizeof(uint128_t)); + cot_mpcot = + MakeCompactOtSendStore(std::move(mpcot_data), base_cot.GetDelta()); } else { break; } @@ -327,13 +315,13 @@ void FerretOtExtSend_cheetah(const std::shared_ptr& ctx, void FerretOtExtRecv_cheetah(const std::shared_ptr& ctx, const OtRecvStore& base_cot, const LpnParam& lpn_param, uint64_t ot_num, - absl::Span out) { + absl::Span out, bool mal) { YACL_ENFORCE(ctx->WorldSize() == 2); // Make sure that OT has two parties YACL_ENFORCE(base_cot.Type() == OtStoreType::Compact); YACL_ENFORCE(base_cot.Size() >= FerretCotHelper(lpn_param, ot_num)); // get constants: the number of cot needed for mpcot phase - const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n); + const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n, mal); // get constants: batch information const uint64_t cache_size = lpn_param.k + mpcot_cot_num; @@ -341,7 +329,7 @@ void FerretOtExtRecv_cheetah(const std::shared_ptr& ctx, const uint64_t batch_num = (ot_num + batch_size - 1) / batch_size; // F2, but we store it in uint128_t - AlignedVector u(lpn_param.k); + UninitAlignedVector u(lpn_param.k); // prepare u, w, where w = v ^ u * delta auto cot_seed = base_cot.Slice(0, lpn_param.k); @@ -349,11 +337,11 @@ void FerretOtExtRecv_cheetah(const std::shared_ptr& ctx, auto working_w = cot_seed.CopyBlocks(); // get lpn public matrix A - uint128_t seed = GenSyncedSeed(ctx); + uint128_t seed = SyncSeedRecv(ctx); LocalLinearCode<10> llc(seed, lpn_param.n, lpn_param.k); // placeholder for the outputs - // AlignedVector out(ot_num); + // UninitAlignedVector out(ot_num); YACL_ENFORCE(out.size() == ot_num); auto out_span = out; @@ -379,7 +367,8 @@ void FerretOtExtRecv_cheetah(const std::shared_ptr& ctx, auto idx_range = batch_ot_num; if (lpn_param.noise_asm == LpnNoiseAsm::RegularNoise) { - MpCotRNRecv(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_r); + MpCotRNRecv(ctx, cot_mpcot, idx_range, idx_num, spcot_size, working_r, + mal); } else { YACL_THROW("Not Implemented!"); // MpCotUNRecv(ctx, cot_mpcot, simple_map, option, e, working_r); @@ -392,17 +381,14 @@ void FerretOtExtRecv_cheetah(const std::shared_ptr& ctx, // bool is_last_batch = (i == batch_num - 1); if ((ot_num - i * batch_size) > batch_ot_num) { // update u, w (first lpn_k of va^s) - // for (uint64_t j = 0; j < lpn_param.k; ++j) { - // working_w[j] = working_r[batch_size + j]; - // } memcpy(working_w.data(), working_r.data() + batch_size, lpn_param.k * sizeof(uint128_t)); // manually set the cot for next batch mpcot - cot_mpcot.ResetSlice(); - for (uint64_t j = 0; j < mpcot_cot_num; ++j) { - cot_mpcot.SetBlock(j, working_r[batch_size + j + lpn_param.k]); - } + UninitAlignedVector mpcot_data(mpcot_cot_num); + memcpy(mpcot_data.data(), working_r.data() + batch_size + lpn_param.k, + mpcot_cot_num * sizeof(uint128_t)); + cot_mpcot = MakeCompactOtRecvStore(std::move(mpcot_data)); } else { break; } diff --git a/yacl/kernels/algorithms/ferret_ote.h b/yacl/kernels/algorithms/ferret_ote.h index 3588a0cc..1109f2fe 100644 --- a/yacl/kernels/algorithms/ferret_ote.h +++ b/yacl/kernels/algorithms/ferret_ote.h @@ -57,15 +57,18 @@ namespace yacl::crypto { // implementation, see `yacl/crypto-tools/rp.h` // > Primal LPN, for more details, please see the original paper -uint64_t FerretCotHelper(const LpnParam& lpn_param, uint64_t ot_num); +uint64_t FerretCotHelper(const LpnParam& lpn_param, uint64_t ot_num, + bool mal = false); OtSendStore FerretOtExtSend(const std::shared_ptr& ctx, const OtSendStore& base_cot, - const LpnParam& lpn_param, uint64_t ot_num); + const LpnParam& lpn_param, uint64_t ot_num, + bool mal = false); OtRecvStore FerretOtExtRecv(const std::shared_ptr& ctx, const OtRecvStore& base_cot, - const LpnParam& lpn_param, uint64_t ot_num); + const LpnParam& lpn_param, uint64_t ot_num, + bool mal = false); // // -------------------------- @@ -76,11 +79,11 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr& ctx, void FerretOtExtSend_cheetah(const std::shared_ptr& ctx, const OtSendStore& base_cot, const LpnParam& lpn_param, uint64_t ot_num, - absl::Span out); + absl::Span out, bool mal = false); void FerretOtExtRecv_cheetah(const std::shared_ptr& ctx, const OtRecvStore& base_cot, const LpnParam& lpn_param, uint64_t ot_num, - absl::Span out); + absl::Span out, bool mal = false); } // namespace yacl::crypto diff --git a/yacl/kernels/algorithms/ferret_ote_rn.h b/yacl/kernels/algorithms/ferret_ote_rn.h index fe396dda..3866ed48 100644 --- a/yacl/kernels/algorithms/ferret_ote_rn.h +++ b/yacl/kernels/algorithms/ferret_ote_rn.h @@ -18,6 +18,8 @@ #include #include +#include "yacl/crypto/hash/hash_utils.h" +#include "yacl/crypto/tools/common.h" #include "yacl/math/gadget.h" #include "yacl/secparam.h" @@ -29,76 +31,141 @@ YACL_MODULE_DECLARE("ferret_ote_rn", SecParam::C::k128, SecParam::S::INF); namespace yacl::crypto { -inline uint64_t MpCotRNHelper(uint64_t idx_num, uint64_t idx_range) { - const auto batch_size = idx_range / idx_num; +inline uint64_t MpCotRNHelper(uint64_t idx_num, uint64_t idx_range, + bool mal = false) { + const auto batch_size = (idx_range + idx_num - 1) / idx_num; const auto last_size = idx_range - batch_size * (idx_num - 1); - return math::Log2Ceil(batch_size) * (idx_num - 1) + math::Log2Ceil(last_size); + const auto check_size = (mal == true) ? 128 : 0; + return math::Log2Ceil(batch_size) * (idx_num - 1) + + math::Log2Ceil(last_size) + check_size; } inline void MpCotRNSend(const std::shared_ptr& ctx, const OtSendStore& cot, uint64_t idx_range, uint64_t idx_num, uint64_t spcot_size, - absl::Span out) { + absl::Span out, bool mal = false) { const uint64_t full_size = idx_range; const uint64_t batch_size = spcot_size; const uint64_t batch_num = math::DivCeil(full_size, batch_size); YACL_ENFORCE(batch_num <= idx_num); const uint64_t last_size = full_size - (batch_num - 1) * batch_size; + const uint64_t batch_length = math::Log2Ceil(batch_size); + const uint64_t last_length = math::Log2Ceil(last_size); // for each bin, call single-point cot for (uint64_t i = 0; i < batch_num - 1; ++i) { const auto& cot_slice = - cot.Slice(i * math::Log2Ceil(batch_size), - i * math::Log2Ceil(batch_size) + math::Log2Ceil(batch_size)); + cot.Slice(i * batch_length, i * batch_length + batch_length); GywzOtExtSend_ferret(ctx, cot_slice, batch_size, out.subspan(i * batch_size, batch_size)); } // deal with last batch if (last_size == 1) { out[(batch_num - 1) * batch_size] = - cot.GetBlock((batch_num - 1) * math::Log2Ceil(batch_size), 0); + cot.GetBlock((batch_num - 1) * batch_length, 0); } else { const auto& cot_slice = - cot.Slice((batch_num - 1) * math::Log2Ceil(batch_size), - (batch_num - 1) * math::Log2Ceil(batch_size) + - math::Log2Ceil(last_size)); + cot.Slice((batch_num - 1) * batch_length, + (batch_num - 1) * batch_length + last_length); GywzOtExtSend_ferret(ctx, cot_slice, last_size, out.subspan((batch_num - 1) * batch_size, last_size)); } + + if (mal) { + // COT for Consistency check + auto check_cot = + cot.Slice(batch_length * (batch_num - 1) + last_length, + batch_length * (batch_num - 1) + last_length + 128); + + auto seed = SyncSeedSend(ctx); + auto uhash = math::UniversalHash(seed, out); + + auto recv_buf = ctx->Recv(ctx->NextRank(), "FerretCheck: masked choices"); + auto choices = dynamic_bitset(0); + choices.append(DeserializeUint128(recv_buf)); + + std::array check_cot_data; + for (size_t i = 0; i < 128; ++i) { + check_cot_data[i] = check_cot.GetBlock(i, choices[i]); + } + auto diff = PackGf128(absl::MakeSpan(check_cot_data)); + uhash = uhash ^ diff; + + auto hash = Blake3(SerializeUint128(uhash)); + ctx->SendAsync(ctx->NextRank(), ByteContainerView(hash), + "FerretCheck: hash value"); + } } inline void MpCotRNRecv(const std::shared_ptr& ctx, const OtRecvStore& cot, uint64_t idx_range, uint64_t idx_num, uint64_t spcot_size, - absl::Span out) { + absl::Span out, bool mal = false) { const uint64_t full_size = idx_range; const uint64_t batch_size = spcot_size; const uint64_t batch_num = math::DivCeil(full_size, batch_size); YACL_ENFORCE(batch_num <= idx_num); const uint64_t last_size = full_size - (batch_num - 1) * batch_size; + const uint64_t batch_length = math::Log2Ceil(batch_size); + const uint64_t last_length = math::Log2Ceil(last_size); // for each bin, call single-point cot for (uint64_t i = 0; i < batch_num - 1; ++i) { const auto cot_slice = - cot.Slice(i * math::Log2Ceil(batch_size), - i * math::Log2Ceil(batch_size) + math::Log2Ceil(batch_size)); + cot.Slice(i * batch_length, i * batch_length + batch_length); GywzOtExtRecv_ferret(ctx, cot_slice, batch_size, out.subspan(i * batch_size, batch_size)); } // deal with last batch if (last_size == 1) { out[(batch_num - 1) * batch_size] = - cot.GetBlock((batch_num - 1) * math::Log2Ceil(batch_size)); + cot.GetBlock((batch_num - 1) * batch_length); } else { const auto& cot_slice = - cot.Slice((batch_num - 1) * math::Log2Ceil(batch_size), - (batch_num - 1) * math::Log2Ceil(batch_size) + - math::Log2Ceil(last_size)); + cot.Slice((batch_num - 1) * batch_length, + (batch_num - 1) * batch_length + last_length); GywzOtExtRecv_ferret(ctx, cot_slice, last_size, out.subspan((batch_num - 1) * batch_size, last_size)); } + // malicious: consistency check + if (mal) { + // COT for consistency check + auto check_cot = + cot.Slice(batch_length * (batch_num - 1) + last_length, + batch_length * (batch_num - 1) + last_length + 128); + auto seed = SyncSeedRecv(ctx); + auto uhash = math::UniversalHash(seed, out); + + // [Warning] low efficency + uint128_t choices = check_cot.CopyChoice().data()[0]; + + auto check_cot_data = check_cot.CopyBlocks(); + auto diff = PackGf128(absl::MakeSpan(check_cot_data)); + uhash = uhash ^ diff; + + // find punctured indexes + std::vector indexes; + for (size_t i = 0; i < out.size(); ++i) { + if (out[i] & 0x1) { + indexes.push_back(i); + } + } + // extract the coefficent for universal hash + auto ceof = math::ExtractHashCoef(seed, absl::MakeConstSpan(indexes)); + choices = std::accumulate(ceof.cbegin(), ceof.cend(), choices, + std::bit_xor()); + + ctx->SendAsync(ctx->NextRank(), SerializeUint128(choices), + "FerretCheck: masked choices"); + + auto hash = Blake3(SerializeUint128(uhash)); + auto buf = ctx->Recv(ctx->NextRank(), "FerretCheck: hash value"); + + YACL_ENFORCE(ByteContainerView(hash) == ByteContainerView(buf), + "FerretCheck: fail"); + } } } // namespace yacl::crypto diff --git a/yacl/kernels/algorithms/ferret_ote_test.cc b/yacl/kernels/algorithms/ferret_ote_test.cc index c98fdcc9..2660ccdf 100644 --- a/yacl/kernels/algorithms/ferret_ote_test.cc +++ b/yacl/kernels/algorithms/ferret_ote_test.cc @@ -87,8 +87,8 @@ TEST_P(FerretOtExtTest, CheetahWorks) { auto cots_compact = MockCompactOts(cot_num); // mock cots auto delta = cots_compact.send.GetDelta(); - auto send_out = AlignedVector(ot_num); - auto recv_out = AlignedVector(ot_num); + auto send_out = UninitAlignedVector(ot_num); + auto recv_out = UninitAlignedVector(ot_num); // WHEN auto sender = std::async([&] { FerretOtExtSend_cheetah(lctxs[0], cots_compact.send, lpn_param, ot_num, diff --git a/yacl/kernels/algorithms/ferret_ote_un.h b/yacl/kernels/algorithms/ferret_ote_un.h index 80421d04..25121938 100644 --- a/yacl/kernels/algorithms/ferret_ote_un.h +++ b/yacl/kernels/algorithms/ferret_ote_un.h @@ -48,7 +48,7 @@ inline std::unique_ptr MakeSimpleMap( auto out = std::make_unique(bin_num); // get index set {0, 1, ..., n}, and then RP - AlignedVector idx_blocks(n); + UninitAlignedVector idx_blocks(n); std::iota(idx_blocks.begin(), idx_blocks.end(), 0); // random permutation @@ -112,7 +112,7 @@ inline void MpCotUNSend(const std::shared_ptr& ctx, const uint64_t bin_num = cuckoo_option.NumBins(); // for each bin, call single-point cot - AlignedVector> s(bin_num); + UninitAlignedVector> s(bin_num); uint64_t slice_begin = 0; for (uint64_t i = 0; i < bin_num && !simple_map->operator[](i).empty(); ++i) { @@ -148,7 +148,7 @@ inline void MpCotUNRecv(const std::shared_ptr& ctx, const uint64_t bin_num = cuckoo_option.NumBins(); // random permutation - AlignedVector idx_blocks(idxes.begin(), idxes.end()); + UninitAlignedVector idx_blocks(idxes.begin(), idxes.end()); auto idxes_h = kRP.Gen(idx_blocks); CuckooIndex cuckoo_index(cuckoo_option); @@ -156,7 +156,7 @@ inline void MpCotUNRecv(const std::shared_ptr& ctx, // for each (non-empty) cuckoo bin, call single-point c-ot std::fill(out.begin(), out.end(), 0); - AlignedVector> r(bin_num); + UninitAlignedVector> r(bin_num); uint64_t slice_begin = 0; for (uint64_t i = 0; i < bin_num && !simple_map->operator[](i).empty(); ++i) { diff --git a/yacl/kernels/algorithms/gywz_ote.cc b/yacl/kernels/algorithms/gywz_ote.cc index 4ad72116..6d1f4260 100644 --- a/yacl/kernels/algorithms/gywz_ote.cc +++ b/yacl/kernels/algorithms/gywz_ote.cc @@ -35,7 +35,7 @@ void CggmFullEval(uint128_t delta, uint128_t seed, uint32_t n, // if n is power of two, // all_msgs would have enough space to store the leaves bool is_two_power = (n == (static_cast(1) << height)); - AlignedVector extra_buff; + UninitAlignedVector extra_buff; auto& working_seeds = all_msgs; // first level @@ -82,7 +82,7 @@ void CggmPuncFullEval(uint32_t index, absl::Span sibling_sums, uint128_t one = Uint128Max()) { YACL_ENFORCE(punctured_msgs.size() >= n); uint32_t height = sibling_sums.size(); - AlignedVector extra_buff; + UninitAlignedVector extra_buff; // if n is power of two, // punctured_msgs would have enough space to store all leaves @@ -160,7 +160,7 @@ void GywzOtExtRecv(const std::shared_ptr& ctx, // receive punctured seed thought cot auto recv_buf = ctx->Recv(ctx->NextRank(), "GYWZ_OTE: message"); - AlignedVector sibling_sums(height); + UninitAlignedVector sibling_sums(height); memcpy(sibling_sums.data(), recv_buf.data(), recv_buf.size()); for (uint32_t i = 0; i < height; ++i) { sibling_sums[i] ^= cot.GetBlock(i); @@ -178,7 +178,7 @@ void GywzOtExtSend(const std::shared_ptr& ctx, // get delta from cot uint128_t delta = cot.GetDelta(); - AlignedVector left_sums(height); + UninitAlignedVector left_sums(height); uint128_t seed = SecureRandSeed(); CggmFullEval(delta, seed, n, output, absl::MakeSpan(left_sums)); @@ -211,7 +211,7 @@ void GywzOtExtRecv_ferret(const std::shared_ptr& ctx, uint128_t one = MakeUint128(0xffffffffffffffff, 0xfffffffffffffffe); auto recv_buf = ctx->Recv(ctx->NextRank(), "GYWZ_OTE: messages"); - AlignedVector sibling_sums(height); + UninitAlignedVector sibling_sums(height); memcpy(sibling_sums.data(), recv_buf.data(), recv_buf.size()); for (uint32_t i = 0; i < height; ++i) { sibling_sums[i] ^= (cot.GetBlock(i) & one); @@ -239,7 +239,7 @@ void GywzOtExtSend_ferret(const std::shared_ptr& ctx, uint128_t delta = cot.GetDelta() & one; uint128_t seed = SecureRandSeed() & one; - AlignedVector left_sums(height); + UninitAlignedVector left_sums(height); CggmFullEval(delta, seed, n, output, absl::MakeSpan(left_sums), one); for (uint32_t i = 0; i < height; ++i) { @@ -274,7 +274,7 @@ void GywzOtExtSend_fixed_index(const std::shared_ptr& ctx, YACL_ENFORCE(cot.Size() == height); YACL_ENFORCE_GT(n, (uint32_t)1); - AlignedVector left_sums(height); + UninitAlignedVector left_sums(height); GywzOtExtSend_fixed_index(cot, n, output, absl::MakeSpan(left_sums)); ctx->SendAsync( @@ -296,8 +296,8 @@ void GywzOtExtRecv_fixed_index(const OtRecvStore& cot, uint32_t n, index |= (cot.GetChoice(i)) << i; } - AlignedVector sibling_sums(recv_msgs.data(), - recv_msgs.data() + height); + UninitAlignedVector sibling_sums(recv_msgs.data(), + recv_msgs.data() + height); for (uint32_t i = 0; i < height; ++i) { sibling_sums[i] ^= cot.GetBlock(i); } diff --git a/yacl/kernels/algorithms/kos_ote.cc b/yacl/kernels/algorithms/kos_ote.cc index 4fe5499a..136ceeeb 100644 --- a/yacl/kernels/algorithms/kos_ote.cc +++ b/yacl/kernels/algorithms/kos_ote.cc @@ -21,6 +21,7 @@ #include "yacl/base/byte_container_view.h" #include "yacl/base/int128.h" +#include "yacl/crypto/tools/common.h" #include "yacl/math/f2k/f2k.h" #include "yacl/utils/matrix_utils.h" #include "yacl/utils/serialize.h" @@ -156,9 +157,7 @@ void KosOtExtSend(const std::shared_ptr& ctx, std::array q_check{0}; // Sender generates a random seed and sends it to receiver. - uint128_t seed = SecureRandSeed(); - ctx->SendAsync(ctx->NextRank(), SerializeUint128(seed), - fmt::format("KOS-Seed")); + uint128_t seed = SyncSeedSend(ctx); // Generate the coefficent for consistency check std::vector rand_samples(batch_num * 2); PrgAesCtr(seed, absl::MakeSpan(rand_samples)); @@ -263,8 +262,7 @@ void KosOtExtRecv(const std::shared_ptr& ctx, CheckMsg check_msgs; // Recevies the random seed from sender - uint128_t seed = - DeserializeUint128(ctx->Recv(ctx->NextRank(), fmt::format("KOS-Seed"))); + uint128_t seed = SyncSeedRecv(ctx); // Generate coefficent for consistency check std::vector rand_samples(batch_num * 2); PrgAesCtr(seed, absl::Span(rand_samples)); diff --git a/yacl/kernels/algorithms/mp_vole.h b/yacl/kernels/algorithms/mp_vole.h index 17256376..0a31ac78 100644 --- a/yacl/kernels/algorithms/mp_vole.h +++ b/yacl/kernels/algorithms/mp_vole.h @@ -85,6 +85,8 @@ class MpVoleSender { is_finish_ = false; } + // Multi-Point VOLE + // MpVoleSender.Send would set 'c' as a * delta + b. void Send(const std::shared_ptr& ctx, const OtSendStore& /*cot*/ send_ot, absl::Span c, bool fixed_index = false) { @@ -155,6 +157,9 @@ class MpVoleReceiver { is_setup_ = true; } + // Multi-Point VOLE + // MpVoleReceiver.Recv would set 'a' and 'b' + // s.t. c = a * delta + b, where a is t-weight vector. void Recv(const std::shared_ptr& ctx, const OtRecvStore& /*cot*/ recv_ot, absl::Span a, absl::Span b, bool fixed_index = false) { @@ -169,10 +174,13 @@ class MpVoleReceiver { MpfssRecv(ctx, recv_ot, param_, b); } + // reset a + std::memset(a.data(), 0, a.size() * sizeof(T)); std::vector indexes(param_.noise_num_); for (size_t i = 0; i < indexes.size(); ++i) { auto index = i * param_.sp_vole_size_ + param_.indexes_[i]; indexes[i] = index; + // insert base-VOLE value a[index] = pre_a_[i]; b[index] = b[index] ^ pre_b_[i]; } diff --git a/yacl/kernels/algorithms/mpfss.cc b/yacl/kernels/algorithms/mpfss.cc index f6b45bb9..b195581a 100644 --- a/yacl/kernels/algorithms/mpfss.cc +++ b/yacl/kernels/algorithms/mpfss.cc @@ -45,7 +45,7 @@ void MpfssSend(const std::shared_ptr& ctx, const auto& batch_size = param.sp_vole_size_; const auto& last_batch_size = param.last_sp_vole_size_; - AlignedVector send_msgs(batch_num, 0); + UninitAlignedVector send_msgs(batch_num, 0); std::transform(send_msgs.cbegin(), send_msgs.cend(), w.cbegin(), send_msgs.begin(), op.sub); @@ -86,7 +86,7 @@ void MpfssRecv(const std::shared_ptr& ctx, const auto& last_batch_size = param.last_sp_vole_size_; const auto& indexes = param.indexes_; - AlignedVector dpf_sum(batch_num, 0); + UninitAlignedVector dpf_sum(batch_num, 0); for (uint32_t i = 0; i < batch_num; ++i) { auto this_size = (i == batch_num - 1) ? last_batch_size : batch_size; @@ -132,7 +132,7 @@ void MpfssSend(const std::shared_ptr& ctx, const auto& batch_size = param.sp_vole_size_; const auto& last_batch_size = param.last_sp_vole_size_; - AlignedVector send_msgs(batch_num); + UninitAlignedVector send_msgs(batch_num, 0); std::transform(send_msgs.cbegin(), send_msgs.cend(), w.cbegin(), send_msgs.begin(), op.sub); @@ -140,7 +140,8 @@ void MpfssSend(const std::shared_ptr& ctx, Buffer(std::max(batch_size, last_batch_size) * sizeof(uint128_t)); auto dpf_span = absl::MakeSpan(dpf_buff.data(), dpf_buff.size() / sizeof(uint128_t)); - // AlignedVector dpf_buff(std::max(batch_size, last_batch_size)); + // UninitAlignedVector dpf_buff(std::max(batch_size, + // last_batch_size)); for (uint32_t i = 0; i < batch_num; ++i) { auto this_size = (i == batch_num - 1) ? last_batch_size : batch_size; @@ -189,8 +190,8 @@ void MpfssRecv(const std::shared_ptr& ctx, Buffer(std::max(batch_size, last_batch_size) * sizeof(uint128_t)); auto dpf_span = absl::MakeSpan(dpf_buf.data(), dpf_buf.size() / sizeof(uint128_t)); - // AlignedVector dpf_buff(std::max(batch_size, last_batch_size)); - AlignedVector dpf_sum(batch_num, 0); + + UninitAlignedVector dpf_sum(batch_num, 0); for (uint32_t i = 0; i < batch_num; ++i) { auto this_size = (i == batch_num - 1) ? last_batch_size : batch_size; @@ -245,11 +246,11 @@ void MpfssSend_fixed_index(const std::shared_ptr& ctx, const auto last_batch_length = math::Log2Ceil(last_batch_size); // Copy vector w - AlignedVector dpf_sum(batch_num, 0); + UninitAlignedVector dpf_sum(batch_num, 0); std::transform(dpf_sum.cbegin(), dpf_sum.cend(), w.cbegin(), dpf_sum.begin(), op.sub); // send message buff for GYWZ OTe - auto gywz_send_msgs = AlignedVector( + auto gywz_send_msgs = UninitAlignedVector( batch_length * (kSuperBatch - 1) + last_batch_length); const auto super_batch_num = math::DivCeil(batch_num, kSuperBatch); @@ -320,7 +321,7 @@ void MpfssRecv_fixed_index(const std::shared_ptr& ctx, const auto super_batch_num = math::DivCeil(batch_num, kSuperBatch); // Copy vector v - auto dpf_sum = AlignedVector(batch_num, 0); + auto dpf_sum = UninitAlignedVector(batch_num, 0); for (uint32_t s = 0; s < super_batch_num; ++s) { const uint32_t bound = @@ -406,19 +407,19 @@ void MpfssSend_fixed_index(const std::shared_ptr& ctx, const auto last_batch_length = math::Log2Ceil(last_batch_size); // copy w - AlignedVector dpf_sum(batch_num, 0); + UninitAlignedVector dpf_sum(batch_num, 0); std::transform(dpf_sum.cbegin(), dpf_sum.cend(), w.cbegin(), dpf_sum.begin(), op.sub); // GywzOtExt need uint128_t buffer auto dpf_buf = Buffer((1 << std::max(batch_length, last_batch_length)) * sizeof(uint128_t)); // auto dpf_buf = - // AlignedVector(1 << std::max(batch_length, + // UninitAlignedVector(1 << std::max(batch_length, // last_batch_length)); auto dpf_span = absl::MakeSpan(dpf_buf.data(), dpf_buf.size() / sizeof(uint128_t)); // send message buffer for GYWZ OTe - auto gywz_send_msgs = AlignedVector( + auto gywz_send_msgs = UninitAlignedVector( batch_length * (kSuperBatch - 1) + last_batch_length); const auto super_batch_num = math::DivCeil(batch_num, kSuperBatch); @@ -497,12 +498,12 @@ void MpfssRecv_fixed_index(const std::shared_ptr& ctx, const auto super_batch_num = math::DivCeil(batch_num, kSuperBatch); - auto dpf_sum = AlignedVector(batch_num, 0); + auto dpf_sum = UninitAlignedVector(batch_num, 0); // GywzOtExt need uint128_t buffer auto dpf_buf = Buffer((1 << std::max(batch_length, last_batch_length)) * sizeof(uint128_t)); // auto dpf_buf = - // AlignedVector(1 << std::max(batch_length, + // UninitAlignedVector(1 << std::max(batch_length, // last_batch_length)); auto dpf_span = absl::MakeSpan(dpf_buf.data(), dpf_buf.size() / sizeof(uint128_t)); diff --git a/yacl/kernels/algorithms/ot_store.cc b/yacl/kernels/algorithms/ot_store.cc index 4a33fe0a..1086243f 100644 --- a/yacl/kernels/algorithms/ot_store.cc +++ b/yacl/kernels/algorithms/ot_store.cc @@ -93,21 +93,19 @@ OtRecvStore::OtRecvStore(uint64_t num, OtStoreType type) : type_(type) { if (type_ == OtStoreType::Normal) { bit_buf_ = std::make_shared>(num); } - blk_buf_ = std::make_shared>(num); + blk_buf_ = std::make_shared>(num, 0); InitCtrs(0, num, 0, num); ConsistencyCheck(); } -std::unique_ptr OtRecvStore::GetChoiceBuf() { +Buffer OtRecvStore::GetChoiceBuf() { // Constructs Buffer object by copy - return std::make_unique(bit_buf_->data(), - bit_buf_->num_blocks() * sizeof(uint128_t)); + return Buffer(bit_buf_->data(), bit_buf_->num_blocks() * sizeof(uint128_t)); } -std::unique_ptr OtRecvStore::GetBlockBuf() { +Buffer OtRecvStore::GetBlockBuf() { // Constructs Buffer object by copy - return std::make_unique(blk_buf_->data(), - blk_buf_->size() * sizeof(uint128_t)); + return Buffer(blk_buf_->data(), blk_buf_->size() * sizeof(uint128_t)); } void OtRecvStore::Reset() { @@ -204,15 +202,23 @@ void OtRecvStore::FlipChoice(uint64_t idx) { } dynamic_bitset OtRecvStore::CopyChoice() const { - YACL_ENFORCE(type_ == OtStoreType::Normal, - "Copying choice is currently not allowed in compact mode"); + // [Warning] low efficency + if (type_ == OtStoreType::Compact) { + dynamic_bitset out(Size()); + for (size_t i = 0; i < Size(); ++i) { + out[i] = GetBlock(i) & 0x1; + } + return out; + } + // YACL_ENFORCE(type_ == OtStoreType::Normal, + // "Copying choice is currently not allowed in compact mode"); dynamic_bitset out = *bit_buf_; // copy out >>= GetUseCtr(); out.resize(GetUseSize()); return out; } -AlignedVector OtRecvStore::CopyBlocks() const { +UninitAlignedVector OtRecvStore::CopyBlocks() const { return {blk_buf_->begin() + internal_use_ctr_, blk_buf_->begin() + internal_use_ctr_ + internal_use_size_}; } @@ -220,7 +226,8 @@ AlignedVector OtRecvStore::CopyBlocks() const { OtRecvStore MakeOtRecvStore(const dynamic_bitset& choices, const std::vector& blocks) { auto tmp1_ptr = std::make_shared>(choices); // copy - auto tmp2_ptr = std::make_shared>(blocks.size()); + auto tmp2_ptr = + std::make_shared>(blocks.size()); std::memcpy(tmp2_ptr->data(), blocks.data(), blocks.size() * sizeof(uint128_t)); // copy @@ -229,16 +236,18 @@ OtRecvStore MakeOtRecvStore(const dynamic_bitset& choices, } OtRecvStore MakeOtRecvStore(const dynamic_bitset& choices, - const AlignedVector& blocks) { + const UninitAlignedVector& blocks) { auto tmp1_ptr = std::make_shared>(choices); // copy - auto tmp2_ptr = std::make_shared>(blocks); // copy + auto tmp2_ptr = + std::make_shared>(blocks); // copy return {tmp1_ptr, tmp2_ptr, 0, tmp1_ptr->size(), 0, tmp1_ptr->size(), OtStoreType::Normal}; } OtRecvStore MakeCompactOtRecvStore(const std::vector& blocks) { - auto tmp_ptr = std::make_shared>(blocks.size()); + auto tmp_ptr = + std::make_shared>(blocks.size()); std::memcpy(tmp_ptr->data(), blocks.data(), blocks.size() * sizeof(uint128_t)); // copy @@ -251,8 +260,10 @@ OtRecvStore MakeCompactOtRecvStore(const std::vector& blocks) { OtStoreType::Compact}; } -OtRecvStore MakeCompactOtRecvStore(const AlignedVector& blocks) { - auto tmp_ptr = std::make_shared>(blocks); // copy +OtRecvStore MakeCompactOtRecvStore( + const UninitAlignedVector& blocks) { + auto tmp_ptr = + std::make_shared>(blocks); // copy return {nullptr, tmp_ptr, @@ -264,7 +275,8 @@ OtRecvStore MakeCompactOtRecvStore(const AlignedVector& blocks) { } OtRecvStore MakeCompactOtRecvStore(std::vector&& blocks) { - auto tmp_ptr = std::make_shared>(blocks.size()); + auto tmp_ptr = + std::make_shared>(blocks.size()); std::memcpy(tmp_ptr->data(), blocks.data(), blocks.size() * sizeof(uint128_t)); // copy @@ -278,9 +290,9 @@ OtRecvStore MakeCompactOtRecvStore(std::vector&& blocks) { OtStoreType::Compact}; } -OtRecvStore MakeCompactOtRecvStore(AlignedVector&& blocks) { - auto tmp_ptr = - std::make_shared>(std::move(blocks)); // move +OtRecvStore MakeCompactOtRecvStore(UninitAlignedVector&& blocks) { + auto tmp_ptr = std::make_shared>( + std::move(blocks)); // move return {nullptr, tmp_ptr, @@ -309,15 +321,14 @@ OtSendStore::OtSendStore(uint64_t num, OtStoreType type) : type_(type) { buf_size = num * 2; } - blk_buf_ = std::make_shared>(buf_size); + blk_buf_ = std::make_shared>(buf_size, 0); InitCtrs(0, buf_size, 0, buf_size); ConsistencyCheck(); } -std::unique_ptr OtSendStore::GetBlockBuf() { +Buffer OtSendStore::GetBlockBuf() { // Constructs Buffer object by copy - return std::make_unique(blk_buf_->data(), - blk_buf_->size() * sizeof(uint128_t)); + return Buffer(blk_buf_->data(), blk_buf_->size() * sizeof(uint128_t)); } void OtSendStore::Reset() { @@ -432,7 +443,7 @@ void OtSendStore::SetCompactBlock(uint64_t ot_idx, uint128_t val) { blk_buf_->operator[](GetBufIdx(ot_idx)) = val; } -AlignedVector OtSendStore::CopyCotBlocks() const { +UninitAlignedVector OtSendStore::CopyCotBlocks() const { YACL_ENFORCE(type_ == OtStoreType::Compact, "CopyCotBlocks() is only allowed in compact mode"); return {blk_buf_->begin() + internal_buf_ctr_, @@ -442,7 +453,8 @@ AlignedVector OtSendStore::CopyCotBlocks() const { OtSendStore MakeOtSendStore( const std::vector>& blocks) { // warning: copy - auto buf_ptr = std::make_shared>(blocks.size() * 2); + auto buf_ptr = + std::make_shared>(blocks.size() * 2); memcpy(buf_ptr->data(), blocks.data(), buf_ptr->size() * sizeof(uint128_t)); return {buf_ptr, @@ -455,9 +467,10 @@ OtSendStore MakeOtSendStore( } OtSendStore MakeOtSendStore( - const AlignedVector>& blocks) { + const UninitAlignedVector>& blocks) { // warning: copy - auto buf_ptr = std::make_shared>(blocks.size() * 2); + auto buf_ptr = + std::make_shared>(blocks.size() * 2); memcpy(buf_ptr->data(), blocks.data(), buf_ptr->size() * sizeof(uint128_t)); return {buf_ptr, @@ -472,7 +485,8 @@ OtSendStore MakeOtSendStore( OtSendStore MakeCompactOtSendStore(const std::vector& blocks, uint128_t delta) { // warning: copy - auto buf_ptr = std::make_shared>(blocks.size()); + auto buf_ptr = + std::make_shared>(blocks.size()); std::memcpy(buf_ptr->data(), blocks.data(), blocks.size() * sizeof(uint128_t)); // copy @@ -485,10 +499,10 @@ OtSendStore MakeCompactOtSendStore(const std::vector& blocks, OtStoreType::Compact}; } -OtSendStore MakeCompactOtSendStore(const AlignedVector& blocks, +OtSendStore MakeCompactOtSendStore(const UninitAlignedVector& blocks, uint128_t delta) { // warning: copy - auto buf_ptr = std::make_shared>(blocks); + auto buf_ptr = std::make_shared>(blocks); return {buf_ptr, delta, @@ -501,7 +515,8 @@ OtSendStore MakeCompactOtSendStore(const AlignedVector& blocks, OtSendStore MakeCompactOtSendStore(std::vector&& blocks, uint128_t delta) { - auto buf_ptr = std::make_shared>(blocks.size()); + auto buf_ptr = + std::make_shared>(blocks.size()); std::memcpy(buf_ptr->data(), blocks.data(), blocks.size() * sizeof(uint128_t)); // copy @@ -514,10 +529,10 @@ OtSendStore MakeCompactOtSendStore(std::vector&& blocks, OtStoreType::Compact}; } -OtSendStore MakeCompactOtSendStore(AlignedVector&& blocks, +OtSendStore MakeCompactOtSendStore(UninitAlignedVector&& blocks, uint128_t delta) { - auto buf_ptr = - std::make_shared>(std::move(blocks)); // move + auto buf_ptr = std::make_shared>( + std::move(blocks)); // move return {buf_ptr, delta, @@ -535,8 +550,8 @@ MockOtStore MockRots(uint64_t num) { MockOtStore MockRots(uint64_t num, dynamic_bitset choices) { YACL_ENFORCE(choices.size() == num); - AlignedVector recv_blocks; - AlignedVector> send_blocks; + UninitAlignedVector recv_blocks; + UninitAlignedVector> send_blocks; Prg gen(FastRandSeed()); for (uint64_t i = 0; i < num; ++i) { @@ -556,8 +571,8 @@ MockOtStore MockCots(uint64_t num, uint128_t delta) { MockOtStore MockCots(uint64_t num, uint128_t delta, dynamic_bitset choices) { YACL_ENFORCE(choices.size() == num); - AlignedVector recv_blocks; - AlignedVector send_blocks; + UninitAlignedVector recv_blocks; + UninitAlignedVector send_blocks; Prg gen(FastRandSeed()); for (uint64_t i = 0; i < num; ++i) { @@ -577,8 +592,8 @@ MockOtStore MockCots(uint64_t num, uint128_t delta, MockOtStore MockCompactOts(uint64_t num) { uint128_t delta = FastRandU128(); delta |= 0x1; // make sure its last bits = 1; - AlignedVector recv_blocks; - AlignedVector send_blocks; + UninitAlignedVector recv_blocks; + UninitAlignedVector send_blocks; Prg gen(FastRandSeed()); for (uint64_t i = 0; i < num; ++i) { diff --git a/yacl/kernels/algorithms/ot_store.h b/yacl/kernels/algorithms/ot_store.h index e46beb5e..08d2274d 100644 --- a/yacl/kernels/algorithms/ot_store.h +++ b/yacl/kernels/algorithms/ot_store.h @@ -94,7 +94,7 @@ class SliceBase { class OtRecvStore : public SliceBase { public: using BitBufPtr = std::shared_ptr>; - using BlkBufPtr = std::shared_ptr>; + using BlkBufPtr = std::shared_ptr>; // full constructor for ot receiver store OtRecvStore(BitBufPtr bit_ptr, BlkBufPtr blk_ptr, uint64_t use_ctr, @@ -114,10 +114,10 @@ class OtRecvStore : public SliceBase { OtStoreType Type() const { return type_; } // get a buffer copy of choice buf - std::unique_ptr GetChoiceBuf(); + Buffer GetChoiceBuf(); // get a buffer copy of block buf - std::unique_ptr GetBlockBuf(); + Buffer GetBlockBuf(); // reset ot store void Reset(); @@ -144,7 +144,7 @@ class OtRecvStore : public SliceBase { dynamic_bitset CopyChoice() const; // copy out the sliced choice buffer [wanring: low efficiency] - AlignedVector CopyBlocks() const; + UninitAlignedVector CopyBlocks() const; private: // check the consistency of ot receiver store @@ -169,18 +169,19 @@ class OtRecvStore : public SliceBase { // Easier way of generate a ot_store pointer from a given choice buffer and // a block buffer OtRecvStore MakeOtRecvStore(const dynamic_bitset& choices, - const AlignedVector& blocks); + const UninitAlignedVector& blocks); OtRecvStore MakeOtRecvStore(const dynamic_bitset& choices, const std::vector& blocks); // Easier way of generate a compact cot_store pointer from a given block buffer // Note: Compact ot is correlated-ot (or called delta-ot) -OtRecvStore MakeCompactOtRecvStore(const AlignedVector& blocks); +OtRecvStore MakeCompactOtRecvStore( + const UninitAlignedVector& blocks); OtRecvStore MakeCompactOtRecvStore(const std::vector& blocks); -OtRecvStore MakeCompactOtRecvStore(AlignedVector&& blocks); +OtRecvStore MakeCompactOtRecvStore(UninitAlignedVector&& blocks); OtRecvStore MakeCompactOtRecvStore(std::vector&& blocks); @@ -189,7 +190,7 @@ OtRecvStore MakeCompactOtRecvStore(std::vector&& blocks); // Data structure that stores multiple ot sender's data (a.k.a. the ot messages) class OtSendStore : public SliceBase { public: - using BlkBufPtr = std::shared_ptr>; + using BlkBufPtr = std::shared_ptr>; // full constructor for ot receiver store OtSendStore(BlkBufPtr blk_ptr, uint128_t delta, uint64_t use_ctr, @@ -209,7 +210,7 @@ class OtSendStore : public SliceBase { OtStoreType Type() const { return type_; } // get a buffer copy of block buf - std::unique_ptr GetBlockBuf(); + Buffer GetBlockBuf(); // reset ot store void Reset(); @@ -233,7 +234,7 @@ class OtSendStore : public SliceBase { void SetCompactBlock(uint64_t ot_idx, uint128_t val); // copy out cot blocks - AlignedVector CopyCotBlocks() const; + UninitAlignedVector CopyCotBlocks() const; private: // check the consistency of ot receiver store @@ -249,7 +250,7 @@ class OtSendStore : public SliceBase { // Easier way of generate a ot_store pointer from a given blocks buffer OtSendStore MakeOtSendStore( - const AlignedVector>& blocks); + const UninitAlignedVector>& blocks); OtSendStore MakeOtSendStore( const std::vector>& blocks); @@ -260,13 +261,13 @@ OtSendStore MakeOtSendStore( OtSendStore MakeCompactOtSendStore(const std::vector& blocks, uint128_t delta); -OtSendStore MakeCompactOtSendStore(const AlignedVector& blocks, +OtSendStore MakeCompactOtSendStore(const UninitAlignedVector& blocks, uint128_t delta); OtSendStore MakeCompactOtSendStore(std::vector&& blocks, uint128_t delta); -OtSendStore MakeCompactOtSendStore(AlignedVector&& blocks, +OtSendStore MakeCompactOtSendStore(UninitAlignedVector&& blocks, uint128_t delta); // OT Store (for mocking only) diff --git a/yacl/kernels/algorithms/silent_vole.cc b/yacl/kernels/algorithms/silent_vole.cc index 55e8dbcb..7e10ed60 100644 --- a/yacl/kernels/algorithms/silent_vole.cc +++ b/yacl/kernels/algorithms/silent_vole.cc @@ -226,7 +226,7 @@ void SilentVoleSender::SendImpl(const std::shared_ptr& ctx, // w would be moved into mpvole mpvole.OneTimeSetup(static_cast(delta_), std::move(w)); // mp_vole output - // AlignedVector mp_vole_output(mp_param.mp_vole_size_); + // UninitAlignedVector mp_vole_output(mp_param.mp_vole_size_); auto buf = Buffer(mp_param.mp_vole_size_ * sizeof(K)); auto mp_vole_output = absl::MakeSpan(buf.data(), mp_param.mp_vole_size_); // mpvole with fixed index @@ -277,8 +277,8 @@ void SilentVoleReceiver::RecvImpl(const std::shared_ptr& ctx, // u && v would be moved into mpvole mpvole.OneTimeSetup(std::move(u), std::move(v)); // sparse_noise && mp_vole output - AlignedVector sparse_noise(mp_param.mp_vole_size_); - // AlignedVector mp_vole_output(mp_param.mp_vole_size_); + UninitAlignedVector sparse_noise(mp_param.mp_vole_size_); + // UninitAlignedVector mp_vole_output(mp_param.mp_vole_size_); auto buf = Buffer(mp_param.mp_vole_size_ * sizeof(K)); auto mp_vole_output = absl::MakeSpan(buf.data(), mp_param.mp_vole_size_); // mpvole with fixed index diff --git a/yacl/kernels/algorithms/softspoken_ote.cc b/yacl/kernels/algorithms/softspoken_ote.cc index 7515b591..eaea1db6 100644 --- a/yacl/kernels/algorithms/softspoken_ote.cc +++ b/yacl/kernels/algorithms/softspoken_ote.cc @@ -22,6 +22,7 @@ #include "yacl/base/aligned_vector.h" #include "yacl/base/byte_container_view.h" +#include "yacl/crypto/tools/common.h" #include "yacl/math/f2k/f2k.h" #include "yacl/utils/matrix_utils.h" #include "yacl/utils/serialize.h" @@ -158,8 +159,8 @@ inline void XorReduceImpl(uint64_t k, absl::Span inout) { } // namespace SoftspokenOtExtSender::SoftspokenOtExtSender(uint64_t k, uint64_t step, - bool mal) - : k_(k), step_(step), mal_(mal) { + bool mal, bool compact) + : k_(k), step_(step), mal_(mal), compact_(compact) { counter_ = 0; pprf_num_ = (kKappa + k_ - 1) / k_; pprf_range_ = static_cast(1) << k_; @@ -168,11 +169,11 @@ SoftspokenOtExtSender::SoftspokenOtExtSender(uint64_t k, uint64_t step, const uint128_t total_size = pprf_num_ * pprf_range_ - empty_num; // punctured_leaves_ would save leaves in all pprf - punctured_leaves_ = AlignedVector(total_size); + punctured_leaves_ = UninitAlignedVector(total_size); // punctured_idx_ would record all punctured indexs - punctured_idx_ = AlignedVector(pprf_num_); + punctured_idx_ = UninitAlignedVector(pprf_num_); // remove the empty entries in punctured_leaves_ - compress_leaves_ = AlignedVector(total_size - pprf_num_); + compress_leaves_ = UninitAlignedVector(total_size - pprf_num_); // init delta delta_ = MakeUint128(0, 0); @@ -189,15 +190,15 @@ SoftspokenOtExtSender::SoftspokenOtExtSender(uint64_t k, uint64_t step, } SoftspokenOtExtReceiver::SoftspokenOtExtReceiver(uint64_t k, uint64_t step, - bool mal) - : k_(k), step_(step), mal_(mal) { + bool mal, bool compact) + : k_(k), step_(step), mal_(mal), compact_(compact) { counter_ = 0; pprf_num_ = (kKappa + k_ - 1) / k_; pprf_range_ = static_cast(1) << k_; const uint64_t empty_num = pprf_range_ - (1 << (kKappa - (pprf_num_ - 1) * k_)); const uint64_t total_size = pprf_num_ * pprf_range_ - empty_num; - all_leaves_ = AlignedVector(total_size); + all_leaves_ = UninitAlignedVector(total_size); // set default step or super batch if (step_ == 0) { @@ -233,7 +234,13 @@ void SoftspokenOtExtSender::OneTimeSetup( // FIXME: Copy base_ot, since NextSlice is not const auto dup_base_ot = base_ot; // set delta - delta_ = base_ot.CopyChoice().data()[0]; + + if (compact_) { + dup_base_ot.SetBlock(0, MakeUint128(0, 0)); + dup_base_ot.SetChoice(0, 1); + } + + delta_ = dup_base_ot.CopyChoice().data()[0]; auto recv_size = 128 * 2 * sizeof(uint128_t) + pprf_num_ * (mal_ ? 64 : 0); auto recv_buf = ctx->Recv(ctx->NextRank(), "SGRR_OTE:RECV-CORR"); @@ -297,8 +304,13 @@ void SoftspokenOtExtReceiver::OneTimeSetup( if (inited_) { return; } + YACL_ENFORCE(base_ot.Size() == kKappa); // FIXME: Copy base_ot, since NextSlice is not const auto dup_base_ot = base_ot; + if (compact_) { + dup_base_ot.SetNormalBlock(0, 0, MakeUint128(0, 0)); + dup_base_ot.SetNormalBlock(0, 1, MakeUint128(0, 0)); + } // Send Message Buffer auto send_size = 128 * 2 * sizeof(uint128_t) + pprf_num_ * (mal_ ? 64 : 0); auto send_buf = Buffer(send_size); @@ -398,18 +410,8 @@ void SoftspokenOtExtReceiver::GenRot(const std::shared_ptr& ctx, void SoftspokenOtExtReceiver::GenCot(const std::shared_ptr& ctx, uint64_t num_ot, OtRecvStore* out) { - YACL_ENFORCE(out->Size() == num_ot); - YACL_ENFORCE(out->Type() == OtStoreType::Normal); auto choices = SecureRandBits>(num_ot); - auto recv_blocks = std::vector(num_ot); - Recv(ctx, choices, absl::MakeSpan(recv_blocks), true); - - // out->SetChoices did not implement - // [Warning] low efficiency - for (uint64_t i = 0; i < num_ot; ++i) { - out->SetBlock(i, recv_blocks[i]); - out->SetChoice(i, choices[i]); - } + GenCot(ctx, choices, out); } void SoftspokenOtExtReceiver::GenCot(const std::shared_ptr& ctx, @@ -417,15 +419,21 @@ void SoftspokenOtExtReceiver::GenCot(const std::shared_ptr& ctx, OtRecvStore* out) { const uint64_t num_ot = choices.size(); YACL_ENFORCE(out->Size() == num_ot); - YACL_ENFORCE(out->Type() == OtStoreType::Normal); + YACL_ENFORCE(out->Type() == + (compact_ ? OtStoreType::Compact : OtStoreType::Normal)); auto recv_blocks = std::vector(num_ot); Recv(ctx, choices, absl::MakeSpan(recv_blocks), true); - // out->SetChoices did not implement // [Warning] low efficiency - for (uint64_t i = 0; i < num_ot; ++i) { - out->SetBlock(i, recv_blocks[i]); - out->SetChoice(i, choices[i]); + if (compact_) { + for (uint64_t i = 0; i < num_ot; ++i) { + out->SetBlock(i, recv_blocks[i]); + } + } else { + for (uint64_t i = 0; i < num_ot; ++i) { + out->SetBlock(i, recv_blocks[i]); + out->SetChoice(i, choices[i]); + } } } @@ -451,6 +459,9 @@ OtRecvStore SoftspokenOtExtReceiver::GenRot( OtRecvStore SoftspokenOtExtReceiver::GenCot( const std::shared_ptr& ctx, uint64_t num_ot) { OtRecvStore out(num_ot, OtStoreType::Normal); + if (compact_) { + out = OtRecvStore(num_ot, OtStoreType::Compact); + } // [Warning] low efficiency. GenCot(ctx, num_ot, &out); return out; @@ -463,6 +474,9 @@ OtRecvStore SoftspokenOtExtReceiver::GenCot( const std::shared_ptr& ctx, const dynamic_bitset& choices) { OtRecvStore out(choices.size(), OtStoreType::Normal); + if (compact_) { + out = OtRecvStore(choices.size(), OtStoreType::Compact); + } // [Warning] low efficiency. GenCot(ctx, choices, &out); return out; @@ -526,6 +540,10 @@ void SoftspokenOtExtSender::GenSfVole(absl::Span hash_buff, xor_offset += pprf_range_; // hash_offset = i * pprf_range } } + + if (compact_) { + V[0] = MakeUint128(0, 0); + } } // Generate Smallfield VOLE and Subspace VOLE @@ -565,6 +583,9 @@ void SoftspokenOtExtReceiver::GenSfVole(const uint128_t choice, xor_offset += pprf_range_; // xor_offset = i * pprf_range; } } + if (compact_) { + W[0] = choice; + } } // old style interface @@ -589,16 +610,16 @@ void SoftspokenOtExtSender::Send( const uint64_t all_batch_num = super_batch_num * step + batch_num; YACL_ENFORCE(all_batch_num * kBatchSize == expand_numOt); - AlignedVector, 32> allV(all_batch_num); + UninitAlignedVector, 32> allV(all_batch_num); // OT extension // AVX need to be aligned to 32 bytes. // Extra one array for consitency check in batch_num for-loop. - AlignedVector, 32> V(step + 1); - AlignedVector, 32> V_xor_delta(step + 1); + UninitAlignedVector, 32> V(step + 1); + UninitAlignedVector, 32> V_xor_delta(step + 1); // Hash Buffer to perform AES/PRG // Xor Buffer to perform XorReduce ( \sum x PRG(M_x) ) - auto hash_buff = AlignedVector(compress_leaves_.size()); - auto xor_buff = AlignedVector(pprf_num_ * pprf_range_); + auto hash_buff = UninitAlignedVector(compress_leaves_.size()); + auto xor_buff = UninitAlignedVector(pprf_num_ * pprf_range_, 0); // deal with super batch for (uint64_t t = 0; t < super_batch_num; ++t) { @@ -672,9 +693,7 @@ void SoftspokenOtExtSender::Send( if (mal_) { // Sender generates a random seed and sends it to receiver. - uint128_t seed = SecureRandSeed(); - ctx->SendAsync(ctx->NextRank(), SerializeUint128(seed), - fmt::format("SSMal-Seed")); + uint128_t seed = SyncSeedSend(ctx); // Consistency check std::vector rand_samples(all_batch_num * 2); @@ -727,14 +746,14 @@ void SoftspokenOtExtReceiver::Recv(const std::shared_ptr& ctx, const uint64_t all_batch_num = super_batch_num * step + batch_num; YACL_ENFORCE(all_batch_num * kBatchSize == expand_numOt); - AlignedVector, 32> allW(all_batch_num); + UninitAlignedVector, 32> allW(all_batch_num); auto choice_ext = ExtendChoice(choices, expand_numOt); // AVX need to be aligned to 32 bytes. // Extra one array for consitency check in batch_num for-loop. - AlignedVector, 32> W(step + 1); + UninitAlignedVector, 32> W(step + 1); // AES Buffer & Xor Buffer to perform AES/PRG and XorReduce - auto xor_buff = AlignedVector(pprf_num_ * pprf_range_); - AlignedVector U(pprf_num_ * step); + auto xor_buff = UninitAlignedVector(pprf_num_ * pprf_range_, 0); + UninitAlignedVector U(pprf_num_ * step); // deal with super batch for (uint64_t t = 0; t < super_batch_num; ++t) { @@ -797,8 +816,7 @@ void SoftspokenOtExtReceiver::Recv(const std::shared_ptr& ctx, if (mal_) { // Recevies the random seed from sender - uint128_t seed = DeserializeUint128( - ctx->Recv(ctx->NextRank(), fmt::format("SSMal-Seed"))); + uint128_t seed = SyncSeedRecv(ctx); // Consistency check std::vector rand_samples(all_batch_num * 2); diff --git a/yacl/kernels/algorithms/softspoken_ote.h b/yacl/kernels/algorithms/softspoken_ote.h index d2ac47f4..16784b58 100644 --- a/yacl/kernels/algorithms/softspoken_ote.h +++ b/yacl/kernels/algorithms/softspoken_ote.h @@ -77,7 +77,7 @@ namespace yacl::crypto { class SoftspokenOtExtSender { public: explicit SoftspokenOtExtSender(uint64_t k = 2, uint64_t step = 0, - bool mal = false); + bool mal = false, bool compact = false); void OneTimeSetup(const std::shared_ptr& ctx); @@ -125,22 +125,24 @@ class SoftspokenOtExtSender { // Softspoken one time setup bool inited_{false}; + uint64_t k_; // parameter k uint64_t pprf_num_; // kKappa / k uint64_t pprf_range_; // the number of leaves for single pprf - AlignedVector punctured_leaves_; // leaves for all pprf - AlignedVector punctured_idx_; // pprf punctured index + UninitAlignedVector punctured_leaves_; // leaves for all pprf + UninitAlignedVector punctured_idx_; // pprf punctured index uint128_t delta_; // cot delta a.k.a the choices of base OT - std::array p_idx_mask_; // mask for punctured index - AlignedVector compress_leaves_; // compressed pprf leaves - uint64_t step_{32}; // super batch size = step_ * 128 - bool mal_{false}; // malicous + std::array p_idx_mask_; // mask for punctured index + UninitAlignedVector compress_leaves_; // compressed pprf leaves + uint64_t step_{32}; // super batch size = step_ * 128 + bool mal_{false}; // malicous + bool compact_{false}; // compact mode }; class SoftspokenOtExtReceiver { public: explicit SoftspokenOtExtReceiver(uint64_t k = 2, uint64_t step = 0, - bool mal = false); + bool mal = false, bool compact = false); void OneTimeSetup(const std::shared_ptr& ctx); @@ -160,6 +162,7 @@ class SoftspokenOtExtReceiver { const dynamic_bitset& choices, OtRecvStore* out); // [Warning] low efficiency + // Compact Softspoken would the type of OtRecvStore is "OtStoreType:Compact". void GenCot(const std::shared_ptr& ctx, uint64_t num_ot, OtRecvStore* out); @@ -176,6 +179,7 @@ class SoftspokenOtExtReceiver { // OtStore-style interface // [Warning] low efficiency + // Compact Softspoken would return "Compact" OtRecvStore OtRecvStore GenCot(const std::shared_ptr& ctx, uint64_t num_ot); @@ -199,12 +203,13 @@ class SoftspokenOtExtReceiver { // Softspoken one time setup bool inited_{false}; - uint64_t k_; // parameter k - uint64_t pprf_num_; // kkappa / k - uint64_t pprf_range_; // the number of leaves for single pprf - AlignedVector all_leaves_; // leaves for all pprf - uint64_t step_{32}; // super batch size = step_ * 128 - bool mal_{false}; // malicous + uint64_t k_; // parameter k + uint64_t pprf_num_; // kkappa / k + uint64_t pprf_range_; // the number of leaves for single pprf + UninitAlignedVector all_leaves_; // leaves for all pprf + uint64_t step_{32}; // super batch size = step_ * 128 + bool mal_{false}; // malicous + bool compact_{false}; // compact mode }; // Softspoken Ot Extension interface @@ -212,8 +217,8 @@ inline void SoftspokenOtExtSend( const std::shared_ptr& ctx, const OtRecvStore& base_ot /* rot */, absl::Span> send_blocks, uint64_t k = 2, - bool cot = false, bool mal = false) { - auto ssSender = SoftspokenOtExtSender(k, 0, mal); + bool cot = false, bool mal = false, bool compact = false) { + auto ssSender = SoftspokenOtExtSender(k, 0, mal, compact); ssSender.OneTimeSetup(ctx, base_ot); ssSender.Send(ctx, send_blocks, cot); } @@ -223,8 +228,8 @@ inline void SoftspokenOtExtRecv(const std::shared_ptr& ctx, const dynamic_bitset& choices, absl::Span recv_blocks, uint64_t k = 2, bool cot = false, - bool mal = false) { - auto ssReceiver = SoftspokenOtExtReceiver(k, 0, mal); + bool mal = false, bool compact = false) { + auto ssReceiver = SoftspokenOtExtReceiver(k, 0, mal, compact); ssReceiver.OneTimeSetup(ctx, base_ot); ssReceiver.Recv(ctx, choices, recv_blocks, cot); } diff --git a/yacl/kernels/algorithms/softspoken_ote_test.cc b/yacl/kernels/algorithms/softspoken_ote_test.cc index b2ce606f..4bfb6690 100644 --- a/yacl/kernels/algorithms/softspoken_ote_test.cc +++ b/yacl/kernels/algorithms/softspoken_ote_test.cc @@ -31,6 +31,7 @@ namespace yacl::crypto { struct OtTestParams { unsigned num_ot; bool mal = false; + bool compact = false; }; struct KTestParams { @@ -183,6 +184,7 @@ TEST_P(SoftspokenOtExtTest, RotExtWorks) { const int kWorldSize = 2; const size_t num_ot = GetParam().num_ot; const bool mal = GetParam().mal; + const bool compact = GetParam().compact; auto lctxs = link::test::SetupWorld(kWorldSize); // setup network auto base_ot = MockRots(128); // mock option auto choices = RandBits>(num_ot); // get input @@ -192,11 +194,11 @@ TEST_P(SoftspokenOtExtTest, RotExtWorks) { std::vector recv_out(num_ot); std::future sender = std::async([&] { SoftspokenOtExtSend(lctxs[0], base_ot.recv, absl::MakeSpan(send_out), 2, - false, mal); + false, mal, compact); }); std::future receiver = std::async([&] { SoftspokenOtExtRecv(lctxs[1], base_ot.send, choices, - absl::MakeSpan(recv_out), 2, false, mal); + absl::MakeSpan(recv_out), 2, false, mal, compact); }); receiver.get(); sender.get(); @@ -215,6 +217,7 @@ TEST_P(SoftspokenOtExtTest, CotExtWorks) { const int kWorldSize = 2; const size_t num_ot = GetParam().num_ot; const bool mal = GetParam().mal; + const bool compact = GetParam().compact; auto lctxs = link::test::SetupWorld(kWorldSize); // setup network auto base_ot = MockRots(128); // mock option auto choices = RandBits>(num_ot); // get input @@ -225,24 +228,30 @@ TEST_P(SoftspokenOtExtTest, CotExtWorks) { std::future sender = std::async([&] { SoftspokenOtExtSend(lctxs[0], base_ot.recv, absl::MakeSpan(send_out), 3, - true, mal); + true, mal, compact); }); std::future receiver = std::async([&] { SoftspokenOtExtRecv(lctxs[1], base_ot.send, choices, - absl::MakeSpan(recv_out), 3, true, mal); + absl::MakeSpan(recv_out), 3, true, mal, compact); }); receiver.get(); sender.get(); // THEN // cot correlation = base ot choice - uint128_t check = base_ot.recv.CopyChoice().data()[0]; // get delta + uint128_t check = send_out[0][0] ^ send_out[0][1]; for (size_t i = 0; i < num_ot; ++i) { EXPECT_NE(recv_out[i], 0); EXPECT_NE(send_out[i][0], 0); EXPECT_NE(send_out[i][1], 0); EXPECT_EQ(send_out[i][choices[i]], recv_out[i]); EXPECT_EQ(check, send_out[i][0] ^ send_out[i][1]); + // Compact Mode + if (compact) { + EXPECT_EQ(send_out[i][0] & 0x1, 0); + EXPECT_EQ(send_out[i][1] & 0x1, 1); + EXPECT_EQ(recv_out[i] & 0x1, choices[i]); + } } } @@ -251,20 +260,21 @@ TEST_P(SoftspokenOtExtTest, RotStoreWorks) { const int kWorldSize = 2; const size_t num_ot = GetParam().num_ot; const bool mal = GetParam().mal; + const bool compact = GetParam().compact; auto lctxs = link::test::SetupWorld(kWorldSize); // setup network auto base_ot = MockRots(128); // mock option // WHEN // One time setup for Softspoken auto ssReceiverTask = - std::async([&] { return SoftspokenOtExtReceiver(2, 0, mal); }); + std::async([&] { return SoftspokenOtExtReceiver(2, 0, mal, compact); }); auto ssSenderTask = - std::async([&] { return SoftspokenOtExtSender(2, 0, mal); }); + std::async([&] { return SoftspokenOtExtSender(2, 0, mal, compact); }); auto ssReceiver = ssReceiverTask.get(); auto ssSender = ssSenderTask.get(); - // Generate COT + // Generate ROT std::vector> send_out1(num_ot); std::vector recv_out1(num_ot); auto sendTask1 = std::async([&] { @@ -297,15 +307,16 @@ TEST_P(SoftspokenOtExtTest, CotStoreWorks) { const int kWorldSize = 2; const size_t num_ot = GetParam().num_ot; const bool mal = GetParam().mal; + const bool compact = GetParam().compact; auto lctxs = link::test::SetupWorld(kWorldSize); // setup network auto base_ot = MockRots(128); // mock option // WHEN // One time setup for Softspoken auto ssReceiverTask = - std::async([&] { return SoftspokenOtExtReceiver(2, 0, mal); }); + std::async([&] { return SoftspokenOtExtReceiver(2, 0, mal, compact); }); auto ssSenderTask = - std::async([&] { return SoftspokenOtExtSender(2, 0, mal); }); + std::async([&] { return SoftspokenOtExtSender(2, 0, mal, compact); }); auto ssReceiver = ssReceiverTask.get(); auto ssSender = ssSenderTask.get(); @@ -325,7 +336,12 @@ TEST_P(SoftspokenOtExtTest, CotStoreWorks) { auto sendStore = sendTask1.get(); auto recvStore = recvTask1.get(); - EXPECT_EQ(recvStore.Type(), OtStoreType::Normal); + if (compact) { + EXPECT_EQ(recvStore.Type(), OtStoreType::Compact); + } else { + EXPECT_EQ(recvStore.Type(), OtStoreType::Normal); + } + EXPECT_EQ(sendStore.Type(), OtStoreType::Compact); // THEN auto delta = ssSender.GetDelta(); @@ -336,6 +352,12 @@ TEST_P(SoftspokenOtExtTest, CotStoreWorks) { EXPECT_EQ(sendStore.GetBlock(i, 0) ^ sendStore.GetBlock(i, 1), delta); EXPECT_EQ(sendStore.GetBlock(i, recvStore.GetChoice(i)), recvStore.GetBlock(i)); + // Compact Mode + if (compact) { + EXPECT_EQ(sendStore.GetBlock(i, 0) & 0x1, 0); + EXPECT_EQ(sendStore.GetBlock(i, 1) & 0x1, 1); + EXPECT_EQ(recvStore.GetBlock(i) & 0x1, recvStore.GetChoice(i)); + } } } @@ -380,19 +402,28 @@ INSTANTIATE_TEST_SUITE_P(Works_Instances, SoftspokenKTest, KTestParams{10, true})); INSTANTIATE_TEST_SUITE_P(Works_Instances, SoftspokenOtExtTest, - testing::Values(OtTestParams{8}, // - OtTestParams{128}, // - OtTestParams{129}, // - OtTestParams{4095}, // - OtTestParams{4096}, // - OtTestParams{65536}, // - OtTestParams{100000}, // - OtTestParams{8, true}, // - OtTestParams{128, true}, // - OtTestParams{129, true}, // - OtTestParams{4095, true}, // - OtTestParams{4096, true}, // - OtTestParams{65536, true}, // - OtTestParams{100000, true})); + testing::Values(OtTestParams{8}, // + OtTestParams{128}, // + OtTestParams{129}, // + OtTestParams{4095}, // + OtTestParams{4096}, // + OtTestParams{65536}, // + OtTestParams{100000}, // + // malicious OT + OtTestParams{8, true}, // + OtTestParams{128, true}, // + OtTestParams{129, true}, // + OtTestParams{4095, true}, // + OtTestParams{4096, true}, // + OtTestParams{65536, true}, // + OtTestParams{100000, true}, // + // malicious && compact OT + OtTestParams{8, true, true}, // + OtTestParams{128, true, true}, // + OtTestParams{129, true, true}, // + OtTestParams{4095, true, true}, // + OtTestParams{4096, true, true}, // + OtTestParams{65536, true, true}, // + OtTestParams{100000, true, true})); } // namespace yacl::crypto diff --git a/yacl/kernels/benchmark/ot_bench.cc b/yacl/kernels/benchmark/ot_bench.cc index 36414009..d2a2be01 100644 --- a/yacl/kernels/benchmark/ot_bench.cc +++ b/yacl/kernels/benchmark/ot_bench.cc @@ -49,6 +49,7 @@ BM_REGISTER_ALL_OT(BM_DefaultArguments); // BM_REGISTER_SGRR_OTE(BM_DefaultArguments); // BM_REGISTER_GYWZ_OTE(BM_PerfArguments); // BM_REGISTER_FERRET_OTE(BM_PerfArguments); +// BM_REGISTER_MAL_FERRET_OTE(BM_PerfArguments); // BM_REGISTER_SOFTSPOKEN_OTE(BM_PerfArguments); // BM_REGISTER_MAL_SOFTSPOKEN_OTE(BM_PerfArguments); diff --git a/yacl/kernels/benchmark/ot_bench.h b/yacl/kernels/benchmark/ot_bench.h index 48804eb8..1d20a5d9 100644 --- a/yacl/kernels/benchmark/ot_bench.h +++ b/yacl/kernels/benchmark/ot_bench.h @@ -62,8 +62,8 @@ BENCHMARK_DEFINE_F(OtBench, SimplestOT)(benchmark::State& state) { // preprare inputs auto choices = RandBits>(num_ot); - AlignedVector> send_blocks(num_ot); - AlignedVector recv_blocks(num_ot); + UninitAlignedVector> send_blocks(num_ot); + UninitAlignedVector recv_blocks(num_ot); state.ResumeTiming(); @@ -88,8 +88,8 @@ BENCHMARK_DEFINE_F(OtBench, IknpOTe)(benchmark::State& state) { const auto num_ot = state.range(0); // preprare inputs - AlignedVector> send_blocks(num_ot); - AlignedVector recv_blocks(num_ot); + UninitAlignedVector> send_blocks(num_ot); + UninitAlignedVector recv_blocks(num_ot); auto choices = RandBits>(num_ot); auto base_ot = MockRots(128); @@ -119,8 +119,8 @@ BENCHMARK_DEFINE_F(OtBench, KosOTe)(benchmark::State& state) { const auto num_ot = state.range(0); // preprare inputs - AlignedVector> send_blocks(num_ot); - AlignedVector recv_blocks(num_ot); + UninitAlignedVector> send_blocks(num_ot); + UninitAlignedVector recv_blocks(num_ot); auto choices = RandBits>(num_ot); auto base_ot = MockRots(128); @@ -150,8 +150,8 @@ BENCHMARK_DEFINE_F(OtBench, KkrtOTe)(benchmark::State& state) { const auto num_ot = state.range(0); // preprare inputs - AlignedVector inputs(num_ot); - AlignedVector recv_out(num_ot); + UninitAlignedVector inputs(num_ot); + UninitAlignedVector recv_out(num_ot); auto base_ot = MockRots(512); state.ResumeTiming(); @@ -181,8 +181,8 @@ BENCHMARK_DEFINE_F(OtBench, SgrrOTe)(benchmark::State& state) { // preprare inputs uint32_t choice_value = RandInRange(range_n); auto base_ot = MockRots(math::Log2Ceil(range_n)); - AlignedVector send_out(range_n); - AlignedVector recv_out(range_n); + UninitAlignedVector send_out(range_n); + UninitAlignedVector recv_out(range_n); state.ResumeTiming(); @@ -214,8 +214,8 @@ BENCHMARK_DEFINE_F(OtBench, GywzOTe)(benchmark::State& state) { uint32_t choice_value = RandInRange(range_n); uint128_t delta = SecureRandSeed(); auto base_ot = MockCots(math::Log2Ceil(range_n), delta); - AlignedVector send_out(range_n); - AlignedVector recv_out(range_n); + UninitAlignedVector send_out(range_n); + UninitAlignedVector recv_out(range_n); state.ResumeTiming(); @@ -313,6 +313,38 @@ BENCHMARK_DEFINE_F(OtBench, FerretOTe)(benchmark::State& state) { state.ResumeTiming(); } } + +BENCHMARK_DEFINE_F(OtBench, MalFerretOTe)(benchmark::State& state) { + YACL_ENFORCE(lctxs_.size() == 2); + for (auto _ : state) { + state.PauseTiming(); + { + const size_t num_ot = state.range(0); + + // preprare inputs + auto lpn_param = LpnParam::GetDefault(); + auto cot_num = FerretCotHelper(lpn_param, num_ot, true); // make option + auto cots_compact = MockCompactOts(cot_num); // mock cots + + state.ResumeTiming(); + + // run base OT + auto sender = std::async([&] { + return FerretOtExtSend(lctxs_[0], cots_compact.send, lpn_param, num_ot, + true); + }); + auto receiver = std::async([&] { + return FerretOtExtRecv(lctxs_[1], cots_compact.recv, lpn_param, num_ot, + true); + }); + sender.get(); + receiver.get(); + state.PauseTiming(); + } + state.ResumeTiming(); + } +} + #define DELCARE_MAL_SOFTSPOKEN_BENCH(K) \ BENCHMARK_DEFINE_F(OtBench, MalSoftspokenOTe##K)(benchmark::State & state) { \ YACL_ENFORCE(lctxs_.size() == 2); \ @@ -403,6 +435,9 @@ DELCARE_MAL_SOFTSPOKEN_BENCH(8) #define BM_REGISTER_FERRET_OTE(Arguments) \ BENCHMARK_REGISTER_F(OtBench, FerretOTe)->Apply(Arguments); +#define BM_REGISTER_MAL_FERRET_OTE(Arguments) \ + BENCHMARK_REGISTER_F(OtBench, MalFerretOTe)->Apply(Arguments); + #define BM_REGISTER_ALL_OT(Arguments) \ BM_REGISTER_SIMPLEST_OT(Arguments) \ BM_REGISTER_IKNP_OTE(Arguments) \ @@ -411,6 +446,7 @@ DELCARE_MAL_SOFTSPOKEN_BENCH(8) BM_REGISTER_SGRR_OTE(Arguments) \ BM_REGISTER_GYWZ_OTE(Arguments) \ BM_REGISTER_FERRET_OTE(Arguments) \ + BM_REGISTER_MAL_FERRET_OTE(Arguments) \ BM_REGISTER_SOFTSPOKEN_OTE(Arguments) \ BM_REGISTER_MAL_SOFTSPOKEN_OTE(Arguments) } // namespace yacl::crypto diff --git a/yacl/kernels/svole_kernel.cc b/yacl/kernels/svole_kernel.cc index a92b045b..48bcd84f 100644 --- a/yacl/kernels/svole_kernel.cc +++ b/yacl/kernels/svole_kernel.cc @@ -88,8 +88,8 @@ void SVoleKernel::eval_multithread(const std::shared_ptr& lctx, } tl_c[threads - 1] = out_c.subspan(iter_size * (threads - 1), last_size); - *out_delta = std::get(core_).GetDelta(); uint128_t shared_seed = SyncSeedSend(lctx); + *out_delta = std::get(core_).GetDelta(); auto lctx_tl = SetupLink(lctx, threads); /* thread-local link */ ThreadPool pool(threads); // the destructor joins all threads diff --git a/yacl/math/f2k/f2k.h b/yacl/math/f2k/f2k.h index 9896aec5..bbfd3e5f 100644 --- a/yacl/math/f2k/f2k.h +++ b/yacl/math/f2k/f2k.h @@ -140,7 +140,7 @@ inline uint64_t GfMul64(uint64_t x, uint64_t y) { } // inverse over Galois Field F_{2^64} -inline uint64_t Inv64(uint64_t x) { +inline uint64_t GfInv64(uint64_t x) { uint64_t t0 = x; uint64_t t1 = GfMul64(t0, t0); uint64_t t2 = GfMul64(t1, t0); @@ -275,4 +275,18 @@ inline std::array GenGf64Basis() { static std::array gf64_basis = GenGf64Basis(); static std::array gf128_basis = GenGf128Basis(); + +inline uint128_t PackGf128(absl::Span data) { + const size_t size = data.size(); + YACL_ENFORCE(size <= 128); + // inner product + return GfMul128(data, absl::MakeSpan(gf128_basis.data(), size)); +} + +inline uint64_t PackGf64(absl::Span data) { + const size_t size = data.size(); + YACL_ENFORCE(size <= 64); + // inner product + return GfMul64(data, absl::MakeSpan(gf64_basis.data(), size)); +} }; // namespace yacl \ No newline at end of file diff --git a/yacl/math/f2k/f2k_test.cc b/yacl/math/f2k/f2k_test.cc index cbb20588..d8e00a00 100644 --- a/yacl/math/f2k/f2k_test.cc +++ b/yacl/math/f2k/f2k_test.cc @@ -128,4 +128,15 @@ TEST(F2kTest, GfMul64_inner_product) { EXPECT_EQ(ret, check); EXPECT_NE(ret, zero); +} + +TEST(F2kTest, GfInv64_inner_product) { + const uint64_t size = 1001; + + auto x = yacl::crypto::RandVec(size); + for (uint64_t i = 0; i < size; ++i) { + auto inv = yacl::GfInv64(x[i]); + auto check = yacl::GfMul64(x[i], inv); + EXPECT_EQ(uint64_t(1), check); + } } \ No newline at end of file diff --git a/yacl/math/gadget.h b/yacl/math/gadget.h index 9dde74ae..5134b745 100644 --- a/yacl/math/gadget.h +++ b/yacl/math/gadget.h @@ -57,7 +57,7 @@ uint64_t inline GfMul(absl::Span a, uint128_t inline GfMul(absl::Span a, absl::Span b) { - AlignedVector tmp(b.size()); + UninitAlignedVector tmp(b.size()); std::transform(b.cbegin(), b.cend(), tmp.begin(), [](const uint64_t& val) { return static_cast(val); });