Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix poseidon sponge bug #148

Merged
merged 2 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions crypto-primitives/src/sponge/poseidon/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,13 @@ impl<F: PrimeField> PoseidonSpongeVar<F> {
..(self.parameters.capacity + num_elements_squeezed + rate_start_index)],
);

// Repeat with updated output slices and rate start index
remaining_output = &mut remaining_output[num_elements_squeezed..];

// Unless we are done with squeezing in this call, permute.
if remaining_output.len() != self.parameters.rate {
if !remaining_output.is_empty() {
self.permute()?;
}
// Repeat with updated output slices and rate start index
remaining_output = &mut remaining_output[num_elements_squeezed..];
rate_start_index = 0;
}
}
Expand Down
7 changes: 4 additions & 3 deletions crypto-primitives/src/sponge/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,13 @@ impl<F: PrimeField> PoseidonSponge<F> {
..(self.parameters.capacity + num_elements_squeezed + rate_start_index)],
);

// Repeat with updated output slices
output_remaining = &mut output_remaining[num_elements_squeezed..];
// Unless we are done with squeezing in this call, permute.
if output_remaining.len() != self.parameters.rate {
if !output_remaining.is_empty() {
self.permute();
}
// Repeat with updated output slices
output_remaining = &mut output_remaining[num_elements_squeezed..];

rate_start_index = 0;
}
}
Expand Down
235 changes: 234 additions & 1 deletion crypto-primitives/src/sponge/poseidon/tests.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,243 @@
use crate::sponge::poseidon::{PoseidonConfig, PoseidonSponge};
use crate::sponge::poseidon::{PoseidonConfig, PoseidonDefaultConfigField, PoseidonSponge};
use crate::sponge::test::Fr;
use crate::sponge::{Absorb, AbsorbWithLength, CryptographicSponge, FieldBasedCryptographicSponge};
use crate::{absorb, collect_sponge_bytes, collect_sponge_field_elements};
use ark_ff::{One, PrimeField, UniformRand};
use ark_std::test_rng;

#[test]
// Remove once this PR matures
fn demo_bug() {
let sponge_params = Fr::get_default_poseidon_parameters(2, false).unwrap();

let rng = &mut test_rng();
let input = (0..3).map(|_| Fr::rand(rng)).collect::<Vec<_>>();

// works good
let e0 = {
let mut sponge = PoseidonSponge::<Fr>::new(&sponge_params);
sponge.absorb(&input);
sponge.squeeze_native_field_elements(3)
};

// works good
let e1 = {
let mut sponge = PoseidonSponge::<Fr>::new(&sponge_params);
sponge.absorb(&input);
let e0 = sponge.squeeze_native_field_elements(1);
let e1 = sponge.squeeze_native_field_elements(1);
let e2 = sponge.squeeze_native_field_elements(1);
e0.iter()
.chain(e1.iter())
.chain(e2.iter())
.cloned()
.collect::<Vec<_>>()
};

// also works good
let e2 = {
let mut sponge = PoseidonSponge::<Fr>::new(&sponge_params);
sponge.absorb(&input);

let e0 = sponge.squeeze_native_field_elements(2);
let e1 = sponge.squeeze_native_field_elements(1);
e0.iter().chain(e1.iter()).cloned().collect::<Vec<_>>()
};

// skips a permutation if sponge
// * in squeezing mode
// * number of elements are equal to rate
let e3 = {
let mut sponge = PoseidonSponge::<Fr>::new(&sponge_params);
sponge.absorb(&input);
let e0 = sponge.squeeze_native_field_elements(1);
let e1 = sponge.squeeze_native_field_elements(2);
e0.iter().chain(e1.iter()).cloned().collect::<Vec<_>>()
};

assert_eq!(e0, e1);
assert_eq!(e0, e2);
assert_eq!(e0, e3); // this will fail
}

// Remove once this PR matures
fn run_cross_test<F: PrimeField + Absorb>(cfg: &PoseidonConfig<F>) {
#[derive(Debug, PartialEq, Eq)]
enum SpongeMode {
Absorbing,
Squeezing,
}

#[derive(Clone, Debug)]
struct Reference<F: PrimeField> {
cfg: PoseidonConfig<F>,
state: Vec<F>,
absorbing: Vec<F>,
squeeze_count: Option<usize>,
}

// workaround to permute a state
fn permute<F: PrimeField>(cfg: &PoseidonConfig<F>, state: &mut [F]) {
let mut sponge = PoseidonSponge::new(&cfg);
sponge.state.copy_from_slice(state);
sponge.permute();
state.copy_from_slice(&sponge.state)
}

impl<F: PrimeField> Reference<F> {
fn new(cfg: &PoseidonConfig<F>) -> Self {
let t = cfg.rate + cfg.capacity;
let state = vec![F::zero(); t];
Self {
cfg: cfg.clone(),
state,
absorbing: Vec::new(),
squeeze_count: None,
}
}

fn mode(&self) -> SpongeMode {
match self.squeeze_count {
Some(_) => {
assert!(self.absorbing.is_empty());
SpongeMode::Squeezing
}
None => SpongeMode::Absorbing,
}
}

fn absorb(&mut self, input: &[F]) {
if !input.is_empty() {
match self.mode() {
SpongeMode::Absorbing => self.absorbing.extend_from_slice(input),
SpongeMode::Squeezing => {
// Wash the state as mode changes
// This is not appied in SAFE sponge
permute(&self.cfg, &mut self.state);
// Append inputs to the absorbing line
self.absorbing.extend_from_slice(input);
// Change mode to absorbing
self.squeeze_count = None;
}
}
}
}

fn _absorb(&mut self) {
let rate = self.cfg.rate;
self.absorbing.chunks(rate).for_each(|chunk| {
self.state
.iter_mut()
.skip(self.cfg.capacity)
.zip(chunk.iter())
.for_each(|(s, c)| *s += *c);
permute(&self.cfg, &mut self.state);
});

// This case can only happen in the begining when the absorbing line is empty
// and user wants to squeeze elements. Notice that after moving to squueze mode
// if user calls absorb again with empty input it will be ignored
self.absorbing
.is_empty()
.then(|| permute(&self.cfg, &mut self.state));

// flush the absorbing line
self.absorbing.clear();

// Change to the squeezing mode
assert_eq!(self.mode(), SpongeMode::Absorbing);
self.squeeze_count = Some(0);
}

pub fn squeeze(&mut self, n: usize) -> Vec<F> {
match self.mode() {
SpongeMode::Absorbing => self._absorb(),
SpongeMode::Squeezing => {
assert!(self.absorbing.is_empty());
assert!(self.squeeze_count.is_some());

// ???
// **This seems nonsense to me**
// If,
// * number of squeeze is zero AND
// * in squeezing mode AND
// * output index is is at `rate`
// it applies a useless permutation.
// This is also not appied in SAFE sponge

if n == 0 {
let squeeze_count = self.squeeze_count.unwrap();
let out_index = self.squeeze_count.unwrap() % self.cfg.rate;
(out_index == 0 && squeeze_count != 0).then(|| {
permute(&self.cfg, &mut self.state);
self.squeeze_count = Some(0);
});
}
}
}

let rate = self.cfg.rate;
let mut output = Vec::new();
for _ in 0..n {
let squeeze_count = self.squeeze_count.unwrap();
let out_index = squeeze_count % rate;

// proceed with a permutation if
// * the rate is full
// * and it is not the first output
(out_index == 0 && squeeze_count != 0).then(|| permute(&self.cfg, &mut self.state));

// skip the capacity elements
let out_index = out_index + self.cfg.capacity;
output.push(self.state[out_index]);
self.squeeze_count.as_mut().map(|c| *c += 1);
}

output
}
}

let mut sponge = PoseidonSponge::new(cfg);
let mut sponge_ref = Reference::new(cfg);
let mut rng = test_rng();

for _ in 0..1000 {
let test = (0..100)
.map(|_| {
use crate::ark_std::rand::Rng;
let do_absorb = rng.gen_bool(0.5);
let do_squeeze = rng.gen_bool(0.5);

(
(do_absorb, rng.gen_range(0..=cfg.rate * 2 + 1)),
(do_squeeze, rng.gen_range(0..=cfg.rate * 2 + 1)),
)
})
.collect::<Vec<_>>();

// fuzz fuzz
for (_i, ((do_absorb, n_absorb), (do_squeeze, n_squeeze))) in test.into_iter().enumerate() {
do_absorb.then(|| {
let inputs = (0..n_absorb).map(|_| F::rand(&mut rng)).collect::<Vec<_>>();
sponge_ref.absorb(&inputs);
sponge.absorb(&inputs);
});
do_squeeze.then(|| {
let out0 = sponge_ref.squeeze(n_squeeze);
let out1 = sponge.squeeze_field_elements(n_squeeze);
assert_eq!(out0, out1);
});
}
}
}

#[test]
// Remove once this PR matures
fn test_cross() {
let cfg = Fr::get_default_poseidon_parameters(2, false).unwrap();
run_cross_test::<Fr>(&cfg);
}

fn assert_different_encodings<F: PrimeField, A: Absorb>(a: &A, b: &A) {
let bytes1 = a.to_sponge_bytes_as_vec();
let bytes2 = b.to_sponge_bytes_as_vec();
Expand Down
Loading