Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[experiment] Remove PROGRAM_SIZE from type #587

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions ceno_zkvm/examples/fibonacci_elf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ fn main() {
type E = GoldilocksExt2;
type Pcs = Basefold<GoldilocksExt2, BasefoldRSParams>;
const PROGRAM_SIZE: usize = 1 << 14;
type ExampleProgramTableCircuit<E> = ProgramTableCircuit<E, PROGRAM_SIZE>;

// set up logger
let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap();
Expand Down Expand Up @@ -79,15 +78,18 @@ fn main() {
let config = Rv32imConfig::<E>::construct_circuits(&mut zkvm_cs);
let mmu_config = MmuConfig::<E>::construct_circuits(&mut zkvm_cs);
let dummy_config = DummyExtraConfig::<E>::construct_circuits(&mut zkvm_cs);
let prog_config = zkvm_cs.register_table_circuit::<ExampleProgramTableCircuit<E>>();
let ptc_with_size = ProgramTableCircuit::new(PROGRAM_SIZE);
let prog_config =
zkvm_cs.register_table_circuit_param::<ProgramTableCircuit<E>>(ptc_with_size.clone());
zkvm_cs.register_global_state::<GlobalState>();

let mut zkvm_fixed_traces = ZKVMFixedTraces::default();

zkvm_fixed_traces.register_table_circuit::<ExampleProgramTableCircuit<E>>(
zkvm_fixed_traces.register_table_circuit_param::<ProgramTableCircuit<E>>(
&zkvm_cs,
&prog_config,
vm.program(),
ptc_with_size,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I applied the same change to this new file.

);

let mem_init = {
Expand Down Expand Up @@ -236,7 +238,7 @@ fn main() {
.unwrap();
// assign program circuit
zkvm_witness
.assign_table_circuit::<ExampleProgramTableCircuit<E>>(&zkvm_cs, &prog_config, vm.program())
.assign_table_circuit::<ProgramTableCircuit<E>>(&zkvm_cs, &prog_config, vm.program())
.unwrap();

if std::env::var("MOCK_PROVING").is_ok() {
Expand Down
10 changes: 6 additions & 4 deletions ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ const PROGRAM_CODE: [u32; PROGRAM_SIZE] = {
);
program
};
type ExampleProgramTableCircuit<E> = ProgramTableCircuit<E, PROGRAM_SIZE>;

/// Simple program to greet a person
#[derive(Parser, Debug)]
Expand Down Expand Up @@ -123,15 +122,18 @@ fn main() {

let config = Rv32imConfig::<E>::construct_circuits(&mut zkvm_cs);
let mmu_config = MmuConfig::<E>::construct_circuits(&mut zkvm_cs);
let prog_config = zkvm_cs.register_table_circuit::<ExampleProgramTableCircuit<E>>();
let ptc_with_size = ProgramTableCircuit::new(PROGRAM_SIZE);
let prog_config =
zkvm_cs.register_table_circuit_param::<ProgramTableCircuit<E>>(ptc_with_size.clone());
zkvm_cs.register_global_state::<GlobalState>();

let mut zkvm_fixed_traces = ZKVMFixedTraces::default();

zkvm_fixed_traces.register_table_circuit::<ExampleProgramTableCircuit<E>>(
zkvm_fixed_traces.register_table_circuit_param::<ProgramTableCircuit<E>>(
&zkvm_cs,
&prog_config,
&program,
ptc_with_size.clone(),
);

let static_report = StaticReport::new(&zkvm_cs);
Expand Down Expand Up @@ -279,7 +281,7 @@ fn main() {

// assign program circuit
zkvm_witness
.assign_table_circuit::<ExampleProgramTableCircuit<E>>(&zkvm_cs, &prog_config, &program)
.assign_table_circuit::<ProgramTableCircuit<E>>(&zkvm_cs, &prog_config, &program)
.unwrap();

// get instance counts from witness matrices
Expand Down
10 changes: 3 additions & 7 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,13 +644,9 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
fn load_program_table(t_vec: &mut Vec<Vec<u64>>, program: &Program, challenge: [E; 2]) {
let mut cs = ConstraintSystem::<E>::new(|| "mock_program");
let mut cb = CircuitBuilder::new(&mut cs);
let config =
ProgramTableCircuit::<_, MOCK_PROGRAM_SIZE>::construct_circuit(&mut cb).unwrap();
let fixed = ProgramTableCircuit::<E, MOCK_PROGRAM_SIZE>::generate_fixed_traces(
&config,
cs.num_fixed,
program,
);
let mock_ptc = ProgramTableCircuit::new(MOCK_PROGRAM_SIZE);
let config = mock_ptc.construct_circuit(&mut cb).unwrap();
let fixed = mock_ptc.generate_fixed_traces(&config, cs.num_fixed, program);
for table_expr in &cs.lk_table_expressions {
for row in fixed.iter_rows() {
// TODO: Find a better way to obtain the row content.
Expand Down
16 changes: 8 additions & 8 deletions ceno_zkvm/src/scheme/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,12 @@ fn test_single_add_instance_e2e() {
// opcode circuits
let add_config = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
let halt_config = zkvm_cs.register_opcode_circuit::<HaltInstruction<E>>();
let u16_range_config = zkvm_cs.register_table_circuit::<U16TableCircuit<E>>();
let u16_range_config =
zkvm_cs.register_table_circuit_param::<U16TableCircuit<E>>(U16TableCircuit::default());

let prog_config = zkvm_cs.register_table_circuit::<ProgramTableCircuit<E, PROGRAM_SIZE>>();
let prog_config = zkvm_cs.register_table_circuit_param::<ProgramTableCircuit<E>>(
ProgramTableCircuit::new(PROGRAM_SIZE),
);

let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);
Expand All @@ -242,10 +245,11 @@ fn test_single_add_instance_e2e() {
&(),
);

zkvm_fixed_traces.register_table_circuit::<ProgramTableCircuit<E, PROGRAM_SIZE>>(
zkvm_fixed_traces.register_table_circuit_param::<ProgramTableCircuit<E>>(
&zkvm_cs,
&prog_config,
&program,
ProgramTableCircuit::new(PROGRAM_SIZE),
);

let pk = zkvm_cs
Expand Down Expand Up @@ -295,11 +299,7 @@ fn test_single_add_instance_e2e() {
.assign_table_circuit::<U16TableCircuit<E>>(&zkvm_cs, &u16_range_config, &())
.unwrap();
zkvm_witness
.assign_table_circuit::<ProgramTableCircuit<E, PROGRAM_SIZE>>(
&zkvm_cs,
&prog_config,
&program,
)
.assign_table_circuit::<ProgramTableCircuit<E>>(&zkvm_cs, &prog_config, &program)
.unwrap();

let pi = PublicValues::new(0, 0, 0, 0, 0, vec![0]);
Expand Down
27 changes: 23 additions & 4 deletions ceno_zkvm/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ impl<E: ExtensionField> Default for ZKVMConstraintSystem<E> {
}
}

impl<E: ExtensionField> ZKVMConstraintSystem<E> {
pub fn register_table_circuit<TC: TableCircuit<E> + Default>(&mut self) -> TC::TableConfig {
self.register_table_circuit_param(TC::default())
}
}

impl<E: ExtensionField> ZKVMConstraintSystem<E> {
pub fn register_opcode_circuit<OC: Instruction<E>>(&mut self) -> OC::InstructionConfig {
let mut cs = ConstraintSystem::new(|| format!("riscv_opcode/{}", OC::name()));
Expand All @@ -152,10 +158,10 @@ impl<E: ExtensionField> ZKVMConstraintSystem<E> {
config
}

pub fn register_table_circuit<TC: TableCircuit<E>>(&mut self) -> TC::TableConfig {
pub fn register_table_circuit_param<TC: TableCircuit<E>>(&mut self, tc: TC) -> TC::TableConfig {
let mut cs = ConstraintSystem::new(|| format!("riscv_table/{}", TC::name()));
let mut circuit_builder = CircuitBuilder::<E>::new(&mut cs);
let config = TC::construct_circuit(&mut circuit_builder).unwrap();
let config = tc.construct_circuit(&mut circuit_builder).unwrap();
assert!(self.circuit_css.insert(TC::name(), cs).is_none());

config
Expand Down Expand Up @@ -184,23 +190,35 @@ pub struct ZKVMFixedTraces<E: ExtensionField> {
pub circuit_fixed_traces: BTreeMap<String, Option<RowMajorMatrix<E::BaseField>>>,
}

impl<E: ExtensionField> ZKVMFixedTraces<E> {
pub fn register_table_circuit<TC: TableCircuit<E> + Default>(
&mut self,
cs: &ZKVMConstraintSystem<E>,
config: &TC::TableConfig,
input: &TC::FixedInput,
) {
self.register_table_circuit_param(cs, config, input, TC::default());
}
}

impl<E: ExtensionField> ZKVMFixedTraces<E> {
pub fn register_opcode_circuit<OC: Instruction<E>>(&mut self, _cs: &ZKVMConstraintSystem<E>) {
assert!(self.circuit_fixed_traces.insert(OC::name(), None).is_none());
}

pub fn register_table_circuit<TC: TableCircuit<E>>(
pub fn register_table_circuit_param<TC: TableCircuit<E>>(
&mut self,
cs: &ZKVMConstraintSystem<E>,
config: &TC::TableConfig,
input: &TC::FixedInput,
tc: TC,
) {
let cs = cs.get_cs(&TC::name()).expect("cs not found");
assert!(
self.circuit_fixed_traces
.insert(
TC::name(),
Some(TC::generate_fixed_traces(config, cs.num_fixed, input)),
Some(tc.generate_fixed_traces(config, cs.num_fixed, input)),
)
.is_none()
);
Expand Down Expand Up @@ -277,6 +295,7 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
cs: &ZKVMConstraintSystem<E>,
config: &TC::TableConfig,
input: &TC::WitnessInput,
// tc: TC,
) -> Result<(), ZKVMError> {
assert!(self.combined_lk_mlt.is_some());

Expand Down
7 changes: 6 additions & 1 deletion ceno_zkvm/src/tables/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,23 @@ pub use program::{InsnRecord, ProgramTableCircuit};
mod ram;
pub use ram::*;

pub trait TableCircuit<E: ExtensionField> {
pub trait TableCircuit<E: ExtensionField>
where
Self: Default,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This Default requirement could move as far out as possible towards the function that actually needs it.

{
type TableConfig: Send + Sync;
type FixedInput: Send + Sync + ?Sized;
type WitnessInput: Send + Sync + ?Sized;

fn name() -> String;

fn construct_circuit(
&self,
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::TableConfig, ZKVMError>;

fn generate_fixed_traces(
&self,
config: &Self::TableConfig,
num_fixed: usize,
input: &Self::FixedInput,
Expand Down
5 changes: 5 additions & 0 deletions ceno_zkvm/src/tables/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub use ops_circuit::{OpsTable, OpsTableCircuit};

use crate::structs::ROMType;

#[derive(Default)]
pub struct AndTable;
impl OpsTable for AndTable {
const ROM_TYPE: ROMType = ROMType::And;
Expand All @@ -25,6 +26,7 @@ impl OpsTable for AndTable {
}
pub type AndTableCircuit<E> = OpsTableCircuit<E, AndTable>;

#[derive(Default)]
pub struct OrTable;
impl OpsTable for OrTable {
const ROM_TYPE: ROMType = ROMType::Or;
Expand All @@ -43,6 +45,7 @@ impl OpsTable for OrTable {
}
pub type OrTableCircuit<E> = OpsTableCircuit<E, OrTable>;

#[derive(Default)]
pub struct XorTable;
impl OpsTable for XorTable {
const ROM_TYPE: ROMType = ROMType::Xor;
Expand All @@ -61,6 +64,7 @@ impl OpsTable for XorTable {
}
pub type XorTableCircuit<E> = OpsTableCircuit<E, XorTable>;

#[derive(Default)]
pub struct LtuTable;
impl OpsTable for LtuTable {
const ROM_TYPE: ROMType = ROMType::Ltu;
Expand All @@ -79,6 +83,7 @@ impl OpsTable for LtuTable {
}
pub type LtuTableCircuit<E> = OpsTableCircuit<E, LtuTable>;

#[derive(Default)]
pub struct PowTable;
impl OpsTable for PowTable {
const ROM_TYPE: ROMType = ROMType::Pow;
Expand Down
6 changes: 4 additions & 2 deletions ceno_zkvm/src/tables/ops/ops_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ pub trait OpsTable {
}
}

#[derive(Default)]
pub struct OpsTableCircuit<E, R>(PhantomData<(E, R)>);

impl<E: ExtensionField, OP: OpsTable> TableCircuit<E> for OpsTableCircuit<E, OP> {
impl<E: ExtensionField, OP: OpsTable + Default> TableCircuit<E> for OpsTableCircuit<E, OP> {
type TableConfig = OpTableConfig;
type FixedInput = ();
type WitnessInput = ();
Expand All @@ -39,14 +40,15 @@ impl<E: ExtensionField, OP: OpsTable> TableCircuit<E> for OpsTableCircuit<E, OP>
format!("OPS_{:?}", OP::ROM_TYPE)
}

fn construct_circuit(cb: &mut CircuitBuilder<E>) -> Result<OpTableConfig, ZKVMError> {
fn construct_circuit(&self, cb: &mut CircuitBuilder<E>) -> Result<OpTableConfig, ZKVMError> {
cb.namespace(
|| Self::name(),
|cb| OpTableConfig::construct_circuit(cb, OP::ROM_TYPE, OP::len()),
)
}

fn generate_fixed_traces(
&self,
config: &OpTableConfig,
num_fixed: usize,
_input: &(),
Expand Down
29 changes: 22 additions & 7 deletions ceno_zkvm/src/tables/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,22 @@ pub struct ProgramTableConfig {
mlt: WitIn,
}

pub struct ProgramTableCircuit<E, const PROGRAM_SIZE: usize>(PhantomData<E>);
#[derive(Clone, Default)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default here is 0 size which is probably not what you want. Just leave it without default to force the use of the …_param function?

pub struct ProgramTableCircuit<E> {
program_size: usize,
phantom_data: PhantomData<E>,
}

impl<E: ExtensionField> ProgramTableCircuit<E> {
pub fn new(program_size: usize) -> Self {
ProgramTableCircuit {
program_size,
phantom_data: PhantomData,
}
}
}

impl<E: ExtensionField, const PROGRAM_SIZE: usize> TableCircuit<E>
for ProgramTableCircuit<E, PROGRAM_SIZE>
{
impl<E: ExtensionField> TableCircuit<E> for ProgramTableCircuit<E> {
type TableConfig = ProgramTableConfig;
type FixedInput = Program;
type WitnessInput = Program;
Expand All @@ -117,7 +128,10 @@ impl<E: ExtensionField, const PROGRAM_SIZE: usize> TableCircuit<E>
"PROGRAM".into()
}

fn construct_circuit(cb: &mut CircuitBuilder<E>) -> Result<ProgramTableConfig, ZKVMError> {
fn construct_circuit(
&self,
cb: &mut CircuitBuilder<E>,
) -> Result<ProgramTableConfig, ZKVMError> {
let record = InsnRecord([
cb.create_fixed(|| "pc")?,
cb.create_fixed(|| "kind")?,
Expand All @@ -137,7 +151,7 @@ impl<E: ExtensionField, const PROGRAM_SIZE: usize> TableCircuit<E>

cb.lk_table_record(
|| "prog table",
PROGRAM_SIZE,
self.program_size,
ROMType::Instruction,
record_exprs,
mlt.expr(),
Expand All @@ -147,13 +161,14 @@ impl<E: ExtensionField, const PROGRAM_SIZE: usize> TableCircuit<E>
}

fn generate_fixed_traces(
&self,
config: &ProgramTableConfig,
num_fixed: usize,
program: &Self::FixedInput,
) -> RowMajorMatrix<E::BaseField> {
let num_instructions = program.instructions.len();
let pc_base = program.base_address;
assert!(num_instructions <= PROGRAM_SIZE);
assert!(num_instructions <= self.program_size);

let mut fixed = RowMajorMatrix::<E::BaseField>::new(num_instructions, num_fixed);

Expand Down
8 changes: 4 additions & 4 deletions ceno_zkvm/src/tables/ram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod ram_circuit;
mod ram_impl;
pub use ram_circuit::{DynVolatileRamTable, MemFinalRecord, MemInitRecord, NonVolatileTable};

#[derive(Clone)]
#[derive(Clone, Default)]
pub struct MemTable;

impl DynVolatileRamTable for MemTable {
Expand All @@ -29,7 +29,7 @@ impl DynVolatileRamTable for MemTable {
pub type MemCircuit<E> = DynVolatileRamCircuit<E, MemTable>;

/// RegTable, fix size without offset
#[derive(Clone)]
#[derive(Clone, Default)]
pub struct RegTable;

impl NonVolatileTable for RegTable {
Expand All @@ -48,7 +48,7 @@ impl NonVolatileTable for RegTable {

pub type RegTableCircuit<E> = NonVolatileRamCircuit<E, RegTable>;

#[derive(Clone)]
#[derive(Clone, Default)]
pub struct StaticMemTable;

impl NonVolatileTable for StaticMemTable {
Expand All @@ -68,7 +68,7 @@ impl NonVolatileTable for StaticMemTable {

pub type StaticMemCircuit<E> = NonVolatileRamCircuit<E, StaticMemTable>;

#[derive(Clone)]
#[derive(Clone, Default)]
pub struct PubIOTable;

impl NonVolatileTable for PubIOTable {
Expand Down
Loading