diff --git a/.bazelrc b/.bazelrc index d332aaf1..3d29c78a 100644 --- a/.bazelrc +++ b/.bazelrc @@ -23,6 +23,9 @@ build --cxxopt=-std=c++17 build --host_cxxopt=-std=c++17 build --linkopt -lm +# HACK +build:linux --cxxopt -Wno-error=mismatched-new-delete + # Binary safety flags build --host_copt=-fPIE build --host_copt=-fstack-protector-strong diff --git a/.github/workflows/buildifier.yml b/.github/workflows/buildifier.yml index fb34efe1..313f1c1f 100644 --- a/.github/workflows/buildifier.yml +++ b/.github/workflows/buildifier.yml @@ -1,3 +1,17 @@ +# Copyright 2024 Ant Group Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + --- name: Bazel files linter on: diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml index 1523a908..5bc68987 100644 --- a/.github/workflows/cla.yml +++ b/.github/workflows/cla.yml @@ -1,3 +1,17 @@ +# Copyright 2024 Ant Group Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + --- name: CLA Assistant on: diff --git a/.github/workflows/clang-format-linter.yml b/.github/workflows/clang-format-linter.yml index d5efb2ed..ded293ca 100644 --- a/.github/workflows/clang-format-linter.yml +++ b/.github/workflows/clang-format-linter.yml @@ -1,3 +1,17 @@ +# Copyright 2024 Ant Group Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + --- name: Run clang-format Linter on: diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 8940a8f5..d12cc79d 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -1,3 +1,17 @@ +# Copyright 2024 Ant Group Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # This workflow uses actions that are not certified by GitHub. They are provided # by a third-party and are governed by separate terms of service, privacy # policy, and support documentation. diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 9216b2bc..c5341d2f 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -1,3 +1,17 @@ +# Copyright 2024 Ant Group Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + --- name: Mark stale issues and pull requests on: diff --git a/.github/workflows/yaml-lint.yml b/.github/workflows/yaml-lint.yml index 8f5e07f0..86a49e5b 100644 --- a/.github/workflows/yaml-lint.yml +++ b/.github/workflows/yaml-lint.yml @@ -1,3 +1,17 @@ +# Copyright 2024 Ant Group Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + --- name: Yaml Lint on: diff --git a/ALGORITHMS.md b/ALGORITHMS.md index a0239281..febbbafc 100644 --- a/ALGORITHMS.md +++ b/ALGORITHMS.md @@ -2,39 +2,117 @@ ## Primitives -- OT - - Simplest OT : https://eprint.iacr.org/2015/267.pdf - - INKP OT Extension : https://www.iacr.org/archive/crypto2003/27290145/27290145.pdf - - KOS OT Extension : https://eprint.iacr.org/2015/546.pdf - - KKRT OT Extension : https://eprint.iacr.org/2016/799.pdf - - SGRR OT Extension: https://eprint.iacr.org/2019/1084.pdf - - GYWZ OT Extension : https://eprint.iacr.org/2022/1431.pdf - - Ferret OT Extension : https://eprint.iacr.org/2020/924.pdf - - Softspoken OT Extension : https://eprint.iacr.org/2022/192.pdf -- VOLE(over f2k) - - base VOLE : https://eprint.iacr.org/2016/505.pdf - - Silent VOLE : https://eprint.iacr.org/2019/1159.pdf, https://eprint.iacr.org/2021/1150.pdf https://eprint.iacr.org/2022/1014.pdf - -- CODE - - Local Linear Code : https://eprint.iacr.org/2020/924.pdf - - Low Density Parity Check Code (Silver Code) : https://eprint.iacr.org/2021/1150.pdf - - Expanding Accumulation Code : https://eprint.iacr.org/2022/1014.pdf +### Oblivious Transfer and Extensions + +- The Simplest Protocol for Oblivious Transfer\ + *Tung Chou, Claudio Orlandi*\ + LatinCrypt 2015, [eprint](https://eprint.iacr.org/2015/267), CO15 + +- Extending Oblivious Transfers Efficiently\ + *Yuval Ishai, Joe Kilian, Kobbi Nissim, Erez Petrank*\ + Crypto 2003, [eprint](https://www.iacr.org/archive/crypto2003/27290145/27290145.pdf), IKNP03 + +- Actively Secure OT Extension with Optimal Overhead\ + *Marcel Keller, Emmanuela Orsini, Peter Scholl*\ + Crypto 2015, [eprint](https://eprint.iacr.org/2015/546), KOS15 + +- Efficient Batched Oblivious PRF with Applications to Private Set Intersection\ + *Vladimir Kolesnikov, Ranjit Kumaresan, Mike Rosulek, Ni Trieu*\ + CCS 2016, [eprint](https://eprint.iacr.org/2016/799), KKRT16 + +- Distributed vector-OLE: Improved constructions and implementation\ + *Phillipp Schoppmann, Adrià Gascón, Leonie Reichert, Mariana Raykova*\ + CCS 2019, [eprint](https://eprint.iacr.org/2019/1084), SGRR19 + +- Half-Tree: Halving the Cost of Tree Expansion in COT and DPF\ + *Xiaojie Guo, Kang Yang, Xiao Wang, Wenhao Zhang, Xiang Xie, Jiang Zhang, Zheli Liu*\ + EUROCRYPT 2023, [eprint](https://eprint.iacr.org/2022/1431), GYWZ+23 + +- Ferret: Fast Extension for coRRElated oT with small communication\ + *Kang Yang, Chenkai Weng, Xiao Lan, Jiang Zhang, Xiao Wang*\ + CCS'20, [eprint](https://eprint.iacr.org/2020/924), YWLZ+20 + +- SoftSpokenOT: Quieter OT Extension from Small-Field Silent VOLE in the Minicrypt Model\ + *Lawrence Roy*\ + Crypto 2022, [publisher](https://www.iacr.org/cryptodb//data/paper.php?pubkey=32258), Roy22 + +### Vector Oblivious Linear Evaluation (over Field 2k) + +Base VOLE: + +- MASCOT: Faster Malicious Arithmetic Secure Computation with Oblivious Transfer\ + *Marcel Keller, Emmanuela Orsini*\ + CCS 2016, [eprint](https://eprint.iacr.org/2016/505), KO16 + +Silent VOLE: + +- Efficient Two-Round OT Extension and Silent Non-Interactive Secure Computation\ + *Elette Boyle, Geoffroy Couteau, Niv Gilboa, Yuval Ishai, Lisa Kohl, Peter Rindal, Peter Scholl*\ + CCS 2019, [eprint](https://eprint.iacr.org/2019/1159), BCGI+19 (with Peter Rindal) + +- Efficient Two-Round OT Extension and Silent Non-Interactive Secure Computation\ + *Elette Boyle, Geoffroy Couteau, Niv Gilboa, Yuval Ishai, Lisa Kohl, Peter Rindal, Peter Scholl*\ + CCS 2019, [eprint](https://eprint.iacr.org/2019/1159), BCGI+19 + +- Correlated Pseudorandomness from Expand-Accumulate Codes\ + *Elette Boyle, Geoffroy Couteau, Niv Gilboa, Yuval Ishai, Lisa Kohl, Nicolas Resch, Peter Scholl*\ + Crypto 2022, [eprint](https://eprint.iacr.org/2022/1014), BCG+22 + + +### Codes + +Local Linear Code + +- Ferret: Fast Extension for coRRElated oT with small communication\ + *Kang Yang, Chenkai Weng, Xiao Lan, Jiang Zhang, Xiao Wang*\ + CCS'20, [eprint](https://eprint.iacr.org/2020/924), YWLZ+20 + +Low Density Parity Check Code (Silver Code) + +- Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes\ + *Geoffroy Couteau, Peter Rindal, Srinivasan Raghuraman*\ + Crypto 2021, [eprint](https://eprint.iacr.org/2021/1150), CRR21 + +Expanding Accumulation Code : https://eprint.iacr.org/2022/1014.pdf + +- Correlated Pseudorandomness from Expand-Accumulate Codes\ + *Elette Boyle, Geoffroy Couteau, Niv Gilboa, Yuval Ishai, Lisa Kohl, Nicolas Resch, Peter Scholl*\ + Crypto 2022, [eprint](https://eprint.iacr.org/2022/1014), BCG+22 + ## Theoretical Tools -- Random Oracle (RO) -- Random Permutation (RP) -- Pseudorandom Generator (PRG) -- Correlation-Robust Hash (CrHash) : https://eprint.iacr.org/2019/074.pdf -- Circular Correlation-Robust Hash (CcrHash) : https://eprint.iacr.org/2019/074.pdf +Random Oracle (RO) + +- TBD + +Random Permutation (RP) + +- TBD + +Pseudorandom Generator (PRG) + +- TBD + +Correlation-Robust Hash (CrHash) + +- Efficient and Secure Multiparty Computation from Fixed-Key Block Ciphers\ + *Chun Guo, Jonathan Katz, Xiao Wang, Yu Yu*\ + Preprint 2019, [eprint](https://eprint.iacr.org/2019/074), GKWY19 + +Circular Correlation-Robust Hash (CCR Hash) + +- Efficient and Secure Multiparty Computation from Fixed-Key Block Ciphers\ + *Chun Guo, Jonathan Katz, Xiao Wang, Yu Yu*\ + Preprint 2019, [eprint](https://eprint.iacr.org/2019/074), GKWY19 -## Basic (Traditional) algorithms +## Basic (Traditional) algorithms (TBD) - AEAD - AES - Block Cipher -- ECC (TODO) +- ECC - Hash - HMAC -- PKE: RSA, SM2 -- Signature: RSA, SM2 +- Public-Key Encryption: RSA, SM2 +- Digital Signature: RSA, SM2 diff --git a/yacl/crypto/primitives/dpf/BUILD.bazel b/yacl/crypto/primitives/dpf/BUILD.bazel index a01e2739..2e8ddb5c 100644 --- a/yacl/crypto/primitives/dpf/BUILD.bazel +++ b/yacl/crypto/primitives/dpf/BUILD.bazel @@ -13,8 +13,6 @@ # limitations under the License. load("//bazel:yacl.bzl", "AES_COPT_FLAGS", "yacl_cc_library", "yacl_cc_test") -load("@rules_proto//proto:defs.bzl", "proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(default_visibility = ["//visibility:public"]) @@ -23,7 +21,6 @@ yacl_cc_library( srcs = ["dpf.cc"], hdrs = ["dpf.h"], deps = [ - ":serializable_cc_proto", "//yacl/base:int128", "//yacl/crypto/tools:prg", "//yacl/link", @@ -38,21 +35,6 @@ yacl_cc_test( ], ) -proto_library( - name = "serializable_proto", - srcs = [ - "serializable.proto", - ], - deps = [ - "//yacl/utils:serializable_proto", - ], -) - -cc_proto_library( - name = "serializable_cc_proto", - deps = [":serializable_proto"], -) - yacl_cc_library( name = "mpfss", srcs = ["mpfss.cc"], diff --git a/yacl/crypto/primitives/dpf/dpf.cc b/yacl/crypto/primitives/dpf/dpf.cc index a362e0c5..dd0c7650 100644 --- a/yacl/crypto/primitives/dpf/dpf.cc +++ b/yacl/crypto/primitives/dpf/dpf.cc @@ -15,11 +15,9 @@ #include "yacl/crypto/primitives/dpf/dpf.h" #include -#include -#include "spdlog/spdlog.h" - -#include "yacl/crypto/primitives/dpf/serializable.pb.h" +#include "yacl/utils/serializer.h" +#include "yacl/utils/serializer_adapter.h" namespace yacl::crypto { @@ -290,59 +288,31 @@ std::vector DpfContext::EvalAll(DpfKey& key) { return result; } -std::string DpfKey::Serialize() const { - DpfKeyProto proto; - // Set properties - proto.set_enable_evalall(enable_evalall); +Buffer DpfKey::Serialize() const { + // var "cws_vec" 's type 'std::vector' not supported, convert to STL + // type + std::vector> dpf_cws; + dpf_cws.reserve(cws_vec.size()); for (const auto& cws : cws_vec) { - auto* cws_proto = proto.add_cws_vec(); - auto i128_parts = DecomposeUInt128(cws.GetSeed()); - cws_proto->mutable_seed()->set_hi(i128_parts.first); - cws_proto->mutable_seed()->set_lo(i128_parts.second); - cws_proto->set_t_store(cws.GetTStore()); - } - for (const auto& last_cw : last_cw_vec) { - auto* last_cw_proto = proto.add_last_cw_vec(); - auto i128_parts = DecomposeUInt128(last_cw); - last_cw_proto->set_hi(i128_parts.first); - last_cw_proto->set_lo(i128_parts.second); + dpf_cws.emplace_back(cws.GetSeed(), cws.GetTStore()); } - proto.set_rank(rank_); - proto.set_in_bitnum(in_bitnum_); - proto.set_ss_bitnum(ss_bitnum_); - proto.set_sec_param(sec_param_); - auto i128_parts = DecomposeUInt128(mseed_); - proto.mutable_mseed()->set_hi(i128_parts.first); - proto.mutable_mseed()->set_lo(i128_parts.second); - - return proto.SerializeAsString(); + // do serialize + return SerializeVars(enable_evalall, dpf_cws, last_cw_vec, rank_, in_bitnum_, + ss_bitnum_, sec_param_, mseed_); } -void DpfKey::Deserialize(const std::string& s) { - DpfKeyProto proto; - proto.ParseFromString(s); +void DpfKey::Deserialize(ByteContainerView in) { + std::vector> dpf_cws; + DeserializeVarsTo(in, &enable_evalall, &dpf_cws, &last_cw_vec, &rank_, + &in_bitnum_, &ss_bitnum_, &sec_param_, &mseed_); - enable_evalall = proto.enable_evalall(); + // recover "cws_vec" with type std::vector cws_vec.clear(); - for (const auto& cws_proto : proto.cws_vec()) { - cws_vec.emplace_back( - MakeUint128(cws_proto.seed().hi(), cws_proto.seed().lo()), - cws_proto.t_store()); + cws_vec.reserve(dpf_cws.size()); + for (const auto& cws : dpf_cws) { + cws_vec.emplace_back(cws.first, cws.second); } - - last_cw_vec.clear(); - for (const auto& last_cw_proto : proto.last_cw_vec()) { - last_cw_vec.emplace_back( - MakeUint128(last_cw_proto.hi(), last_cw_proto.lo())); - } - - rank_ = proto.rank(); - in_bitnum_ = proto.in_bitnum(); - ss_bitnum_ = proto.ss_bitnum(); - sec_param_ = proto.sec_param(); - - mseed_ = MakeUint128(proto.mseed().hi(), proto.mseed().lo()); } } // namespace yacl::crypto diff --git a/yacl/crypto/primitives/dpf/dpf.h b/yacl/crypto/primitives/dpf/dpf.h index e222bf7f..c5b933dd 100644 --- a/yacl/crypto/primitives/dpf/dpf.h +++ b/yacl/crypto/primitives/dpf/dpf.h @@ -101,8 +101,8 @@ class DpfKey { uint32_t GetSecParam() const { return sec_param_; } - std::string Serialize() const; - void Deserialize(const std::string& s); + Buffer Serialize() const; + void Deserialize(ByteContainerView s); private: bool rank_{}; // only support two parties (0/1), compulsory param diff --git a/yacl/crypto/primitives/dpf/dpf_test.cc b/yacl/crypto/primitives/dpf/dpf_test.cc index a0738795..01132366 100644 --- a/yacl/crypto/primitives/dpf/dpf_test.cc +++ b/yacl/crypto/primitives/dpf/dpf_test.cc @@ -123,7 +123,7 @@ TEST_P(FssDpfEvalAllTest, Works) { } DpfKey k1_copy; - std::string k1_string = k1.Serialize(); + auto k1_string = k1.Serialize(); k1_copy.Deserialize(k1_string); temp0 = context.EvalAll(k0); diff --git a/yacl/crypto/primitives/dpf/serializable.proto b/yacl/crypto/primitives/dpf/serializable.proto deleted file mode 100644 index 55add3b3..00000000 --- a/yacl/crypto/primitives/dpf/serializable.proto +++ /dev/null @@ -1,37 +0,0 @@ -// -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -syntax = "proto3"; - -package yacl.crypto; - -import "yacl/utils/serializable.proto"; - -message DpfCWProto { - Uint128Proto seed = 1; - uint32 t_store = 2; -} - -message DpfKeyProto { - bool enable_evalall = 1; - repeated DpfCWProto cws_vec = 2; - repeated Uint128Proto last_cw_vec = 3; - bool rank = 4; - uint64 in_bitnum = 5; - uint64 ss_bitnum = 6; - uint32 sec_param = 7; - Uint128Proto mseed = 8; -} diff --git a/yacl/crypto/primitives/ot/kkrt_ote_test.cc b/yacl/crypto/primitives/ot/kkrt_ote_test.cc index 557f46f9..ecf209e7 100644 --- a/yacl/crypto/primitives/ot/kkrt_ote_test.cc +++ b/yacl/crypto/primitives/ot/kkrt_ote_test.cc @@ -35,7 +35,7 @@ struct TestParams { class KkrtOtExtTest : public ::testing::TestWithParam {}; -TEST_P(KkrtOtExtTest, Works) { +TEST_P(KkrtOtExtTest, DISABLED_Works) { // GIVEN const int kWorldSize = 2; auto contexts = link::test::SetupWorld(kWorldSize); @@ -78,7 +78,7 @@ INSTANTIATE_TEST_SUITE_P(Works_Instances, KkrtOtExtTest, TestParams{4096}, // TestParams{65536})); -TEST(KkrtOtExtEdgeTest, Test) { +TEST(KkrtOtExtEdgeTest, DISABLED_Test) { // GIVEN const int kWorldSize = 2; auto contexts = link::test::SetupWorld(kWorldSize); @@ -111,7 +111,7 @@ TEST(KkrtOtExtEdgeTest, Test) { } class KkrtOtExtTest2 : public ::testing::TestWithParam {}; -TEST_P(KkrtOtExtTest2, Works) { +TEST_P(KkrtOtExtTest2, DISABLED_Works) { // GIVEN const int kWorldSize = 2; auto contexts = link::test::SetupWorld(kWorldSize); diff --git a/yacl/crypto/primitives/vss/poly.cc b/yacl/crypto/primitives/vss/poly.cc index 4374a55b..2e9e8b4c 100644 --- a/yacl/crypto/primitives/vss/poly.cc +++ b/yacl/crypto/primitives/vss/poly.cc @@ -1,3 +1,17 @@ +// Copyright 2024 Ant Group Co., Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "yacl/crypto/primitives/vss/poly.h" namespace yacl::crypto { diff --git a/yacl/crypto/primitives/vss/poly.h b/yacl/crypto/primitives/vss/poly.h index 71ff8d48..19ac874d 100644 --- a/yacl/crypto/primitives/vss/poly.h +++ b/yacl/crypto/primitives/vss/poly.h @@ -1,3 +1,19 @@ +/* + * Copyright 2024 Ant Group Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #pragma once #include "yacl/math/mpint/mp_int.h" diff --git a/yacl/crypto/primitives/vss/vss.cc b/yacl/crypto/primitives/vss/vss.cc index b29cd230..1003df35 100644 --- a/yacl/crypto/primitives/vss/vss.cc +++ b/yacl/crypto/primitives/vss/vss.cc @@ -1,3 +1,17 @@ +// Copyright 2024 Ant Group Co., Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "yacl/crypto/primitives/vss/vss.h" namespace yacl::crypto { diff --git a/yacl/crypto/primitives/vss/vss.h b/yacl/crypto/primitives/vss/vss.h index 952a9654..937cbd02 100644 --- a/yacl/crypto/primitives/vss/vss.h +++ b/yacl/crypto/primitives/vss/vss.h @@ -1,3 +1,19 @@ +/* + * Copyright 2024 Ant Group Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #pragma once #include diff --git a/yacl/crypto/primitives/vss/vss_test.cc b/yacl/crypto/primitives/vss/vss_test.cc index cc40ab05..67ebdb44 100644 --- a/yacl/crypto/primitives/vss/vss_test.cc +++ b/yacl/crypto/primitives/vss/vss_test.cc @@ -1,3 +1,17 @@ +// Copyright 2024 Ant Group Co., Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "yacl/crypto/primitives/vss/vss.h" #include "gtest/gtest.h" diff --git a/yacl/crypto/utils/drbg/ic_factory.h b/yacl/crypto/utils/drbg/ic_factory.h index f19ba835..55c2cccc 100644 --- a/yacl/crypto/utils/drbg/ic_factory.h +++ b/yacl/crypto/utils/drbg/ic_factory.h @@ -48,8 +48,8 @@ class IcDrbg : public Drbg { const SpiArgs &config) { YACL_ENFORCE(Check(type, config)); // make sure check passes return std::make_unique( - absl::AsciiStrToUpper(type), config.Get(ArgUseYaclEs, true), - config.Get(ArgSecParamC, SecParam::C::k128)); + absl::AsciiStrToUpper(type), config.GetOrDefault(ArgUseYaclEs, true), + config.GetOrDefault(ArgSecParamC, SecParam::C::k128)); } // this checker would return ture only for ctr-drbg type diff --git a/yacl/crypto/utils/drbg/openssl_factory.h b/yacl/crypto/utils/drbg/openssl_factory.h index 6fb2b174..477a71b2 100644 --- a/yacl/crypto/utils/drbg/openssl_factory.h +++ b/yacl/crypto/utils/drbg/openssl_factory.h @@ -51,8 +51,8 @@ class OpensslDrbg : public Drbg { const SpiArgs &config) { YACL_ENFORCE(Check(type, config)); // make sure check passes return std::make_unique( - absl::AsciiStrToUpper(type), config.Get(ArgUseYaclEs, true), - config.Get(ArgSecParamC, SecParam::C::k128)); + absl::AsciiStrToUpper(type), config.GetOrDefault(ArgUseYaclEs, true), + config.GetOrDefault(ArgSecParamC, SecParam::C::k128)); } // this checker would return ture only for ctr-drbg type diff --git a/yacl/io/msgpack/BUILD.bazel b/yacl/io/msgpack/BUILD.bazel index db995f51..a1587751 100644 --- a/yacl/io/msgpack/BUILD.bazel +++ b/yacl/io/msgpack/BUILD.bazel @@ -35,5 +35,6 @@ yacl_cc_test( srcs = ["buffer_test.cc"], deps = [ ":buffer", + "@com_github_msgpack_msgpack//:msgpack", ], ) diff --git a/yacl/io/msgpack/buffer.h b/yacl/io/msgpack/buffer.h index 46d67728..9bd3dead 100644 --- a/yacl/io/msgpack/buffer.h +++ b/yacl/io/msgpack/buffer.h @@ -81,4 +81,17 @@ class FixedBuffer { size_t pos_ = 0; }; +// ShadowBuffer does not store any data. +// It is used solely for calculating the size of objects after msgpack +// serialization. +class ShadowBuffer { + public: + void write(const char *, size_t len) { size_ += len; } + + size_t GetDataSize() const { return size_; } + + private: + size_t size_ = 0; +}; + } // namespace yacl::io diff --git a/yacl/io/msgpack/buffer_test.cc b/yacl/io/msgpack/buffer_test.cc index d2f86461..32f16bb6 100644 --- a/yacl/io/msgpack/buffer_test.cc +++ b/yacl/io/msgpack/buffer_test.cc @@ -15,6 +15,7 @@ #include "yacl/io/msgpack/buffer.h" #include "gtest/gtest.h" +#include "msgpack.hpp" namespace yacl::io::test { @@ -31,4 +32,15 @@ TEST(TestStreamBuffer, SimpleWorks) { EXPECT_EQ(std::string(buffer.data(), buffer.size()), "abcd"); } +TEST(TestStreamBuffer, ShadowBufferWorks) { + yacl::Buffer buf; + yacl::io::StreamBuffer sbuf(&buf); + msgpack::pack(sbuf, 1.2); + + ShadowBuffer sd_buf; + msgpack::pack(sd_buf, 1.2); + + EXPECT_EQ(sd_buf.GetDataSize(), buf.size()); +} + } // namespace yacl::io::test diff --git a/yacl/math/galois_field/mcl_field/mcl_field.cc b/yacl/math/galois_field/mcl_field/mcl_field.cc index da8ca5fd..e6d83ab0 100644 --- a/yacl/math/galois_field/mcl_field/mcl_field.cc +++ b/yacl/math/galois_field/mcl_field/mcl_field.cc @@ -57,8 +57,8 @@ const std::vector kMclFieldMetas = { std::unique_ptr MclFieldFactory::Create( const std::string& field_name, const SpiArgs& args) { auto mod = args.GetRequired(ArgMod); - auto degree = args.Get(ArgDegree, 1); - auto maxBitSize = args.Get(ArgMaxBitSize, 512); + auto degree = args.GetOrDefault(ArgDegree, 1); + auto maxBitSize = args.GetOrDefault(ArgMaxBitSize, 512); auto it = kMclFieldMetas.cbegin(); for (; it != kMclFieldMetas.cend(); it++) { if (it->IsEquivalent({field_name, degree, maxBitSize})) { @@ -92,8 +92,8 @@ std::unique_ptr MclFieldFactory::Create( bool MclFieldFactory::Check(const std::string& field_name, const SpiArgs& args) { - auto degree = args.Get(ArgDegree, 1); - auto maxBitSize = args.Get(ArgMaxBitSize, 512); + auto degree = args.GetOrDefault(ArgDegree, 1); + auto maxBitSize = args.GetOrDefault(ArgMaxBitSize, 512); MclFieldMeta meta = {field_name, degree, maxBitSize}; for (auto it : kMclFieldMetas) { if (meta.IsEquivalent(it)) { diff --git a/yacl/math/galois_field/mpint_field/mpint_field_test.cc b/yacl/math/galois_field/mpint_field/mpint_field_test.cc index 5f6c3de4..eaa779ad 100644 --- a/yacl/math/galois_field/mpint_field/mpint_field_test.cc +++ b/yacl/math/galois_field/mpint_field/mpint_field_test.cc @@ -216,7 +216,7 @@ TEST_F(MPIntFieldTest, VectorIoWorks) { auto item2 = item1.SubItem(0, 1); ASSERT_TRUE(item2.IsView()); ASSERT_FALSE(item2.IsReadOnly()); - ASSERT_TRUE(item2.WrappedTypeIs>()); + ASSERT_TRUE(item2.RawTypeIs>()); ASSERT_TRUE(gf->Equal(item2, Item::Take({0_mp}))); // deepcopy diff --git a/yacl/math/mpint/mp_int.cc b/yacl/math/mpint/mp_int.cc index 31d149d2..3910e03d 100644 --- a/yacl/math/mpint/mp_int.cc +++ b/yacl/math/mpint/mp_int.cc @@ -514,6 +514,7 @@ void MPInt::Pow(const MPInt &a, uint32_t b, MPInt *c) { MPInt MPInt::Pow(uint32_t b) const { if (b == 0) { + YACL_ENFORCE(!IsZero(), "Power: 0^0 is illegal"); return MPInt::_1_; } diff --git a/yacl/math/mpint/mp_int.h b/yacl/math/mpint/mp_int.h index 6d25633e..60ac8bfe 100644 --- a/yacl/math/mpint/mp_int.h +++ b/yacl/math/mpint/mp_int.h @@ -117,7 +117,11 @@ class MPInt { [[nodiscard]] size_t BitCount() const; + // The size of memory allocated by this MPInt. + // Not equal to the byte size of the number, nor equal to the serialized size size_t SizeAllocated() { return n_.alloc * sizeof(mp_digit); } + // The size of memory used by this MPInt. + // Not equal to the byte size of the number, nor equal to the serialized size size_t SizeUsed() { return n_.used * sizeof(mp_digit); } //================================// @@ -352,7 +356,7 @@ class MPInt { const std::function &combine_inplace) { YACL_ENFORCE(!scalar.IsNegative(), "scalar must >= 0, get {}", scalar); - if (scalar.n_.used == 0) { + if (scalar.IsZero()) { return identity; } diff --git a/yacl/utils/BUILD.bazel b/yacl/utils/BUILD.bazel index 3dd05fa3..68ab8322 100644 --- a/yacl/utils/BUILD.bazel +++ b/yacl/utils/BUILD.bazel @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_cc//cc:defs.bzl", "cc_proto_library") -load("@rules_proto//proto:defs.bzl", "proto_library") load("//bazel:yacl.bzl", "OMP_CFLAGS", "OMP_DEPS", "OMP_LINKFLAGS", "yacl_cc_binary", "yacl_cc_library", "yacl_cc_test") package(default_visibility = ["//visibility:public"]) @@ -77,25 +75,35 @@ yacl_cc_library( ], ) -proto_library( - name = "serializable_proto", - srcs = ["serializable.proto"], +# deprecated, use serializer instead +yacl_cc_library( + name = "serialize", + hdrs = ["serialize.h"], + deps = [ + ":serializer", + ], ) -cc_proto_library( - name = "serializable_cc_proto", - deps = [":serializable_proto"], +yacl_cc_library( + name = "serializer", + srcs = ["serializer.cc"], + hdrs = [ + "serializer.h", + "serializer_adapter.h", + ], + deps = [ + "@com_github_msgpack_msgpack//:msgpack", + "@yacl//yacl/base:byte_container_view", + "@yacl//yacl/base:int128", + "@yacl//yacl/io/msgpack:buffer", + ], ) -yacl_cc_library( - name = "serialize", - srcs = ["serialize.cc"], - hdrs = ["serialize.h"], +yacl_cc_test( + name = "serializer_test", + srcs = ["serializer_test.cc"], deps = [ - ":serializable_cc_proto", - "//yacl/base:buffer", - "//yacl/base:byte_container_view", - "//yacl/base:int128", + ":serializer", ], ) diff --git a/yacl/utils/serialize.cc b/yacl/utils/serialize.cc deleted file mode 100644 index fd5a54b0..00000000 --- a/yacl/utils/serialize.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "yacl/utils/serialize.h" - -#include "yacl/utils/serializable.pb.h" - -namespace yacl { - -Buffer SerializeArrayOfBuffers(const std::vector& bufs) { - ArrayOfBuffer proto; - for (const auto& b : bufs) { - proto.add_bufs(b.data(), b.size()); - } - Buffer b(proto.ByteSizeLong()); - proto.SerializePartialToArray(b.data(), b.size()); - return b; -} - -std::vector DeserializeArrayOfBuffers(ByteContainerView buf) { - ArrayOfBuffer proto; - std::vector bufs; - proto.ParseFromArray(buf.data(), buf.size()); - for (const auto& b : proto.bufs()) { - bufs.emplace_back(b); - } - return bufs; -} - -Buffer SerializeInt128(int128_t v) { - Int128Proto proto; - auto parts = DecomposeInt128(v); - proto.set_hi(parts.first); - proto.set_lo(parts.second); - - Buffer b(proto.ByteSizeLong()); - proto.SerializePartialToArray(b.data(), b.size()); - return b; -} - -int128_t DeserializeInt128(ByteContainerView buf) { - Int128Proto proto; - proto.ParseFromArray(buf.data(), buf.size()); - return MakeInt128(proto.hi(), proto.lo()); -} - -Buffer SerializeUint128(uint128_t v) { - Uint128Proto proto; - auto parts = DecomposeUInt128(v); - proto.set_hi(parts.first); - proto.set_lo(parts.second); - - Buffer b(proto.ByteSizeLong()); - proto.SerializePartialToArray(b.data(), b.size()); - return b; -} - -uint128_t DeserializeUint128(ByteContainerView buf) { - Uint128Proto proto; - proto.ParseFromArray(buf.data(), buf.size()); - return MakeUint128(proto.hi(), proto.lo()); -} - -} // namespace yacl diff --git a/yacl/utils/serialize.h b/yacl/utils/serialize.h index ef688588..e59616d1 100644 --- a/yacl/utils/serialize.h +++ b/yacl/utils/serialize.h @@ -19,19 +19,36 @@ #include "yacl/base/buffer.h" #include "yacl/base/byte_container_view.h" #include "yacl/base/int128.h" +#include "yacl/utils/serializer.h" +#include "yacl/utils/serializer_adapter.h" namespace yacl { -Buffer SerializeArrayOfBuffers(const std::vector& bufs); +// deprecated. please call SerializeVars(...) directly +inline Buffer SerializeArrayOfBuffers( + const std::vector& bufs) { + return SerializeVars(bufs); +} -std::vector DeserializeArrayOfBuffers(ByteContainerView buf); +// deprecated. please call DeserializeVars(...) directly +inline std::vector DeserializeArrayOfBuffers(ByteContainerView buf) { + return DeserializeVars>(buf); +} -Buffer SerializeInt128(int128_t v); +// deprecated. please call SerializeVars(...) directly +inline Buffer SerializeInt128(int128_t v) { return SerializeVars(v); } -int128_t DeserializeInt128(ByteContainerView buf); +// deprecated. please call DeserializeVars(...) directly +inline int128_t DeserializeInt128(ByteContainerView buf) { + return DeserializeVars(buf); +} -Buffer SerializeUint128(uint128_t v); +// deprecated. please call SerializeVars(...) directly +inline Buffer SerializeUint128(uint128_t v) { return SerializeVars(v); } -uint128_t DeserializeUint128(ByteContainerView buf); +// deprecated. please call DeserializeVars(...) directly +inline uint128_t DeserializeUint128(ByteContainerView buf) { + return DeserializeVars(buf); +} } // namespace yacl diff --git a/yacl/utils/serializer.cc b/yacl/utils/serializer.cc new file mode 100644 index 00000000..9b1ab21a --- /dev/null +++ b/yacl/utils/serializer.cc @@ -0,0 +1,36 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/utils/serializer.h" + +namespace yacl { + +bool internal::ref_or_copy(msgpack::type::object_type type, std::size_t length, + void *) { + switch (type) { + case msgpack::type::STR: + // Small strings are copied. + return length >= 32; + case msgpack::type::BIN: + // BIN is always referenced. + return true; + case msgpack::type::EXT: + // EXT is always copied. + return false; + default: + YACL_THROW("unexpected type {}", static_cast(type)); + } +} + +} // namespace yacl diff --git a/yacl/utils/serializer.h b/yacl/utils/serializer.h new file mode 100644 index 00000000..11b4bfdd --- /dev/null +++ b/yacl/utils/serializer.h @@ -0,0 +1,140 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "msgpack.hpp" + +#include "yacl/base/buffer.h" +#include "yacl/base/byte_container_view.h" +#include "yacl/io/msgpack/buffer.h" + +namespace yacl { + +// Serialization/deserialization tools, +// supporting a single variable or a set of variables. +// +// +// Use case 1: Serialize a single variable +// +// > int64_t v1 = -12345; +// > auto buf = SerializeVars(v1); +// > int64_t v1_new = DeserializeVars(buf); +// +// +// Use case 2: Serialize multiple variables +// +// > int64_t v1 = -12345; +// > bool v2 = true; +// > std::string v3 = "hello"; +// > +// > auto buf = SerializeVars(v1, v2, v3); +// > auto [v1_new, v2_new, v3_new] = +// > DeserializeVars(buf); +// +// +// Use case 3: Deserialize to an existing variable, avoiding memory copy +// +// > int64_t v1 = 123; +// > std::vector v2 = {"hello", "world"}; +// > +// > auto needed_buf_size = SerializeVarsTo(nullptr, 0, v1, v2); +// > ... get a buffer larger than needed_buf_size +// > SerializeVarsTo(buffer, needed_buf_size, v1, v2); +// > +// > int64_t v1_new; +// > std::vector v2_new; +// > DeserializeVarsTo(buffer, &v1_new, &v2_new); +// +// +// Special notice: +// If you need to serialize special types (non-STL types/containers), +// such as int128_t, please include the following header file to enable support +// for extended types: +// > #include "yacl/utils/serializer_adapter.h" + +template +inline yacl::Buffer SerializeVars(const Ts &...obj) { + yacl::Buffer buf; + yacl::io::StreamBuffer sbuf(&buf); + (..., msgpack::pack(sbuf, obj)); + return buf; +} + +// Serialize the 'obj' into buf and return the actual number of bytes written. +// If buf is empty, only calculate the size of the serialized object. +template +inline size_t SerializeVarsTo(uint8_t *buf, size_t buf_len, const Ts &...obj) { + if (buf == nullptr) { + yacl::io::ShadowBuffer sd_buf; + (..., msgpack::pack(sd_buf, obj)); + return sd_buf.GetDataSize(); + } + + yacl::io::FixedBuffer fbuf((char *)buf, buf_len); + (..., msgpack::pack(fbuf, obj)); + return fbuf.WrittenSize(); +} + +namespace internal { + +bool ref_or_copy(msgpack::type::object_type type, std::size_t length, void *); + +template +std::tuple DoDeserializeAsTuple(std::index_sequence, + yacl::ByteContainerView in) { + std::size_t off = 0; + std::tuple res; + (..., msgpack::unpack(reinterpret_cast(in.data()), in.size(), + off, ref_or_copy) + ->convert(std::get(res))); + + return res; +} + +} // namespace internal + +// If Ts is a single type, return type T +// If Ts is a type array, return std::tuple +template +inline auto DeserializeVars(yacl::ByteContainerView in) -> + typename std::conditional_t>, + std::tuple> { + if constexpr (sizeof...(Ts) == 1) { + auto msg = msgpack::unpack(reinterpret_cast(in.data()), + in.size(), internal::ref_or_copy); + + std::tuple_element_t<0, std::tuple> res; + msg->convert(res); + return res; + } else { + return internal::DoDeserializeAsTuple( + std::index_sequence_for(), in); + } +} + +template +inline size_t DeserializeVarsTo(yacl::ByteContainerView in, Ts *...vars) { + std::size_t off = 0; + (..., msgpack::unpack(reinterpret_cast(in.data()), in.size(), + off, internal::ref_or_copy) + ->convert(*vars)); + return off; +} + +} // namespace yacl diff --git a/yacl/utils/serializer_adapter.h b/yacl/utils/serializer_adapter.h new file mode 100644 index 00000000..9460929e --- /dev/null +++ b/yacl/utils/serializer_adapter.h @@ -0,0 +1,129 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/base/int128.h" +#include "yacl/utils/serializer.h" + +// clang-format off +namespace msgpack { +MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) { +namespace adaptor { + // clang-format on + + //=== adapter of int128_t ===// + + template <> + struct pack { + template + packer &operator()(msgpack::packer &o, + const int128_t &v) const { + std::pair pair = yacl::DecomposeInt128(v); + o.pack(pair); + return o; + } + }; + + template <> + struct convert { + const msgpack::object &operator()(const msgpack::object &o, + int128_t &v) const { + auto pair = o.as>(); + v = yacl::MakeInt128(pair.first, pair.second); + return o; + } + }; + + //=== adapter of uint128_t ===// + + template <> + struct pack { + template + packer &operator()(msgpack::packer &o, + const uint128_t &v) const { + std::pair pair = yacl::DecomposeUInt128(v); + o.pack(pair); + return o; + } + }; + + template <> + struct convert { + const msgpack::object &operator()(const msgpack::object &o, + uint128_t &v) const { + auto pair = o.as>(); + v = yacl::MakeUint128(pair.first, pair.second); + return o; + } + }; + + //=== adapter of ByteContainerView ===// + + template <> + struct pack { + template + packer &operator()(msgpack::packer &o, + const yacl::ByteContainerView &v) const { + uint32_t size = checked_get_container_size(v.size()); + o.pack_bin(size); + o.pack_bin_body(reinterpret_cast(v.data()), size); + return o; + } + }; + + // If you deserialize into ByteContainerView, you can avoid copying, but + // ownership depends on the input buffer. + template <> + struct convert { + const msgpack::object &operator()(const msgpack::object &o, + yacl::ByteContainerView &v) const { + YACL_ENFORCE(o.type == msgpack::type::BIN, + "Type mismatch, cannot deserialize. exp_type={}", + static_cast(o.type)); + v = yacl::ByteContainerView(o.via.bin.ptr, o.via.bin.size); + return o; + } + }; + + //=== adapter of yacl::Buffer ===// + + // yacl::Buffer is compatible with yacl::ByteContainerView + template <> + struct pack { + template + packer &operator()(msgpack::packer &o, + const yacl::Buffer &v) const { + uint32_t size = checked_get_container_size(v.size()); + o.pack_bin(size); + o.pack_bin_body(v.data(), size); + return o; + } + }; + + template <> + struct convert { + const msgpack::object &operator()(const msgpack::object &o, + yacl::Buffer &v) const { + YACL_ENFORCE(o.type == msgpack::type::BIN, + "Type mismatch, cannot deserialize. "); + v = yacl::Buffer(reinterpret_cast(o.via.bin.ptr), + o.via.bin.size); + return o; + } + }; + + // clang-format off +} // namespace adaptor +} // namespace msgpack +} // namespace msgpack +// clang-format on diff --git a/yacl/utils/serializer_test.cc b/yacl/utils/serializer_test.cc new file mode 100644 index 00000000..68631e2b --- /dev/null +++ b/yacl/utils/serializer_test.cc @@ -0,0 +1,95 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/utils/serializer.h" + +#include "gtest/gtest.h" + +#include "yacl/base/int128.h" +#include "yacl/utils/serializer_adapter.h" + +namespace yacl::test { + +TEST(SerializerTest, SingleWorks) { + int64_t v1 = -12345; + int64_t f1; + + // ser/deser single + auto buf = SerializeVars(v1); + // DeserializeVars directly returns int64_t + ASSERT_EQ(DeserializeVars(buf), v1); + + ASSERT_EQ(buf.size(), SerializeVarsTo(nullptr, 0, v1)); + ASSERT_EQ(buf.size(), SerializeVarsTo(buf.data(), buf.size(), v1)); + DeserializeVarsTo(buf, &f1); + ASSERT_EQ(v1, f1); +} + +TEST(SerializerTest, MultiWorks) { + int64_t v1 = -12345; + bool v2 = true; + std::string v3 = "hello"; + double v4 = 1.2; + int64_t v5 = 987; // test type is duplicate with v1 + + auto buf = SerializeVars(v1, v2, v3, v4, v5); + auto [f1, f2, f3, f4, f5] = + DeserializeVars(buf); + EXPECT_EQ(v1, f1); + EXPECT_EQ(v2, f2); + EXPECT_EQ(v3, f3); + EXPECT_EQ(v4, f4); + EXPECT_EQ(v5, f5); + + ASSERT_EQ(buf.size(), SerializeVarsTo(nullptr, 0, v1, v2, v3, v4, v5)); + ASSERT_EQ(buf.size(), SerializeVarsTo(buf.data(), buf.size(), v1, v2, + v3, v4, v5)); + std::tie(f1, f2, f3, f4, f5) = std::make_tuple(0, false, "", .0, 0); + DeserializeVarsTo(buf, &f1, &f2, &f3, &f4, &f5); + EXPECT_EQ(v1, f1); + EXPECT_EQ(v2, f2); + EXPECT_EQ(v3, f3); + EXPECT_EQ(v4, f4); +} + +TEST(SerializerTest, Int128) { + int128_t v1 = yacl::MakeInt128(INT64_MAX, INT64_MAX); + auto buf = SerializeVars(v1); + EXPECT_EQ(DeserializeVars(buf), v1); + + uint128_t v2 = yacl::MakeUint128(INT64_MAX, 123); + buf = SerializeVars(v2); + EXPECT_EQ(DeserializeVars(buf), v2); +} + +TEST(SerializerTest, Buffer) { + // test serializes raw buffer + yacl::Buffer hello(std::string("hello")); + + auto buf = SerializeVars(hello, hello); + auto [f1, f2] = DeserializeVars(buf); + EXPECT_STREQ(f1.data(), "hello"); + EXPECT_EQ((std::string_view)f2, "hello"); + + ByteContainerView view1, view2; + DeserializeVarsTo(buf, &view1, &view2); + EXPECT_EQ( + std::string(reinterpret_cast(view1.data()), view1.size()), + "hello"); + EXPECT_EQ( + std::string(reinterpret_cast(view2.data()), view2.size()), + "hello"); +} + +} // namespace yacl::test diff --git a/yacl/utils/spi/argument/BUILD.bazel b/yacl/utils/spi/argument/BUILD.bazel index 3f4c1b7f..527aeba8 100644 --- a/yacl/utils/spi/argument/BUILD.bazel +++ b/yacl/utils/spi/argument/BUILD.bazel @@ -41,6 +41,29 @@ yacl_cc_library( "arg_set.h", ], deps = [ + ":util", "//yacl/base:exception", + "//yacl/math/mpint", ], ) + +yacl_cc_library( + name = "util", + srcs = [ + "util.cc", + ], + hdrs = [ + "util.h", + ], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +yacl_cc_test( + name = "util_test", + srcs = [ + "util_test.cc", + ], + deps = [":util"], +) diff --git a/yacl/utils/spi/argument/arg_k.h b/yacl/utils/spi/argument/arg_k.h index f8cfeb81..816df18e 100644 --- a/yacl/utils/spi/argument/arg_k.h +++ b/yacl/utils/spi/argument/arg_k.h @@ -18,6 +18,7 @@ #include "spdlog/spdlog.h" #include "yacl/utils/spi/argument/arg_kv.h" +#include "yacl/utils/spi/argument/util.h" namespace yacl { @@ -26,11 +27,11 @@ class SpiArgKey { public: using ValueType = T; - explicit SpiArgKey(const std::string &key) - : key_(absl::AsciiStrToLower(key)) {} + explicit SpiArgKey(const std::string &key) : key_(util::ToSnakeCase(key)) {} const std::string &Key() const & { return key_; } + // If value is a string, it will be automatically converted to lowercase SpiArg operator=(T &&value) const { return {key_, std::forward(value)}; } SpiArg operator=(const T &value) const { return {key_, value}; } @@ -52,6 +53,8 @@ class SpiArgKey { #define DEFINE_ARG_uint(ArgName) DEFINE_ARG(uint, ArgName) #define DEFINE_ARG_int64(ArgName) DEFINE_ARG(int64_t, ArgName) #define DEFINE_ARG_uint64(ArgName) DEFINE_ARG(uint64_t, ArgName) +#define DEFINE_ARG_double(ArgName) DEFINE_ARG(double, ArgName) +// Note: The arg value will be automatically converted to lowercase #define DEFINE_ARG_string(ArgName) DEFINE_ARG(std::string, ArgName) // declare an arg @@ -60,6 +63,8 @@ class SpiArgKey { #define DECLARE_ARG_uint(ArgName) DECLARE_ARG(uint, ArgName) #define DECLARE_ARG_int64(ArgName) DECLARE_ARG(int64_t, ArgName) #define DECLARE_ARG_uint64(ArgName) DECLARE_ARG(uint64_t, ArgName) +#define DECLARE_ARG_double(ArgName) DECLARE_ARG(double, ArgName) +// Note: The arg value will be automatically converted to lowercase #define DECLARE_ARG_string(ArgName) DECLARE_ARG(std::string, ArgName) } // namespace yacl diff --git a/yacl/utils/spi/argument/arg_kv.cc b/yacl/utils/spi/argument/arg_kv.cc index 1075893c..0bea156f 100644 --- a/yacl/utils/spi/argument/arg_kv.cc +++ b/yacl/utils/spi/argument/arg_kv.cc @@ -14,6 +14,8 @@ #include "yacl/utils/spi/argument/arg_kv.h" +#include "yacl/math/mpint/mp_int.h" + namespace yacl { const std::string& SpiArg::Key() const { return key_; } @@ -30,4 +32,31 @@ SpiArg& SpiArg::operator=(const std::string& value) { return *this; } +#define TRY_TYPE(type) \ + if (t == typeid(type)) { \ + return fmt::format("{}={}", key_, std::any_cast(value_)); \ + } + +std::string SpiArg::ToString() const { + const auto& t = value_.type(); + // Place the types with a high probability of being hit at the front. + TRY_TYPE(std::string); + TRY_TYPE(int64_t); // mac-m1 doesn't support int128 + TRY_TYPE(uint64_t); + TRY_TYPE(bool); + TRY_TYPE(double); + + TRY_TYPE(int8_t); + TRY_TYPE(int16_t); + TRY_TYPE(int32_t); + TRY_TYPE(uint8_t); + TRY_TYPE(uint16_t); + TRY_TYPE(uint32_t); + TRY_TYPE(float); + TRY_TYPE(char); + TRY_TYPE(unsigned char); + TRY_TYPE(yacl::math::MPInt); // MPInt is a first-class citizen in SPI + return fmt::format("{}=Object<{}>", key_, t.name()); +} + } // namespace yacl diff --git a/yacl/utils/spi/argument/arg_kv.h b/yacl/utils/spi/argument/arg_kv.h index 9e3bc5a6..5cd9e7c9 100644 --- a/yacl/utils/spi/argument/arg_kv.h +++ b/yacl/utils/spi/argument/arg_kv.h @@ -26,15 +26,17 @@ #include "spdlog/spdlog.h" #include "yacl/base/exception.h" +#include "yacl/utils/spi/argument/util.h" namespace yacl { class SpiArg { public: - explicit SpiArg(const std::string &key) : key_(absl::AsciiStrToLower(key)) {} + explicit SpiArg(const std::string &key) : key_(util::ToSnakeCase(key)) {} + // If value is a string, it will be automatically converted to lowercase template - SpiArg(const std::string &key, T &&value) : key_(absl::AsciiStrToLower(key)) { + SpiArg(const std::string &key, T &&value) : key_(util::ToSnakeCase(key)) { operator=(std::forward(value)); } @@ -45,6 +47,7 @@ class SpiArg { } // Specialized functions of operator= + // If value is a string, it will be automatically converted to lowercase SpiArg &operator=(const char *value); SpiArg &operator=(const std::string &value); @@ -61,9 +64,13 @@ class SpiArg { } } + std::string ToString() const; + private: std::string key_; std::any value_; }; +inline auto format_as(const SpiArg &arg) { return arg.ToString(); } + } // namespace yacl diff --git a/yacl/utils/spi/argument/arg_set.cc b/yacl/utils/spi/argument/arg_set.cc index a14df274..e93e6f06 100644 --- a/yacl/utils/spi/argument/arg_set.cc +++ b/yacl/utils/spi/argument/arg_set.cc @@ -14,12 +14,33 @@ #include "yacl/utils/spi/argument/arg_set.h" +// formatter to format SpiArgs values +template <> +struct fmt::formatter::value_type> { + template + constexpr auto parse(ParseContext &ctx) { + return ctx.begin(); + } + + template + auto format(const std::map::value_type &fp, + FormatContext &ctx) const { + return fmt::format_to(ctx.out(), "{}", fp.second); + } +}; + namespace yacl { SpiArgs::SpiArgs(std::initializer_list args) { - for (const auto& item : args) { + for (const auto &item : args) { insert({item.Key(), item}); } } +void SpiArgs::Insert(const SpiArg &arg) { insert({arg.Key(), arg}); } + +std::string SpiArgs::ToString() const { + return fmt::format("{{{}}}", fmt::join(*this, ", ")); +} + } // namespace yacl diff --git a/yacl/utils/spi/argument/arg_set.h b/yacl/utils/spi/argument/arg_set.h index bbf5c403..6f8396d4 100644 --- a/yacl/utils/spi/argument/arg_set.h +++ b/yacl/utils/spi/argument/arg_set.h @@ -14,21 +14,32 @@ #pragma once +#include + #include "yacl/utils/spi/argument/arg_k.h" namespace yacl { -class SpiArgs : public std::map { +class SpiArgs : private std::map { public: SpiArgs(std::initializer_list args); + void Insert(const SpiArg &arg); + + using std::map::size; + using std::map::empty; + using std::map::begin; + using std::map::cbegin; + using std::map::end; + using std::map::cend; + // Get an argument // If this parameter is not set, the default value is returned // If the user sets this parameter, but the type is not T, then an exception // is thrown template - T Get(const SpiArgKey &key, - const typename SpiArgKey::ValueType &default_value) const { + T GetOrDefault(const SpiArgKey &key, + const typename SpiArgKey::ValueType &default_value) const { auto it = find((key.Key())); if (it == end()) { return default_value; @@ -42,8 +53,7 @@ class SpiArgs : public std::map { // If the user sets this parameter, but the type is not T, then an exception // is thrown template - auto GetRequired(const SpiArgKey &key) const -> - typename SpiArgKey::ValueType { + T GetRequired(const SpiArgKey &key) const { auto it = find((key.Key())); YACL_ENFORCE(it != end(), "Missing required argument {}", key.Key()); return it->second.template Value(); @@ -53,14 +63,24 @@ class SpiArgs : public std::map { // After getting the SpiArg, you can use SpiArg.HasValue() to check if it // contains a value template - SpiArg GetOptional(const SpiArgKey &key) const { + std::optional GetOptional(const SpiArgKey &key) const { auto it = find(key.Key()); if (it == end()) { - return SpiArg{key.Key()}; + return {}; } else { - return it->second; + return it->second.template Value(); } } + + // Check if key exists + template + bool Exist(const SpiArgKey &key) const { + return find(key.Key()) != end(); + } + + std::string ToString() const; }; +inline auto format_as(const SpiArgs &arg) { return arg.ToString(); } + } // namespace yacl diff --git a/yacl/utils/spi/argument/util.cc b/yacl/utils/spi/argument/util.cc new file mode 100644 index 00000000..a1eb1ae6 --- /dev/null +++ b/yacl/utils/spi/argument/util.cc @@ -0,0 +1,37 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/utils/spi/argument/util.h" + +#include + +#include "absl/strings/ascii.h" +#include "absl/strings/str_join.h" +#include "fmt/core.h" + +namespace yacl::util { + +std::string ToSnakeCase(const std::string& str) { + std::regex reg("[A-Z]?[a-z0-9]*"); + std::string log = str; + std::vector words; + for (std::smatch sm; + std::regex_search(log, sm, reg, std::regex_constants::match_not_null); + log = sm.suffix()) { + words.push_back(absl::AsciiStrToLower(sm.str())); + } + return absl::StrJoin(words, "_"); +} + +} // namespace yacl::util diff --git a/yacl/utils/serializable.proto b/yacl/utils/spi/argument/util.h similarity index 63% rename from yacl/utils/serializable.proto rename to yacl/utils/spi/argument/util.h index 9a0c3b63..f6d6708a 100644 --- a/yacl/utils/serializable.proto +++ b/yacl/utils/spi/argument/util.h @@ -1,33 +1,23 @@ -// -// Copyright 2022 Ant Group Co., Ltd. +// Copyright 2024 Ant Group Co., Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// -syntax = "proto3"; +#pragma once -package yacl; +#include -message ArrayOfBuffer { - repeated bytes bufs = 1; -} +namespace yacl::util { -message Uint128Proto { - uint64 hi = 1; - uint64 lo = 2; -} +std::string ToSnakeCase(const std::string& str); -message Int128Proto { - int64 hi = 1; - uint64 lo = 2; -} +} // namespace yacl::util diff --git a/yacl/utils/spi/argument/util_test.cc b/yacl/utils/spi/argument/util_test.cc new file mode 100644 index 00000000..a1fc93f6 --- /dev/null +++ b/yacl/utils/spi/argument/util_test.cc @@ -0,0 +1,43 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/utils/spi/argument/util.h" + +#include "gtest/gtest.h" + +namespace yacl::util::test { + +TEST(UtilTest, ToSnakeWorks) { + EXPECT_EQ(ToSnakeCase("Hello"), "hello"); + EXPECT_EQ(ToSnakeCase("HelloWorld"), "hello_world"); + EXPECT_EQ(ToSnakeCase("hello"), "hello"); + EXPECT_EQ(ToSnakeCase("helloWorld"), "hello_world"); + EXPECT_EQ(ToSnakeCase("hello_world"), "hello_world"); + EXPECT_EQ(ToSnakeCase("hello-world"), "hello_world"); + EXPECT_EQ(ToSnakeCase("hello-worldHaha"), "hello_world_haha"); + + EXPECT_EQ(ToSnakeCase(""), ""); + EXPECT_EQ(ToSnakeCase("T"), "t"); + EXPECT_EQ(ToSnakeCase("t"), "t"); + EXPECT_EQ(ToSnakeCase("Tesla"), "tesla"); + EXPECT_EQ(ToSnakeCase("Tesla3"), "tesla3"); + EXPECT_EQ(ToSnakeCase("TeslaModel3"), "tesla_model3"); + EXPECT_EQ(ToSnakeCase("teslaModel3"), "tesla_model3"); + EXPECT_EQ(ToSnakeCase("Tesla_Model3"), "tesla_model3"); + EXPECT_EQ(ToSnakeCase("tesla_Model3"), "tesla_model3"); + EXPECT_EQ(ToSnakeCase("Tesla_model3"), "tesla_model3"); + EXPECT_EQ(ToSnakeCase("tesla_model3"), "tesla_model3"); +} + +} // namespace yacl::util::test diff --git a/yacl/utils/spi/item.cc b/yacl/utils/spi/item.cc index b888c8bd..e3ef7011 100644 --- a/yacl/utils/spi/item.cc +++ b/yacl/utils/spi/item.cc @@ -85,7 +85,7 @@ bool Item::IsAll(const bool &element) const { std::string Item::ToString() const { if (IsArray()) { - return fmt::format("{} Item, element_type={}, {}, Content={}", + return fmt::format("{} Item, element_type={}, {}, Content=[{}]", IsView() ? "Span" : "Vector", v_.type().name(), IsReadOnly() ? "RO" : "RW", TryRead(v_)); } else { diff --git a/yacl/utils/spi/item.h b/yacl/utils/spi/item.h index 5292a062..df638962 100644 --- a/yacl/utils/spi/item.h +++ b/yacl/utils/spi/item.h @@ -63,6 +63,33 @@ enum class OperandType : int { Vector2Vector = 0b11, }; +// Item is a container that can hold any type. It has 3 basic attributes: +// - Item::IsArray(): Whether the underlying type is a scalar or a vector +// - Item::IsView(): Whether the Item has ownership of the data +// - Item::IsReadonly(): For reference types, whether the underlying data is +// writable, used to distinguish between "reference to T" and "reference to +// const T" +// +// The combination of the three attributes is as follows: +// +---+-------+-------+----------+---------------+-------------------------+ +// | # | Is | Is | Is | Underlying | Remark | +// | | Array | View | ReadOnly | Type | | +// +---+-------+-------+----------+---------------+-------------------------+ +// | 1 | false | - | - | T | Item own value T | +// +---+-------+-------+----------+---------------+-------------------------+ +// | 2 | true | false | false | vector | Item own Array | +// +---+-------+-------+----------+---------------+-------------------------+ +// | 3 | true | false | true | vector | Not in use | +// | | | | | | Ignore readonly mark | +// +---+-------+-------+----------+---------------+-------------------------+ +// | 4 | true | true | false | Span | Ref to an Array | +// +---+-------+-------+----------+---------------+-------------------------+ +// | 5 | true | true | true | Span | Ref to a const Array | +// +---+-------+-------+----------+---------------+-------------------------+ +// +// For Item with ownership, (row #3 of the table), it currently does not +// distinguish whether the data is read-only. If you want to represent a +// constant vector, please use top-level const, that is, "const Item &var" class Item { public: // Take or copy scalar @@ -220,27 +247,123 @@ class Item { } } + // **This is a dark magic function; do not use it if you are not familiar with + // its behavior.** + // + // ResizeAndSpan checks and expands the underlying array, and returns a Span + // for external reading and writing to the Item. + // + // The function consists of three parts: + // 1) Resize part: Resize the underlying array; + // 2) Span part: Create a writable span for the underlying data and return + // it, so that the data inside the Item can be modified externally + // 3) Type reset part: If the original underlying type (denoted as U) is not + // consistent with the target type T, attempt to erase U and reset it to + // T. In this case, all existing data in the Item will be lost. + // + // Resize part: The function will attempt to resize the underlying data and + // perform a data type substitution when necessary: + // - If the underlying data is T, create a new Vector and replace it + // - If the underlying data is Vector, resize it to target size + // - If the underlying data is Span, This means it cannot be resized, + // however, the function can check the actual size of the Span and if it is + // greater than or equal to the target size, nothing needs to be done, + // otherwise an exception is thrown. + // - If the underlying data is Span, a writable Span cannot be + // created and an exception is thrown. + // + // Span part: Creates and returns a Span for the underlying data. + // - If the underlying data is a Vector, a Span of size "expected_size" is + // created based on the vector. + // - If the underlying data is a Span, and the size of the span is exactly + // equal to "expected_size", then the underlying span is returned. + // Otherwise, a new span of size "expected_size" is created based on that + // span. + // + // Type reset part: This is the most dangerous part of the function, as it + // will do everything possible to accommodate user needs, which means + // rewriting the underlying data type when necessary. + // Suppose the underlying type of Item is U, and the function ends up + // returning Span: + // - If the underlying type is a scalar U, then U is deleted and replaced + // with a new Vector. + // - If the underlying data is a Vector, then the Vector is deleted and + // replaced with Vector. + // - If the underlying data is Span, it indicates that the Item has + // referential properties and the data cannot be modified, in which case + // the function throws an exception. + // + // Usage example + // + // Proper use of ResizeAndSpan() can be very convenient. Here's an example: + // > void func(const Item &in, Item *out) { + // > auto in_sp = in.AsSpan(); + // > auto out_sp = out->ResizeAndSpan(in_sp.size()); + // > // ... now you can write data to out_sp + // > } template absl::Span ResizeAndSpan(size_t expected_size) { - auto sp = AsSpan(); + static_assert(!std::is_const_v, "Cannot resize as a const span"); + + // If the underlying data is T + if (!IsArray()) { + // create a vector and replace T + auto vec = std::vector(expected_size); + if (RawTypeIs() && expected_size > 0) { + vec[0] = std::move(As()); + } + // Don't do this: `*this = Item(std::move(vec))` + // this will discard the custom slots + v_ = std::move(vec); + Setup(true, false, false); + // now return + return absl::MakeSpan(As>()); + } + + // Now the underlying data is Vector or Span + if (!IsView()) { + // Vector case + if (RawTypeIs>()) { + auto& vec = As>(); + // do resize + if (vec.size() < expected_size) { + vec.resize(expected_size); + } + // do span + return absl::MakeSpan(vec.data(), expected_size); + } else { + // just discard vector and replace with vector + v_ = std::vector(expected_size); + return absl::MakeSpan(As>()); + } + } + + // Now the underlying data is Span or Const_Span + YACL_ENFORCE(!IsReadOnly(), + "The underlying data is readonly, Cannot create a read-write " + "span. Item detail: {}", + this->ToString()); + YACL_ENFORCE(RawTypeIs>(), + "The underlying type of item is {}, excepted type is {}, " + "cannot resize", + v_.type().name(), typeid(T).name()); + auto sp = As>(); if (sp.size() == expected_size) { return sp; + } else if (sp.size() > expected_size) { + return sp.subspan(0, expected_size); } - - // size doesn't match, now we try to resize - YACL_ENFORCE(IsArray() && !IsView() && !IsReadOnly(), - "Resize item fail, actual size={}, resize={}, item_info: {}", - sp.size(), expected_size, ToString()); - auto& vec = As>(); - vec.resize(expected_size); - return absl::MakeSpan(vec); + YACL_THROW( + "The underlying data is Span, cannot resize. Current size={}, " + "expected={}", + sp.size(), expected_size); } bool HasValue() const noexcept { return v_.has_value(); } // Check the type that directly stored, which is, container type + data type template - bool WrappedTypeIs() const noexcept { + bool RawTypeIs() const noexcept { return v_.type() == typeid(T); } @@ -248,11 +371,10 @@ class Item { template bool DataTypeIs() const noexcept { if (IsArray()) { - return WrappedTypeIs>() || - WrappedTypeIs>() || - WrappedTypeIs>(); + return RawTypeIs>() || RawTypeIs>() || + RawTypeIs>(); } else { - return WrappedTypeIs(); + return RawTypeIs(); } } @@ -287,7 +409,7 @@ class Item { static_assert(!std::is_same_v, "Cannot compare to another Item, since the Type info is " "discarded at runtime"); - return HasValue() && WrappedTypeIs() && As() == other; + return HasValue() && RawTypeIs() && As() == other; } template @@ -323,6 +445,7 @@ class Item { } virtual std::string ToString() const; + friend std::ostream& operator<<(std::ostream& os, const Item& a); protected: template @@ -424,8 +547,6 @@ class Item { uint8_t meta_ = 0; }; -std::ostream& operator<<(std::ostream& os, const Item& a); - template <> bool Item::IsAll(const bool& element) const; diff --git a/yacl/utils/spi/item_test.cc b/yacl/utils/spi/item_test.cc index 452a4750..26a4180d 100644 --- a/yacl/utils/spi/item_test.cc +++ b/yacl/utils/spi/item_test.cc @@ -125,6 +125,7 @@ TEST(ItemTest, RefPtr) { } TEST(ItemTest, ResizeAndSpan) { + // Vector case auto item = Item::Take(std::vector()); EXPECT_EQ(item.AsSpan().size(), 0); EXPECT_EQ(item.Size(), 0); @@ -133,8 +134,53 @@ TEST(ItemTest, ResizeAndSpan) { EXPECT_EQ(sp.size(), 100); EXPECT_EQ(item.Size(), 100); + sp = item.ResizeAndSpan(50); + EXPECT_EQ(sp.size(), 50); + EXPECT_EQ(item.Size(), 100); + sp[29] = 456; EXPECT_EQ(item.SubItem(29, 20).AsSpan()[0], 456); + + // ... with type change + auto sp2 = item.ResizeAndSpan(0); + EXPECT_EQ(sp2.size(), 0); + EXPECT_EQ(item.Size(), 0); + + // Single T case + item = (int)3; + EXPECT_FALSE(item.IsArray()); + sp = item.ResizeAndSpan(40); + EXPECT_TRUE(item.IsArray()); + EXPECT_FALSE(item.IsView()); + EXPECT_FALSE(item.IsReadOnly()); + EXPECT_EQ(sp.size(), 40); + EXPECT_EQ(item.Size(), 40); + EXPECT_EQ(sp[0], 3); + + // ... with type change + item = "hello"; + auto sp3 = item.ResizeAndSpan(1); + EXPECT_EQ(sp3.size(), 1); + EXPECT_EQ(item.Size(), 1); + + // Span case + std::vector vec = {1, 2, 3}; + item = Item::Ref(vec); + EXPECT_TRUE(item.IsArray()); + EXPECT_TRUE(item.IsView()); + sp = item.ResizeAndSpan(2); + EXPECT_EQ(sp.size(), 2); + EXPECT_EQ(item.Size(), 3); + sp[1] = 456; + EXPECT_EQ(vec[1], 456); + sp = item.ResizeAndSpan(3); + EXPECT_EQ(sp.size(), 3); + EXPECT_EQ(item.Size(), 3); + EXPECT_EQ(sp[2], 3); + // cannot resize a span + EXPECT_ANY_THROW(item.ResizeAndSpan(4)); + // cannot change type + EXPECT_ANY_THROW(item.ResizeAndSpan(1)); } class DummyItem : public Item { diff --git a/yacl/utils/spi/spi_factory.h b/yacl/utils/spi/spi_factory.h index 57d8c93f..b6e9839c 100644 --- a/yacl/utils/spi/spi_factory.h +++ b/yacl/utils/spi/spi_factory.h @@ -41,6 +41,16 @@ using SpiCreatorT = // Returns: True is supported and false is unsupported. using SpiCheckerT = std::function; +template +struct SpiLibMeta { + int64_t performance; + + // pointer to Ckeck(...) function + SpiCheckerT Check; + // pointer to Create(...) function + SpiCreatorT Create; +}; + // The base factory of SPI. // Each SPI can inherit this class for better flexibility template @@ -51,63 +61,102 @@ class SpiFactoryBase { void operator=(const SpiFactoryBase &) = delete; void operator=(SpiFactoryBase &&) = delete; - // Auto selects the best library and creates the spi instance. - // feature_name: The actual meaning is defined by each SPI. For example, - // feature_name represents the name of the elliptic curve in ECC SPI, and it - // represents the name of the phe algorithm in PHE SPI. + // Create a library instance + // + // If `extra_args` explicitly specifies the library to be created (with + // ArgLib=xxx_name), the factory checks if the library supports the input + // parameters; if it does, it creates an instance of the library. + // If `extra_args` does not specify a library name, the factory automatically + // selects the highest-performing library that meets the parameter + // requirements and creates an instance. + // + // 中文(translation): + // 如果extra_args明确指定了要创建的库,则工厂检查该库是否支持输入参数,如果支持则创建库实例。 + // 如果extra_args未指定库名称,则工厂自动选择性能最高,且满足参数要求的库并创建实例 + // + // @param: feature_name: The actual meaning is defined by each SPI. For + // example, feature_name represents the name of the elliptic curve in ECC + // SPI, and it represents the name of the HE algorithm in HE SPI. template std::unique_ptr Create(const std::string &feature_name, T &&...extra_args) const { - SpiArgs args({std::forward(extra_args)...}); + return CreateFromArgPkg(feature_name, {std::forward(extra_args)...}); + } + + std::unique_ptr CreateFromArgPkg(const std::string &feature_name, + const SpiArgs &args) const { auto lib_name = args.GetOptional(ArgLib); - if (!lib_name.HasValue()) { - // auto select best lib + if (!lib_name) { + // no lib name, auto select best lib for (const auto &perf_item : performance_map_) { - if (checker_map_.at(perf_item.second)(feature_name, args)) { + if (libs_map_.at(perf_item.second).Check(feature_name, args)) { lib_name = perf_item.second; break; } SPDLOG_DEBUG("SPI lib {} does not support feature {}, try next ...", perf_item.second, feature_name); } + + // check the target lib is founded after for-loop + YACL_ENFORCE( + lib_name, + "There are no lib supports {}, please use other feature/args", + feature_name); } else { // The user has specified lib - auto lib_it = checker_map_.find(lib_name.Value()); - YACL_ENFORCE(lib_it != checker_map_.end(), "Lib {} not exist", - lib_name.Value()); - YACL_ENFORCE(lib_it->second(feature_name, args), - "Lib {} does not support feature {} or args", - lib_name.Value(), feature_name); + auto lib_it = libs_map_.find(*lib_name); + YACL_ENFORCE(lib_it != libs_map_.end(), "Lib {} not exist", *lib_name); + YACL_ENFORCE(lib_it->second.Check(feature_name, args), + "Lib {} does not support feature {} or args", *lib_name, + feature_name); } - YACL_ENFORCE(lib_name.HasValue(), - "There are no lib supports {}, please use other feature/args", - feature_name); - YACL_ENFORCE(creator_map_.count(lib_name.Value()) > 0, - "Create {} instance fail, spi lib not found", - lib_name.Value()); - - return creator_map_.at(lib_name.Value())(feature_name, args); + try { + return libs_map_.at(*lib_name).Create(feature_name, args); + } catch (const std::exception &ex) { + SPDLOG_ERROR( + "SPI: Create Lib {} fail, Input args are: {}, Detail message:\n{}", + *lib_name, args.ToString(), ex.what()); + throw; + } } // List all registered libraries std::vector ListLibraries() const { std::vector res; - res.reserve(creator_map_.size()); - for (const auto &[key, _] : creator_map_) { + res.reserve(libs_map_.size()); + for (const auto &[key, _] : libs_map_) { res.push_back(key); } return res; } // List libraries that support this feature + // + // * If `extra_args` explicitly specifies the library to create (for example, + // ArgLib=xxx_name), then the method will check if the library supports the + // specified parameters. If it does, it returns a list containing only that + // library; otherwise, it returns an empty list. + // * If `extra_args` does not specify a library name, then the method returns + // the names of all libraries that satisfy the parameter requirements. template std::vector ListLibraries(const std::string &feature_name, T &&...extra_args) const { + return ListLibrariesFromArgPkg(feature_name, + {std::forward(extra_args)...}); + } + + std::vector ListLibrariesFromArgPkg( + const std::string &feature_name, const SpiArgs &args) const { std::vector res; - SpiArgs args({std::forward(extra_args)...}); - for (const auto &item : checker_map_) { - if (!item.second(feature_name, args)) { + auto lib_name = args.GetOptional(ArgLib); + for (const auto &item : libs_map_) { + // Check ArgLib limit + if (lib_name && *lib_name != item.first) { + continue; + } + // Check other args limit + if (!item.second.Check(feature_name, args)) { continue; } res.push_back(item.first); @@ -115,19 +164,15 @@ class SpiFactoryBase { return res; } - void Register(const std::string &lib_name, uint64_t performance, + void Register(const std::string &lib_name, int64_t performance, const SpiCheckerT &checker, const SpiCreatorT &creator) { auto lib_key = absl::AsciiStrToLower(lib_name); - YACL_ENFORCE(creator_map_.count(lib_key) == 0, + YACL_ENFORCE(libs_map_.count(lib_key) == 0, "SPI lib name conflict, {} already exist", lib_key); - while (performance_map_.count(performance) > 0) { - ++performance; - } performance_map_.insert({performance, lib_key}); - checker_map_.insert({lib_key, checker}); - creator_map_.insert({lib_key, creator}); + libs_map_.insert({lib_key, {performance, checker, creator}}); } protected: @@ -135,11 +180,9 @@ class SpiFactoryBase { private: // performance/priority -> lib name - std::map> performance_map_; - // lib name -> lib factory - std::map> creator_map_; - // lib name -> lib factory - std::map checker_map_; + std::multimap> performance_map_; + // lib name -> lib meta (include factory) + std::map> libs_map_; }; // Helper class for REGISTER_SPI_LIBRARY_HELPER macro @@ -152,7 +195,7 @@ class Registration { /// \param performance the estimated performance of this lib, bigger is /// better template - Registration(const std::string &lib_name, uint64_t performance, + Registration(const std::string &lib_name, int64_t performance, const CheckerT &checker, const CreatorT &creator) { FACTORY_T::Instance().Register(lib_name, performance, checker, creator); } @@ -163,6 +206,6 @@ class Registration { #define CONCAT(x, y) CONCAT_IMPL(x, y) #define REGISTER_SPI_LIBRARY_HELPER(factory_t, lib_name, performance, checker, \ creator) \ - static Registration CONCAT(registration_spi_, __COUNTER__)( \ - lib_name, performance, checker, creator) + static ::yacl::Registration CONCAT( \ + registration_spi_, __COUNTER__)(lib_name, performance, checker, creator) } // namespace yacl diff --git a/yacl/utils/spi/spi_factory_test.cc b/yacl/utils/spi/spi_factory_test.cc index 553708bb..da7c9063 100644 --- a/yacl/utils/spi/spi_factory_test.cc +++ b/yacl/utils/spi/spi_factory_test.cc @@ -53,12 +53,16 @@ class MockPaillierLib : public MockPheSpi { static std::unique_ptr Create(const std::string &phe_name, const SpiArgs &args) { + fmt::println("Create MockPaillierLib with args {}", args); + YACL_ENFORCE(phe_name == "paillier"); - return std::make_unique(args.Get(ArgKeySize, 2048)); + return std::make_unique( + args.GetOrDefault(ArgKeySize, 2048)); } static bool Check(const std::string &phe_name, const SpiArgs &args) { - return phe_name == "paillier" && args.Get(ArgKeySize, 2048) <= 4096; + return phe_name == "paillier" && + args.GetOrDefault(ArgKeySize, 2048) <= 4096; } std::string ToString() override { @@ -83,11 +87,13 @@ class MockQuantumLib : public MockPheSpi { static std::unique_ptr Create(const std::string &phe_name, const SpiArgs &args) { YACL_ENFORCE(phe_name == "elgamal"); - return std::make_unique(args.Get(Curve, "ed25519")); + return std::make_unique( + args.GetOrDefault(Curve, "ed25519")); } static bool Check(const std::string &phe_name, const SpiArgs &args) { - return phe_name == "elgamal" && args.Get(Curve, "ed25519") == "ed25519"; + return phe_name == "elgamal" && + args.GetOrDefault(Curve, "ed25519") == "ed25519"; } std::string ToString() override { @@ -118,8 +124,8 @@ TEST(SpiFactoryTest, TestListLibs) { ASSERT_EQ(libs.size(), 1); ASSERT_TRUE(libs[0] == "mock_paillier_lib"); - libs = MockPheSpiFactory::Instance().ListLibraries("paillier", - ArgKeySize = 2048); + libs = MockPheSpiFactory::Instance().ListLibraries( + "paillier", ArgLib = "mock_paillier_lib", ArgKeySize = 2048); ASSERT_EQ(libs.size(), 1); ASSERT_TRUE(libs[0] == "mock_paillier_lib"); @@ -127,6 +133,10 @@ TEST(SpiFactoryTest, TestListLibs) { ArgKeySize = 100000); ASSERT_EQ(libs.size(), 0); + libs = MockPheSpiFactory::Instance().ListLibraries("paillier", + ArgLib = "no-lib"); + ASSERT_EQ(libs.size(), 0); + libs = MockPheSpiFactory::Instance().ListLibraries("elgamal", Curve = "ed25519"); ASSERT_EQ(libs.size(), 1);