Skip to content

Commit

Permalink
Dynamic platform (#608)
Browse files Browse the repository at this point in the history
Makes the platform dynamic:
- Extracts some of the more fixed behaviors of `Platform` into static
methods, which replace some calls to `CENO_PLATFORM`.
- Introduces a `ProgramParams` struct which stores the platform, the
program size (no longer present in the type) and other useful params.
- `ProgramPrams` flow from `ZKVMConstraintSystem`, through the
`CircuitBuilder`, through some of the `TableConfig`s and into the
relevant table trait methods which implement the behavior we wanted to
make dynamic.

Observations:
- Could instate some safeguards to use the same platform throughout the
pipeline. One example: presently it's technically possible to use one
platform for building the `VM` and another for the constraint system.
- There is a default `ProgramParams` which is implicitly used in some
constructors. This was done so as not to edit a lot of tests
prematurely. While the default values should generally be compatible
with the original intention of these test cases, it may be wise to
remove the implicit defaults and force ourselves to choose params
everywhere. This can be done in a subsequent PR.

---------

Co-authored-by: naure <[email protected]>
  • Loading branch information
mcalancea and naure authored Nov 21, 2024
1 parent a50cfc5 commit a1e7462
Show file tree
Hide file tree
Showing 21 changed files with 205 additions and 138 deletions.
21 changes: 12 additions & 9 deletions ceno_emul/src/platform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::addr::{Addr, RegIdx};
/// - the layout of virtual memory,
/// - special addresses, such as the initial PC,
/// - codes of environment calls.
#[derive(Copy, Clone)]
#[derive(Clone, Debug)]
pub struct Platform {
pub rom_start: Addr,
pub rom_end: Addr,
Expand Down Expand Up @@ -78,13 +78,13 @@ impl Platform {
}

/// Virtual address of a register.
pub const fn register_vma(&self, index: RegIdx) -> Addr {
pub const fn register_vma(index: RegIdx) -> Addr {
// Register VMAs are aligned, cannot be confused with indices, and readable in hex.
(index << 8) as Addr
}

/// Register index from a virtual address (unchecked).
pub const fn register_index(&self, vma: Addr) -> RegIdx {
pub const fn register_index(vma: Addr) -> RegIdx {
(vma >> 8) as RegIdx
}

Expand All @@ -111,27 +111,27 @@ impl Platform {
// Environment calls.

/// Register containing the ecall function code. (x5, t0)
pub const fn reg_ecall(&self) -> RegIdx {
pub const fn reg_ecall() -> RegIdx {
5
}

/// Register containing the first function argument. (x10, a0)
pub const fn reg_arg0(&self) -> RegIdx {
pub const fn reg_arg0() -> RegIdx {
10
}

/// Register containing the 2nd function argument. (x11, a1)
pub const fn reg_arg1(&self) -> RegIdx {
pub const fn reg_arg1() -> RegIdx {
11
}

/// The code of ecall HALT.
pub const fn ecall_halt(&self) -> u32 {
pub const fn ecall_halt() -> u32 {
0
}

/// The code of success.
pub const fn code_success(&self) -> u32 {
pub const fn code_success() -> u32 {
0
}
}
Expand All @@ -151,7 +151,10 @@ mod tests {
assert!(!p.is_ram(p.rom_start()));
assert!(!p.is_ram(p.rom_end()));
// Registers do not overlap with ROM or RAM.
for reg in [p.register_vma(0), p.register_vma(VMState::REG_COUNT - 1)] {
for reg in [
Platform::register_vma(0),
Platform::register_vma(VMState::REG_COUNT - 1),
] {
assert!(!p.is_rom(reg));
assert!(!p.is_ram(reg));
}
Expand Down
18 changes: 8 additions & 10 deletions ceno_emul/src/tracer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{collections::HashMap, fmt, mem};

use crate::{
CENO_PLATFORM, InsnKind, PC_STEP_SIZE,
CENO_PLATFORM, InsnKind, PC_STEP_SIZE, Platform,
addr::{ByteAddr, Cycle, RegIdx, Word, WordAddr},
encode_rv32,
rv32im::DecodedInstruction,
Expand Down Expand Up @@ -35,7 +35,7 @@ pub struct StepRecord {
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MemOp<T> {
/// Virtual Memory Address.
/// For registers, get it from `CENO_PLATFORM.register_vma(idx)`.
/// For registers, get it from `Platform::register_vma(idx)`.
pub addr: WordAddr,
/// The Word read, or the Change<Word> to be written.
pub value: T,
Expand All @@ -46,7 +46,7 @@ pub struct MemOp<T> {
impl<T> MemOp<T> {
/// Get the register index of this operation.
pub fn register_index(&self) -> RegIdx {
CENO_PLATFORM.register_index(self.addr.into())
Platform::register_index(self.addr.into())
}
}

Expand Down Expand Up @@ -227,19 +227,17 @@ impl StepRecord {
pc,
insn_code,
rs1: rs1_read.map(|rs1| ReadOp {
addr: CENO_PLATFORM.register_vma(insn.rs1() as RegIdx).into(),
addr: Platform::register_vma(insn.rs1() as RegIdx).into(),
value: rs1,
previous_cycle,
}),
rs2: rs2_read.map(|rs2| ReadOp {
addr: CENO_PLATFORM.register_vma(insn.rs2() as RegIdx).into(),
addr: Platform::register_vma(insn.rs2() as RegIdx).into(),
value: rs2,
previous_cycle,
}),
rd: rd.map(|rd| WriteOp {
addr: CENO_PLATFORM
.register_vma(insn.rd_internal() as RegIdx)
.into(),
addr: Platform::register_vma(insn.rd_internal() as RegIdx).into(),
value: rd,
previous_cycle,
}),
Expand Down Expand Up @@ -335,7 +333,7 @@ impl Tracer {
}

pub fn load_register(&mut self, idx: RegIdx, value: Word) {
let addr = CENO_PLATFORM.register_vma(idx).into();
let addr = Platform::register_vma(idx).into();

match (&self.record.rs1, &self.record.rs2) {
(None, None) => {
Expand All @@ -361,7 +359,7 @@ impl Tracer {
unimplemented!("Only one register write is supported");
}

let addr = CENO_PLATFORM.register_vma(idx).into();
let addr = Platform::register_vma(idx).into();
self.record.rd = Some(WriteOp {
addr,
value,
Expand Down
6 changes: 3 additions & 3 deletions ceno_emul/src/vm_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ impl VMState {
impl EmuContext for VMState {
// Expect an ecall to terminate the program: function HALT with argument exit_code.
fn ecall(&mut self) -> Result<bool> {
let function = self.load_register(self.platform.reg_ecall())?;
let arg0 = self.load_register(self.platform.reg_arg0())?;
if function == self.platform.ecall_halt() {
let function = self.load_register(Platform::reg_ecall())?;
let arg0 = self.load_register(Platform::reg_arg0())?;
if function == Platform::ecall_halt() {
tracing::debug!("halt with exit_code={}", arg0);

self.halt();
Expand Down
4 changes: 2 additions & 2 deletions ceno_emul/tests/test_elf.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::Result;
use ceno_emul::{ByteAddr, CENO_PLATFORM, EmuContext, InsnKind, StepRecord, VMState};
use ceno_emul::{ByteAddr, CENO_PLATFORM, EmuContext, InsnKind, Platform, StepRecord, VMState};

#[test]
fn test_ceno_rt_mini() -> Result<()> {
Expand All @@ -16,7 +16,7 @@ fn test_ceno_rt_panic() -> Result<()> {
let steps = run(&mut state)?;
let last = steps.last().unwrap();
assert_eq!(last.insn().codes().kind, InsnKind::EANY);
assert_eq!(last.rs1().unwrap().value, CENO_PLATFORM.ecall_halt());
assert_eq!(last.rs1().unwrap().value, Platform::ecall_halt());
assert_eq!(last.rs2().unwrap().value, 1); // panic / halt(1)
Ok(())
}
Expand Down
10 changes: 5 additions & 5 deletions ceno_emul/tests/test_vm_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use anyhow::Result;
use std::collections::{BTreeMap, HashMap};

use ceno_emul::{
CENO_PLATFORM, Cycle, EmuContext, InsnKind, Program, StepRecord, Tracer, VMState, WORD_SIZE,
WordAddr,
CENO_PLATFORM, Cycle, EmuContext, InsnKind, Platform, Program, StepRecord, Tracer, VMState,
WORD_SIZE, WordAddr,
};

#[test]
Expand Down Expand Up @@ -119,7 +119,7 @@ fn expected_ops_fibonacci_20() -> Vec<InsnKind> {
/// Reconstruct the last access of each register.
fn expected_final_accesses_fibonacci_20() -> HashMap<WordAddr, Cycle> {
let mut accesses = HashMap::new();
let x = |i| WordAddr::from(CENO_PLATFORM.register_vma(i));
let x = |i| WordAddr::from(Platform::register_vma(i));
const C: Cycle = Tracer::SUBCYCLES_PER_INSN;

let mut cycle = C; // First cycle.
Expand All @@ -141,8 +141,8 @@ fn expected_final_accesses_fibonacci_20() -> HashMap<WordAddr, Cycle> {
cycle += C;

// Now at the final ECALL cycle.
accesses.insert(x(CENO_PLATFORM.reg_ecall()), cycle + Tracer::SUBCYCLE_RS1);
accesses.insert(x(CENO_PLATFORM.reg_arg0()), cycle + Tracer::SUBCYCLE_RS2);
accesses.insert(x(Platform::reg_ecall()), cycle + Tracer::SUBCYCLE_RS1);
accesses.insert(x(Platform::reg_arg0()), cycle + Tracer::SUBCYCLE_RS2);

accesses
}
25 changes: 13 additions & 12 deletions ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ use ceno_zkvm::{
instructions::riscv::{MemPadder, MmuConfig, Rv32imConfig, constants::EXIT_PC},
scheme::{mock_prover::MockProver, prover::ZKVMProver},
state::GlobalState,
structs::ProgramParams,
tables::{MemFinalRecord, ProgramTableCircuit},
};
use clap::Parser;

use ceno_emul::{
CENO_PLATFORM, EmuContext,
InsnKind::{ADD, BLTU, EANY, LUI, LW},
PC_WORD_SIZE, Program, StepRecord, Tracer, VMState, Word, WordAddr, encode_rv32,
PC_WORD_SIZE, Platform, Program, StepRecord, Tracer, VMState, Word, WordAddr, encode_rv32,
};
use ceno_zkvm::{
scheme::{PublicValues, constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier},
Expand Down Expand Up @@ -54,7 +55,7 @@ const PROGRAM_CODE: [u32; PROGRAM_SIZE] = {
);
program
};
type ExampleProgramTableCircuit<E> = ProgramTableCircuit<E, PROGRAM_SIZE>;
type ExampleProgramTableCircuit<E> = ProgramTableCircuit<E>;

/// Simple program to greet a person
#[derive(Parser, Debug)]
Expand Down Expand Up @@ -119,7 +120,11 @@ fn main() {
// keygen
let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup");
let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
let mut zkvm_cs = ZKVMConstraintSystem::default();
let program_params = ProgramParams {
program_size: PROGRAM_SIZE,
..Default::default()
};
let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params);

let config = Rv32imConfig::<E>::construct_circuits(&mut zkvm_cs);
let mmu_config = MmuConfig::<E>::construct_circuits(&mut zkvm_cs);
Expand All @@ -136,17 +141,13 @@ fn main() {

let static_report = StaticReport::new(&zkvm_cs);

let reg_init = MmuConfig::<E>::initial_registers();
let reg_init = mmu_config.initial_registers();

// RAM is not used in this program, but it must have a particular size at the moment.
let mem_init = MemPadder::init_mem(mem_addresses, MmuConfig::<E>::static_mem_len(), &[]);
let mem_init = MemPadder::init_mem(mem_addresses, mmu_config.static_mem_len(), &[]);

let init_public_io = |values: &[Word]| {
MemPadder::init_mem(
io_addresses.clone(),
MmuConfig::<E>::public_io_len(),
values,
)
MemPadder::init_mem(io_addresses.clone(), mmu_config.public_io_len(), values)
};

let io_addrs = init_public_io(&[]).iter().map(|v| v.addr).collect_vec();
Expand Down Expand Up @@ -197,7 +198,7 @@ fn main() {
.rev()
.find(|record| {
record.insn().codes().kind == EANY
&& record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt()
&& record.rs1().unwrap().value == Platform::ecall_halt()
})
.expect("halt record not found");

Expand Down Expand Up @@ -227,7 +228,7 @@ fn main() {
.map(|rec| {
let index = rec.addr as usize;
if index < VMState::REG_COUNT {
let vma: WordAddr = CENO_PLATFORM.register_vma(index).into();
let vma: WordAddr = Platform::register_vma(index).into();
MemFinalRecord {
addr: rec.addr,
value: vm.peek_register(index),
Expand Down
23 changes: 14 additions & 9 deletions ceno_zkvm/src/bin/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use ceno_zkvm::{
verifier::ZKVMVerifier,
},
state::GlobalState,
structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit},
};
use clap::{Parser, ValueEnum};
Expand Down Expand Up @@ -55,7 +55,7 @@ fn main() {
type E = GoldilocksExt2;
type Pcs = Basefold<GoldilocksExt2, BasefoldRSParams>;
const PROGRAM_SIZE: usize = 1 << 14;
type ExampleProgramTableCircuit<E> = ProgramTableCircuit<E, PROGRAM_SIZE>;
type ExampleProgramTableCircuit<E> = ProgramTableCircuit<E>;

// set up logger
let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap();
Expand Down Expand Up @@ -93,12 +93,17 @@ fn main() {

tracing::info!("Loading ELF file: {}", args.elf);
let elf_bytes = fs::read(&args.elf).expect("read elf file");
let mut vm = VMState::new_from_elf(platform, &elf_bytes).unwrap();
let mut vm = VMState::new_from_elf(platform.clone(), &elf_bytes).unwrap();

// keygen
let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup");
let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
let mut zkvm_cs = ZKVMConstraintSystem::default();
let program_params = ProgramParams {
platform: platform.clone(),
program_size: PROGRAM_SIZE,
..ProgramParams::default()
};
let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params);

let config = Rv32imConfig::<E>::construct_circuits(&mut zkvm_cs);
let mmu_config = MmuConfig::<E>::construct_circuits(&mut zkvm_cs);
Expand Down Expand Up @@ -130,13 +135,13 @@ fn main() {

let mem_init = chain!(program_addrs, stack_addrs).collect_vec();

mem_padder.padded_sorted(MmuConfig::<E>::static_mem_len(), mem_init)
mem_padder.padded_sorted(mmu_config.static_mem_len(), mem_init)
};

// IO is not used in this program, but it must have a particular size at the moment.
let io_init = mem_padder.padded_sorted(MmuConfig::<E>::public_io_len(), vec![]);
let io_init = mem_padder.padded_sorted(mmu_config.public_io_len(), vec![]);

let reg_init = MmuConfig::<E>::initial_registers();
let reg_init = mmu_config.initial_registers();
config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces);
mmu_config.generate_fixed_traces(
&zkvm_cs,
Expand Down Expand Up @@ -175,7 +180,7 @@ fn main() {
.rev()
.find(|record| {
record.insn().codes().kind == EANY
&& record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt()
&& record.rs1().unwrap().value == Platform::ecall_halt()
})
.and_then(|halt_record| halt_record.rs2())
.map(|rs2| rs2.value);
Expand Down Expand Up @@ -208,7 +213,7 @@ fn main() {
.map(|rec| {
let index = rec.addr as usize;
if index < VMState::REG_COUNT {
let vma: WordAddr = CENO_PLATFORM.register_vma(index).into();
let vma: WordAddr = Platform::register_vma(index).into();
MemFinalRecord {
addr: rec.addr,
value: vm.peek_register(index),
Expand Down
10 changes: 7 additions & 3 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ use crate::{
END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX,
UINT_LIMBS,
},
structs::{RAMType, ROMType},
structs::{ProgramParams, RAMType, ROMType},
tables::InsnRecord,
};

impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
pub fn new(cs: &'a mut ConstraintSystem<E>) -> Self {
Self { cs }
Self::new_with_params(cs, ProgramParams::default())
}
pub fn new_with_params(cs: &'a mut ConstraintSystem<E>, params: ProgramParams) -> Self {
Self { cs, params }
}

pub fn create_witin<NR, N>(&mut self, name_fn: N) -> WitIn
Expand Down Expand Up @@ -321,7 +324,8 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
cb: impl FnOnce(&mut CircuitBuilder<E>) -> Result<T, ZKVMError>,
) -> Result<T, ZKVMError> {
self.cs.namespace(name_fn, |cs| {
let mut inner_circuit_builder = CircuitBuilder::new(cs);
let mut inner_circuit_builder =
CircuitBuilder::new_with_params(cs, self.params.clone());
cb(&mut inner_circuit_builder)
})
}
Expand Down
Loading

0 comments on commit a1e7462

Please sign in to comment.