diff --git a/poseidon2-starky/src/plonky2/generation.rs b/poseidon2-starky/src/plonky2/generation.rs index 1694891..191e9ff 100644 --- a/poseidon2-starky/src/plonky2/generation.rs +++ b/poseidon2-starky/src/plonky2/generation.rs @@ -96,40 +96,61 @@ fn generate_outputs(preimage: &[Field; STATE_SIZE]) -> [Field; /// Function to generate the Poseidon2 trace pub fn generate_poseidon2_trace(step_rows: &Vec>) -> [Vec; NUM_COLS] { - let trace_len = step_rows.len(); + let mut trace_len = step_rows.len(); + if trace_len == 0 { + trace_len = 1; + } let mut trace: Vec> = vec![vec![F::ZERO; trace_len]; NUM_COLS]; - for (i, row) in step_rows.iter().enumerate() { - trace[COL_IS_EXE][i] = F::ONE; + let mut add_rows = |step_rows: &Vec>, is_exe: bool| { + for (i, row) in step_rows.iter().enumerate() { + if is_exe { + trace[COL_IS_EXE][i] = F::ONE; + } else { + trace[COL_IS_EXE][i] = F::ZERO; + } - for j in 0..STATE_SIZE { - trace[COL_INPUT_START + j][i] = row.preimage[j]; - } - let outputs = generate_outputs(&row.preimage); - for j in 0..STATE_SIZE { - trace[COL_OUTPUT_START + j][i] = outputs[j]; - } + for j in 0..STATE_SIZE { + trace[COL_INPUT_START + j][i] = row.preimage[j]; + } + let outputs = generate_outputs(&row.preimage); + for j in 0..STATE_SIZE { + trace[COL_OUTPUT_START + j][i] = outputs[j]; + } - let first_full_round_state = generate_1st_full_round_state(&row.preimage); - let partial_round_state = generate_partial_round_state( - first_full_round_state.last().unwrap().try_into().unwrap(), - ); - let second_full_round_state = - generate_2st_full_round_state(partial_round_state.last().unwrap().try_into().unwrap()); - for j in 0..(ROUNDS_F / 2) { - for k in 0..STATE_SIZE { - trace[COL_1ST_FULLROUND_STATE_START + j * STATE_SIZE + k][i] = - first_full_round_state[j][k]; - trace[COL_2ND_FULLROUND_STATE_START + j * STATE_SIZE + k][i] = - second_full_round_state[j][k]; + let first_full_round_state = generate_1st_full_round_state(&row.preimage); + let partial_round_state = generate_partial_round_state( + first_full_round_state.last().unwrap().try_into().unwrap(), + ); + let second_full_round_state = generate_2st_full_round_state( + partial_round_state.last().unwrap().try_into().unwrap(), + ); + for j in 0..(ROUNDS_F / 2) { + for k in 0..STATE_SIZE { + trace[COL_1ST_FULLROUND_STATE_START + j * STATE_SIZE + k][i] = + first_full_round_state[j][k]; + trace[COL_2ND_FULLROUND_STATE_START + j * STATE_SIZE + k][i] = + second_full_round_state[j][k]; + } + } + for j in 0..ROUNDS_P { + trace[COL_PARTIAL_ROUND_STATE_START + j][i] = partial_round_state[j][0]; + } + for j in 0..STATE_SIZE { + trace[COL_PARTIAL_ROUND_END_STATE_START + j][i] = + partial_round_state[ROUNDS_P - 1][j]; } } - for j in 0..ROUNDS_P { - trace[COL_PARTIAL_ROUND_STATE_START + j][i] = partial_round_state[j][0]; - } - for j in 0..STATE_SIZE { - trace[COL_PARTIAL_ROUND_END_STATE_START + j][i] = partial_round_state[ROUNDS_P - 1][j]; - } + }; + + add_rows(step_rows, true); + + if step_rows.len() == 0 { + let preimage = (0..STATE_SIZE).map(|_| F::rand()).collect::>(); + let dummy_rows = vec![Row { + preimage: preimage.try_into().expect("can't fail"), + }]; + add_rows(&dummy_rows, false); } trace = pad_trace(trace); @@ -144,7 +165,7 @@ pub fn generate_poseidon2_trace(step_rows: &Vec>) -> [Vec; NUM_COLS] = super::generate_poseidon2_trace(&step_rows); + assert_eq!(trace[0].len(), 1); + } }