feat+refactor: remove allocation from hashing #82

Open · wants to merge 13 commits into main
43 changes: 23 additions & 20 deletions plonky2/src/hash/field_merkle_tree.rs
@@ -3,12 +3,12 @@ use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

use itertools::Itertools;
use itertools::{chain, Itertools};

use crate::hash::hash_types::{RichField, NUM_HASH_OUT_ELTS};
use crate::hash::hash_types::RichField;
use crate::hash::merkle_proofs::MerkleProof;
use crate::hash::merkle_tree::{
capacity_up_to_mut, fill_digests_buf, merkle_tree_prove, MerkleCap,
capacity_up_to_mut, fill_digests_buf, fill_digests_buf_custom, merkle_tree_prove, MerkleCap,
};
use crate::plonk::config::{GenericHashOut, Hasher};
use crate::util::log2_strict;
@@ -83,24 +83,17 @@ impl<F: RichField, H: Hasher<F>> FieldMerkleTree<F, H> {
);
} else {
// The rest leaf layers
let new_leaves: Vec<Vec<F>> = cap
.iter()
.enumerate()
.map(|(i, cap_hash)| {
let mut new_hash = Vec::with_capacity(NUM_HASH_OUT_ELTS + cur[i].len());
new_hash.extend(&cap_hash.to_vec());
new_hash.extend(&cur[i]);
new_hash
})
.collect();
cap.clear();
cap.reserve_exact(next_cap_len);
let new_leaves = cap;
cap = Vec::with_capacity(next_cap_len);
let tmp_cap_buf = capacity_up_to_mut(&mut cap, next_cap_len);
fill_digests_buf::<F, H>(
fill_digests_buf_custom::<F, H, _, _>(
&mut digests_buf[digests_buf_pos..(digests_buf_pos + num_tmp_digests)],
tmp_cap_buf,
&new_leaves[..],
next_cap_height,
|i, cap_hash| {
H::hash_or_noop_iter(chain!(cap_hash.into_iter(), cur[i].iter().copied()))
Collaborator:

This smells like it wants to be a zip or so?

Author (@Daniel-Aaron-Bloom, Apr 16, 2024):

zip? Why? The goal is to concatenate them before hashing, mirroring the previous code, which used extend.

},
);
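For reference, a minimal self-contained sketch of the point in the exchange above: chain concatenates the two sequences in order, which is exactly what the old Vec::extend code produced, whereas zip would pair elements up. Plain u64s stand in for field elements here.

use itertools::chain;

fn main() {
    let cap_hash = [1u64, 2, 3, 4];
    let leaf_data = [5u64, 6];
    // `chain!` yields all of `cap_hash`, then all of `leaf_data`: the same
    // concatenation the old code built eagerly, produced lazily instead.
    let combined: Vec<u64> = chain!(cap_hash, leaf_data).collect();
    assert_eq!(combined, [1, 2, 3, 4, 5, 6]);
}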
}

@@ -214,7 +207,10 @@ mod tests {
assert_eq!(layer_1, fmt.digests[2..4]);

let root = H::two_to_one(layer_1[0], layer_1[1]);
assert_eq!(fmt.cap.flatten(), root.to_vec());
assert_eq!(
fmt.cap.flatten().collect_vec(),
root.into_iter().collect_vec()
);

let proof = fmt.open_batch(2);
assert_eq!(proof.siblings, [mat_1_leaf_hashes[3], layer_1[0]]);
@@ -259,8 +255,12 @@
assert_eq!(mat_1_leaf_hashes, fmt.digests[0..4]);

let hidden_layer = [
H::two_to_one(mat_1_leaf_hashes[0], mat_1_leaf_hashes[1]).to_vec(),
H::two_to_one(mat_1_leaf_hashes[2], mat_1_leaf_hashes[3]).to_vec(),
H::two_to_one(mat_1_leaf_hashes[0], mat_1_leaf_hashes[1])
.into_iter()
.collect_vec(),
H::two_to_one(mat_1_leaf_hashes[2], mat_1_leaf_hashes[3])
.into_iter()
.collect_vec(),
];
let new_leaves = hidden_layer
.iter()
@@ -278,7 +278,10 @@
assert_eq!(layer_1, fmt.digests[4..]);

let root = H::two_to_one(layer_1[0], layer_1[1]);
assert_eq!(fmt.cap.flatten(), root.to_vec());
assert_eq!(
fmt.cap.flatten().collect_vec(),
root.into_iter().collect_vec()
);

let proof = fmt.open_batch(1);
assert_eq!(proof.siblings, [mat_1_leaf_hashes[0], layer_1[1]]);
74 changes: 48 additions & 26 deletions plonky2/src/hash/hash_types.rs
@@ -1,5 +1,6 @@
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use core::borrow::BorrowMut;

use anyhow::ensure;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -83,27 +84,41 @@ where
}

impl<F: RichField> GenericHashOut<F> for HashOut<F> {
fn to_bytes(&self) -> Vec<u8> {
self.elements
.into_iter()
.flat_map(|x| x.to_canonical_u64().to_le_bytes())
.collect()
fn to_bytes(self) -> impl AsRef<[u8]> + AsMut<[u8]> + BorrowMut<[u8]> + Copy {
let mut bytes = [0u8; NUM_HASH_OUT_ELTS * 8];
for (i, x) in self.elements.into_iter().enumerate() {
let i = i * 8;
bytes[i..i + 8].copy_from_slice(&x.to_canonical_u64().to_le_bytes())
}
bytes
}
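A usage sketch of the new opaque return type (digest_len is a hypothetical helper, not part of this PR): the buffer lives on the stack and is viewed as a byte slice through AsRef, with no Vec allocated.

// Hypothetical helper, assuming `GenericHashOut` is in scope.
fn digest_len<F: RichField>(h: HashOut<F>) -> usize {
    let buf = h.to_bytes();  // opaque stack buffer, Copy
    buf.as_ref().len()       // view as &[u8]; == NUM_HASH_OUT_ELTS * 8
}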

fn from_bytes(bytes: &[u8]) -> Self {
let mut bytes = bytes
.chunks(8)
.take(NUM_HASH_OUT_ELTS)
.map(|x| F::from_canonical_u64(u64::from_le_bytes(x.try_into().unwrap())));
HashOut {
elements: [(); NUM_HASH_OUT_ELTS].map(|()| bytes.next().unwrap()),
}
}

fn from_byte_iter(mut bytes: impl Iterator<Item = u8>) -> Self {
let bytes = [[(); 8]; NUM_HASH_OUT_ELTS].map(|b| b.map(|()| bytes.next().unwrap()));

HashOut {
elements: bytes.map(|x| F::from_canonical_u64(u64::from_le_bytes(x))),
}
}

fn from_iter(mut inputs: impl Iterator<Item = F>) -> Self {
HashOut {
elements: bytes
.chunks(8)
.take(NUM_HASH_OUT_ELTS)
.map(|x| F::from_canonical_u64(u64::from_le_bytes(x.try_into().unwrap())))
.collect::<Vec<_>>()
.try_into()
.unwrap(),
elements: [(); NUM_HASH_OUT_ELTS].map(|()| inputs.next().unwrap()),
Collaborator:

Is this better than .collect::<Vec<_>>().try_into().unwrap() or something like that?

Author:

"Better" is a matter of opinion. This approach avoids allocation at the cost of being pretty ugly.

}
}

fn to_vec(&self) -> Vec<F> {
self.elements.to_vec()
fn into_iter(self) -> impl Iterator<Item = F> {
Collaborator (@matthiasgoergens, Apr 16, 2024):

Shouldn't this be a trait implementation for the IntoIterator trait or so?

Author:

Not sure what you're asking. This is using "return position impl Trait" to enable unique return types between the different implementations without associated types.
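A minimal sketch of the pattern the author describes, with hypothetical trait and types (not from this PR): each implementation returns its own concrete iterator type, and the trait needs neither an associated type nor boxing. Return position impl Trait in traits is stable since Rust 1.75.

trait Elements {
    fn elements(self) -> impl Iterator<Item = u64>;
}

struct Four([u64; 4]);
struct One(u64);

impl Elements for Four {
    fn elements(self) -> impl Iterator<Item = u64> {
        self.0.into_iter() // concrete type: core::array::IntoIter<u64, 4>
    }
}

impl Elements for One {
    fn elements(self) -> impl Iterator<Item = u64> {
        core::iter::once(self.0) // concrete type: core::iter::Once<u64>
    }
}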

self.elements.into_iter()
}
}

@@ -172,24 +187,31 @@ impl<const N: usize> Sample for BytesHash<N> {
}

impl<F: RichField, const N: usize> GenericHashOut<F> for BytesHash<N> {
fn to_bytes(&self) -> Vec<u8> {
self.0.to_vec()
fn to_bytes(self) -> impl AsRef<[u8]> + AsMut<[u8]> + BorrowMut<[u8]> + Copy {
self.0
}

fn from_bytes(bytes: &[u8]) -> Self {
Self(bytes.try_into().unwrap())
}

fn to_vec(&self) -> Vec<F> {
self.0
// Chunks of 7 bytes since 8 bytes would allow collisions.
.chunks(7)
.map(|bytes| {
let mut arr = [0; 8];
arr[..bytes.len()].copy_from_slice(bytes);
F::from_canonical_u64(u64::from_le_bytes(arr))
})
.collect()
fn from_byte_iter(mut bytes: impl Iterator<Item = u8>) -> Self {
Self(core::array::from_fn(|_| bytes.next().unwrap()))
}

fn into_iter(self) -> impl Iterator<Item = F> {
// Chunks of 7 bytes since 8 bytes would allow collisions.
Collaborator:

@dimdumon Please check with your work. Perhaps you can piggy-back on this, instead of re-implementing so much of your own packing?

const STRIDE: usize = 7;

(0..N).step_by(STRIDE).map(move |i| {
let mut bytes = &self.0[i..];
if bytes.len() > STRIDE {
bytes = &bytes[..STRIDE];
}
let mut arr = [0; 8];
arr[..bytes.len()].copy_from_slice(bytes);
F::from_canonical_u64(u64::from_le_bytes(arr))
})
}
}
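A sketch of why the 7-byte stride is collision-free, assuming the Goldilocks modulus p = 2^64 - 2^32 + 1 used elsewhere in plonky2: seven little-endian bytes always decode to a value below 2^56 < p, so distinct chunks map to distinct field elements, while full 8-byte chunks could exceed p and collide after reduction.

fn pack_chunk(bytes: &[u8]) -> u64 {
    debug_assert!(bytes.len() <= 7, "8 bytes could exceed the field modulus");
    let mut arr = [0u8; 8];
    arr[..bytes.len()].copy_from_slice(bytes);
    // The top byte stays zero, so the result is < 2^56 and already canonical
    // in any 64-bit field whose modulus exceeds 2^56.
    u64::from_le_bytes(arr)
}

fn main() {
    assert_eq!(pack_chunk(&[0xff; 7]), (1u64 << 56) - 1);
}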

44 changes: 44 additions & 0 deletions plonky2/src/hash/hashing.rs
@@ -2,6 +2,9 @@
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use core::fmt::Debug;
use core::iter::repeat_with;

use itertools::chain;

use crate::field::extension::Extendable;
use crate::field::types::Field;
@@ -91,6 +94,9 @@ pub trait PlonkyPermutation<T: Copy + Default>:

/// Return a slice of `RATE` elements
fn squeeze(&self) -> &[T];

/// Return an array of `RATE` elements
fn squeeze_iter(self) -> impl IntoIterator<Item = T> + Copy;
}

/// A one-way compression function which takes two ~256 bit inputs and returns a ~256 bit output.
@@ -140,6 +146,44 @@ pub fn hash_n_to_m_no_pad<F: RichField, P: PlonkyPermutation<F>>(
}
}

/// Hash a message without any padding step. Note that this can enable length-extension attacks.
/// However, it is still collision-resistant in cases where the input has a fixed length.
pub fn hash_n_to_m_no_pad_iter<F: RichField, P: PlonkyPermutation<F>, I: IntoIterator<Item = F>>(
inputs: I,
) -> impl Iterator<Item = F> {
let mut perm = P::new(core::iter::repeat(F::ZERO));

// Absorb all input chunks.
let mut inputs = inputs.into_iter().peekable();
while inputs.peek().is_some() {
let input_chunk = inputs.by_ref().take(P::RATE);
perm.set_from_iter(input_chunk, 0);
perm.permute();
}

chain!(
[perm.squeeze_iter()],
repeat_with(move || {
perm.permute();
perm.squeeze_iter()
})
)
.flatten()
}
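A self-contained sketch of the absorb loop's chunking idiom above, with a toy rate in place of P::RATE: peekable plus by_ref().take(rate) walks the input in rate-sized chunks without materializing them up front.

fn chunks_of<I: IntoIterator<Item = u64>>(inputs: I, rate: usize) -> Vec<Vec<u64>> {
    let mut out = Vec::new();
    let mut inputs = inputs.into_iter().peekable();
    while inputs.peek().is_some() {
        // `by_ref()` borrows the iterator, so the next pass resumes where
        // this chunk stopped.
        out.push(inputs.by_ref().take(rate).collect());
    }
    out
}

fn main() {
    assert_eq!(chunks_of(1u64..=7, 3), [vec![1, 2, 3], vec![4, 5, 6], vec![7]]);
}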

pub fn hash_n_to_hash_no_pad<F: RichField, P: PlonkyPermutation<F>>(inputs: &[F]) -> HashOut<F> {
HashOut::from_vec(hash_n_to_m_no_pad::<F, P>(inputs, NUM_HASH_OUT_ELTS))
}

pub fn hash_n_to_hash_no_pad_iter<
F: RichField,
P: PlonkyPermutation<F>,
I: IntoIterator<Item = F>,
>(
inputs: I,
) -> HashOut<F> {
let mut elements = hash_n_to_m_no_pad_iter::<F, P, I>(inputs);
HashOut {
elements: core::array::from_fn(|_| elements.next().unwrap()),
Collaborator:

Why from_fn here, and e.g. Self([(); N].map(|()| bytes.next().unwrap())) elsewhere?

Author:

Just being silly, I guess.

}
}
40 changes: 24 additions & 16 deletions plonky2/src/hash/keccak.rs
@@ -1,5 +1,4 @@
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
use core::borrow::Borrow;
use core::mem::size_of;

use itertools::Itertools;
@@ -60,23 +59,25 @@ impl<F: RichField> PlonkyPermutation<F> for KeccakPermutation<F> {
}

fn permute(&mut self) {
let mut state_bytes = vec![0u8; SPONGE_WIDTH * size_of::<u64>()];
for i in 0..SPONGE_WIDTH {
let mut state_bytes = [0u8; SPONGE_WIDTH * size_of::<u64>()];
for (i, x) in self.state.iter().enumerate() {
state_bytes[i * size_of::<u64>()..(i + 1) * size_of::<u64>()]
.copy_from_slice(&self.state[i].to_canonical_u64().to_le_bytes());
.copy_from_slice(&x.to_canonical_u64().to_le_bytes());
}

let hash_onion = core::iter::repeat_with(|| {
let output = keccak(state_bytes.clone()).0;
state_bytes = output.to_vec();
output
let hash_onion = (0..).scan(keccak(state_bytes), |state, _| {
Collaborator:

Why do you use scan here and repeat_with elsewhere? (This one could be repeat_with and take, or the other version could turn into a scan.) successors might also work, and would probably be the cleanest, if it does.
let output = state.0;
*state = keccak(output);
Some(output)
});
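For comparison, a hedged sketch of the successors variant the reviewer suggests, written against the same state_bytes and keccak names used above (untested against this PR):

// Each item is the current Keccak output; the next state hashes that output.
let hash_onion = core::iter::successors(Some(keccak(state_bytes)), |prev| Some(keccak(prev.0)))
    .map(|hash| hash.0);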

let hash_onion_u64s = hash_onion.flat_map(|output| {
output
.chunks_exact(size_of::<u64>())
.map(|word| u64::from_le_bytes(word.try_into().unwrap()))
.collect_vec()
const STRIDE: usize = size_of::<u64>();

(0..32).step_by(STRIDE).map(move |i| {
let bytes = output[i..].first_chunk::<STRIDE>().unwrap();
u64::from_le_bytes(*bytes)
})
});

// Parse field elements from u64 stream, using rejection sampling such that words that don't
@@ -95,6 +96,12 @@ impl<F: RichField> PlonkyPermutation<F> for KeccakPermutation<F> {
fn squeeze(&self) -> &[F] {
&self.state[..Self::RATE]
}

fn squeeze_iter(self) -> impl IntoIterator<Item = F> + Copy {
let mut vals = [F::default(); SPONGE_RATE];
vals.copy_from_slice(self.squeeze());
vals
}
}

/// Keccak-256 hash function.
@@ -105,12 +112,13 @@ impl<F: RichField, const N: usize> Hasher<F> for KeccakHash<N> {
type Hash = BytesHash<N>;
type Permutation = KeccakPermutation<F>;

fn hash_no_pad(input: &[F]) -> Self::Hash {
fn hash_no_pad_iter<I: IntoIterator<Item = F>>(input: I) -> Self::Hash {
let mut keccak256 = Keccak::v256();
for x in input.iter() {
let b = x.to_canonical_u64().to_le_bytes();
for x in input.into_iter() {
let b = x.borrow().to_canonical_u64().to_le_bytes();
keccak256.update(&b);
}

let mut hash_bytes = [0u8; 32];
keccak256.finalize(&mut hash_bytes);

10 changes: 6 additions & 4 deletions plonky2/src/hash/merkle_proofs.rs
@@ -2,7 +2,7 @@
use alloc::{vec, vec::Vec};

use anyhow::{ensure, Result};
use itertools::Itertools;
use itertools::{chain, Itertools};
use serde::{Deserialize, Serialize};

use crate::field::extension::Extendable;
@@ -91,9 +91,11 @@ pub fn verify_field_merkle_proof_to_cap<F: RichField, H: Hasher<F>>(
let mut leaf_data_index = 1;
for &sibling_digest in proof.siblings.iter() {
if leaf_data_index < leaf_heights.len() && current_height == leaf_heights[leaf_data_index] {
let mut new_leaves = current_digest.to_vec();
new_leaves.extend_from_slice(&leaf_data[leaf_data_index]);
current_digest = H::hash_or_noop(&new_leaves);
let new_leaves = chain!(
current_digest.into_iter(),
leaf_data[leaf_data_index].iter().copied(),
);
current_digest = H::hash_or_noop_iter(new_leaves);
leaf_data_index += 1;
}
