Skip to content

Commit

Permalink
feat: now sha256 works
Browse files Browse the repository at this point in the history
  • Loading branch information
Jamie-Cui committed Nov 22, 2024
1 parent 41120af commit d3766df
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 51 deletions.
47 changes: 34 additions & 13 deletions yacl/engine/plaintext/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ namespace yacl::engine {

class PlainExecutor {
public:
using BlockType = uint8_t;

// Constructor
explicit PlainExecutor() = default;

Expand All @@ -38,17 +40,21 @@ class PlainExecutor {
void SetupInputs(absl::Span<T> inputs) {
YACL_ENFORCE(inputs.size() == circ_->niv);

dynamic_bitset<uint128_t> input_wires;
dynamic_bitset<BlockType> input_wires;
input_wires.resize(sizeof(T) * 8 * inputs.size());
std::memcpy(input_wires.data(), inputs.data(), inputs.size() * sizeof(T));
wires_.append(input_wires);
wires_.resize(circ_->nw);
}

// Setup the input wire
//
// NOTE internally this function simply copies the memory of bytes to internal
// dynamic_bitset
void SetupInputBytes(ByteContainerView bytes) {
wires_.resize(circ_->nw);
wires_.resize(bytes.size() * 8);
std::memcpy(wires_.data(), bytes.data(), bytes.size());
wires_.resize(circ_->nw);
}

// Execute the circuit
Expand All @@ -60,7 +66,7 @@ class PlainExecutor {
YACL_ENFORCE(outputs.size() >= circ_->nov);
size_t index = wires_.size();
for (size_t i = 0; i < circ_->nov; ++i) {
dynamic_bitset<T> result(circ_->now[i]);
dynamic_bitset<BlockType> result(circ_->now[i]);
for (size_t j = 0; j < circ_->now[i]; ++j) {
result[j] = wires_[index - circ_->now[i] + j];
}
Expand All @@ -76,23 +82,38 @@ class PlainExecutor {
total_out_bitnum += circ_->now[i];
}

// Make sure that the circuit output wire is full bytes
YACL_ENFORCE(total_out_bitnum % 8 == 0);
// // Make sure that the circuit output wire is full bytes
// YACL_ENFORCE(total_out_bitnum % 8 == 0);

// const size_t wire_size = wires_.size();
// dynamic_bitset<BlockType> result(total_out_bitnum);
// for (size_t i = 0; i < total_out_bitnum; ++i) {
// result[total_out_bitnum - i - 1] = wires_[wire_size - i - 1];
// }
// YACL_ENFORCE(result.size() == total_out_bitnum);
// std::vector<uint8_t> out(total_out_bitnum / 8);
// std::memcpy(out.data(), result.data(), out.size());
// SPDLOG_INFO(result.to_string());
// return out;

const size_t wire_size = wires_.size();
dynamic_bitset<uint128_t> result(total_out_bitnum);
for (size_t i = 0; i < total_out_bitnum; ++i) {
result[total_out_bitnum - i - 1] = wires_[wire_size - i - 1];
}
YACL_ENFORCE(result.size() == total_out_bitnum);
std::vector<uint8_t> out(total_out_bitnum / 8);
std::memcpy(out.data(), result.data(), out.size());

size_t index = wires_.size();
for (size_t i = 0; i < 32; ++i) {
dynamic_bitset<BlockType> result(8);
for (size_t j = 0; j < 8; ++j) {
result[j] = wires_[index - 8 + j];
}
out[32 - i - 1] = *(uint8_t *)result.data();
index -= 8;
}
std::reverse(out.begin(), out.end());
return out;
}

private:
// NOTE: please make sure you use the correct order of wires
dynamic_bitset<uint128_t> wires_; // shares
dynamic_bitset<BlockType> wires_; // shares
std::shared_ptr<io::BFCircuit> circ_; // bristol fashion circuit
};

Expand Down
24 changes: 13 additions & 11 deletions yacl/engine/plaintext/executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,21 +211,23 @@ TEST(CryptoTest, Sha256Test) {
/* GIVEN */
auto input = crypto::FastRandBytes(crypto::RandLtN(10));

std::string temp = "1";
SPDLOG_INFO(absl::BytesToHexString(
ByteContainerView(io::BuiltinBFCircuit::PrepareSha256Input(temp))));
// std::array<char, 3> temp = {'a', 'b', 'c'};
// std::string temp = "1";
std::array<char, 1> message = {'1'};
auto in_buf = io::BuiltinBFCircuit::PrepareSha256Input(message);

/* WHEN */
// PlainExecutor exec;
// exec.LoadCircuitFile(io::BuiltinBFCircuit::Sha256Path());
// exec.SetupInputBytes(io::BuiltinBFCircuit::PrepareSha256Input(input));
// exec.Exec();
// auto result = exec.FinalizeBytes();
PlainExecutor exec;
exec.LoadCircuitFile(io::BuiltinBFCircuit::Sha256Path());
exec.SetupInputBytes(in_buf);
exec.Exec();
auto result = exec.FinalizeBytes();

/* THEN */
// auto compare = crypto::Sha256Hash().Update(input).CumulativeHash();
// SPDLOG_INFO(absl::BytesToHexString(ByteContainerView(result)));
// SPDLOG_INFO(absl::BytesToHexString(ByteContainerView(compare)));
auto compare = crypto::Sha256Hash().Update(message).CumulativeHash();
SPDLOG_INFO(absl::BytesToHexString(ByteContainerView(result)));
SPDLOG_INFO(absl::BytesToHexString(ByteContainerView(compare)));
EXPECT_EQ(compare.size(), result.size());
}

} // namespace yacl::engine
1 change: 1 addition & 0 deletions yacl/io/circuit/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ yacl_cc_library(
deps = [
"//yacl/io/stream:file_io",
"//yacl/link:context",
"//yacl/utils/spi:type_traits",
],
)

Expand Down
77 changes: 50 additions & 27 deletions yacl/io/circuit/bristol_fashion.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
#include <string>
#include <vector>

#include "absl/strings/escaping.h"
#include "spdlog/spdlog.h"

#include "yacl/base/byte_container_view.h"
#include "yacl/base/exception.h"
#include "yacl/io/stream/file_io.h"
#include "yacl/io/stream/interface.h"
#include "yacl/utils/spi/type_traits.h"

namespace yacl::io {

Expand Down Expand Up @@ -137,48 +139,59 @@ class BuiltinBFCircuit {
std::filesystem::current_path().string());
}

constexpr static std::array<uint8_t, 32> GetSha256InitialHashValues() {
return {0x6a, 0x09, 0xe6, 0x67, 0xbb, 0x67, 0xae, 0x85, 0x3c, 0x6e, 0xf3,
0x72, 0xa5, 0x4f, 0xf5, 0x3a, 0x51, 0x0e, 0x52, 0x7f, 0x9b, 0x05,
0x68, 0x8c, 0x1f, 0x83, 0xd9, 0xab, 0x5b, 0xe0, 0xcd, 0x19};
}

// Prepare (append & tweak) the input sha256 message before fed to the sha256
// bristol circuit.
//
// For more details, please check:
// https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf
//
// NOTE since we are using dynamic_bitset for bristol format circuit
// representation, the actual bit operation here is slightly different from
// the standards.
static std::vector<uint8_t> PrepareSha256Input(ByteContainerView input) {
constexpr size_t kSha256FixPadSize = 1; // in bytes
constexpr size_t kSha256MessageBlockSize = 64; // in bytes
constexpr size_t kFixPadSize = 1; // in bytes
constexpr size_t kMsgLenSize = sizeof(uint64_t); // in bytes
constexpr size_t kMsgBlockSize = 64; // in bytes
auto kInitSha256Bytes = GetSha256InitialHashValues();

uint64_t input_size = input.size();
constexpr auto kInitSha256Bytes = GetSha256InitialHashValues();
uint64_t input_size = input.size(); // in bits
uint64_t zero_padding_size =
(input_size + kSha256FixPadSize) % kSha256MessageBlockSize == 0
(input_size + kFixPadSize + kMsgLenSize) % kMsgBlockSize == 0
? 0
: kSha256MessageBlockSize -
(input_size + kSha256FixPadSize) % kSha256MessageBlockSize;
uint64_t message_size = input_size + kSha256FixPadSize + zero_padding_size;
: kMsgBlockSize -
(input_size + kFixPadSize + kMsgLenSize) % kMsgBlockSize;
uint64_t message_size =
input_size + kFixPadSize + zero_padding_size + kMsgLenSize;
uint64_t result_size = message_size + kInitSha256Bytes.size();

YACL_ENFORCE(message_size % kSha256MessageBlockSize == 0);
YACL_ENFORCE(message_size % kMsgBlockSize == 0);

// Declare the resut byte-vector
// Declare the result byte-vector
size_t offset = 0;
std::vector<uint8_t> result(result_size);

// original input message
size_t offset = kInitSha256Bytes.size();
std::memcpy(result.data() + offset, input.data(), input_size);

// additional padding (as a mark)
offset = kInitSha256Bytes.size() + input_size;
result[offset] = 0x80;
// the next 64 bits should be the byte length of input message
uint64_t input_bitnum = input_size * 8; // in bytes
std::memcpy(result.data() + offset, &input_bitnum, sizeof(input_bitnum));
offset += sizeof(uint64_t);

// zero padding (result vector has zero initialization)
// ... should doing nothing ...
offset += zero_padding_size;

// the last 64 bits should be the byte length of input message
offset = kInitSha256Bytes.size() + input_size + kSha256FixPadSize;
std::memcpy(result.data() + offset, &input_size, sizeof(uint64_t));
// additional padding bit-'1' (as a mark)
result[offset] = 0x80;
offset += kFixPadSize;

// original input message
// auto input_reverse = ReverseBytes(absl::MakeSpan(input)); // copy here
std::memcpy(result.data() + offset, input.data(), input_size);
offset += input_size;

// initial hash values
std::memcpy(result.data(), kInitSha256Bytes.data(), kInitSha256Bytes.size());
std::memcpy(result.data() + offset, kInitSha256Bytes.data(),
kInitSha256Bytes.size());
offset += kInitSha256Bytes.size();

return result;
}
Expand All @@ -200,6 +213,16 @@ class BuiltinBFCircuit {
return fmt::format("{}/yacl/io/circuit/data/sha256.txt",
std::filesystem::current_path().string());
}


static std::array<uint8_t, 32> GetSha256InitialHashValues() {
std::array<uint8_t, 32> standard_init_array = {
0x6a, 0x09, 0xe6, 0x67, 0xbb, 0x67, 0xae, 0x85, 0x3c, 0x6e, 0xf3,
0x72, 0xa5, 0x4f, 0xf5, 0x3a, 0x51, 0x0e, 0x52, 0x7f, 0x9b, 0x05,
0x68, 0x8c, 0x1f, 0x83, 0xd9, 0xab, 0x5b, 0xe0, 0xcd, 0x19};
std::reverse(standard_init_array.begin(), standard_init_array.end());
return standard_init_array;
}
};

} // namespace yacl::io

0 comments on commit d3766df

Please sign in to comment.