incorporate well-formed address padding and add unit test
mcalancea committed Dec 11, 2024
1 parent 0953c7e commit bdf705c
Showing 3 changed files with 106 additions and 25 deletions.
5 changes: 5 additions & 0 deletions ceno_zkvm/src/instructions.rs
@@ -5,6 +5,7 @@ use rayon::{
iter::{IndexedParallelIterator, ParallelIterator},
slice::ParallelSlice,
};
use std::sync::Arc;

use crate::{
circuit_builder::CircuitBuilder,
@@ -18,6 +19,10 @@ pub mod riscv;
pub enum InstancePaddingStrategy {
Zero,
RepeatLast,
    // A custom strategy is a closure `pad(i, j)` that returns the padding value
    // for the cell at row `i`, column `j`. The closure must be `Send + Sync`,
    // since padding may cross thread boundaries.
    Custom(Arc<dyn Fn(u64, u64) -> u64 + Send + Sync>),
}
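
For illustration only (not part of this commit), a minimal sketch of how a caller might construct the new variant; the import path is assumed from the file location, and the column choice and padding values are hypothetical:

    use std::sync::Arc;
    use ceno_zkvm::instructions::InstancePaddingStrategy; // assumed path

    fn main() {
        // Hypothetical strategy: pad column 0 with the row index and every
        // other column with zero. The closure captures nothing thread-local,
        // so it satisfies the `Send + Sync` bound.
        let strategy = InstancePaddingStrategy::Custom(Arc::new(|row, col| {
            if col == 0 { row } else { 0 }
        }));
        let _ = strategy; // would be passed to RowMajorMatrix::new(...)
    }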

pub trait Instruction<E: ExtensionField> {
75 changes: 72 additions & 3 deletions ceno_zkvm/src/tables/ram/ram_impl.rs
@@ -1,4 +1,4 @@
-use std::marker::PhantomData;
+use std::{marker::PhantomData, sync::Arc};

use ceno_emul::{Addr, Cycle};
use ff_ext::ExtensionField;
@@ -388,8 +388,23 @@ impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableConfig
) -> Result<RowMajorMatrix<F>, ZKVMError> {
assert!(final_mem.len() <= DVRAM::max_len(&self.params));
assert!(DVRAM::max_len(&self.params).is_power_of_two());
-        let mut final_table =
-            RowMajorMatrix::<F>::new(final_mem.len(), num_witness, InstancePaddingStrategy::Zero);
+        let params = self.params.clone();
+        // TODO: make this more robust
+        let addr_column = 0;
+        let padding_fn = move |row: u64, col: u64| {
+            if col == addr_column {
+                DVRAM::addr(&params, row as usize) as u64
+            } else {
+                0u64
+            }
+        };
+
+        let mut final_table = RowMajorMatrix::<F>::new(
+            final_mem.len(),
+            num_witness,
+            InstancePaddingStrategy::Custom(Arc::new(padding_fn)),
+        );

final_table
.par_iter_mut()
@@ -416,3 +431,57 @@ impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableConfig
Ok(final_table)
}
}
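
As a worked illustration of what the padding above produces: a hedged, self-contained sketch, assuming (as the unit test below suggests) that DVRAM::addr(params, i) evaluates to base_addr + i * WORD_SIZE; the base address and word size here are made up:

    fn main() {
        const WORD_SIZE: u64 = 4; // assumption: 32-bit words
        let base_addr = 0x8000_0000u64; // hypothetical base address
        let addr = move |row: u64| base_addr + row * WORD_SIZE;

        // Mirrors `padding_fn` above: the address column keeps counting
        // upward past the last real row, every other column pads with zero.
        let padding_fn = move |row: u64, col: u64| if col == 0 { addr(row) } else { 0 };

        // With 23 real rows, padding row 23 continues the address sequence.
        assert_eq!(padding_fn(23, 0), base_addr + 23 * WORD_SIZE);
        assert_eq!(padding_fn(23, 1), 0);
    }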

#[cfg(test)]
mod tests {
use std::iter::successors;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
structs::ProgramParams,
tables::{DynVolatileRamTable, HintsCircuit, HintsTable, MemFinalRecord, TableCircuit},
witness::LkMultiplicity,
};

use ceno_emul::WORD_SIZE;
use goldilocks::{Goldilocks as F, GoldilocksExt2 as E};
use itertools::Itertools;

#[test]
fn test_well_formed_address_padding() {
let mut cs = ConstraintSystem::<E>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = HintsCircuit::construct_circuit(&mut cb).unwrap();

let def_params = ProgramParams::default();
let lkm = LkMultiplicity::default().into_finalize_result();

let input = (0..23)
.map(|i| MemFinalRecord {
addr: HintsTable::addr(&def_params, i),
cycle: 0,
value: 0,
})
.collect_vec();
let wit =
HintsCircuit::<E>::assign_instances(&config, cb.cs.num_witin as usize, &lkm, &input)
.unwrap();

let addr_column = cb
.cs
.witin_namespace_map
.iter()
.position(|name| name == "riscv/RAM_Memory_HintsTable/addr")
.unwrap();

let addr_padded_view = wit.column_padded_view(addr_column);
let expected = successors(Some(addr_padded_view[0]), |idx| {
Some(*idx + F::from(WORD_SIZE as u64))
})
.take(addr_padded_view.len())
.collect::<Vec<_>>();

// Address column should contain increasing addresses everywhere, including padding
assert_eq!(addr_padded_view, expected)
}
}
51 changes: 29 additions & 22 deletions ceno_zkvm/src/witness.rs
@@ -3,7 +3,6 @@ use std::{
array,
cell::RefCell,
collections::HashMap,
-    iter,
mem::{self},
ops::Index,
slice::{Chunks, ChunksMut},
@@ -46,7 +45,7 @@ pub struct RowMajorMatrix<T: Sized + Sync + Clone + Send + Copy> {
padding_strategy: InstancePaddingStrategy,
}

-impl<T: Sized + Sync + Clone + Send + Copy + Default> RowMajorMatrix<T> {
+impl<T: Sized + Sync + Clone + Send + Copy + Default + From<u64>> RowMajorMatrix<T> {
pub fn new(num_rows: usize, num_col: usize, padding_strategy: InstancePaddingStrategy) -> Self {
RowMajorMatrix {
values: (0..num_rows * num_col)
@@ -81,33 +80,41 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default> RowMajorMatrix<T> {
pub fn par_batch_iter_mut(&mut self, num_rows: usize) -> rayon::slice::ChunksMut<T> {
self.values.par_chunks_mut(num_rows * self.num_col)
}

pub fn column_padded_view(&self, col_index: usize) -> Vec<T> {
let num_instances = self.num_instances();
let num_padding_instances = self.num_padding_instances();

let padding_fn: Arc<dyn Fn(u64) -> T> = match &self.padding_strategy {
InstancePaddingStrategy::Custom(fun) => Arc::new(|i| T::from(fun(i, col_index as u64))),
InstancePaddingStrategy::RepeatLast if !self.values.is_empty() => {
Arc::new(|_| self.values[self.values.len() - self.num_col + col_index])
}
            // Zero padding (or RepeatLast on an empty matrix): default value
_ => Arc::new(|_| T::default()),
};

let padding_iter = (num_instances..num_instances + num_padding_instances)
.map(|i| padding_fn(i as u64))
.take(num_padding_instances);

self.values
.iter()
.skip(col_index)
.step_by(self.num_col)
.copied()
.chain(padding_iter)
.collect::<Vec<_>>()
}
}
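
To see the padding logic in isolation, a self-contained sketch of what column_padded_view computes, using plain vectors instead of the crate's matrix type (all names and values here are illustrative):

    fn column_padded_view(
        values: &[u64],
        num_col: usize,
        num_padding: usize,
        pad: &dyn Fn(u64, u64) -> u64,
        col: usize,
    ) -> Vec<u64> {
        let num_instances = values.len() / num_col;
        values
            .iter()
            .skip(col)
            .step_by(num_col)
            .copied()
            // Padding rows are indexed from num_instances upward, so a
            // custom strategy sees absolute row numbers.
            .chain((num_instances..num_instances + num_padding).map(|row| pad(row as u64, col as u64)))
            .collect()
    }

    fn main() {
        // Row-major 2x2 matrix [[10, 11], [20, 21]], padded with two rows.
        let values = vec![10, 11, 20, 21];
        // Custom strategy: column 0 continues the sequence 10, 20, 30, 40.
        let pad = |row: u64, col: u64| if col == 0 { (row + 1) * 10 } else { 0 };
        assert_eq!(column_padded_view(&values, 2, 2, &pad, 0), vec![10, 20, 30, 40]);
        assert_eq!(column_padded_view(&values, 2, 2, &pad, 1), vec![11, 21, 0, 0]);
    }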

-impl<F: Field> RowMajorMatrix<F> {
+impl<F: Field + From<u64>> RowMajorMatrix<F> {
pub fn into_mles<E: ff_ext::ExtensionField<BaseField = F>>(
self,
) -> Vec<DenseMultilinearExtension<E>> {
-        let padding_row = match self.padding_strategy {
-            // Repeat last row if it exists
-            InstancePaddingStrategy::RepeatLast if !self.values.is_empty() => {
-                self.values[self.values.len() - self.num_col..].to_vec()
-            }
-            // Otherwise use zeroes
-            _ => vec![F::ZERO; self.num_col],
-        };
-        let num_padding = self.num_padding_instances();
        (0..self.num_col)
            .into_par_iter()
-            .map(|i| {
-                self.values
-                    .iter()
-                    .skip(i)
-                    .step_by(self.num_col)
-                    .chain(&mut iter::repeat(&padding_row[i]).take(num_padding))
-                    .copied()
-                    .collect::<Vec<_>>()
-                    .into_mle()
-            })
+            .map(|i| self.column_padded_view(i).into_mle())
            .collect()
}
}