diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index bb3109d82..a6533ad22 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -5,6 +5,7 @@ use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, slice::ParallelSlice, }; +use std::sync::Arc; use crate::{ circuit_builder::CircuitBuilder, @@ -18,6 +19,10 @@ pub mod riscv; pub enum InstancePaddingStrategy { Zero, RepeatLast, + // Custom strategy consists of a closure + // `pad(i, j) = padding value for cell at row i, column j` + // pad should be able to cross thread boundaries + Custom(Arc u64 + Send + Sync>), } pub trait Instruction { diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index b824c2646..5baf948f3 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -1,4 +1,4 @@ -use std::marker::PhantomData; +use std::{marker::PhantomData, sync::Arc}; use ceno_emul::{Addr, Cycle}; use ff_ext::ExtensionField; @@ -388,8 +388,23 @@ impl DynVolatileRamTableConfig ) -> Result, ZKVMError> { assert!(final_mem.len() <= DVRAM::max_len(&self.params)); assert!(DVRAM::max_len(&self.params).is_power_of_two()); - let mut final_table = - RowMajorMatrix::::new(final_mem.len(), num_witness, InstancePaddingStrategy::Zero); + + let params = self.params.clone(); + // TODO: make this more robust + let addr_column = 0; + let padding_fn = move |row: u64, col: u64| { + if col == addr_column { + DVRAM::addr(¶ms, row as usize) as u64 + } else { + 0u64 + } + }; + + let mut final_table = RowMajorMatrix::::new( + final_mem.len(), + num_witness, + InstancePaddingStrategy::Custom(Arc::new(padding_fn)), + ); final_table .par_iter_mut() @@ -416,3 +431,57 @@ impl DynVolatileRamTableConfig Ok(final_table) } } + +#[cfg(test)] +mod tests { + use std::iter::successors; + + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + structs::ProgramParams, + tables::{DynVolatileRamTable, HintsCircuit, HintsTable, MemFinalRecord, TableCircuit}, + witness::LkMultiplicity, + }; + + use ceno_emul::WORD_SIZE; + use goldilocks::{Goldilocks as F, GoldilocksExt2 as E}; + use itertools::Itertools; + + #[test] + fn test_well_formed_address_padding() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = HintsCircuit::construct_circuit(&mut cb).unwrap(); + + let def_params = ProgramParams::default(); + let lkm = LkMultiplicity::default().into_finalize_result(); + + let input = (0..23) + .map(|i| MemFinalRecord { + addr: HintsTable::addr(&def_params, i), + cycle: 0, + value: 0, + }) + .collect_vec(); + let wit = + HintsCircuit::::assign_instances(&config, cb.cs.num_witin as usize, &lkm, &input) + .unwrap(); + + let addr_column = cb + .cs + .witin_namespace_map + .iter() + .position(|name| name == "riscv/RAM_Memory_HintsTable/addr") + .unwrap(); + + let addr_padded_view = wit.column_padded(addr_column); + // Expect addresses to proceed consecutively inside the padding as well + let expected = successors(Some(addr_padded_view[0]), |idx| { + Some(*idx + F::from(WORD_SIZE as u64)) + }) + .take(addr_padded_view.len()) + .collect::>(); + + assert_eq!(addr_padded_view, expected) + } +} diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 79cb6a9ba..134a66689 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -3,7 +3,6 @@ use std::{ array, cell::RefCell, collections::HashMap, - iter, mem::{self}, ops::Index, slice::{Chunks, ChunksMut}, @@ -46,7 +45,7 @@ pub struct RowMajorMatrix { padding_strategy: InstancePaddingStrategy, } -impl RowMajorMatrix { +impl> RowMajorMatrix { pub fn new(num_rows: usize, num_col: usize, padding_strategy: InstancePaddingStrategy) -> Self { RowMajorMatrix { values: (0..num_rows * num_col) @@ -66,6 +65,10 @@ impl RowMajorMatrix { self.values.len() / self.num_col } + pub fn row(&self, row: usize) -> &[T] { + &self.values[row * self.num_col..(row + 1) * self.num_col] + } + pub fn iter_rows(&self) -> Chunks { self.values.chunks(self.num_col) } @@ -83,31 +86,36 @@ impl RowMajorMatrix { } } -impl RowMajorMatrix { +impl> RowMajorMatrix { + // Returns column number `column`, padded appropriately according to the stored strategy + pub fn column_padded(&self, column: usize) -> Vec { + let num_instances = self.num_instances(); + let num_padding_instances = self.num_padding_instances(); + let last_element = self.row(num_instances - 1)[column]; + + let padding_iter = (num_instances..num_instances + num_padding_instances).map(|i| { + match &self.padding_strategy { + InstancePaddingStrategy::Custom(fun) => F::from(fun(i as u64, column as u64)), + InstancePaddingStrategy::RepeatLast => last_element, + _ => F::ZERO, + } + }); + + self.values + .iter() + .skip(column) + .step_by(self.num_col) + .copied() + .chain(padding_iter) + .collect::>() + } + pub fn into_mles>( self, ) -> Vec> { - let padding_row = match self.padding_strategy { - // Repeat last row if it exists - InstancePaddingStrategy::RepeatLast if !self.values.is_empty() => { - self.values[self.values.len() - self.num_col..].to_vec() - } - // Otherwise use zeroes - _ => vec![F::ZERO; self.num_col], - }; - let num_padding = self.num_padding_instances(); (0..self.num_col) .into_par_iter() - .map(|i| { - self.values - .iter() - .skip(i) - .step_by(self.num_col) - .chain(&mut iter::repeat(&padding_row[i]).take(num_padding)) - .copied() - .collect::>() - .into_mle() - }) + .map(|i| self.column_padded(i).into_mle()) .collect() } }