Skip to content

Commit

Permalink
Add eq.
Browse files Browse the repository at this point in the history
  • Loading branch information
thealmarty committed Jan 15, 2024
1 parent 5521161 commit b087138
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 65 deletions.
8 changes: 8 additions & 0 deletions alu_u32/src/com/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,20 @@ pub struct Com32Cols<T> {
pub input_1: Word<T>,
pub input_2: Word<T>,

/// When doing an equality test between two words, `x` and `y`, this holds the sum of
/// `(x_i - y_i)^2`, which is zero if and only if `x = y`.
pub diff: T,
/// The inverse of `diff`, or undefined if `diff = 0`.
pub diff_inv: T,
/// A boolean flag indicating whether `diff != 0`.
pub not_equal: T,

pub output: T,

pub is_ne: T,
pub is_eq: T,
pub is_ne: T,
pub is_eq: T,
}

pub const NUM_COM_COLS: usize = size_of::<Com32Cols<u8>>();
Expand Down
56 changes: 26 additions & 30 deletions alu_u32/src/com/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use valida_cpu::MachineWithCpuChip;
use valida_machine::{
instructions, Chip, Instruction, Interaction, Operands, Word, MEMORY_CELL_BYTES,
};
use valida_opcodes::{NE32, EQ32};
use valida_opcodes::{EQ32, NE32};

use p3_air::VirtualPairCol;
use p3_field::PrimeField;
Expand All @@ -25,6 +25,7 @@ pub mod stark;
pub enum Operation {
Ne32(Word<u8>, Word<u8>, Word<u8>), // (dst, src1, src2)
Eq32(Word<u8>, Word<u8>, Word<u8>), // (dst, src1, src2)
Eq32(Word<u8>, Word<u8>, Word<u8>), // (dst, src1, src2)
}

#[derive(Default)]
Expand Down Expand Up @@ -53,6 +54,13 @@ where
}

fn global_receives(&self, machine: &M) -> Vec<Interaction<M::F>> {
let opcode = VirtualPairCol::new_main(
vec![
(COM_COL_MAP.is_ne, M::F::from_canonical_u32(NE32)),
(COM_COL_MAP.is_eq, M::F::from_canonical_u32(EQ32)),
],
M::F::zero(),
);
let opcode = VirtualPairCol::new_main(
vec![
(COM_COL_MAP.is_ne, M::F::from_canonical_u32(NE32)),
Expand All @@ -71,14 +79,12 @@ where
fields.extend(input_2);
fields.extend(output);

let is_real = VirtualPairCol::sum_main(vec![
COM_COL_MAP.is_ne,
COM_COL_MAP.is_eq,
]);
let is_real = VirtualPairCol::sum_main(vec![COM_COL_MAP.is_ne, COM_COL_MAP.is_eq]);

let receive = Interaction {
fields,
count: is_real,
count: is_real,
argument_index: machine.general_bus(),
};
vec![receive]
Expand All @@ -94,39 +100,23 @@ impl Com32Chip {
let cols: &mut Com32Cols<F> = unsafe { transmute(&mut row) };

match op {
Operation::Ne32(dst, src1, src2) => {
Operation::Ne32(_, _, _) => {
cols.is_ne = F::one();
}
Operation::Eq32(_, _, _) => {
cols.is_eq = F::one();
}

// if let Some(n) = src1
// .into_iter()
// .zip(src2.into_iter())
// .enumerate()
// .find_map(|(n, (x, y))| if x == y { Some(n) } else { None })
// {
// let z = 256u16 + src1[n] as u16 - src2[n] as u16;
// for i in 0..10 {
// cols.bits[i] = F::from_canonical_u16(z >> i & 1);
// }
// if n < 3 {
// cols.byte_flag[n] = F::one();
// }
}
row
// cols.input_1 = src1.transform(F::from_canonical_u8);
// cols.input_2 = src2.transform(F::from_canonical_u8);
// cols.output = F::from_canonical_u8(dst[3]);
}
}
row
}
}

pub trait MachineWithCom32Chip: MachineWithCpuChip {
fn com_u32(&self) -> &Com32Chip;
fn com_u32_mut(&mut self) -> &mut Com32Chip;
}

instructions!(Ne32Instruction, Eq32Instruction);
instructions!(Ne32Instruction, Eq32Instruction);

impl<M> Instruction<M> for Ne32Instruction
Expand All @@ -142,14 +132,18 @@ where
let mut imm: Option<Word<u8>> = None;
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
let write_addr = (state.cpu().fp as i32 + ops.a()) as u32;
let src1 = state.mem_mut().read(clk, read_addr_1, true, pc, opcode, 0, "");
let src1 = state
.mem_mut()
.read(clk, read_addr_1, true, pc, opcode, 0, "");
let src2 = if ops.is_imm() == 1 {
let c = (ops.c() as u32).into();
imm = Some(c);
c
} else {
let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32;
state.mem_mut().read(clk, read_addr_2, true, pc, opcode, 1, "")
state
.mem_mut()
.read(clk, read_addr_2, true, pc, opcode, 1, "")
};

let dst = if src1 != src2 {
Expand All @@ -164,8 +158,10 @@ where
.operations
.push(Operation::Ne32(dst, src1, src2));
state
.cpu_mut()
.push_bus_op(imm, opcode, ops);
.com_u32_mut()
.operations
.push(Operation::Eq32(dst, src1, src2));
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
}

Expand Down
34 changes: 2 additions & 32 deletions alu_u32/src/com/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,50 +22,20 @@ where
let main = builder.main();
let local: &Com32Cols<AB::Var> = main.row_slice(0).borrow();

// let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512].map(AB::Expr::from_canonical_u32);

// let bit_comp: AB::Expr = local
// .bits
// .into_iter()
// .zip(base_2.iter().cloned())
// .map(|(bit, base)| bit * base)
// .sum();

// Check if the first two operand values are equal, in case we're doing a conditional branch.
// (when is_imm == 1, the second read value is guaranteed to be an immediate value)
builder.assert_eq(
local.diff,
local
.input_1()
.input_1
.into_iter()
.zip(local.input_2())
.zip(local.input_2)
.map(|(a, b)| (a - b) * (a - b))
.sum::<AB::Expr>(),
);
builder.assert_bool(local.not_equal);
builder.assert_eq(local.not_equal, local.diff * local.diff_inv);

// Output constraints
builder.when(local.bits[8]).assert_zero(local.output);
builder
.when_ne(local.multiplicity, AB::Expr::zero())
.when_ne(local.bits[8], AB::Expr::one())
.assert_one(local.output);


// Check the resulting output byte
builder
.when(local.is_ne)
.assert_eq(com_ne.clone(), local.output);
builder
.when(local.is_eq)
.assert_eq(com_eq.clone(), local.output);

// Check that bits are boolean values
// for bit in local.bits_1[i].into_iter().chain(local.bits_2[i]) {
// builder.assert_bool(bit);
// }

builder.assert_bool(local.is_ne);
builder.assert_bool(local.is_eq);
builder.assert_bool(local.is_ne + local.is_eq);
Expand Down
2 changes: 1 addition & 1 deletion alu_u32/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ extern crate alloc;

pub mod add;
pub mod bitwise;
pub mod com;
pub mod div;
pub mod lt;
pub mod com;
pub mod mul;
pub mod shift;
pub mod sub;
3 changes: 3 additions & 0 deletions basic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use valida_alu_u32::{
And32Instruction, Bitwise32Chip, MachineWithBitwise32Chip, Or32Instruction,
Xor32Instruction,
},
com::{Com32Chip, Eq32Instruction, MachineWithCom32Chip, Ne32Instruction},
div::{Div32Chip, Div32Instruction, MachineWithDiv32Chip, SDiv32Instruction},
lt::{Lt32Chip, Lt32Instruction, MachineWithLt32Chip},
com::{Com32Chip, Ne32Instruction, Eq32Instruction, MachineWithCom32Chip},
Expand Down Expand Up @@ -91,6 +92,8 @@ pub struct BasicMachine<F: PrimeField32 + TwoAdicField> {
ne32: Ne32Instruction,
#[instruction(com_u32)]
eq32: Eq32Instruction,
#[instruction(com_u32)]
eq32: Eq32Instruction,
#[instruction(bitwise_u32)]
and32: And32Instruction,
#[instruction(bitwise_u32)]
Expand Down
4 changes: 2 additions & 2 deletions opcodes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ pub const AND32: u32 = 107;
pub const OR32: u32 = 108;
pub const XOR32: u32 = 109;
pub const NE32: u32 = 111;
pub const MULHU32: u32 = 112;
pub const SRA32: u32 = 113;
pub const MULHU32: u32 = 112;
pub const SRA32: u32 = 113;
pub const MULHS32: u32 = 114;
pub const LTE32: u32 = 115; //TODO
pub const EQ32: u32 = 116;
Expand Down

0 comments on commit b087138

Please sign in to comment.