diff --git a/alu_u32/src/add/mod.rs b/alu_u32/src/add/mod.rs index 77435858..f3f1db0b 100644 --- a/alu_u32/src/add/mod.rs +++ b/alu_u32/src/add/mod.rs @@ -11,9 +11,10 @@ use valida_opcodes::ADD32; use valida_range::MachineWithRangeChip; use p3_air::VirtualPairCol; -use p3_field::PrimeField; +use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::*; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; @@ -29,12 +30,12 @@ pub struct Add32Chip { pub operations: Vec, } -impl Chip for Add32Chip +impl Chip for Add32Chip where - F: PrimeField, - M: MachineWithGeneralBus + MachineWithRangeBus8, + M: MachineWithGeneralBus + MachineWithRangeBus8, + SC: StarkConfig, { - fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { let rows = self .operations .par_iter() @@ -44,12 +45,12 @@ where let mut trace = RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_ADD_COLS); - pad_to_power_of_two::(&mut trace.values); + pad_to_power_of_two::(&mut trace.values); trace } - fn global_sends(&self, machine: &M) -> Vec> { + fn global_sends(&self, machine: &M) -> Vec> { let sends = ADD_COL_MAP .output .0 @@ -66,8 +67,8 @@ where sends } - fn global_receives(&self, machine: &M) -> Vec> { - let opcode = VirtualPairCol::constant(M::F::from_canonical_u32(ADD32)); + fn global_receives(&self, machine: &M) -> Vec> { + let opcode = VirtualPairCol::constant(SC::Val::from_canonical_u32(ADD32)); let input_1 = ADD_COL_MAP.input_1.0.map(VirtualPairCol::single_main); let input_2 = ADD_COL_MAP.input_2.0.map(VirtualPairCol::single_main); let output = ADD_COL_MAP.output.0.map(VirtualPairCol::single_main); @@ -120,21 +121,22 @@ impl Add32Chip { } } -pub trait MachineWithAdd32Chip: MachineWithCpuChip { +pub trait MachineWithAdd32Chip: MachineWithCpuChip { fn add_u32(&self) -> &Add32Chip; fn add_u32_mut(&mut self) -> &mut Add32Chip; } instructions!(Add32Instruction); -impl Instruction for Add32Instruction +impl Instruction for Add32Instruction where - M: MachineWithAdd32Chip + MachineWithRangeChip<256>, + M: MachineWithAdd32Chip + MachineWithRangeChip, + F: Field, { const OPCODE: u32 = ADD32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; diff --git a/alu_u32/src/bitwise/mod.rs b/alu_u32/src/bitwise/mod.rs index 206441e3..5d16f176 100644 --- a/alu_u32/src/bitwise/mod.rs +++ b/alu_u32/src/bitwise/mod.rs @@ -10,9 +10,10 @@ use valida_machine::{instructions, Chip, Instruction, Interaction, Operands, Wor use valida_opcodes::{AND32, OR32, XOR32}; use p3_air::VirtualPairCol; -use p3_field::PrimeField; +use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::*; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; @@ -30,12 +31,12 @@ pub struct Bitwise32Chip { pub operations: Vec, } -impl Chip for Bitwise32Chip +impl Chip for Bitwise32Chip where - F: PrimeField, - M: MachineWithGeneralBus, + M: MachineWithGeneralBus, + SC: StarkConfig, { - fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { let rows = self .operations .par_iter() @@ -47,19 +48,19 @@ where NUM_BITWISE_COLS, ); - pad_to_power_of_two::(&mut trace.values); + pad_to_power_of_two::(&mut trace.values); trace } - fn global_receives(&self, machine: &M) -> Vec> { + fn global_receives(&self, machine: &M) -> Vec> { let opcode = VirtualPairCol::new_main( vec![ - (COL_MAP.is_and, M::F::from_canonical_u32(AND32)), - (COL_MAP.is_or, M::F::from_canonical_u32(OR32)), - (COL_MAP.is_xor, M::F::from_canonical_u32(XOR32)), + (COL_MAP.is_and, SC::Val::from_canonical_u32(AND32)), + (COL_MAP.is_or, SC::Val::from_canonical_u32(OR32)), + (COL_MAP.is_xor, SC::Val::from_canonical_u32(XOR32)), ], - M::F::zero(), + SC::Val::zero(), ); let input_1 = COL_MAP.input_1.0.map(VirtualPairCol::single_main); let input_2 = COL_MAP.input_2.0.map(VirtualPairCol::single_main); @@ -126,21 +127,22 @@ impl Bitwise32Chip { } } -pub trait MachineWithBitwise32Chip: MachineWithCpuChip { +pub trait MachineWithBitwise32Chip: MachineWithCpuChip { fn bitwise_u32(&self) -> &Bitwise32Chip; fn bitwise_u32_mut(&mut self) -> &mut Bitwise32Chip; } instructions!(And32Instruction, Or32Instruction, Xor32Instruction); -impl Instruction for Xor32Instruction +impl Instruction for Xor32Instruction where - M: MachineWithBitwise32Chip, + M: MachineWithBitwise32Chip, + F: Field, { const OPCODE: u32 = XOR32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; @@ -171,14 +173,15 @@ where } } -impl Instruction for And32Instruction +impl Instruction for And32Instruction where - M: MachineWithBitwise32Chip, + M: MachineWithBitwise32Chip, + F: Field, { const OPCODE: u32 = AND32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; @@ -209,14 +212,15 @@ where } } -impl Instruction for Or32Instruction +impl Instruction for Or32Instruction where - M: MachineWithBitwise32Chip, + M: MachineWithBitwise32Chip, + F: Field, { const OPCODE: u32 = OR32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; diff --git a/alu_u32/src/div/mod.rs b/alu_u32/src/div/mod.rs index 38f0814c..410cb0c8 100644 --- a/alu_u32/src/div/mod.rs +++ b/alu_u32/src/div/mod.rs @@ -12,9 +12,10 @@ use valida_opcodes::{DIV32, SDIV32}; use valida_range::MachineWithRangeChip; use p3_air::VirtualPairCol; -use p3_field::PrimeField; +use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::*; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; @@ -31,12 +32,12 @@ pub struct Div32Chip { pub operations: Vec, } -impl Chip for Div32Chip +impl Chip for Div32Chip where - F: PrimeField, - M: MachineWithGeneralBus, + M: MachineWithGeneralBus, + SC: StarkConfig, { - fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { let rows = self .operations .par_iter() @@ -46,18 +47,18 @@ where let mut trace = RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_DIV_COLS); - pad_to_power_of_two::(&mut trace.values); + pad_to_power_of_two::(&mut trace.values); trace } - fn global_receives(&self, machine: &M) -> Vec> { + fn global_receives(&self, machine: &M) -> Vec> { let opcode = VirtualPairCol::new_main( vec![ - (DIV_COL_MAP.is_div, M::F::from_canonical_u32(DIV32)), - (DIV_COL_MAP.is_sdiv, M::F::from_canonical_u32(SDIV32)), + (DIV_COL_MAP.is_div, SC::Val::from_canonical_u32(DIV32)), + (DIV_COL_MAP.is_sdiv, SC::Val::from_canonical_u32(SDIV32)), ], - M::F::zero(), + SC::Val::zero(), ); let input_1 = DIV_COL_MAP.input_1.0.map(VirtualPairCol::single_main); let input_2 = DIV_COL_MAP.input_2.0.map(VirtualPairCol::single_main); @@ -102,21 +103,22 @@ impl Div32Chip { } } -pub trait MachineWithDiv32Chip: MachineWithCpuChip { +pub trait MachineWithDiv32Chip: MachineWithCpuChip { fn div_u32(&self) -> &Div32Chip; fn div_u32_mut(&mut self) -> &mut Div32Chip; } instructions!(Div32Instruction, SDiv32Instruction); -impl Instruction for Div32Instruction +impl Instruction for Div32Instruction where - M: MachineWithDiv32Chip + MachineWithRangeChip<256>, + M: MachineWithDiv32Chip + MachineWithRangeChip, + F: Field, { const OPCODE: u32 = DIV32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; @@ -149,14 +151,15 @@ where } } -impl Instruction for SDiv32Instruction +impl Instruction for SDiv32Instruction where - M: MachineWithDiv32Chip + MachineWithRangeChip<256>, + M: MachineWithDiv32Chip + MachineWithRangeChip, + F: Field, { const OPCODE: u32 = SDIV32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; diff --git a/alu_u32/src/lt/mod.rs b/alu_u32/src/lt/mod.rs index c493ba66..2dd34959 100644 --- a/alu_u32/src/lt/mod.rs +++ b/alu_u32/src/lt/mod.rs @@ -13,9 +13,10 @@ use valida_machine::{ use valida_opcodes::LT32; use p3_air::VirtualPairCol; -use p3_field::PrimeField; +use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::*; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; @@ -31,12 +32,12 @@ pub struct Lt32Chip { pub operations: Vec, } -impl Chip for Lt32Chip +impl Chip for Lt32Chip where - F: PrimeField, - M: MachineWithGeneralBus, + M: MachineWithGeneralBus, + SC: StarkConfig, { - fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { let rows = self .operations .par_iter() @@ -46,17 +47,17 @@ where let mut trace = RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_LT_COLS); - pad_to_power_of_two::(&mut trace.values); + pad_to_power_of_two::(&mut trace.values); trace } - fn global_receives(&self, machine: &M) -> Vec> { - let opcode = VirtualPairCol::constant(M::F::from_canonical_u32(LT32)); + fn global_receives(&self, machine: &M) -> Vec> { + let opcode = VirtualPairCol::constant(SC::Val::from_canonical_u32(LT32)); let input_1 = LT_COL_MAP.input_1.0.map(VirtualPairCol::single_main); let input_2 = LT_COL_MAP.input_2.0.map(VirtualPairCol::single_main); let output = (0..MEMORY_CELL_BYTES - 1) - .map(|_| VirtualPairCol::constant(M::F::zero())) + .map(|_| VirtualPairCol::constant(SC::Val::zero())) .chain(iter::once(VirtualPairCol::single_main(LT_COL_MAP.output))); let mut fields = vec![opcode]; @@ -107,21 +108,22 @@ impl Lt32Chip { } } -pub trait MachineWithLt32Chip: MachineWithCpuChip { +pub trait MachineWithLt32Chip: MachineWithCpuChip { fn lt_u32(&self) -> &Lt32Chip; fn lt_u32_mut(&mut self) -> &mut Lt32Chip; } instructions!(Lt32Instruction); -impl Instruction for Lt32Instruction +impl Instruction for Lt32Instruction where - M: MachineWithLt32Chip, + M: MachineWithLt32Chip, + F: Field, { const OPCODE: u32 = LT32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; diff --git a/alu_u32/src/mul/mod.rs b/alu_u32/src/mul/mod.rs index ff6d9c45..45d4c800 100644 --- a/alu_u32/src/mul/mod.rs +++ b/alu_u32/src/mul/mod.rs @@ -11,8 +11,9 @@ use valida_range::MachineWithRangeChip; use core::borrow::BorrowMut; use p3_air::VirtualPairCol; -use p3_field::PrimeField; +use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; +use valida_machine::config::StarkConfig; pub mod columns; pub mod stark; @@ -29,31 +30,31 @@ pub struct Mul32Chip { pub operations: Vec, } -impl Chip for Mul32Chip +impl Chip for Mul32Chip where - F: PrimeField, - M: MachineWithGeneralBus, + M: MachineWithGeneralBus, + SC: StarkConfig, { - fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { const MIN_LENGTH: usize = 1 << 10; // for the range check counter let num_ops = self.operations.len(); let num_padded_ops = num_ops.next_power_of_two().max(MIN_LENGTH); - let mut values = vec![F::zero(); num_padded_ops * NUM_MUL_COLS]; + let mut values = vec![SC::Val::zero(); num_padded_ops * NUM_MUL_COLS]; // Encode the real operations. for (i, op) in self.operations.iter().enumerate() { let row = &mut values[i * NUM_MUL_COLS..(i + 1) * NUM_MUL_COLS]; - let cols: &mut Mul32Cols = row.borrow_mut(); - cols.counter = F::from_canonical_usize(i + 1); + let cols: &mut Mul32Cols = row.borrow_mut(); + cols.counter = SC::Val::from_canonical_usize(i + 1); self.op_to_row(op, cols); } // Encode dummy operations as needed to pad the trace. for i in num_ops..num_padded_ops { let row = &mut values[i * NUM_MUL_COLS..(i + 1) * NUM_MUL_COLS]; - let cols: &mut Mul32Cols = row.borrow_mut(); - cols.counter = F::from_canonical_usize(i + 1); + let cols: &mut Mul32Cols = row.borrow_mut(); + cols.counter = SC::Val::from_canonical_usize(i + 1); } RowMajorMatrix { @@ -62,14 +63,14 @@ where } } - fn global_receives(&self, machine: &M) -> Vec> { + fn global_receives(&self, machine: &M) -> Vec> { let opcode = VirtualPairCol::new_main( vec![ - (MUL_COL_MAP.is_mul, M::F::from_canonical_u32(MUL32)), - (MUL_COL_MAP.is_mulhs, M::F::from_canonical_u32(MULHS32)), - (MUL_COL_MAP.is_mulhu, M::F::from_canonical_u32(MULHU32)), + (MUL_COL_MAP.is_mul, SC::Val::from_canonical_u32(MUL32)), + (MUL_COL_MAP.is_mulhs, SC::Val::from_canonical_u32(MULHS32)), + (MUL_COL_MAP.is_mulhu, SC::Val::from_canonical_u32(MULHU32)), ], - M::F::zero(), + SC::Val::zero(), ); let input_1 = MUL_COL_MAP.input_1.0.map(VirtualPairCol::single_main); let input_2 = MUL_COL_MAP.input_2.0.map(VirtualPairCol::single_main); @@ -94,7 +95,7 @@ where vec![receive] } - fn local_sends(&self) -> Vec> { + fn local_sends(&self) -> Vec> { // TODO vec![] } @@ -131,21 +132,22 @@ impl Mul32Chip { } } -pub trait MachineWithMul32Chip: MachineWithCpuChip { +pub trait MachineWithMul32Chip: MachineWithCpuChip { fn mul_u32(&self) -> &Mul32Chip; fn mul_u32_mut(&mut self) -> &mut Mul32Chip; } instructions!(Mul32Instruction, Mulhs32Instruction, Mulhu32Instruction); -impl Instruction for Mul32Instruction +impl Instruction for Mul32Instruction where - M: MachineWithMul32Chip + MachineWithRangeChip<256>, + M: MachineWithMul32Chip + MachineWithRangeChip, + F: Field, { const OPCODE: u32 = MUL32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; @@ -179,14 +181,15 @@ where } } -impl Instruction for Mulhs32Instruction +impl Instruction for Mulhs32Instruction where - M: MachineWithMul32Chip + MachineWithRangeChip<256>, + M: MachineWithMul32Chip + MachineWithRangeChip, + F: Field, { const OPCODE: u32 = MULHS32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; @@ -220,14 +223,15 @@ where } } -impl Instruction for Mulhu32Instruction +impl Instruction for Mulhu32Instruction where - M: MachineWithMul32Chip + MachineWithRangeChip<256>, + M: MachineWithMul32Chip + MachineWithRangeChip, + F: Field, { const OPCODE: u32 = MULHU32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; diff --git a/alu_u32/src/shift/mod.rs b/alu_u32/src/shift/mod.rs index b6855211..b9a0452a 100644 --- a/alu_u32/src/shift/mod.rs +++ b/alu_u32/src/shift/mod.rs @@ -12,9 +12,10 @@ use valida_machine::{instructions, Chip, Instruction, Interaction, Operands, Sra use valida_opcodes::{DIV32, MUL32, SDIV32, SHL32, SHR32, SRA32}; use p3_air::VirtualPairCol; -use p3_field::PrimeField; +use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::*; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; @@ -32,12 +33,12 @@ pub struct Shift32Chip { pub operations: Vec, } -impl Chip for Shift32Chip +impl Chip for Shift32Chip where - F: PrimeField, - M: MachineWithGeneralBus + MachineWithRangeBus8, + M: MachineWithGeneralBus + MachineWithRangeBus8, + SC: StarkConfig, { - fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { let rows = self .operations .par_iter() @@ -49,19 +50,19 @@ where NUM_SHIFT_COLS, ); - pad_to_power_of_two::(&mut trace.values); + pad_to_power_of_two::(&mut trace.values); trace } - fn global_sends(&self, machine: &M) -> Vec> { + fn global_sends(&self, machine: &M) -> Vec> { let opcode = VirtualPairCol::new_main( vec![ - (COL_MAP.is_shl, M::F::from_canonical_u32(MUL32)), - (COL_MAP.is_shr, M::F::from_canonical_u32(DIV32)), - (COL_MAP.is_sra, M::F::from_canonical_u32(SDIV32)), + (COL_MAP.is_shl, SC::Val::from_canonical_u32(MUL32)), + (COL_MAP.is_shr, SC::Val::from_canonical_u32(DIV32)), + (COL_MAP.is_sra, SC::Val::from_canonical_u32(SDIV32)), ], - M::F::zero(), + SC::Val::zero(), ); let input_1 = COL_MAP.input_1.0.map(VirtualPairCol::single_main); let input_2 = COL_MAP.power_of_two.0.map(VirtualPairCol::single_main); @@ -84,14 +85,14 @@ where vec![send] } - fn global_receives(&self, machine: &M) -> Vec> { + fn global_receives(&self, machine: &M) -> Vec> { let opcode = VirtualPairCol::new_main( vec![ - (COL_MAP.is_shl, M::F::from_canonical_u32(SHL32)), - (COL_MAP.is_shr, M::F::from_canonical_u32(SHR32)), - (COL_MAP.is_sra, M::F::from_canonical_u32(SRA32)), + (COL_MAP.is_shl, SC::Val::from_canonical_u32(SHL32)), + (COL_MAP.is_shr, SC::Val::from_canonical_u32(SHR32)), + (COL_MAP.is_sra, SC::Val::from_canonical_u32(SRA32)), ], - M::F::zero(), + SC::Val::zero(), ); let input_1 = COL_MAP.input_1.0.map(VirtualPairCol::single_main); let input_2 = COL_MAP.input_2.0.map(VirtualPairCol::single_main); @@ -162,21 +163,22 @@ impl Shift32Chip { } } -pub trait MachineWithShift32Chip: MachineWithCpuChip { +pub trait MachineWithShift32Chip: MachineWithCpuChip { fn shift_u32(&self) -> &Shift32Chip; fn shift_u32_mut(&mut self) -> &mut Shift32Chip; } instructions!(Shl32Instruction, Shr32Instruction, Sra32Instruction); -impl Instruction for Shl32Instruction +impl Instruction for Shl32Instruction where - M: MachineWithShift32Chip + MachineWithMul32Chip, + M: MachineWithShift32Chip + MachineWithMul32Chip, + F: Field, { const OPCODE: u32 = SHL32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; @@ -215,14 +217,15 @@ where } } -impl Instruction for Shr32Instruction +impl Instruction for Shr32Instruction where - M: MachineWithShift32Chip + MachineWithDiv32Chip, + M: MachineWithShift32Chip + MachineWithDiv32Chip, + F: Field, { const OPCODE: u32 = SHR32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; @@ -261,14 +264,15 @@ where } } -impl Instruction for Sra32Instruction +impl Instruction for Sra32Instruction where - M: MachineWithShift32Chip + MachineWithDiv32Chip, + M: MachineWithShift32Chip + MachineWithDiv32Chip, + F: Field, { const OPCODE: u32 = SRA32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; diff --git a/alu_u32/src/sub/mod.rs b/alu_u32/src/sub/mod.rs index 623301cc..bd92797f 100644 --- a/alu_u32/src/sub/mod.rs +++ b/alu_u32/src/sub/mod.rs @@ -11,9 +11,10 @@ use valida_opcodes::SUB32; use valida_range::MachineWithRangeChip; use p3_air::VirtualPairCol; -use p3_field::PrimeField; +use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::*; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; @@ -29,12 +30,12 @@ pub struct Sub32Chip { pub operations: Vec, } -impl Chip for Sub32Chip +impl Chip for Sub32Chip where - F: PrimeField, - M: MachineWithGeneralBus + MachineWithRangeBus8, + M: MachineWithGeneralBus + MachineWithRangeBus8, + SC: StarkConfig, { - fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { let rows = self .operations .par_iter() @@ -44,12 +45,12 @@ where let mut trace = RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_SUB_COLS); - pad_to_power_of_two::(&mut trace.values); + pad_to_power_of_two::(&mut trace.values); trace } - fn global_sends(&self, machine: &M) -> Vec> { + fn global_sends(&self, machine: &M) -> Vec> { let sends = SUB_COL_MAP .output .0 @@ -66,8 +67,8 @@ where sends } - fn global_receives(&self, machine: &M) -> Vec> { - let opcode = VirtualPairCol::constant(M::F::from_canonical_u32(SUB32)); + fn global_receives(&self, machine: &M) -> Vec> { + let opcode = VirtualPairCol::constant(SC::Val::from_canonical_u32(SUB32)); let input_1 = SUB_COL_MAP.input_1.0.map(VirtualPairCol::single_main); let input_2 = SUB_COL_MAP.input_2.0.map(VirtualPairCol::single_main); let output = SUB_COL_MAP.output.0.map(VirtualPairCol::single_main); @@ -116,21 +117,22 @@ impl Sub32Chip { } } -pub trait MachineWithSub32Chip: MachineWithCpuChip { +pub trait MachineWithSub32Chip: MachineWithCpuChip { fn sub_u32(&self) -> &Sub32Chip; fn sub_u32_mut(&mut self) -> &mut Sub32Chip; } instructions!(Sub32Instruction); -impl Instruction for Sub32Instruction +impl Instruction for Sub32Instruction where - M: MachineWithSub32Chip + MachineWithRangeChip<256>, + M: MachineWithSub32Chip + MachineWithRangeChip, + F: Field, { const OPCODE: u32 = SUB32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; diff --git a/basic/src/bin/valida.rs b/basic/src/bin/valida.rs index 90135586..6220a318 100644 --- a/basic/src/bin/valida.rs +++ b/basic/src/bin/valida.rs @@ -23,7 +23,7 @@ struct Args { fn main() { let args = Args::parse(); - let mut machine = BasicMachine::::default(); + let mut machine = BasicMachine::::default(); let rom = match ProgramROM::from_file(&args.program) { Ok(contents) => contents, Err(e) => panic!("Failure to load file: {}. {}", &args.program, e), diff --git a/basic/src/lib.rs b/basic/src/lib.rs index 48fa4b97..c93917e5 100644 --- a/basic/src/lib.rs +++ b/basic/src/lib.rs @@ -3,7 +3,8 @@ extern crate alloc; -use p3_field::TwoAdicField; +use core::marker::PhantomData; +use p3_field::{Field, PrimeField32, TwoAdicField}; use valida_alu_u32::{ add::{Add32Chip, Add32Instruction, MachineWithAdd32Chip}, bitwise::{ @@ -39,10 +40,11 @@ use valida_program::{MachineWithProgramChip, ProgramChip}; use valida_range::{MachineWithRangeChip, RangeCheckerChip}; use p3_maybe_rayon::*; +use valida_machine::config::StarkConfig; #[derive(Machine, Default)] -#[machine_fields(F, EF)] -pub struct BasicMachine> { +#[machine_fields(F)] +pub struct BasicMachine { // Core instructions #[instruction] load32: Load32Instruction, @@ -122,45 +124,34 @@ pub struct BasicMachine> { #[chip] range: RangeCheckerChip<256>, - _phantom_base: core::marker::PhantomData, - _phantom_extension: core::marker::PhantomData, + _phantom_sc: PhantomData F>, } -impl> MachineWithGeneralBus - for BasicMachine -{ +impl MachineWithGeneralBus for BasicMachine { fn general_bus(&self) -> BusArgument { BusArgument::Global(0) } } -impl> MachineWithProgramBus - for BasicMachine -{ +impl MachineWithProgramBus for BasicMachine { fn program_bus(&self) -> BusArgument { BusArgument::Global(1) } } -impl> MachineWithMemBus - for BasicMachine -{ +impl MachineWithMemBus for BasicMachine { fn mem_bus(&self) -> BusArgument { BusArgument::Global(2) } } -impl> MachineWithRangeBus8 - for BasicMachine -{ +impl MachineWithRangeBus8 for BasicMachine { fn range_bus(&self) -> BusArgument { BusArgument::Global(3) } } -impl> MachineWithCpuChip - for BasicMachine -{ +impl MachineWithCpuChip for BasicMachine { fn cpu(&self) -> &CpuChip { &self.cpu } @@ -170,9 +161,7 @@ impl> MachineWithCpuChip } } -impl> MachineWithProgramChip - for BasicMachine -{ +impl MachineWithProgramChip for BasicMachine { fn program(&self) -> &ProgramChip { &self.program } @@ -182,9 +171,7 @@ impl> MachineWithProgramCh } } -impl> MachineWithMemoryChip - for BasicMachine -{ +impl MachineWithMemoryChip for BasicMachine { fn mem(&self) -> &MemoryChip { &self.mem } @@ -194,9 +181,7 @@ impl> MachineWithMemoryChi } } -impl> MachineWithAdd32Chip - for BasicMachine -{ +impl MachineWithAdd32Chip for BasicMachine { fn add_u32(&self) -> &Add32Chip { &self.add_u32 } @@ -206,9 +191,7 @@ impl> MachineWithAdd32Chip } } -impl> MachineWithSub32Chip - for BasicMachine -{ +impl MachineWithSub32Chip for BasicMachine { fn sub_u32(&self) -> &Sub32Chip { &self.sub_u32 } @@ -218,9 +201,7 @@ impl> MachineWithSub32Chip } } -impl> MachineWithMul32Chip - for BasicMachine -{ +impl MachineWithMul32Chip for BasicMachine { fn mul_u32(&self) -> &Mul32Chip { &self.mul_u32 } @@ -230,9 +211,7 @@ impl> MachineWithMul32Chip } } -impl> MachineWithDiv32Chip - for BasicMachine -{ +impl MachineWithDiv32Chip for BasicMachine { fn div_u32(&self) -> &Div32Chip { &self.div_u32 } @@ -242,9 +221,7 @@ impl> MachineWithDiv32Chip } } -impl> MachineWithBitwise32Chip - for BasicMachine -{ +impl MachineWithBitwise32Chip for BasicMachine { fn bitwise_u32(&self) -> &Bitwise32Chip { &self.bitwise_u32 } @@ -254,9 +231,7 @@ impl> MachineWithBitwise32 } } -impl> MachineWithLt32Chip - for BasicMachine -{ +impl MachineWithLt32Chip for BasicMachine { fn lt_u32(&self) -> &Lt32Chip { &self.lt_u32 } @@ -266,9 +241,7 @@ impl> MachineWithLt32Chip } } -impl> MachineWithShift32Chip - for BasicMachine -{ +impl MachineWithShift32Chip for BasicMachine { fn shift_u32(&self) -> &Shift32Chip { &self.shift_u32 } @@ -278,9 +251,7 @@ impl> MachineWithShift32Ch } } -impl> MachineWithOutputChip - for BasicMachine -{ +impl MachineWithOutputChip for BasicMachine { fn output(&self) -> &OutputChip { &self.output } @@ -290,9 +261,7 @@ impl> MachineWithOutputChi } } -impl> MachineWithRangeChip<256> - for BasicMachine -{ +impl MachineWithRangeChip for BasicMachine { fn range(&self) -> &RangeCheckerChip<256> { &self.range } diff --git a/basic/tests/test_interpreter.rs b/basic/tests/test_interpreter.rs index fbdb0a4f..4c6daf4d 100644 --- a/basic/tests/test_interpreter.rs +++ b/basic/tests/test_interpreter.rs @@ -9,7 +9,7 @@ use valida_program::MachineWithProgramChip; #[test] fn run_fibonacci() { - let mut machine = BasicMachine::::default(); + let mut machine = BasicMachine::::default(); let asm_path = "tests/programs/assembly/fibonacci.val"; let asm = read_to_string(asm_path).expect("Failed to read asm"); let rom = ProgramROM::from_machine_code(&assemble(&asm).unwrap()); diff --git a/basic/tests/test_prover.rs b/basic/tests/test_prover.rs index 9b459370..d19e37a0 100644 --- a/basic/tests/test_prover.rs +++ b/basic/tests/test_prover.rs @@ -55,35 +55,35 @@ fn prove_fibonacci() { //... program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-4, 0, 0, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-8, 0, 0, 0, 25]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-16, -8, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-20, 0, 0, 0, 28]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-28, fib_bb0, -28, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-12, -24, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([4, -12, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands::default(), }, ]); @@ -97,23 +97,23 @@ fn prove_fibonacci() { // beq .LBB0_1, 0(fp), 0(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-4, 12, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-8, 0, 0, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-12, 0, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-16, 0, 0, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([fib_bb0_1, 0, 0, 0, 0]), }, ]); @@ -123,11 +123,11 @@ fn prove_fibonacci() { // beq .LBB0_4, 0(fp), 0(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([fib_bb0_2, -16, -4, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([fib_bb0_4, 0, 0, 0, 0]), }, ]); @@ -139,19 +139,19 @@ fn prove_fibonacci() { // beq .LBB0_3, 0(fp), 0(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-20, -8, -12, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-8, -12, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-12, -20, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([fib_bb0_3, 0, 0, 0, 0]), }, ]); @@ -161,11 +161,11 @@ fn prove_fibonacci() { // beq .LBB0_1, 0(fp), 0(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-16, -16, 1, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([fib_bb0_1, 0, 0, 0, 0]), }, ]); @@ -175,16 +175,16 @@ fn prove_fibonacci() { // jalv -4(fp), 0(fp), 8(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([4, -8, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: , Val>>::OPCODE, operands: Operands([-4, 0, 8, 0, 0]), }, ]); - let mut machine = BasicMachine::::default(); + let mut machine = BasicMachine::::default(); let rom = ProgramROM::new(program); machine.program_mut().set_program_rom(&rom); machine.cpu_mut().fp = 0x1000; diff --git a/bus/src/lib.rs b/bus/src/lib.rs index 63c2fdf2..49830be5 100644 --- a/bus/src/lib.rs +++ b/bus/src/lib.rs @@ -1,3 +1,4 @@ +use p3_field::Field; use valida_machine::{BusArgument, Machine}; #[derive(Default)] @@ -6,22 +7,22 @@ pub struct CpuMemBus {} #[derive(Default)] pub struct SharedCoprocessorBus {} -pub trait MachineWithGeneralBus: Machine { +pub trait MachineWithGeneralBus: Machine { fn general_bus(&self) -> BusArgument; } -pub trait MachineWithProgramBus: Machine { +pub trait MachineWithProgramBus: Machine { fn program_bus(&self) -> BusArgument; } -pub trait MachineWithMemBus: Machine { +pub trait MachineWithMemBus: Machine { fn mem_bus(&self) -> BusArgument; } -pub trait MachineWithRangeBus8: Machine { +pub trait MachineWithRangeBus8: Machine { fn range_bus(&self) -> BusArgument; } -pub trait MachineWithPowerOfTwoBus: Machine { +pub trait MachineWithPowerOfTwoBus: Machine { fn power_of_two_bus(&self) -> BusArgument; } diff --git a/cpu/src/lib.rs b/cpu/src/lib.rs index b5c0cea1..6fef97d7 100644 --- a/cpu/src/lib.rs +++ b/cpu/src/lib.rs @@ -20,9 +20,10 @@ use valida_opcodes::{ use valida_util::batch_multiplicative_inverse; use p3_air::VirtualPairCol; -use p3_field::PrimeField; +use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::*; +use valida_machine::config::StarkConfig; pub mod columns; pub mod stark; @@ -58,20 +59,21 @@ pub struct Registers { fp: u32, } -impl Chip for CpuChip +impl Chip for CpuChip where - M: MachineWithProgramBus - + MachineWithMemoryChip - + MachineWithGeneralBus - + MachineWithMemBus + M: MachineWithProgramBus + + MachineWithMemoryChip + + MachineWithGeneralBus + + MachineWithMemBus + Sync, + SC: StarkConfig, { - fn generate_trace(&self, machine: &M) -> RowMajorMatrix { + fn generate_trace(&self, machine: &M) -> RowMajorMatrix { let mut rows = self .operations .par_iter() .enumerate() - .map(|(n, op)| self.op_to_row(n, op, machine)) + .map(|(n, op)| self.op_to_row::(n, op, machine)) .collect::>(); // Set diff, diff_inv, and not_equal @@ -85,7 +87,7 @@ where trace } - fn global_sends(&self, machine: &M) -> Vec> { + fn global_sends(&self, machine: &M) -> Vec> { // Memory bus channels let mem_sends = (0..3).map(|i| { let channel = &CPU_COL_MAP.mem_channels[i]; @@ -148,84 +150,80 @@ where } impl CpuChip { - fn op_to_row>( - &self, - clk: usize, - op: &Operation, - machine: &M, - ) -> [F; NUM_CPU_COLS] + fn op_to_row(&self, clk: usize, op: &Operation, machine: &M) -> [SC::Val; NUM_CPU_COLS] where - M: MachineWithMemoryChip, + M: MachineWithMemoryChip, + SC: StarkConfig, { - let mut row = [F::zero(); NUM_CPU_COLS]; - let cols: &mut CpuCols = unsafe { transmute(&mut row) }; + let mut row = [SC::Val::zero(); NUM_CPU_COLS]; + let cols: &mut CpuCols = unsafe { transmute(&mut row) }; - cols.pc = F::from_canonical_u32(self.registers[clk].pc); - cols.fp = F::from_canonical_u32(self.registers[clk].fp); - cols.clk = F::from_canonical_usize(clk); + cols.pc = SC::Val::from_canonical_u32(self.registers[clk].pc); + cols.fp = SC::Val::from_canonical_u32(self.registers[clk].fp); + cols.clk = SC::Val::from_canonical_usize(clk); self.set_instruction_values(clk, cols); - self.set_memory_channel_values(clk, cols, machine); + self.set_memory_channel_values::(clk, cols, machine); match op { Operation::Store32 => { - cols.opcode_flags.is_store = F::one(); + cols.opcode_flags.is_store = SC::Val::one(); } Operation::Load32 => { - cols.opcode_flags.is_load = F::one(); + cols.opcode_flags.is_load = SC::Val::one(); } Operation::Jal => { - cols.opcode_flags.is_jal = F::one(); + cols.opcode_flags.is_jal = SC::Val::one(); } Operation::Jalv => { - cols.opcode_flags.is_jalv = F::one(); + cols.opcode_flags.is_jalv = SC::Val::one(); } Operation::Beq(imm) => { - cols.opcode_flags.is_beq = F::one(); + cols.opcode_flags.is_beq = SC::Val::one(); self.set_imm_value(cols, *imm); } Operation::Bne(imm) => { - cols.opcode_flags.is_bne = F::one(); + cols.opcode_flags.is_bne = SC::Val::one(); self.set_imm_value(cols, *imm); } Operation::Imm32 => { - cols.opcode_flags.is_imm32 = F::one(); + cols.opcode_flags.is_imm32 = SC::Val::one(); } Operation::Bus(imm) => { - cols.opcode_flags.is_bus_op = F::one(); + cols.opcode_flags.is_bus_op = SC::Val::one(); self.set_imm_value(cols, *imm); } Operation::BusWithMemory(imm) => { - cols.opcode_flags.is_bus_op = F::one(); - cols.opcode_flags.is_bus_op_with_mem = F::one(); + cols.opcode_flags.is_bus_op = SC::Val::one(); + cols.opcode_flags.is_bus_op_with_mem = SC::Val::one(); self.set_imm_value(cols, *imm); } Operation::ReadAdvice => { - cols.opcode_flags.is_advice = F::one(); + cols.opcode_flags.is_advice = SC::Val::one(); } Operation::Stop => { - cols.opcode_flags.is_stop = F::one(); + cols.opcode_flags.is_stop = SC::Val::one(); } } row } - fn set_instruction_values(&self, clk: usize, cols: &mut CpuCols) { + fn set_instruction_values(&self, clk: usize, cols: &mut CpuCols) { cols.instruction.opcode = F::from_canonical_u32(self.instructions[clk].opcode); cols.instruction.operands = Operands::::from_i32_slice(&self.instructions[clk].operands.0); } - fn set_memory_channel_values>( + fn set_memory_channel_values, SC: StarkConfig>( &self, clk: usize, - cols: &mut CpuCols, + cols: &mut CpuCols, machine: &M, ) { - cols.mem_channels[0].is_read = F::one(); - cols.mem_channels[1].is_read = F::one(); - cols.mem_channels[2].is_read = F::zero(); + cols.mem_channels[0].is_read = SC::Val::one(); + cols.mem_channels[1].is_read = SC::Val::one(); + cols.mem_channels[2].is_read = SC::Val::zero(); let memory = machine.mem(); for ops in memory.operations.get(&(clk as u32)).iter() { @@ -234,20 +232,22 @@ impl CpuChip { match op { MemoryOperation::Read(addr, value) => { if is_first_read { - cols.mem_channels[0].used = F::one(); - cols.mem_channels[0].addr = F::from_canonical_u32(*addr); - cols.mem_channels[0].value = value.transform(F::from_canonical_u8); + cols.mem_channels[0].used = SC::Val::one(); + cols.mem_channels[0].addr = SC::Val::from_canonical_u32(*addr); + cols.mem_channels[0].value = + value.transform(SC::Val::from_canonical_u8); is_first_read = false; } else { - cols.mem_channels[1].used = F::one(); - cols.mem_channels[1].addr = F::from_canonical_u32(*addr); - cols.mem_channels[1].value = value.transform(F::from_canonical_u8); + cols.mem_channels[1].used = SC::Val::one(); + cols.mem_channels[1].addr = SC::Val::from_canonical_u32(*addr); + cols.mem_channels[1].value = + value.transform(SC::Val::from_canonical_u8); } } MemoryOperation::Write(addr, value) => { - cols.mem_channels[2].used = F::one(); - cols.mem_channels[2].addr = F::from_canonical_u32(*addr); - cols.mem_channels[2].value = value.transform(F::from_canonical_u8); + cols.mem_channels[2].used = SC::Val::one(); + cols.mem_channels[2].addr = SC::Val::from_canonical_u32(*addr); + cols.mem_channels[2].value = value.transform(SC::Val::from_canonical_u8); } _ => {} } @@ -325,7 +325,7 @@ impl CpuChip { }); } - fn set_imm_value(&self, cols: &mut CpuCols, imm: Option>) { + fn set_imm_value(&self, cols: &mut CpuCols, imm: Option>) { if let Some(imm) = imm { cols.opcode_flags.is_imm_op = F::one(); cols.mem_channels[1].value = imm.transform(F::from_canonical_u8); @@ -333,7 +333,7 @@ impl CpuChip { } } -pub trait MachineWithCpuChip: MachineWithMemoryChip { +pub trait MachineWithCpuChip: MachineWithMemoryChip { fn cpu(&self) -> &CpuChip; fn cpu_mut(&mut self) -> &mut CpuChip; } @@ -352,9 +352,10 @@ instructions!( /// Non-deterministic instructions -impl Instruction for ReadAdviceInstruction +impl Instruction for ReadAdviceInstruction where - M: MachineWithCpuChip, + M: MachineWithCpuChip, + F: Field, { const OPCODE: u32 = READ_ADVICE; @@ -364,7 +365,7 @@ where fn execute_with_advice(state: &mut M, ops: Operands, advice: &mut Adv) where - M: MachineWithCpuChip, + M: MachineWithCpuChip, Adv: AdviceProvider, { let clk = state.cpu().clock; @@ -378,22 +379,25 @@ where .write(clk, mem_addr as u32, Word::from_u8(advice_byte), true); state.cpu_mut().pc += 1; - state - .cpu_mut() - .push_op(Operation::ReadAdvice, >::OPCODE, ops); + state.cpu_mut().push_op( + Operation::ReadAdvice, + >::OPCODE, + ops, + ); } } /// Deterministic instructions -impl Instruction for Load32Instruction +impl Instruction for Load32Instruction where - M: MachineWithCpuChip, + M: MachineWithCpuChip, + F: Field, { const OPCODE: u32 = LOAD32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let fp = state.cpu().fp; @@ -422,14 +426,15 @@ where } } -impl Instruction for Store32Instruction +impl Instruction for Store32Instruction where - M: MachineWithCpuChip, + M: MachineWithCpuChip, + F: Field, { const OPCODE: u32 = STORE32; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let read_addr = (state.cpu().fp as i32 + ops.c()) as u32; let write_addr_loc = (state.cpu().fp as i32 + ops.b()) as u32; @@ -446,9 +451,10 @@ where } } -impl Instruction for JalInstruction +impl Instruction for JalInstruction where - M: MachineWithCpuChip, + M: MachineWithCpuChip, + F: Field, { const OPCODE: u32 = JAL; @@ -466,18 +472,19 @@ where state.cpu_mut().fp = (state.cpu().fp as i32 + ops.c()) as u32; state .cpu_mut() - .push_op(Operation::Jal, >::OPCODE, ops); + .push_op(Operation::Jal, >::OPCODE, ops); } } -impl Instruction for JalvInstruction +impl Instruction for JalvInstruction where - M: MachineWithCpuChip, + M: MachineWithCpuChip, + F: Field, { const OPCODE: u32 = JALV; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; // Store pc + 1 to local stack variable at offset a @@ -504,14 +511,15 @@ where } } -impl Instruction for BeqInstruction +impl Instruction for BeqInstruction where - M: MachineWithCpuChip, + M: MachineWithCpuChip, + F: Field, { const OPCODE: u32 = BEQ; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let mut imm: Option> = None; let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32; @@ -538,14 +546,15 @@ where } } -impl Instruction for BneInstruction +impl Instruction for BneInstruction where - M: MachineWithCpuChip, + M: MachineWithCpuChip, + F: Field, { const OPCODE: u32 = BNE; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let mut imm: Option> = None; let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32; @@ -572,9 +581,10 @@ where } } -impl Instruction for Imm32Instruction +impl Instruction for Imm32Instruction where - M: MachineWithCpuChip, + M: MachineWithCpuChip, + F: Field, { const OPCODE: u32 = IMM32; @@ -586,13 +596,14 @@ where state.cpu_mut().pc += 1; state .cpu_mut() - .push_op(Operation::Imm32, >::OPCODE, ops); + .push_op(Operation::Imm32, >::OPCODE, ops); } } -impl Instruction for StopInstruction +impl Instruction for StopInstruction where - M: MachineWithCpuChip, + M: MachineWithCpuChip, + F: Field, { const OPCODE: u32 = STOP; @@ -600,7 +611,7 @@ where state.cpu_mut().pc = state.cpu().pc; state .cpu_mut() - .push_op(Operation::Stop, >::OPCODE, ops); + .push_op(Operation::Stop, >::OPCODE, ops); } } diff --git a/derive/src/lib.rs b/derive/src/lib.rs index f4db9c2e..f184fef0 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -10,24 +10,19 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::quote; use syn::parse::{Parse, ParseStream}; -use syn::{spanned::Spanned, Data, Field, Fields, Ident, Token}; +use syn::{spanned::Spanned, Data, Field, Fields, Ident}; +// TODO: now trivial with a single field struct MachineFields { - base_field: Ident, - ext_field: Ident, + val: Ident, } impl Parse for MachineFields { fn parse(input: ParseStream) -> syn::Result { let content; syn::parenthesized!(content in input); - let base_field = content.parse()?; - content.parse::()?; - let ext_field = content.parse()?; - Ok(MachineFields { - base_field, - ext_field, - }) + let val = content.parse()?; + Ok(MachineFields { val }) } } @@ -56,30 +51,25 @@ fn impl_machine(machine: &syn::DeriveInput) -> TokenStream { .copied() .collect::>(); - let name = &machine.ident; - let run = run_method(machine, &instructions); - let prove = prove_method(&chips); - let verify = verify_method(&chips); - - let (impl_generics, ty_generics, where_clause) = machine.generics.split_for_impl(); - let machine_fields = machine .attrs .iter() .filter(|a| a.path.segments.len() == 1 && a.path.segments[0].ident == "machine_fields") .next() .expect("machine_fields attribute required to derive Machine"); - let machine_fields: MachineFields = syn::parse2(machine_fields.tokens.clone()).expect( - "Invalid machine_fields attribute, expected #[machine_fields(, )]", - ); + let machine_fields: MachineFields = syn::parse2(machine_fields.tokens.clone()) + .expect("Invalid machine_fields attribute, expected #[machine_fields()]"); + let val = &machine_fields.val; + + let name = &machine.ident; + let run = run_method(machine, &instructions, &val); + let prove = prove_method(&chips); + let verify = verify_method(&chips); - let base_field = &machine_fields.base_field; - let ext_field = &machine_fields.ext_field; + let (impl_generics, ty_generics, where_clause) = machine.generics.split_for_impl(); let stream = quote! { - impl #impl_generics Machine for #name #ty_generics #where_clause { - type F = #base_field; - type EF = #ext_field; + impl #impl_generics Machine<#val> for #name #ty_generics #where_clause { #run #prove #verify @@ -137,7 +127,7 @@ fn chip_methods(chip: &Field) -> TokenStream2 { } } -fn run_method(machine: &syn::DeriveInput, instructions: &[&Field]) -> TokenStream2 { +fn run_method(machine: &syn::DeriveInput, instructions: &[&Field], val: &Ident) -> TokenStream2 { let name = &machine.ident; let (_, ty_generics, _) = machine.generics.split_for_impl(); @@ -146,7 +136,8 @@ fn run_method(machine: &syn::DeriveInput, instructions: &[&Field]) -> TokenStrea .map(|inst| { let ty = &inst.ty; quote! { - <#ty as Instruction<#name #ty_generics>>::OPCODE => + // TODO: Self instead of #name #ty_generics? + <#ty as Instruction<#name #ty_generics, #val>>::OPCODE => #ty::execute_with_advice::(self, ops, advice), } }) @@ -169,7 +160,7 @@ fn run_method(machine: &syn::DeriveInput, instructions: &[&Field]) -> TokenStrea self.read_word(pc as usize); // A STOP instruction signals the end of the program - if opcode == >::OPCODE { + if opcode == >::OPCODE { break; } } @@ -202,7 +193,7 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 { let chip_name = chip.ident.as_ref().unwrap(); quote! { #[cfg(debug_assertions)] - check_constraints( + check_constraints::( self, self.#chip_name(), &main_traces[#n], @@ -217,9 +208,7 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 { quote! { #[tracing::instrument(name = "prove machine execution", skip_all)] - fn prove(&self, config: &SC) -> ::valida_machine::proof::MachineProof - where - SC: ::valida_machine::config::StarkConfig, + fn prove>(&self, config: &SC) -> ::valida_machine::proof::MachineProof { use ::valida_machine::__internal::*; use ::valida_machine::__internal::p3_challenger::{CanObserve, FieldChallenger}; @@ -232,7 +221,7 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 { use alloc::vec::Vec; use alloc::boxed::Box; - let mut chips: [Box<&dyn Chip>; #num_chips] = [ #chip_list ]; + let mut chips: [Box<&dyn Chip>; #num_chips] = [ #chip_list ]; let mut challenger = config.challenger(); @@ -293,7 +282,7 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 { #prove_starks #[cfg(debug_assertions)] - check_cumulative_sums::(&perm_traces[..]); + check_cumulative_sums(&perm_traces[..]); MachineProof { // opening_proof, @@ -306,11 +295,9 @@ fn prove_method(chips: &[&Field]) -> TokenStream2 { fn verify_method(_chips: &[&Field]) -> TokenStream2 { quote! { - fn verify( + fn verify>( proof: &::valida_machine::proof::MachineProof, ) -> core::result::Result<(), ()> - where - SC: ::valida_machine::config::StarkConfig { Ok(()) // TODO } diff --git a/machine/src/__internal/check_constraints.rs b/machine/src/__internal/check_constraints.rs index 029bc537..424c7af5 100644 --- a/machine/src/__internal/check_constraints.rs +++ b/machine/src/__internal/check_constraints.rs @@ -1,23 +1,25 @@ use crate::__internal::DebugConstraintBuilder; use crate::chip::eval_permutation_constraints; +use crate::config::StarkConfig; use crate::{Chip, Machine}; use p3_air::{Air, TwoRowMatrixView}; -use p3_field::AbstractField; +use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use p3_matrix::MatrixRowSlices; use p3_maybe_rayon::{MaybeIntoParIter, ParallelIterator}; /// Check that all constraints vanish on the subgroup. -pub fn check_constraints( +pub fn check_constraints( machine: &M, air: &A, - main: &RowMajorMatrix, - perm: &RowMajorMatrix, - perm_challenges: &[M::EF], + main: &RowMajorMatrix, + perm: &RowMajorMatrix, + perm_challenges: &[SC::Challenge], ) where - M: Machine + Sync, - A: for<'a> Air> + Chip, + M: Machine + Sync, + A: Chip + for<'a> Air>, + SC: StarkConfig, { assert_eq!(main.height(), perm.height()); let height = main.height(); @@ -63,16 +65,16 @@ pub fn check_constraints( next: &perm_next, }, perm_challenges, - is_first_row: M::F::zero(), - is_last_row: M::F::zero(), - is_transition: M::F::one(), + is_first_row: SC::Val::zero(), + is_last_row: SC::Val::zero(), + is_transition: SC::Val::one(), }; if i == 0 { - builder.is_first_row = M::F::one(); + builder.is_first_row = SC::Val::one(); } if i == height - 1 { - builder.is_last_row = M::F::one(); - builder.is_transition = M::F::zero(); + builder.is_last_row = SC::Val::one(); + builder.is_transition = SC::Val::zero(); } air.eval(&mut builder); @@ -81,13 +83,10 @@ pub fn check_constraints( } /// Check that the combined cumulative sum across all lookup tables is zero. -pub fn check_cumulative_sums(perms: &[RowMajorMatrix]) -where - M: Machine + Sync, -{ - let sum: M::EF = perms +pub fn check_cumulative_sums(perms: &[RowMajorMatrix]) { + let sum: Challenge = perms .iter() .map(|perm| *perm.row_slice(perm.height() - 1).last().unwrap()) .sum(); - assert_eq!(sum, M::EF::zero()); + assert_eq!(sum, Challenge::zero()); } diff --git a/machine/src/__internal/debug_builder.rs b/machine/src/__internal/debug_builder.rs index cb237977..6cfcaa08 100644 --- a/machine/src/__internal/debug_builder.rs +++ b/machine/src/__internal/debug_builder.rs @@ -1,70 +1,30 @@ +use crate::config::StarkConfig; use crate::{Machine, ValidaAirBuilder}; use p3_air::{AirBuilder, PairBuilder, PermutationAirBuilder, TwoRowMatrixView}; -use p3_field::{ExtensionField, Field}; +use p3_field::AbstractField; /// An `AirBuilder` which asserts that each constraint is zero, allowing any failed constraints to /// be detected early. -pub struct DebugConstraintBuilder<'a, F: Field, EF: ExtensionField, M: Machine> { +pub struct DebugConstraintBuilder<'a, M: Machine, SC: StarkConfig> { pub(crate) machine: &'a M, - pub(crate) main: TwoRowMatrixView<'a, F>, - pub(crate) preprocessed: TwoRowMatrixView<'a, F>, - pub(crate) perm: TwoRowMatrixView<'a, EF>, - pub(crate) perm_challenges: &'a [EF], - pub(crate) is_first_row: F, - pub(crate) is_last_row: F, - pub(crate) is_transition: F, + pub(crate) main: TwoRowMatrixView<'a, SC::Val>, + pub(crate) preprocessed: TwoRowMatrixView<'a, SC::Val>, + pub(crate) perm: TwoRowMatrixView<'a, SC::Challenge>, + pub(crate) perm_challenges: &'a [SC::Challenge], + pub(crate) is_first_row: SC::Val, + pub(crate) is_last_row: SC::Val, + pub(crate) is_transition: SC::Val, } -impl<'a, F, EF, M> PermutationAirBuilder for DebugConstraintBuilder<'a, F, EF, M> +impl<'a, M, SC> AirBuilder for DebugConstraintBuilder<'a, M, SC> where - F: Field, - EF: ExtensionField, - M: Machine, + M: Machine, + SC: StarkConfig, { - type EF = M::EF; - type VarEF = M::EF; - type ExprEF = M::EF; - type MP = TwoRowMatrixView<'a, EF>; - - fn permutation(&self) -> Self::MP { - self.perm - } - - fn permutation_randomness(&self) -> &[Self::EF] { - // TODO: implement - self.perm_challenges - } -} - -impl<'a, F, EF, M> PairBuilder for DebugConstraintBuilder<'a, F, EF, M> -where - F: Field, - EF: ExtensionField, - M: Machine, -{ - fn preprocessed(&self) -> Self::M { - self.preprocessed - } -} - -impl<'a, M: Machine> ValidaAirBuilder for DebugConstraintBuilder<'a, M::F, M::EF, M> { - type Machine = M; - - fn machine(&self) -> &Self::Machine { - self.machine - } -} - -impl<'a, F, EF, M> AirBuilder for DebugConstraintBuilder<'a, F, EF, M> -where - F: Field, - EF: ExtensionField, - M: Machine, -{ - type F = F; - type Expr = F; - type Var = F; - type M = TwoRowMatrixView<'a, F>; + type F = SC::Val; + type Expr = SC::Val; + type Var = SC::Val; + type M = TwoRowMatrixView<'a, SC::Val>; fn is_first_row(&self) -> Self::Expr { self.is_first_row @@ -87,6 +47,51 @@ where } fn assert_zero>(&mut self, x: I) { - assert_eq!(x.into(), F::zero(), "constraints must evaluate to zero"); + assert_eq!( + x.into(), + SC::Val::zero(), + "constraints must evaluate to zero" + ); + } +} + +impl<'a, M, SC> PairBuilder for DebugConstraintBuilder<'a, M, SC> +where + M: Machine, + SC: StarkConfig, +{ + fn preprocessed(&self) -> Self::M { + self.preprocessed + } +} + +impl<'a, M, SC> PermutationAirBuilder for DebugConstraintBuilder<'a, M, SC> +where + M: Machine, + SC: StarkConfig, +{ + type EF = SC::Challenge; + type VarEF = SC::Challenge; + type ExprEF = SC::Challenge; + type MP = TwoRowMatrixView<'a, SC::Challenge>; + + fn permutation(&self) -> Self::MP { + self.perm + } + + fn permutation_randomness(&self) -> &[Self::EF] { + self.perm_challenges + } +} + +impl<'a, M: Machine, SC> ValidaAirBuilder for DebugConstraintBuilder<'a, M, SC> +where + M: Machine, + SC: StarkConfig, +{ + type Machine = M; + + fn machine(&self) -> &Self::Machine { + self.machine } } diff --git a/machine/src/__internal/folding_builder.rs b/machine/src/__internal/folding_builder.rs index 7f9a4fd2..85ecfea3 100644 --- a/machine/src/__internal/folding_builder.rs +++ b/machine/src/__internal/folding_builder.rs @@ -1,68 +1,27 @@ use crate::{Machine, ValidaAirBuilder}; use p3_air::{AirBuilder, PairBuilder, PermutationAirBuilder, TwoRowMatrixView}; -use p3_field::{ExtensionField, Field}; +use valida_machine::config::StarkConfig; -pub struct ConstraintFolder<'a, F: Field, EF: ExtensionField, M: Machine> { +pub struct ConstraintFolder<'a, M: Machine, SC: StarkConfig> { pub(crate) machine: &'a M, - pub(crate) main: TwoRowMatrixView<'a, F>, - pub(crate) preprocessed: TwoRowMatrixView<'a, F>, - pub(crate) perm: TwoRowMatrixView<'a, EF>, - pub(crate) rand_elems: &'a [EF], - pub(crate) is_first_row: F, - pub(crate) is_last_row: F, - pub(crate) is_transition: F, + pub(crate) main: TwoRowMatrixView<'a, SC::Val>, + pub(crate) preprocessed: TwoRowMatrixView<'a, SC::Val>, + pub(crate) perm: TwoRowMatrixView<'a, SC::Challenge>, + pub(crate) rand_elems: &'a [SC::Challenge], + pub(crate) is_first_row: SC::Val, + pub(crate) is_last_row: SC::Val, + pub(crate) is_transition: SC::Val, } -impl<'a, F, EF, M> PermutationAirBuilder for ConstraintFolder<'a, F, EF, M> +impl<'a, M, SC> AirBuilder for ConstraintFolder<'a, M, SC> where - F: Field, - EF: ExtensionField, - M: Machine, + M: Machine, + SC: StarkConfig, { - type EF = M::EF; - type VarEF = M::EF; - type ExprEF = M::EF; - type MP = TwoRowMatrixView<'a, EF>; - - fn permutation(&self) -> Self::MP { - self.perm - } - - fn permutation_randomness(&self) -> &[Self::EF] { - // TODO: implement - self.rand_elems - } -} - -impl<'a, F, EF, M> PairBuilder for ConstraintFolder<'a, F, EF, M> -where - F: Field, - EF: ExtensionField, - M: Machine, -{ - fn preprocessed(&self) -> Self::M { - self.preprocessed - } -} - -impl<'a, M: Machine> ValidaAirBuilder for ConstraintFolder<'a, M::F, M::EF, M> { - type Machine = M; - - fn machine(&self) -> &Self::Machine { - self.machine - } -} - -impl<'a, F, EF, M> AirBuilder for ConstraintFolder<'a, F, EF, M> -where - F: Field, - EF: ExtensionField, - M: Machine, -{ - type F = F; - type Expr = F; - type Var = F; - type M = TwoRowMatrixView<'a, F>; + type F = SC::Val; + type Expr = SC::Val; // TODO: PackedVal + type Var = SC::Val; // TODO: PackedVal + type M = TwoRowMatrixView<'a, SC::Val>; // TODO: PackedVal fn is_first_row(&self) -> Self::Expr { self.is_first_row @@ -88,3 +47,45 @@ where // TODO } } + +impl<'a, M, SC> PairBuilder for ConstraintFolder<'a, M, SC> +where + M: Machine, + SC: StarkConfig, +{ + fn preprocessed(&self) -> Self::M { + self.preprocessed + } +} + +impl<'a, M, SC> PermutationAirBuilder for ConstraintFolder<'a, M, SC> +where + M: Machine, + SC: StarkConfig, +{ + type EF = SC::Challenge; + type VarEF = SC::Challenge; + type ExprEF = SC::Challenge; + type MP = TwoRowMatrixView<'a, SC::Challenge>; // TODO: packed challenge? + + fn permutation(&self) -> Self::MP { + self.perm + } + + fn permutation_randomness(&self) -> &[Self::EF] { + // TODO: implement + self.rand_elems + } +} + +impl<'a, M, SC> ValidaAirBuilder for ConstraintFolder<'a, M, SC> +where + M: Machine, + SC: StarkConfig, +{ + type Machine = M; + + fn machine(&self) -> &Self::Machine { + self.machine + } +} diff --git a/machine/src/__internal/prove.rs b/machine/src/__internal/prove.rs index 4565f5d1..405de103 100644 --- a/machine/src/__internal/prove.rs +++ b/machine/src/__internal/prove.rs @@ -1,8 +1,6 @@ -use crate::__internal::ConstraintFolder; use crate::config::StarkConfig; use crate::proof::ChipProof; use crate::{Chip, Machine}; -use p3_air::Air; pub fn prove( _machine: &M, @@ -11,9 +9,9 @@ pub fn prove( _challenger: &mut SC::Challenger, ) -> ChipProof where - M: Machine, - A: for<'a> Air> + Chip, - SC: StarkConfig, + M: Machine, + A: Chip, + SC: StarkConfig, { // TODO: Sumcheck ChipProof diff --git a/machine/src/chip.rs b/machine/src/chip.rs index b793352f..7719a36b 100644 --- a/machine/src/chip.rs +++ b/machine/src/chip.rs @@ -1,35 +1,38 @@ use crate::Machine; -use crate::__internal::ConstraintFolder; +use crate::__internal::{ConstraintFolder, DebugConstraintBuilder}; use alloc::vec; use alloc::vec::Vec; use valida_util::batch_multiplicative_inverse; +use crate::config::StarkConfig; use p3_air::{Air, AirBuilder, PairBuilder, PermutationAirBuilder, VirtualPairCol}; -use p3_field::{AbstractExtensionField, AbstractField, ExtensionField, Field, Powers, PrimeField}; +use p3_field::{AbstractExtensionField, AbstractField, ExtensionField, Field, Powers}; use p3_matrix::{dense::RowMajorMatrix, Matrix, MatrixRowSlices}; -pub trait Chip: for<'a> Air> { +pub trait Chip, SC: StarkConfig>: + for<'a> Air> + for<'a> Air> +{ /// Generate the main trace for the chip given the provided machine. - fn generate_trace(&self, machine: &M) -> RowMajorMatrix; + fn generate_trace(&self, machine: &M) -> RowMajorMatrix; - fn local_sends(&self) -> Vec> { + fn local_sends(&self) -> Vec> { vec![] } - fn local_receives(&self) -> Vec> { + fn local_receives(&self) -> Vec> { vec![] } - fn global_sends(&self, _machine: &M) -> Vec> { + fn global_sends(&self, _machine: &M) -> Vec> { vec![] } - fn global_receives(&self, _machine: &M) -> Vec> { + fn global_receives(&self, _machine: &M) -> Vec> { vec![] } - fn all_interactions(&self, machine: &M) -> Vec<(Interaction, InteractionType)> { - let mut interactions: Vec<(Interaction, InteractionType)> = vec![]; + fn all_interactions(&self, machine: &M) -> Vec<(Interaction, InteractionType)> { + let mut interactions: Vec<(Interaction, InteractionType)> = vec![]; interactions.extend( self.local_sends() .into_iter() @@ -105,15 +108,15 @@ impl Interaction { /// Generate the permutation trace for a chip with the provided machine. /// This is called only after `generate_trace` has been called on all chips. -pub fn generate_permutation_trace( +pub fn generate_permutation_trace( machine: &M, - chip: &dyn Chip, - main: &RowMajorMatrix, - random_elements: Vec, -) -> RowMajorMatrix + chip: &dyn Chip, + main: &RowMajorMatrix, + random_elements: Vec, +) -> RowMajorMatrix where - F: Field, - M: Machine, + M: Machine, + SC: StarkConfig, { let all_interactions = chip.all_interactions(machine); let (alphas_local, alphas_global) = generate_rlc_elements(machine, chip, &random_elements); @@ -134,7 +137,7 @@ where let mut perm_values = Vec::with_capacity(main.height() * perm_width); for (n, main_row) in main.rows().enumerate() { - let mut row = vec![M::EF::zero(); perm_width]; + let mut row = vec![SC::Challenge::zero(); perm_width]; for (m, (interaction, _)) in all_interactions.iter().enumerate() { let alpha_m = if interaction.is_local() { alphas_local[interaction.argument_index()] @@ -160,7 +163,7 @@ where let mut perm = RowMajorMatrix::new(perm_values, perm_width); // Compute the running sum column - let mut phi = vec![M::EF::zero(); perm.height()]; + let mut phi = vec![SC::Challenge::zero(); perm.height()]; for (n, (main_row, perm_row)) in main.rows().zip(perm.rows()).enumerate() { if n > 0 { phi[n] = phi[n - 1]; @@ -173,13 +176,13 @@ where for (m, (interaction, interaction_type)) in all_interactions.iter().enumerate() { let mult = interaction .count - .apply::(preprocessed_row, main_row); + .apply::(preprocessed_row, main_row); match interaction_type { InteractionType::LocalSend | InteractionType::GlobalSend => { - phi[n] += M::EF::from_base(mult) * perm_row[m]; + phi[n] += SC::Challenge::from_base(mult) * perm_row[m]; } InteractionType::LocalReceive | InteractionType::GlobalReceive => { - phi[n] -= M::EF::from_base(mult) * perm_row[m]; + phi[n] -= SC::Challenge::from_base(mult) * perm_row[m]; } } } @@ -192,12 +195,15 @@ where perm } -pub fn eval_permutation_constraints(chip: &C, builder: &mut AB, cumulative_sum: AB::EF) -where - F: PrimeField, - M: Machine, - C: Chip + Air, - AB: ValidaAirBuilder, +pub fn eval_permutation_constraints( + chip: &C, + builder: &mut AB, + cumulative_sum: AB::EF, +) where + M: Machine, + C: Chip + Air, + SC: StarkConfig, + AB: ValidaAirBuilder, { let rand_elems = builder.permutation_randomness().to_vec(); @@ -272,15 +278,14 @@ where ); } -fn generate_rlc_elements( +fn generate_rlc_elements( machine: &M, - chip: &dyn Chip, - random_elements: &[EF], -) -> (Vec, Vec) + chip: &dyn Chip, + random_elements: &[SC::Challenge], +) -> (Vec, Vec) where - F: AbstractField, - EF: AbstractExtensionField, - M: Machine, + M: Machine, + SC: StarkConfig, { let alphas_local = random_elements[0] .powers() diff --git a/machine/src/config.rs b/machine/src/config.rs index 617cfe34..c6a7aea4 100644 --- a/machine/src/config.rs +++ b/machine/src/config.rs @@ -1,12 +1,12 @@ use core::marker::PhantomData; use p3_challenger::{CanObserve, FieldChallenger}; use p3_commit::{Pcs, UnivariatePcsWithLde}; -use p3_field::{AbstractExtensionField, ExtensionField, Field, PackedField, TwoAdicField}; +use p3_field::{AbstractExtensionField, ExtensionField, PackedField, PrimeField32, TwoAdicField}; use p3_matrix::dense::RowMajorMatrix; pub trait StarkConfig { /// The field over which trace data is encoded. - type Val: Field; + type Val: PrimeField32 + TwoAdicField; // TODO: Relax to Field? type PackedVal: PackedField; /// The field from which most random challenges are drawn. @@ -51,7 +51,7 @@ impl impl StarkConfig for StarkConfigImpl where - Val: Field, + Val: PrimeField32 + TwoAdicField, // TODO: Relax to Field? Challenge: ExtensionField + TwoAdicField, PackedChallenge: AbstractExtensionField, Pcs: UnivariatePcsWithLde, Challenger>, diff --git a/machine/src/lib.rs b/machine/src/lib.rs index fb4bc1e3..1b6c51d5 100644 --- a/machine/src/lib.rs +++ b/machine/src/lib.rs @@ -31,7 +31,7 @@ pub const CPU_MEMORY_CHANNELS: usize = 3; pub const MEMORY_CELL_BYTES: usize = 4; pub const LOOKUP_DEGREE_BOUND: usize = 3; -pub trait Instruction { +pub trait Instruction, F: Field> { const OPCODE: u32; fn execute(state: &mut M, ops: Operands); @@ -52,7 +52,7 @@ pub struct InstructionWord { } impl InstructionWord { - pub fn flatten(&self) -> [F; INSTRUCTION_ELEMENTS] { + pub fn flatten(&self) -> [F; INSTRUCTION_ELEMENTS] { let mut result = [F::default(); INSTRUCTION_ELEMENTS]; result[0] = F::from_canonical_u32(self.opcode); result[1..].copy_from_slice(&Operands::::from_i32_slice(&self.operands.0).0); @@ -87,7 +87,7 @@ impl Operands { } } -impl Operands { +impl Operands { pub fn from_i32_slice(slice: &[i32]) -> Self { let mut operands = [F::zero(); OPERAND_ELEMENTS]; for (i, &operand) in slice.iter().enumerate() { @@ -152,17 +152,10 @@ impl ProgramROM { } } -pub trait Machine { - type F: PrimeField64; - type EF: ExtensionField; - +pub trait Machine { fn run(&mut self, program: &ProgramROM, advice: &mut Adv); - fn prove(&self, config: &SC) -> MachineProof - where - SC: StarkConfig; + fn prove>(&self, config: &SC) -> MachineProof; - fn verify(proof: &MachineProof) -> Result<(), ()> - where - SC: StarkConfig; + fn verify>(proof: &MachineProof) -> Result<(), ()>; } diff --git a/memory/src/lib.rs b/memory/src/lib.rs index 206185b9..360c8b56 100644 --- a/memory/src/lib.rs +++ b/memory/src/lib.rs @@ -12,9 +12,10 @@ use valida_machine::{BusArgument, Chip, Interaction, Machine, Word}; use valida_util::batch_multiplicative_inverse; use p3_air::VirtualPairCol; -use p3_field::PrimeField; +use p3_field::{Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::*; +use valida_machine::config::StarkConfig; pub mod columns; pub mod stark; @@ -49,7 +50,7 @@ pub struct MemoryChip { pub operations: BTreeMap>, } -pub trait MachineWithMemoryChip: Machine { +pub trait MachineWithMemoryChip: Machine { fn mem(&self) -> &MemoryChip; fn mem_mut(&mut self) -> &mut MemoryChip; } @@ -94,11 +95,12 @@ impl MemoryChip { } } -impl Chip for MemoryChip +impl Chip for MemoryChip where - M: MachineWithMemBus, + M: MachineWithMemBus, + SC: StarkConfig, { - fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { let mut ops = self .operations .par_iter() @@ -134,7 +136,7 @@ where trace } - fn local_sends(&self) -> Vec> { + fn local_sends(&self) -> Vec> { let sends = Interaction { fields: vec![VirtualPairCol::single_main(MEM_COL_MAP.diff)], count: VirtualPairCol::one(), @@ -143,7 +145,7 @@ where vec![sends] } - fn local_receives(&self) -> Vec> { + fn local_receives(&self) -> Vec> { let receives = Interaction { fields: vec![VirtualPairCol::single_main(MEM_COL_MAP.counter)], count: VirtualPairCol::single_main(MEM_COL_MAP.counter_mult), @@ -152,8 +154,8 @@ where vec![receives] } - fn global_receives(&self, machine: &M) -> Vec> { - let is_read: VirtualPairCol = VirtualPairCol::single_main(MEM_COL_MAP.is_read); + fn global_receives(&self, machine: &M) -> Vec> { + let is_read: VirtualPairCol = VirtualPairCol::single_main(MEM_COL_MAP.is_read); let clk = VirtualPairCol::single_main(MEM_COL_MAP.clk); let addr = VirtualPairCol::single_main(MEM_COL_MAP.addr); let value = MEM_COL_MAP.value.0.map(VirtualPairCol::single_main); diff --git a/native_field/src/lib.rs b/native_field/src/lib.rs index a4df31f7..382db7ea 100644 --- a/native_field/src/lib.rs +++ b/native_field/src/lib.rs @@ -8,15 +8,16 @@ use columns::{NativeFieldCols, COL_MAP, NUM_NATIVE_FIELD_COLS}; use core::mem::transmute; use valida_bus::{MachineWithGeneralBus, MachineWithRangeBus8}; use valida_cpu::MachineWithCpuChip; -use valida_machine::{instructions, Chip, Instruction, Interaction, Machine, Operands, Word}; +use valida_machine::{instructions, Chip, Instruction, Interaction, Operands, Word}; use valida_opcodes::{ADD, MUL, SUB}; use valida_range::MachineWithRangeChip; use valida_util::pad_to_power_of_two; use p3_air::VirtualPairCol; -use p3_field::{Field, PrimeField32}; +use p3_field::{AbstractField, Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::*; +use valida_machine::config::StarkConfig; pub mod columns; pub mod stark; @@ -32,12 +33,12 @@ pub struct NativeFieldChip { operations: Vec, } -impl Chip for NativeFieldChip +impl Chip for NativeFieldChip where - F: PrimeField32, - M: MachineWithGeneralBus + MachineWithRangeBus8, + M: MachineWithGeneralBus + MachineWithRangeBus8, + SC: StarkConfig, { - fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { let rows = self .operations .par_iter() @@ -49,12 +50,12 @@ where NUM_NATIVE_FIELD_COLS, ); - pad_to_power_of_two::(&mut trace.values); + pad_to_power_of_two::(&mut trace.values); trace } - fn global_sends(&self, machine: &M) -> Vec> { + fn global_sends(&self, machine: &M) -> Vec> { let sends = COL_MAP .output .0 @@ -74,14 +75,14 @@ where sends } - fn global_receives(&self, machine: &M) -> Vec> { + fn global_receives(&self, machine: &M) -> Vec> { let opcode = VirtualPairCol::new_main( vec![ - (COL_MAP.is_add, M::F::from_canonical_u32(ADD)), - (COL_MAP.is_sub, M::F::from_canonical_u32(SUB)), - (COL_MAP.is_mul, M::F::from_canonical_u32(MUL)), + (COL_MAP.is_add, SC::Val::from_canonical_u32(ADD)), + (COL_MAP.is_sub, SC::Val::from_canonical_u32(SUB)), + (COL_MAP.is_mul, SC::Val::from_canonical_u32(MUL)), ], - M::F::zero(), + SC::Val::zero(), ); let input_1 = COL_MAP.input_1.0.map(VirtualPairCol::single_main); let input_2 = COL_MAP.input_2.0.map(VirtualPairCol::single_main); @@ -137,22 +138,22 @@ impl NativeFieldChip { } } -pub trait MachineWithNativeFieldChip: MachineWithCpuChip { +pub trait MachineWithNativeFieldChip: MachineWithCpuChip { fn native_field(&self) -> NativeFieldChip; fn native_field_mut(&self) -> &mut NativeFieldChip; } instructions!(AddInstruction, SubInstruction, MulInstruction); -impl Instruction for AddInstruction +impl Instruction for AddInstruction where - M: MachineWithNativeFieldChip + MachineWithRangeChip<256> + Machine, + M: MachineWithNativeFieldChip + MachineWithRangeChip, F: PrimeField32, { const OPCODE: u32 = ADD; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; @@ -186,15 +187,15 @@ where } } -impl Instruction for SubInstruction +impl Instruction for SubInstruction where - M: MachineWithNativeFieldChip + MachineWithRangeChip<256> + Machine, + M: MachineWithNativeFieldChip + MachineWithRangeChip, F: PrimeField32, { const OPCODE: u32 = SUB; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; @@ -228,15 +229,15 @@ where } } -impl Instruction for MulInstruction +impl Instruction for MulInstruction where - M: MachineWithNativeFieldChip + MachineWithRangeChip<256> + Machine, + M: MachineWithNativeFieldChip + MachineWithRangeChip, F: PrimeField32, { const OPCODE: u32 = MUL; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; @@ -256,7 +257,7 @@ where .read(clk, read_addr_2, true, pc, opcode, 1, "") }; - let a_m31 = M::F::from_canonical_u32(b.into()) * M::F::from_canonical_u32(c.into()); + let a_m31 = F::from_canonical_u32(b.into()) * F::from_canonical_u32(c.into()); let a = Word::from(a_m31.as_canonical_u32()); state.mem_mut().write(clk, write_addr, a, true); diff --git a/output/src/lib.rs b/output/src/lib.rs index f1684db0..31fc6179 100644 --- a/output/src/lib.rs +++ b/output/src/lib.rs @@ -9,9 +9,10 @@ use valida_machine::{ use valida_opcodes::WRITE; use p3_air::VirtualPairCol; -use p3_field::PrimeField; +use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::*; +use valida_machine::config::StarkConfig; use valida_util::pad_to_power_of_two; pub mod columns; @@ -28,12 +29,12 @@ impl OutputChip { } } -impl Chip for OutputChip +impl Chip for OutputChip where - F: PrimeField, - M: MachineWithGeneralBus, + M: MachineWithGeneralBus, + SC: StarkConfig, { - fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { let table_len = self.values.len() as u32; let mut rows = self .values @@ -47,15 +48,15 @@ where let num_rows = (clk_diff / table_len) as usize + 1; let mut rows = Vec::with_capacity(num_rows); for n in 0..num_rows { - let mut row = [M::F::zero(); NUM_OUTPUT_COLS]; - let cols: &mut OutputCols = unsafe { transmute(&mut row) }; + let mut row = [SC::Val::zero(); NUM_OUTPUT_COLS]; + let cols: &mut OutputCols = unsafe { transmute(&mut row) }; if n == 0 { - cols.is_real = M::F::one(); - cols.clk = M::F::from_canonical_u32(clk_1); - cols.value = M::F::from_canonical_u8(val_1); + cols.is_real = SC::Val::one(); + cols.clk = SC::Val::from_canonical_u32(clk_1); + cols.value = SC::Val::from_canonical_u8(val_1); } else { // Dummy output to satisfy range check - cols.clk = M::F::from_canonical_u32(clk_1 + table_len * (n + 1) as u32); + cols.clk = SC::Val::from_canonical_u32(clk_1 + table_len * (n + 1) as u32); } rows.push(row); } @@ -63,12 +64,12 @@ where // Compute clock diffs rows.iter() .map(|row| row[OUTPUT_COL_MAP.clk]) - .chain(iter::once(M::F::from_canonical_u32(clk_2))) + .chain(iter::once(SC::Val::from_canonical_u32(clk_2))) .collect::>() .windows(2) .enumerate() .for_each(|(n, clks)| { - let cols: &mut OutputCols = unsafe { transmute(&mut rows[n]) }; + let cols: &mut OutputCols = unsafe { transmute(&mut rows[n]) }; cols.diff = clks[1] - clks[0]; }); @@ -79,11 +80,11 @@ where // Add final row if let Some(last_row) = self.values.last() { - let mut row = [M::F::zero(); NUM_OUTPUT_COLS]; - let cols: &mut OutputCols = unsafe { transmute(&mut row) }; - cols.is_real = M::F::one(); - cols.clk = M::F::from_canonical_u32(last_row.0); - cols.value = M::F::from_canonical_u8(last_row.1); + let mut row = [SC::Val::zero(); NUM_OUTPUT_COLS]; + let cols: &mut OutputCols = unsafe { transmute(&mut row) }; + cols.is_real = SC::Val::one(); + cols.clk = SC::Val::from_canonical_u32(last_row.0); + cols.value = SC::Val::from_canonical_u8(last_row.1); rows.push(row); } @@ -91,11 +92,11 @@ where // re-enable local_sends and local_receives let mut values = rows.concat(); - pad_to_power_of_two::(&mut values); + pad_to_power_of_two::(&mut values); RowMajorMatrix::new(values, NUM_OUTPUT_COLS) } - //fn local_sends(&self) -> Vec> { + //fn local_sends(&self) -> Vec> { // let sends = Interaction { // fields: vec![VirtualPairCol::single_main(OUTPUT_COL_MAP.diff)], // count: VirtualPairCol::one(), @@ -104,7 +105,7 @@ where // vec![sends] //} - //fn local_receives(&self) -> Vec> { + //fn local_receives(&self) -> Vec> { // let receives = Interaction { // fields: vec![VirtualPairCol::single_main(OUTPUT_COL_MAP.counter)], // count: VirtualPairCol::single_main(OUTPUT_COL_MAP.counter_mult), @@ -113,12 +114,12 @@ where // vec![receives] //} - fn global_receives(&self, machine: &M) -> Vec> { + fn global_receives(&self, machine: &M) -> Vec> { let opcode = VirtualPairCol::single_main(OUTPUT_COL_MAP.opcode); let clk = VirtualPairCol::single_main(OUTPUT_COL_MAP.clk); let mut values = (0..CPU_MEMORY_CHANNELS * MEMORY_CELL_BYTES) - .map(|_| VirtualPairCol::constant(M::F::zero())) + .map(|_| VirtualPairCol::constant(SC::Val::zero())) .collect::>(); values[MEMORY_CELL_BYTES - 1] = VirtualPairCol::single_main(OUTPUT_COL_MAP.value); @@ -135,21 +136,22 @@ where } } -pub trait MachineWithOutputChip: MachineWithCpuChip { +pub trait MachineWithOutputChip: MachineWithCpuChip { fn output(&self) -> &OutputChip; fn output_mut(&mut self) -> &mut OutputChip; } instructions!(WriteInstruction); -impl Instruction for WriteInstruction +impl Instruction for WriteInstruction where - M: MachineWithOutputChip, + M: MachineWithOutputChip, + F: Field, { const OPCODE: u32 = WRITE; fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + let opcode = >::OPCODE; let clk = state.cpu().clock; let pc = state.cpu().pc; let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32; diff --git a/program/src/lib.rs b/program/src/lib.rs index 77026d2b..151300cf 100644 --- a/program/src/lib.rs +++ b/program/src/lib.rs @@ -6,11 +6,13 @@ use crate::columns::{COL_MAP, NUM_PROGRAM_COLS, PREPROCESSED_COL_MAP}; use alloc::vec; use alloc::vec::Vec; use valida_bus::MachineWithProgramBus; -use valida_machine::{Chip, Interaction, Machine, PrimeField64, ProgramROM}; +use valida_machine::{Chip, Interaction, Machine, ProgramROM}; use valida_util::pad_to_power_of_two; use p3_air::VirtualPairCol; +use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; +use valida_machine::config::StarkConfig; pub mod columns; pub mod stark; @@ -29,24 +31,24 @@ impl ProgramChip { } } -impl Chip for ProgramChip +impl Chip for ProgramChip where - F: PrimeField64, - M: MachineWithProgramBus, + M: MachineWithProgramBus, + SC: StarkConfig, { - fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { let mut values = self .counts .iter() - .map(|c| F::from_canonical_u32(*c)) + .map(|c| SC::Val::from_canonical_u32(*c)) .collect(); - pad_to_power_of_two::(&mut values); + pad_to_power_of_two::(&mut values); RowMajorMatrix::new(values, NUM_PROGRAM_COLS) } - fn global_receives(&self, machine: &M) -> Vec> { + fn global_receives(&self, machine: &M) -> Vec> { let pc = VirtualPairCol::single_preprocessed(PREPROCESSED_COL_MAP.pc); let opcode = VirtualPairCol::single_preprocessed(PREPROCESSED_COL_MAP.opcode); let mut fields = vec![pc, opcode]; @@ -66,7 +68,7 @@ where } } -pub trait MachineWithProgramChip: Machine { +pub trait MachineWithProgramChip: Machine { fn program(&self) -> &ProgramChip; fn program_mut(&mut self) -> &mut ProgramChip; diff --git a/program/src/stark.rs b/program/src/stark.rs index 80196a40..e846857e 100644 --- a/program/src/stark.rs +++ b/program/src/stark.rs @@ -4,18 +4,17 @@ use alloc::vec; use valida_machine::InstructionWord; use p3_air::{Air, BaseAir, PairBuilder}; -use p3_field::PrimeField64; +use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; -impl Air for ProgramChip +impl Air for ProgramChip where - F: PrimeField64, - AB: PairBuilder, + AB: PairBuilder, { fn eval(&self, _builder: &mut AB) {} } -impl BaseAir for ProgramChip { +impl BaseAir for ProgramChip { fn width(&self) -> usize { NUM_PROGRAM_COLS } diff --git a/range/src/lib.rs b/range/src/lib.rs index d66cb6d9..aa988357 100644 --- a/range/src/lib.rs +++ b/range/src/lib.rs @@ -9,10 +9,12 @@ use columns::{RangeCols, NUM_RANGE_COLS, RANGE_COL_MAP}; use core::mem::transmute; use valida_bus::MachineWithRangeBus8; use valida_machine::Interaction; -use valida_machine::{Chip, Machine, PrimeField, Word}; +use valida_machine::{Chip, Machine, Word}; use p3_air::VirtualPairCol; +use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; +use valida_machine::config::StarkConfig; pub mod columns; pub mod stark; @@ -22,26 +24,26 @@ pub struct RangeCheckerChip { pub count: BTreeMap, } -impl Chip for RangeCheckerChip +impl Chip for RangeCheckerChip where - F: PrimeField, - M: MachineWithRangeBus8, + M: MachineWithRangeBus8, + SC: StarkConfig, { - fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { - let mut rows = vec![[F::zero(); NUM_RANGE_COLS]; MAX as usize]; + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + let mut rows = vec![[SC::Val::zero(); NUM_RANGE_COLS]; MAX as usize]; for (n, row) in rows.iter_mut().enumerate() { - let cols: &mut RangeCols = unsafe { transmute(row) }; + let cols: &mut RangeCols = unsafe { transmute(row) }; // FIXME: This is very inefficient when the range is large. // Iterate over key/val pairs instead in a separate loop. if let Some(c) = self.count.get(&(n as u32)) { - cols.mult = M::F::from_canonical_u32(*c); + cols.mult = SC::Val::from_canonical_u32(*c); } - cols.counter = M::F::from_canonical_u32(n as u32); + cols.counter = SC::Val::from_canonical_u32(n as u32); } RowMajorMatrix::new(rows.concat(), NUM_RANGE_COLS) } - fn global_receives(&self, machine: &M) -> Vec> { + fn global_receives(&self, machine: &M) -> Vec> { let input = VirtualPairCol::single_main(RANGE_COL_MAP.counter); let receive = Interaction { @@ -53,7 +55,7 @@ where } } -pub trait MachineWithRangeChip: Machine { +pub trait MachineWithRangeChip: Machine { fn range(&self) -> &RangeCheckerChip; fn range_mut(&mut self) -> &mut RangeCheckerChip;