Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Private input integration #622

Merged
merged 20 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ceno_emul/src/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,17 @@ impl ops::AddAssign<u32> for ByteAddr {
}

pub trait IterAddresses {
fn iter_addresses(&self) -> impl Iterator<Item = Addr>;
fn iter_addresses(&self) -> impl ExactSizeIterator<Item = Addr>;
matthiasgoergens marked this conversation as resolved.
Show resolved Hide resolved
}

impl IterAddresses for Range<Addr> {
fn iter_addresses(&self) -> impl Iterator<Item = Addr> {
fn iter_addresses(&self) -> impl ExactSizeIterator<Item = Addr> {
self.clone().step_by(WORD_SIZE)
}
}

impl<'a, T: GetAddr> IterAddresses for &'a [T] {
fn iter_addresses(&self) -> impl Iterator<Item = Addr> {
fn iter_addresses(&self) -> impl ExactSizeIterator<Item = Addr> {
self.iter().map(T::get_addr)
}
}
Expand Down
8 changes: 7 additions & 1 deletion ceno_emul/src/platform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub struct Platform {
pub rom: Range<Addr>,
pub ram: Range<Addr>,
pub public_io: Range<Addr>,
pub private_io: Range<Addr>,
pub stack_top: Addr,
/// If true, ecall instructions are no-op instead of trap. Testing only.
pub unsafe_ecall_nop: bool,
Expand All @@ -21,6 +22,7 @@ pub const CENO_PLATFORM: Platform = Platform {
rom: 0x2000_0000..0x3000_0000,
ram: 0x8000_0000..0xFFFF_0000,
public_io: 0x3000_1000..0x3000_2000,
private_io: 0x4000_0000..0x5000_0000,
stack_top: 0xC0000000,
unsafe_ecall_nop: false,
};
Expand All @@ -40,6 +42,10 @@ impl Platform {
self.public_io.contains(&addr)
}

/// Returns `true` if `addr` lies within this platform's `private_io` address range.
pub fn is_priv_io(&self, addr: Addr) -> bool {
self.private_io.contains(&addr)
}

/// Virtual address of a register.
pub const fn register_vma(index: RegIdx) -> Addr {
// Register VMAs are aligned, cannot be confused with indices, and readable in hex.
Expand All @@ -60,7 +66,7 @@ impl Platform {
// Permissions.

pub fn can_read(&self, addr: Addr) -> bool {
self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr)
self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr) || self.is_priv_io(addr)
matthiasgoergens marked this conversation as resolved.
Show resolved Hide resolved
}

pub fn can_write(&self, addr: Addr) -> bool {
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ fn main() {
&reg_final,
&mem_final,
&public_io_final,
&[],
)
.unwrap();

Expand Down
45 changes: 42 additions & 3 deletions ceno_zkvm/src/bin/e2e.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ceno_emul::{
ByteAddr, CENO_PLATFORM, EmuContext, InsnKind::EANY, Platform, StepRecord, Tracer, VMState,
WORD_SIZE, WordAddr,
ByteAddr, CENO_PLATFORM, EmuContext, InsnKind::EANY, IterAddresses, Platform, StepRecord,
Tracer, VMState, WORD_SIZE, Word, WordAddr,
};
use ceno_zkvm::{
instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig},
Expand All @@ -19,7 +19,9 @@ use itertools::{Itertools, MinMaxResult, chain, enumerate};
use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme};
use std::{
collections::{HashMap, HashSet},
fs, panic,
fs,
iter::zip,
panic,
time::Instant,
};
use tracing::level_filters::LevelFilter;
Expand All @@ -41,6 +43,11 @@ struct Args {
/// The preset configuration to use.
#[arg(short, long, value_enum, default_value_t = Preset::Ceno)]
platform: Preset,

/// The private input or hints. This is a raw file mounted as a memory segment.
matthiasgoergens marked this conversation as resolved.
Show resolved Hide resolved
/// Zero-padded to the next power-of-two size.
matthiasgoergens marked this conversation as resolved.
Show resolved Hide resolved
#[arg(long)]
private_input: Option<String>,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
Expand Down Expand Up @@ -94,6 +101,17 @@ fn main() {
let elf_bytes = fs::read(&args.elf).expect("read elf file");
let mut vm = VMState::new_from_elf(platform.clone(), &elf_bytes).unwrap();

tracing::info!("Loading private input file: {:?}", args.private_input);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the discussion about private IO, someone suggested using the term 'hint' throughout.

('Private input' works perfectly fine for me; but I remember some people having strong opinions on this.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the ref. If the term "hint" is understood I’ll rename things to that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No worries, you can do that after you finish the sproll-evm test.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let priv_io = memory_from_file(&args.private_input);
assert!(
priv_io.len() <= platform.private_io.iter_addresses().len(),
"private input must fit in {} bytes",
platform.private_io.len()
);
for (addr, value) in zip(platform.private_io.iter_addresses(), &priv_io) {
vm.init_memory(addr.into(), *value);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a comment saying that, if len(private_io) != len(private_input_file), it is right-padded with 0s?

Copy link
Collaborator Author

@naure naure Nov 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. It is not padded. Or, more precisely, it is zero-padded only to the next power-of-two size.

The rest of memory does not exist - unsatisfiable if the program tries to use it.

(added a comment and assert)

Copy link
Collaborator

@matthiasgoergens matthiasgoergens Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For technical reasons, if we are using rkyv, it would be better to pad from the left by default.

That means the last byte that's specified should always appear in the same position in the VM.

The rest of memory does not exist - unsatisfiable if the program tries to use it.

That's good, and exactly what we want. Can you please make both the doc comment and the emulator reflect that behaviour?

Btw, would it be possible to make both kinds of padding 'inaccessible' (from the point of view of the VM), instead of a mix of 0 and 'inaccessible'? That would be best!


// keygen
let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup");
let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
Expand Down Expand Up @@ -249,6 +267,14 @@ fn main() {
.map(|rec| *final_access.get(&rec.addr.into()).unwrap_or(&0))
.collect_vec();

let priv_io_final = zip(platform.private_io.iter_addresses(), &priv_io)
.map(|(addr, &value)| MemFinalRecord {
addr,
value,
cycle: *final_access.get(&addr.into()).unwrap_or(&0),
})
.collect_vec();

// assign table circuits
config
.assign_table_circuit(&zkvm_cs, &mut zkvm_witness)
Expand All @@ -260,6 +286,7 @@ fn main() {
&reg_final,
&mem_final,
&io_final,
&priv_io_final,
)
.unwrap();
// assign program circuit
Expand Down Expand Up @@ -332,6 +359,18 @@ fn main() {
};
}

/// Loads the optional private-input file as a vector of little-endian words.
///
/// Returns an empty vector when `path` is `None`. Otherwise the whole file is
/// read, its length is rounded up to a multiple of `WORD_SIZE` by appending
/// zero bytes, and each `WORD_SIZE`-byte chunk is decoded with
/// `Word::from_le_bytes`.
///
/// # Panics
///
/// Panics with "could not read file" if the file cannot be read.
fn memory_from_file(path: &Option<String>) -> Vec<u32> {
path.as_ref()
.map(|path| {
let mut buf = fs::read(path).expect("could not read file");
// Zero-pad the tail so the byte buffer splits into whole words.
buf.resize(buf.len().next_multiple_of(WORD_SIZE), 0);
matthiasgoergens marked this conversation as resolved.
Show resolved Hide resolved
// `chunks_exact` yields only full `WORD_SIZE`-byte chunks, so the `unwrap` below
// relies on `WORD_SIZE` matching the byte width `Word::from_le_bytes` expects.
buf.chunks_exact(WORD_SIZE)
.map(|word| Word::from_le_bytes(word.try_into().unwrap()))
.collect_vec()
})
// No path given: no private input, hence an empty vector.
.unwrap_or_default()
}

fn debug_memory_ranges(vm: &VMState, mem_final: &[MemFinalRecord]) {
let accessed_addrs = vm
.tracer()
Expand Down
24 changes: 21 additions & 3 deletions ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use crate::{
error::ZKVMError,
structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
tables::{
MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RegTable,
RegTableCircuit, StaticMemCircuit, StaticMemTable, TableCircuit,
MemFinalRecord, MemInitRecord, NonVolatileTable, PrivateIOCircuit, PubIOCircuit,
PubIOTable, RegTable, RegTableCircuit, StaticMemCircuit, StaticMemTable, TableCircuit,
},
};

Expand All @@ -20,6 +20,8 @@ pub struct MmuConfig<E: ExtensionField> {
pub static_mem_config: <StaticMemCircuit<E> as TableCircuit<E>>::TableConfig,
/// Initialization of public IO.
pub public_io_config: <PubIOCircuit<E> as TableCircuit<E>>::TableConfig,
/// Initialization of private IO.
pub private_io_config: <PrivateIOCircuit<E> as TableCircuit<E>>::TableConfig,
pub params: ProgramParams,
}

Expand All @@ -30,11 +32,13 @@ impl<E: ExtensionField> MmuConfig<E> {
let static_mem_config = cs.register_table_circuit::<StaticMemCircuit<E>>();

let public_io_config = cs.register_table_circuit::<PubIOCircuit<E>>();
let private_io_config = cs.register_table_circuit::<PrivateIOCircuit<E>>();

Self {
reg_config,
static_mem_config,
public_io_config,
private_io_config,
params: cs.params.clone(),
}
}
Expand All @@ -48,7 +52,13 @@ impl<E: ExtensionField> MmuConfig<E> {
io_addrs: &[Addr],
) {
assert!(
chain!(static_mem_init.iter_addresses(), io_addrs.iter_addresses()).all_unique(),
chain!(
static_mem_init.iter_addresses(),
io_addrs.iter_addresses(),
// TODO: optimize with min_max and Range.
matthiasgoergens marked this conversation as resolved.
Show resolved Hide resolved
self.params.platform.private_io.iter_addresses(),
)
.all_unique(),
"memory addresses must be unique"
);

Expand All @@ -61,6 +71,7 @@ impl<E: ExtensionField> MmuConfig<E> {
);

fixed.register_table_circuit::<PubIOCircuit<E>>(cs, &self.public_io_config, io_addrs);
fixed.register_table_circuit::<PrivateIOCircuit<E>>(cs, &self.private_io_config, &());
naure marked this conversation as resolved.
Show resolved Hide resolved
}

pub fn assign_table_circuit(
Expand All @@ -70,6 +81,7 @@ impl<E: ExtensionField> MmuConfig<E> {
reg_final: &[MemFinalRecord],
static_mem_final: &[MemFinalRecord],
io_cycles: &[Cycle],
private_io_final: &[MemFinalRecord],
) -> Result<(), ZKVMError> {
witness.assign_table_circuit::<RegTableCircuit<E>>(cs, &self.reg_config, reg_final)?;

Expand All @@ -81,6 +93,12 @@ impl<E: ExtensionField> MmuConfig<E> {

witness.assign_table_circuit::<PubIOCircuit<E>>(cs, &self.public_io_config, io_cycles)?;

witness.assign_table_circuit::<PrivateIOCircuit<E>>(
cs,
&self.private_io_config,
private_io_final,
)?;

Ok(())
}

Expand Down
12 changes: 6 additions & 6 deletions ceno_zkvm/src/tables/ram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,25 @@ impl DynVolatileRamTable for DynMemTable {
pub type DynMemCircuit<E> = DynVolatileRamCircuit<E, DynMemTable>;

#[derive(Clone)]
pub struct PrivateMemTable;
impl DynVolatileRamTable for PrivateMemTable {
pub struct PrivateIOTable;
impl DynVolatileRamTable for PrivateIOTable {
const RAM_TYPE: RAMType = RAMType::Memory;
const V_LIMBS: usize = 1; // See `MemoryExpr`.
const ZERO_INIT: bool = false;

fn offset_addr(params: &ProgramParams) -> Addr {
params.platform.ram.start
params.platform.private_io.start
}

fn end_addr(params: &ProgramParams) -> Addr {
params.platform.ram.end
params.platform.private_io.end
}

fn name() -> &'static str {
"PrivateMemTable"
"PrivateIOTable"
}
}
pub type PrivateMemCircuit<E> = DynVolatileRamCircuit<E, PrivateMemTable>;
pub type PrivateIOCircuit<E> = DynVolatileRamCircuit<E, PrivateIOTable>;

/// RegTable, fix size without offset
#[derive(Clone)]
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/tables/ram/ram_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ impl<E: ExtensionField, DVRAM: DynVolatileRamTable + Send + Sync + Clone> TableC
type WitnessInput = [MemFinalRecord];

fn name() -> String {
format!("RAM_{:?}", DVRAM::RAM_TYPE)
format!("RAM_{:?}_{}", DVRAM::RAM_TYPE, DVRAM::name())
}

fn construct_circuit(cb: &mut CircuitBuilder<E>) -> Result<Self::TableConfig, ZKVMError> {
Expand Down
8 changes: 5 additions & 3 deletions ceno_zkvm/src/tables/ram/ram_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,15 +369,17 @@ impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableConfig
) -> Result<RowMajorMatrix<F>, ZKVMError> {
assert!(final_mem.len() <= DVRAM::max_len(&self.params));
assert!(DVRAM::max_len(&self.params).is_power_of_two());
let mut final_table =
RowMajorMatrix::<F>::new(final_mem.len().next_power_of_two(), num_witness);
let mut final_table = RowMajorMatrix::<F>::new(final_mem.len(), num_witness);

final_table
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip(final_mem.into_par_iter())
.for_each(|(row, rec)| {
.enumerate()
matthiasgoergens marked this conversation as resolved.
Show resolved Hide resolved
.for_each(|(i, (row, rec))| {
assert_eq!(rec.addr, DVRAM::addr(&self.params, i));
set_val!(row, self.addr, rec.addr as u64);

if self.final_v.len() == 1 {
// Assign value directly.
set_val!(row, self.final_v[0], rec.value as u64);
Expand Down