From d3766dfc15e17da274401c3274a956cb8949737a Mon Sep 17 00:00:00 2001 From: Jamie Cui Date: Fri, 22 Nov 2024 19:06:16 +0800 Subject: [PATCH] feat: now sha256 works --- yacl/engine/plaintext/executor.h | 47 +++++++++++----- yacl/engine/plaintext/executor_test.cc | 24 ++++---- yacl/io/circuit/BUILD.bazel | 1 + yacl/io/circuit/bristol_fashion.h | 77 +++++++++++++++++--------- 4 files changed, 98 insertions(+), 51 deletions(-) diff --git a/yacl/engine/plaintext/executor.h b/yacl/engine/plaintext/executor.h index 3028151..9d71e23 100644 --- a/yacl/engine/plaintext/executor.h +++ b/yacl/engine/plaintext/executor.h @@ -27,6 +27,8 @@ namespace yacl::engine { class PlainExecutor { public: + using BlockType = uint8_t; + // Constructor explicit PlainExecutor() = default; @@ -38,7 +40,7 @@ class PlainExecutor { void SetupInputs(absl::Span inputs) { YACL_ENFORCE(inputs.size() == circ_->niv); - dynamic_bitset input_wires; + dynamic_bitset input_wires; input_wires.resize(sizeof(T) * 8 * inputs.size()); std::memcpy(input_wires.data(), inputs.data(), inputs.size() * sizeof(T)); wires_.append(input_wires); @@ -46,9 +48,13 @@ class PlainExecutor { } // Setup the input wire + // + // NOTE internally this function simply copies the memory of bytes to internal + // dynamic_bitset void SetupInputBytes(ByteContainerView bytes) { - wires_.resize(circ_->nw); + wires_.resize(bytes.size() * 8); std::memcpy(wires_.data(), bytes.data(), bytes.size()); + wires_.resize(circ_->nw); } // Execute the circuit @@ -60,7 +66,7 @@ class PlainExecutor { YACL_ENFORCE(outputs.size() >= circ_->nov); size_t index = wires_.size(); for (size_t i = 0; i < circ_->nov; ++i) { - dynamic_bitset result(circ_->now[i]); + dynamic_bitset result(circ_->now[i]); for (size_t j = 0; j < circ_->now[i]; ++j) { result[j] = wires_[index - circ_->now[i] + j]; } @@ -76,23 +82,38 @@ class PlainExecutor { total_out_bitnum += circ_->now[i]; } - // Make sure that the circuit output wire is full bytes - YACL_ENFORCE(total_out_bitnum % 8 == 0); + // // Make sure that the circuit output wire is full bytes + // YACL_ENFORCE(total_out_bitnum % 8 == 0); + + // const size_t wire_size = wires_.size(); + // dynamic_bitset result(total_out_bitnum); + // for (size_t i = 0; i < total_out_bitnum; ++i) { + // result[total_out_bitnum - i - 1] = wires_[wire_size - i - 1]; + // } + // YACL_ENFORCE(result.size() == total_out_bitnum); + // std::vector out(total_out_bitnum / 8); + // std::memcpy(out.data(), result.data(), out.size()); + // SPDLOG_INFO(result.to_string()); + // return out; - const size_t wire_size = wires_.size(); - dynamic_bitset result(total_out_bitnum); - for (size_t i = 0; i < total_out_bitnum; ++i) { - result[total_out_bitnum - i - 1] = wires_[wire_size - i - 1]; - } - YACL_ENFORCE(result.size() == total_out_bitnum); std::vector out(total_out_bitnum / 8); - std::memcpy(out.data(), result.data(), out.size()); + + size_t index = wires_.size(); + for (size_t i = 0; i < 32; ++i) { + dynamic_bitset result(8); + for (size_t j = 0; j < 8; ++j) { + result[j] = wires_[index - 8 + j]; + } + out[32 - i - 1] = *(uint8_t *)result.data(); + index -= 8; + } + std::reverse(out.begin(), out.end()); return out; } private: // NOTE: please make sure you use the correct order of wires - dynamic_bitset wires_; // shares + dynamic_bitset wires_; // shares std::shared_ptr circ_; // bristol fashion circuit }; diff --git a/yacl/engine/plaintext/executor_test.cc b/yacl/engine/plaintext/executor_test.cc index f51b6d9..c7deac5 100644 --- a/yacl/engine/plaintext/executor_test.cc +++ b/yacl/engine/plaintext/executor_test.cc @@ -211,21 +211,23 @@ TEST(CryptoTest, Sha256Test) { /* GIVEN */ auto input = crypto::FastRandBytes(crypto::RandLtN(10)); - std::string temp = "1"; - SPDLOG_INFO(absl::BytesToHexString( - ByteContainerView(io::BuiltinBFCircuit::PrepareSha256Input(temp)))); + // std::array temp = {'a', 'b', 'c'}; + // std::string temp = "1"; + std::array message = {'1'}; + auto in_buf = io::BuiltinBFCircuit::PrepareSha256Input(message); /* WHEN */ - // PlainExecutor exec; - // exec.LoadCircuitFile(io::BuiltinBFCircuit::Sha256Path()); - // exec.SetupInputBytes(io::BuiltinBFCircuit::PrepareSha256Input(input)); - // exec.Exec(); - // auto result = exec.FinalizeBytes(); + PlainExecutor exec; + exec.LoadCircuitFile(io::BuiltinBFCircuit::Sha256Path()); + exec.SetupInputBytes(in_buf); + exec.Exec(); + auto result = exec.FinalizeBytes(); /* THEN */ - // auto compare = crypto::Sha256Hash().Update(input).CumulativeHash(); - // SPDLOG_INFO(absl::BytesToHexString(ByteContainerView(result))); - // SPDLOG_INFO(absl::BytesToHexString(ByteContainerView(compare))); + auto compare = crypto::Sha256Hash().Update(message).CumulativeHash(); + SPDLOG_INFO(absl::BytesToHexString(ByteContainerView(result))); + SPDLOG_INFO(absl::BytesToHexString(ByteContainerView(compare))); + EXPECT_EQ(compare.size(), result.size()); } } // namespace yacl::engine diff --git a/yacl/io/circuit/BUILD.bazel b/yacl/io/circuit/BUILD.bazel index 707f93c..71f3977 100644 --- a/yacl/io/circuit/BUILD.bazel +++ b/yacl/io/circuit/BUILD.bazel @@ -28,6 +28,7 @@ yacl_cc_library( deps = [ "//yacl/io/stream:file_io", "//yacl/link:context", + "//yacl/utils/spi:type_traits", ], ) diff --git a/yacl/io/circuit/bristol_fashion.h b/yacl/io/circuit/bristol_fashion.h index 2748fed..5561541 100644 --- a/yacl/io/circuit/bristol_fashion.h +++ b/yacl/io/circuit/bristol_fashion.h @@ -21,12 +21,14 @@ #include #include +#include "absl/strings/escaping.h" #include "spdlog/spdlog.h" #include "yacl/base/byte_container_view.h" #include "yacl/base/exception.h" #include "yacl/io/stream/file_io.h" #include "yacl/io/stream/interface.h" +#include "yacl/utils/spi/type_traits.h" namespace yacl::io { @@ -137,48 +139,59 @@ class BuiltinBFCircuit { std::filesystem::current_path().string()); } - constexpr static std::array GetSha256InitialHashValues() { - return {0x6a, 0x09, 0xe6, 0x67, 0xbb, 0x67, 0xae, 0x85, 0x3c, 0x6e, 0xf3, - 0x72, 0xa5, 0x4f, 0xf5, 0x3a, 0x51, 0x0e, 0x52, 0x7f, 0x9b, 0x05, - 0x68, 0x8c, 0x1f, 0x83, 0xd9, 0xab, 0x5b, 0xe0, 0xcd, 0x19}; - } - + // Prepare (append & tweak) the input sha256 message before fed to the sha256 + // bristol circuit. + // + // For more details, please check: + // https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf + // + // NOTE since we are using dynamic_bitset for bristol format circuit + // representation, the actual bit operation here is slightly different from + // the standards. static std::vector PrepareSha256Input(ByteContainerView input) { - constexpr size_t kSha256FixPadSize = 1; // in bytes - constexpr size_t kSha256MessageBlockSize = 64; // in bytes + constexpr size_t kFixPadSize = 1; // in bytes + constexpr size_t kMsgLenSize = sizeof(uint64_t); // in bytes + constexpr size_t kMsgBlockSize = 64; // in bytes + auto kInitSha256Bytes = GetSha256InitialHashValues(); - uint64_t input_size = input.size(); - constexpr auto kInitSha256Bytes = GetSha256InitialHashValues(); + uint64_t input_size = input.size(); // in bits uint64_t zero_padding_size = - (input_size + kSha256FixPadSize) % kSha256MessageBlockSize == 0 + (input_size + kFixPadSize + kMsgLenSize) % kMsgBlockSize == 0 ? 0 - : kSha256MessageBlockSize - - (input_size + kSha256FixPadSize) % kSha256MessageBlockSize; - uint64_t message_size = input_size + kSha256FixPadSize + zero_padding_size; + : kMsgBlockSize - + (input_size + kFixPadSize + kMsgLenSize) % kMsgBlockSize; + uint64_t message_size = + input_size + kFixPadSize + zero_padding_size + kMsgLenSize; uint64_t result_size = message_size + kInitSha256Bytes.size(); - YACL_ENFORCE(message_size % kSha256MessageBlockSize == 0); + YACL_ENFORCE(message_size % kMsgBlockSize == 0); - // Declare the resut byte-vector + // Declare the result byte-vector + size_t offset = 0; std::vector result(result_size); - // original input message - size_t offset = kInitSha256Bytes.size(); - std::memcpy(result.data() + offset, input.data(), input_size); - - // additional padding (as a mark) - offset = kInitSha256Bytes.size() + input_size; - result[offset] = 0x80; + // the next 64 bits should be the byte length of input message + uint64_t input_bitnum = input_size * 8; // in bytes + std::memcpy(result.data() + offset, &input_bitnum, sizeof(input_bitnum)); + offset += sizeof(uint64_t); // zero padding (result vector has zero initialization) // ... should doing nothing ... + offset += zero_padding_size; - // the last 64 bits should be the byte length of input message - offset = kInitSha256Bytes.size() + input_size + kSha256FixPadSize; - std::memcpy(result.data() + offset, &input_size, sizeof(uint64_t)); + // additional padding bit-'1' (as a mark) + result[offset] = 0x80; + offset += kFixPadSize; + + // original input message + // auto input_reverse = ReverseBytes(absl::MakeSpan(input)); // copy here + std::memcpy(result.data() + offset, input.data(), input_size); + offset += input_size; // initial hash values - std::memcpy(result.data(), kInitSha256Bytes.data(), kInitSha256Bytes.size()); + std::memcpy(result.data() + offset, kInitSha256Bytes.data(), + kInitSha256Bytes.size()); + offset += kInitSha256Bytes.size(); return result; } @@ -200,6 +213,16 @@ class BuiltinBFCircuit { return fmt::format("{}/yacl/io/circuit/data/sha256.txt", std::filesystem::current_path().string()); } + + + static std::array GetSha256InitialHashValues() { + std::array standard_init_array = { + 0x6a, 0x09, 0xe6, 0x67, 0xbb, 0x67, 0xae, 0x85, 0x3c, 0x6e, 0xf3, + 0x72, 0xa5, 0x4f, 0xf5, 0x3a, 0x51, 0x0e, 0x52, 0x7f, 0x9b, 0x05, + 0x68, 0x8c, 0x1f, 0x83, 0xd9, 0xab, 0x5b, 0xe0, 0xcd, 0x19}; + std::reverse(standard_init_array.begin(), standard_init_array.end()); + return standard_init_array; + } }; } // namespace yacl::io