Skip to content

Commit

Permalink
incorporate well-formed address padding and add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
mcalancea committed Dec 11, 2024
1 parent 0953c7e commit da45ffc
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 25 deletions.
5 changes: 5 additions & 0 deletions ceno_zkvm/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use rayon::{
iter::{IndexedParallelIterator, ParallelIterator},
slice::ParallelSlice,
};
use std::sync::Arc;

use crate::{
circuit_builder::CircuitBuilder,
Expand All @@ -18,6 +19,10 @@ pub mod riscv;
pub enum InstancePaddingStrategy {
Zero,
RepeatLast,
// Custom strategy consists of a closure
// `pad(i, j) = padding value for cell at row i, column j`
// pad should be able to cross thread boundaries
Custom(Arc<dyn Fn(u64, u64) -> u64 + Send + Sync>),
}

pub trait Instruction<E: ExtensionField> {
Expand Down
75 changes: 72 additions & 3 deletions ceno_zkvm/src/tables/ram/ram_impl.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::marker::PhantomData;
use std::{marker::PhantomData, sync::Arc};

use ceno_emul::{Addr, Cycle};
use ff_ext::ExtensionField;
Expand Down Expand Up @@ -388,8 +388,23 @@ impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableConfig
) -> Result<RowMajorMatrix<F>, ZKVMError> {
assert!(final_mem.len() <= DVRAM::max_len(&self.params));
assert!(DVRAM::max_len(&self.params).is_power_of_two());
let mut final_table =
RowMajorMatrix::<F>::new(final_mem.len(), num_witness, InstancePaddingStrategy::Zero);

let params = self.params.clone();
// TODO: make this more robust
let addr_column = 0;
let padding_fn = move |row: u64, col: u64| {
if col == addr_column {
DVRAM::addr(&params, row as usize) as u64
} else {
0u64
}
};

let mut final_table = RowMajorMatrix::<F>::new(
final_mem.len(),
num_witness,
InstancePaddingStrategy::Custom(Arc::new(padding_fn)),
);

final_table
.par_iter_mut()
Expand All @@ -416,3 +431,57 @@ impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableConfig
Ok(final_table)
}
}

#[cfg(test)]
mod tests {
use std::iter::successors;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
structs::ProgramParams,
tables::{DynVolatileRamTable, HintsCircuit, HintsTable, MemFinalRecord, TableCircuit},
witness::LkMultiplicity,
};

use ceno_emul::WORD_SIZE;
use goldilocks::{Goldilocks as F, GoldilocksExt2 as E};
use itertools::Itertools;

#[test]
fn test_well_formed_address_padding() {
let mut cs = ConstraintSystem::<E>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = HintsCircuit::construct_circuit(&mut cb).unwrap();

let def_params = ProgramParams::default();
let lkm = LkMultiplicity::default().into_finalize_result();

let input = (0..23)
.map(|i| MemFinalRecord {
addr: HintsTable::addr(&def_params, i),
cycle: 0,
value: 0,
})
.collect_vec();
let wit =
HintsCircuit::<E>::assign_instances(&config, cb.cs.num_witin as usize, &lkm, &input)
.unwrap();

let addr_column = cb
.cs
.witin_namespace_map
.iter()
.position(|name| name == "riscv/RAM_Memory_HintsTable/addr")
.unwrap();

let addr_padded_view = wit.column_padded(addr_column);
// Expect addresses to proceed consecutively inside the padding as well
let expected = successors(Some(addr_padded_view[0]), |idx| {
Some(*idx + F::from(WORD_SIZE as u64))
})
.take(addr_padded_view.len())
.collect::<Vec<_>>();

assert_eq!(addr_padded_view, expected)
}
}
52 changes: 30 additions & 22 deletions ceno_zkvm/src/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::{
array,
cell::RefCell,
collections::HashMap,
iter,
mem::{self},
ops::Index,
slice::{Chunks, ChunksMut},
Expand Down Expand Up @@ -46,7 +45,7 @@ pub struct RowMajorMatrix<T: Sized + Sync + Clone + Send + Copy> {
padding_strategy: InstancePaddingStrategy,
}

impl<T: Sized + Sync + Clone + Send + Copy + Default> RowMajorMatrix<T> {
impl<T: Sized + Sync + Clone + Send + Copy + Default + From<u64>> RowMajorMatrix<T> {
pub fn new(num_rows: usize, num_col: usize, padding_strategy: InstancePaddingStrategy) -> Self {
RowMajorMatrix {
values: (0..num_rows * num_col)
Expand All @@ -66,6 +65,10 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default> RowMajorMatrix<T> {
self.values.len() / self.num_col
}

pub fn row(&self, row: usize) -> &[T] {
&self.values[row * self.num_col..(row + 1) * self.num_col]
}

pub fn iter_rows(&self) -> Chunks<T> {
self.values.chunks(self.num_col)
}
Expand All @@ -83,31 +86,36 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default> RowMajorMatrix<T> {
}
}

impl<F: Field> RowMajorMatrix<F> {
impl<F: Field + From<u64>> RowMajorMatrix<F> {
// Returns column number `column`, padded appropriately according to the stored strategy
pub fn column_padded(&self, column: usize) -> Vec<F> {
let num_instances = self.num_instances();
let num_padding_instances = self.num_padding_instances();
let last_element = self.row(num_instances - 1)[column];

let padding_iter = (num_instances..num_instances + num_padding_instances).map(|i| {
match &self.padding_strategy {
InstancePaddingStrategy::Custom(fun) => F::from(fun(i as u64, column as u64)),
InstancePaddingStrategy::RepeatLast => last_element,
_ => F::ZERO,
}
});

self.values
.iter()
.skip(column)
.step_by(self.num_col)
.copied()
.chain(padding_iter)
.collect::<Vec<_>>()
}

pub fn into_mles<E: ff_ext::ExtensionField<BaseField = F>>(
self,
) -> Vec<DenseMultilinearExtension<E>> {
let padding_row = match self.padding_strategy {
// Repeat last row if it exists
InstancePaddingStrategy::RepeatLast if !self.values.is_empty() => {
self.values[self.values.len() - self.num_col..].to_vec()
}
// Otherwise use zeroes
_ => vec![F::ZERO; self.num_col],
};
let num_padding = self.num_padding_instances();
(0..self.num_col)
.into_par_iter()
.map(|i| {
self.values
.iter()
.skip(i)
.step_by(self.num_col)
.chain(&mut iter::repeat(&padding_row[i]).take(num_padding))
.copied()
.collect::<Vec<_>>()
.into_mle()
})
.map(|i| self.column_padded(i).into_mle())
.collect()
}
}
Expand Down

0 comments on commit da45ffc

Please sign in to comment.