Skip to content

Commit

Permalink
fixes stark constraints for LT32Chip
Browse files Browse the repository at this point in the history
  • Loading branch information
tess-eract committed May 1, 2024
1 parent 55fa452 commit 605c959
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 25 deletions.
2 changes: 1 addition & 1 deletion alu_u32/src/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl Lt32Chip {
.into_iter()
.zip(c.into_iter())
.enumerate()
.find_map(|(n, (x, y))| if x == y { Some(n) } else { None })
.find_map(|(n, (x, y))| if x == y { None } else { Some(n) })
{
let z = 256u16 + b[n] as u16 - c[n] as u16;
for i in 0..10 {
Expand Down
51 changes: 32 additions & 19 deletions alu_u32/src/lt/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,46 +33,59 @@ where

// Check bit decomposition of z = 256 + input_1[n] - input_2[n], where
// n is the most significant byte that differs between inputs
for i in 0..3 {
builder
.when_ne(local.byte_flag[i], AB::Expr::one())
.assert_eq(local.input_1[i], local.input_2[i]);

for i in 0..4 {
builder.when(local.byte_flag[i]).assert_eq(
AB::Expr::from_canonical_u32(256) + local.input_1[i] - local.input_2[i],
bit_comp.clone(),
);

builder.assert_bool(local.byte_flag[i]);
}

// Check final byte (if no other byte flags were set)
let flag_sum = local.byte_flag[0] + local.byte_flag[1] + local.byte_flag[2];
// ensure at most one byte flag is set
let flag_sum =
local.byte_flag[0] + local.byte_flag[1] + local.byte_flag[2] + local.byte_flag[3];
builder.assert_bool(flag_sum.clone());

// case: top bytes match
builder
.when_ne(local.byte_flag[0], AB::Expr::one())
.assert_eq(local.input_1[0], local.input_2[0]);
// case: top two bytes match
builder
.when_ne(local.byte_flag[0] + local.byte_flag[1], AB::Expr::one())
.assert_eq(local.input_1[1], local.input_2[1]);
// case: top three bytes match
builder
.when_ne(
local.byte_flag[0] + local.byte_flag[1] + local.byte_flag[2],
AB::Expr::one(),
)
.assert_eq(local.input_1[2], local.input_2[2]);
// case: top four bytes match; must set z = 0
builder
.when_ne(local.multiplicity, AB::Expr::zero())
.when_ne(flag_sum.clone(), AB::Expr::one())
.assert_eq(
AB::Expr::from_canonical_u32(256) + local.input_1[3] - local.input_2[3],
bit_comp.clone(),
);
.assert_eq(local.input_1[3], local.input_2[3]);
builder
.when_ne(flag_sum.clone(), AB::Expr::one())
.assert_eq(bit_comp, AB::Expr::zero());

builder.assert_bool(local.is_lt);
builder.assert_bool(local.is_lte);
builder.assert_bool(local.is_lt + local.is_lte);

// Output constraints
// local.bits[8] is 1 iff input_1 > input_2: output should be 0
builder.when(local.bits[8]).assert_zero(local.output);
builder
.when_ne(local.multiplicity, AB::Expr::zero())
.when_ne(local.bits[8], AB::Expr::one())
.assert_one(local.output);
// output should be 1 if is_lte & input_1 == input_2
let all_flag_sum = flag_sum + local.byte_flag[3];
builder
.when(local.is_lte)
.when_ne(all_flag_sum, AB::Expr::one())
.when_ne(flag_sum.clone(), AB::Expr::one())
.assert_one(local.output);
// output should be 0 if is_lt & input_1 == input_2
builder
.when(local.is_lt)
.when_ne(flag_sum, AB::Expr::one())
.assert_zero(local.output);

// Check bit decomposition
for bit in local.bits.into_iter() {
Expand Down
133 changes: 128 additions & 5 deletions basic/tests/test_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ extern crate core;
use p3_baby_bear::BabyBear;
use p3_fri::{TwoAdicFriPcs, TwoAdicFriPcsConfig};
use valida_alu_u32::add::{Add32Instruction, MachineWithAdd32Chip};
use valida_alu_u32::lt::{Lt32Instruction, Lte32Instruction};
use valida_basic::BasicMachine;
use valida_cpu::{
BeqInstruction, BneInstruction, Imm32Instruction, JalInstruction, JalvInstruction,
Expand All @@ -20,7 +21,7 @@ use valida_program::MachineWithProgramChip;
use p3_challenger::DuplexChallenger;
use p3_dft::Radix2Bowers;
use p3_field::extension::BinomialExtensionField;
use p3_field::Field;
use p3_field::{Field, PrimeField32, TwoAdicField};
use p3_fri::FriConfig;
use p3_keccak::Keccak256Hash;
use p3_mds::coset_mds::CosetMds;
Expand All @@ -31,8 +32,7 @@ use rand::thread_rng;
use valida_machine::StarkConfigImpl;
use valida_machine::__internal::p3_commit::ExtensionMmcs;

#[test]
fn prove_fibonacci() {
fn fib_program<Val: PrimeField32 + TwoAdicField>() -> Vec<InstructionWord<i32>> {
let mut program = vec![];

// Label locations
Expand All @@ -46,7 +46,7 @@ fn prove_fibonacci() {
//main: ; @main
//; %bb.0:
// imm32 -4(fp), 0, 0, 0, 0
// imm32 -8(fp), 0, 0, 0, 10
// imm32 -8(fp), 0, 0, 0, 25
// addi -16(fp), -8(fp), 0
// imm32 -20(fp), 0, 0, 0, 28
// jal -28(fp), fib, -28
Expand Down Expand Up @@ -184,7 +184,74 @@ fn prove_fibonacci() {
operands: Operands([-4, 0, 8, 0, 0]),
},
]);
program
}

fn left_imm_ops_program<Val: PrimeField32 + TwoAdicField>() -> Vec<InstructionWord<i32>> {
let mut program = vec![];

// imm32 -4(fp), 0, 0, 0, 3
// lt32 -8(fp), 3, -4(fp), 1, 0
// lte32 -12(fp), 3, -4(fp), 1, 0
// stop
program.extend([
InstructionWord {
opcode: <Imm32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([-4, 0, 0, 0, 3]),
},
InstructionWord {
opcode: <Imm32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([-8, 0, 0, 1, 0]),
},
InstructionWord {
opcode: <Lt32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([4, 3, -4, 1, 0]),
},
InstructionWord {
opcode: <Lte32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([8, 3, -4, 1, 0]),
},
InstructionWord {
opcode: <Lt32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([12, 4, -4, 1, 0]),
},
InstructionWord {
opcode: <Lte32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([16, 4, -4, 1, 0]),
},
InstructionWord {
opcode: <Lt32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([20, 2, -4, 1, 0]),
},
InstructionWord {
opcode: <Lte32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([24, 2, -4, 1, 0]),
},
InstructionWord {
opcode: <Lt32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([28, 256, -4, 1, 0]),
},
InstructionWord {
opcode: <Lte32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([32, 256, -4, 1, 0]),
},
InstructionWord {
opcode: <Lt32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([36, 3, -8, 1, 0]),
},
InstructionWord {
opcode: <Lte32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([40, 3, -8, 1, 0]),
},
InstructionWord {
opcode: <StopInstruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands::default(),
},
]);
program
}

fn prove_program(program: Vec<InstructionWord<i32>>) -> BasicMachine<BabyBear> {
let mut machine = BasicMachine::<Val>::default();
let rom = ProgramROM::new(program);
machine.program_mut().set_program_rom(&rom);
Expand All @@ -194,6 +261,7 @@ fn prove_fibonacci() {
machine.run(&rom, &mut FixedAdviceProvider::empty());

type Val = BabyBear;

type Challenge = BinomialExtensionField<Val, 5>;
type PackedChallenge = BinomialExtensionField<<Val as Field>::Packing, 5>;

Expand Down Expand Up @@ -250,13 +318,68 @@ fn prove_fibonacci() {
.verify(&config, &deserialized_proof)
.expect("verification failed");

machine
}
#[test]
fn prove_fibonacci() {
let program = fib_program::<BabyBear>();

let machine = prove_program(program);

assert_eq!(machine.cpu().clock, 192);
assert_eq!(machine.cpu().operations.len(), 192);
assert_eq!(machine.mem().operations.values().flatten().count(), 401);
assert_eq!(machine.add_u32().operations.len(), 105);

assert_eq!(
*machine.mem().cells.get(&(0x1000 + 4)).unwrap(), // Return value
Word([0, 1, 37, 17,]) // 25th fibonacci number (75025)
);
}

#[test]
fn prove_left_imm_ops() {
let program = left_imm_ops_program::<BabyBear>();

let machine = prove_program(program);

assert_eq!(
*machine.mem().cells.get(&(0x1000 + 4)).unwrap(),
Word([0, 0, 0, 0]) // 3 < 3 (false)
);
assert_eq!(
*machine.mem().cells.get(&(0x1000 + 8)).unwrap(),
Word([0, 0, 0, 1]) // 3 <= 3 (true)
);
assert_eq!(
*machine.mem().cells.get(&(0x1000 + 12)).unwrap(),
Word([0, 0, 0, 0]) // 4 < 3 (false)
);
assert_eq!(
*machine.mem().cells.get(&(0x1000 + 16)).unwrap(),
Word([0, 0, 0, 0]) // 4 <= 3 (false)
);
assert_eq!(
*machine.mem().cells.get(&(0x1000 + 20)).unwrap(),
Word([0, 0, 0, 1]) // 2 < 3 (true)
);
assert_eq!(
*machine.mem().cells.get(&(0x1000 + 24)).unwrap(),
Word([0, 0, 0, 1]) // 2 <= 3 (true)
);
assert_eq!(
*machine.mem().cells.get(&(0x1000 + 28)).unwrap(),
Word([0, 0, 0, 0]) // 256 < 3 (false)
);
assert_eq!(
*machine.mem().cells.get(&(0x1000 + 32)).unwrap(),
Word([0, 0, 0, 0]) // 256 <= 3 (false)
);
assert_eq!(
*machine.mem().cells.get(&(0x1000 + 36)).unwrap(),
Word([0, 0, 0, 1]) // 3 < 256 (true)
);
assert_eq!(
*machine.mem().cells.get(&(0x1000 + 40)).unwrap(),
Word([0, 0, 0, 1]) // 3 <= 256 (false)
);
}

0 comments on commit 605c959

Please sign in to comment.