Skip to content

Commit

Permalink
fixes setting instruction column values with immediate arguments in C…
Browse files Browse the repository at this point in the history
…puChip
  • Loading branch information
tess-eract committed May 3, 2024
1 parent 2bfd1a3 commit f579563
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 148 deletions.
189 changes: 57 additions & 132 deletions alu_u32/src/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,29 +139,18 @@ impl Lt32Chip {
}
cols.multiplicity = F::one();
}
}

pub trait MachineWithLt32Chip<F: Field>: MachineWithCpuChip<F> {
fn lt_u32(&self) -> &Lt32Chip;
fn lt_u32_mut(&mut self) -> &mut Lt32Chip;
}

instructions!(
Lt32Instruction,
Lte32Instruction,
Slt32Instruction,
Sle32Instruction
);

impl<M, F> Instruction<M, F> for Lt32Instruction
where
M: MachineWithLt32Chip<F>,
F: Field,
{
const OPCODE: u32 = LT32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M, F>>::OPCODE;
fn execute_with_closure<M, E, F>(
state: &mut M,
ops: Operands<i32>,
opcode: u32,
comp: F,
) -> (Word<u8>, Word<u8>, Word<u8>)
where
M: MachineWithLt32Chip<E>,
E: Field,
F: Fn(Word<u8>, Word<u8>) -> bool,
{
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
Expand All @@ -187,75 +176,68 @@ where
.read(clk, read_addr_2, true, pc, opcode, 1, "")
};

let dst = if src1 < src2 {
let dst = if comp(src1, src2) {
Word::from(1)
} else {
Word::from(0)
};
state.mem_mut().write(clk, write_addr, dst, true);

state
.lt_u32_mut()
.operations
.push(Operation::Lt32(dst, src1, src2));
if ops.d() == 1 {
state.cpu_mut().push_left_imm_bus_op(imm, opcode, ops);
state.cpu_mut().push_left_imm_bus_op(imm, opcode, ops)
} else {
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
(dst, src1, src2)
}
}

impl<M, F> Instruction<M, F> for Lte32Instruction
pub trait MachineWithLt32Chip<F: Field>: MachineWithCpuChip<F> {
fn lt_u32(&self) -> &Lt32Chip;
fn lt_u32_mut(&mut self) -> &mut Lt32Chip;
}

instructions!(
Lt32Instruction,
Lte32Instruction,
Slt32Instruction,
Sle32Instruction
);

impl<M, F> Instruction<M, F> for Lt32Instruction
where
M: MachineWithLt32Chip<F>,
F: Field,
{
const OPCODE: u32 = LTE32;
const OPCODE: u32 = LT32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
let write_addr = (state.cpu().fp as i32 + ops.a()) as u32;
let src1: Word<u8> = if ops.d() == 1 {
let b = (ops.b() as u32).into();
imm = Some(b);
b
} else {
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
state
.mem_mut()
.read(clk, read_addr_1, true, pc, opcode, 0, "")
};
let src2: Word<u8> = if ops.is_imm() == 1 {
let c = (ops.c() as u32).into();
imm = Some(c);
c
} else {
let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32;
state
.mem_mut()
.read(clk, read_addr_2, true, pc, opcode, 1, "")
};
let comp = |a, b| a < b;
let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);

let dst = if src1 <= src2 {
Word::from(1)
} else {
Word::from(0)
};
state.mem_mut().write(clk, write_addr, dst, true);
state
.lt_u32_mut()
.operations
.push(Operation::Lt32(dst, src1, src2));
}
}

impl<M, F> Instruction<M, F> for Lte32Instruction
where
M: MachineWithLt32Chip<F>,
F: Field,
{
const OPCODE: u32 = LTE32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M, F>>::OPCODE;
let comp = |a, b| a <= b;
let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);
state
.lt_u32_mut()
.operations
.push(Operation::Lte32(dst, src1, src2));
if ops.d() == 1 {
state.cpu_mut().push_left_imm_bus_op(imm, opcode, ops);
} else {
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
}
}

Expand All @@ -268,45 +250,16 @@ where

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
let write_addr = (state.cpu().fp as i32 + ops.a()) as u32;
let src1: Word<u8> = if ops.d() == 1 {
let b = (ops.b() as u32).into();
imm = Some(b);
b
} else {
state
.mem_mut()
.read(clk, read_addr_1, true, pc, opcode, 0, "")
};
let src2: Word<u8> = if ops.is_imm() == 1 {
let c = (ops.c() as u32).into();
imm = Some(c);
c
} else {
let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32;
state
.mem_mut()
.read(clk, read_addr_2, true, pc, opcode, 1, "")
};

let src1_i: i32 = src1.into();
let src2_i: i32 = src2.into();
let dst = if src1_i < src2_i {
Word::from(1)
} else {
Word::from(0)
let comp = |a: Word<u8>, b: Word<u8>| {
let a_i: i32 = a.into();
let b_i: i32 = b.into();
a_i < b_i
};
state.mem_mut().write(clk, write_addr, dst, true);

let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);
state
.lt_u32_mut()
.operations
.push(Operation::Slt32(dst, src1, src2));
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
}

Expand All @@ -319,44 +272,16 @@ where

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
let write_addr = (state.cpu().fp as i32 + ops.a()) as u32;
let src1: Word<u8> = if ops.d() == 1 {
let b = (ops.b() as u32).into();
imm = Some(b);
b
} else {
state
.mem_mut()
.read(clk, read_addr_1, true, pc, opcode, 0, "")
};
let src2: Word<u8> = if ops.is_imm() == 1 {
let c = (ops.c() as u32).into();
imm = Some(c);
c
} else {
let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32;
state
.mem_mut()
.read(clk, read_addr_2, true, pc, opcode, 1, "")
let comp = |a: Word<u8>, b: Word<u8>| {
let a_i: i32 = a.into();
let b_i: i32 = b.into();
a_i <= b_i
};

let src1_i: i32 = src1.into();
let src2_i: i32 = src2.into();
let dst = if src1_i <= src2_i {
Word::from(1)
} else {
Word::from(0)
};
state.mem_mut().write(clk, write_addr, dst, true);
let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);

state
.lt_u32_mut()
.operations
.push(Operation::Sle32(dst, src1, src2));
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
}
24 changes: 13 additions & 11 deletions alu_u32/src/lt/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,13 @@ where
.map(|(bit, base)| bit * base)
.sum();

// Check bit decomposition of z = 256 + input_1[n] - input_2[n], where
// n is the most significant byte that differs between inputs
for i in 0..4 {
builder.when(local.byte_flag[i]).assert_eq(
AB::Expr::from_canonical_u32(256) + local.input_1[i] - local.input_2[i],
bit_comp.clone(),
);
builder.assert_bool(local.byte_flag[i]);
}
// check that the n-th byte flag is set, where n is the first byte that differs between the two inputs

// ensure at most one byte flag is set
let flag_sum =
local.byte_flag[0] + local.byte_flag[1] + local.byte_flag[2] + local.byte_flag[3];
builder.assert_bool(flag_sum.clone());

// check that bytes before the first set byte flag are all equal
// case: top bytes match
builder
.when_ne(local.byte_flag[0], AB::Expr::one())
Expand All @@ -67,7 +59,17 @@ where
.assert_eq(local.input_1[3], local.input_2[3]);
builder
.when_ne(flag_sum.clone(), AB::Expr::one())
.assert_eq(bit_comp, AB::Expr::zero());
.assert_eq(bit_comp.clone(), AB::Expr::zero());

// Check bit decomposition of z = 256 + input_1[n] - input_2[n]
// when `n` is the first byte that differs between the two inputs.
for i in 0..4 {
builder.when(local.byte_flag[i]).assert_eq(
AB::Expr::from_canonical_u32(256) + local.input_1[i] - local.input_2[i],
bit_comp.clone(),
);
builder.assert_bool(local.byte_flag[i]);
}

builder.assert_bool(local.is_lt);
builder.assert_bool(local.is_lte);
Expand Down
18 changes: 14 additions & 4 deletions basic/tests/test_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,59 +190,69 @@ fn fib_program<Val: PrimeField32 + TwoAdicField>() -> Vec<InstructionWord<i32>>
fn left_imm_ops_program<Val: PrimeField32 + TwoAdicField>() -> Vec<InstructionWord<i32>> {
let mut program = vec![];

// imm32 -4(fp), 0, 0, 0, 3
// lt32 -8(fp), 3, -4(fp), 1, 0
// lte32 -12(fp), 3, -4(fp), 1, 0
// stop
program.extend([
// imm32 -4(fp), 0, 0, 0, 3
// ;(0, 0, 1, 0) == 256
InstructionWord {
opcode: <Imm32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([-4, 0, 0, 0, 3]),
},
// imm32 -8(fp), 0, 0, 1, 0
InstructionWord {
opcode: <Imm32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([-8, 0, 0, 1, 0]),
},
// lt32 4(fp), 3, -4(fp), 1, 0
InstructionWord {
opcode: <Lt32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([4, 3, -4, 1, 0]),
},
// lte32 8(fp), 3, -4(fp), 1, 0
InstructionWord {
opcode: <Lte32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([8, 3, -4, 1, 0]),
},
// lt32 12(fp), 4, -4(fp), 1, 0
InstructionWord {
opcode: <Lt32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([12, 4, -4, 1, 0]),
},
// lte32 16(fp), 4, -4(fp), 1, 0
InstructionWord {
opcode: <Lte32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([16, 4, -4, 1, 0]),
},
// lt32 20(fp), 2, -4(fp), 1, 0
InstructionWord {
opcode: <Lt32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([20, 2, -4, 1, 0]),
},
// lte32 24(fp), 2, -4(fp), 1, 0
InstructionWord {
opcode: <Lte32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([24, 2, -4, 1, 0]),
},
// lt32 28(fp), 256, -4(fp), 1, 0
InstructionWord {
opcode: <Lt32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([28, 256, -4, 1, 0]),
},
// lte32 32(fp), 256, -4(fp), 1, 0
InstructionWord {
opcode: <Lte32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([32, 256, -4, 1, 0]),
},
// lt32 36(fp), 3, -8(fp), 1, 0
InstructionWord {
opcode: <Lt32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([36, 3, -8, 1, 0]),
},
// lte32 40(fp), 3, -8(fp), 1, 0
InstructionWord {
opcode: <Lte32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([40, 3, -8, 1, 0]),
},
// stop 0, 0, 0, 0, 0
InstructionWord {
opcode: <StopInstruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands::default(),
Expand Down
2 changes: 1 addition & 1 deletion cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ impl CpuChip {
cols.pc = SC::Val::from_canonical_u32(self.registers[clk].pc);
cols.fp = SC::Val::from_canonical_u32(self.registers[clk].fp);
cols.clk = SC::Val::from_canonical_usize(clk);
self.set_instruction_values(clk, cols);

match op {
Operation::Store32 => {
Expand Down Expand Up @@ -229,7 +230,6 @@ impl CpuChip {
}
}

self.set_instruction_values(clk, cols);
self.set_memory_channel_values::<M, SC>(clk, cols, machine);

row
Expand Down

0 comments on commit f579563

Please sign in to comment.