Skip to content

Commit

Permalink
add "create_prover_input_from_file" func (#454)
Browse files Browse the repository at this point in the history
  • Loading branch information
Stavbe authored Feb 6, 2025
1 parent 1ceb87d commit 76b11a1
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 159 deletions.
40 changes: 34 additions & 6 deletions stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,38 @@ pub enum CairoVerificationError {

#[cfg(test)]
pub mod tests {
use std::path::PathBuf;

use cairo_lang_casm::casm;
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;

use super::ProverConfig;
use crate::cairo_air::{prove_cairo, verify_cairo, ProverInput};
use crate::input::plain::input_from_plain_casm;
use crate::input::vm_import::adapt_vm_output;

/// Creates a prover input from `pub.json`, `priv.json`, `mem`, and `trace` files.
///
/// # Expects
/// - These files must be stored in the `test_data/test_name` directory and contain valid Cairo
/// program data.
/// - They can be downloaded from Google Storage using `./scripts/download_test_data.sh`. See
/// `input/README.md` for details.
///
/// # Panics
/// - If it fails to convert the files into a prover input.
pub fn test_input(test_name: &str) -> ProverInput {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("test_data/");
d.push(test_name);

adapt_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect(
"
Failed to read test files. Checkout input/README.md.",
)
}

fn test_input() -> ProverInput {
fn test_basic_cairo_air_input() -> ProverInput {
let u128_max = u128::MAX;
let instructions = casm! {
// TODO(AlonH): Add actual range check segment.
Expand Down Expand Up @@ -267,7 +291,8 @@ pub mod tests {

#[test]
fn test_basic_cairo_air() {
let cairo_proof = prove_cairo::<Blake2sMerkleChannel>(test_input(), test_cfg()).unwrap();
let cairo_proof =
prove_cairo::<Blake2sMerkleChannel>(test_basic_cairo_air_input(), test_cfg()).unwrap();
verify_cairo::<Blake2sMerkleChannel>(cairo_proof).unwrap();
}

Expand All @@ -281,12 +306,12 @@ pub mod tests {
use stwo_prover::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel;

use super::*;
use crate::input::vm_import::tests::small_cairo_input;

#[test]
fn generate_and_serialise_proof() {
let cairo_proof =
prove_cairo::<Poseidon252MerkleChannel>(test_input(), test_cfg()).unwrap();
prove_cairo::<Poseidon252MerkleChannel>(test_basic_cairo_air_input(), test_cfg())
.unwrap();
let mut output = Vec::new();
CairoSerialize::serialize(&cairo_proof, &mut output);
let proof_str = output.iter().map(|v| v.to_string()).join(",");
Expand All @@ -297,8 +322,11 @@ pub mod tests {

#[test]
fn test_full_cairo_air() {
let cairo_proof =
prove_cairo::<Blake2sMerkleChannel>(small_cairo_input(), test_cfg()).unwrap();
let cairo_proof = prove_cairo::<Blake2sMerkleChannel>(
test_input("test_read_from_small_files"),
test_cfg(),
)
.unwrap();
verify_cairo::<Blake2sMerkleChannel>(cairo_proof).unwrap();
}
}
Expand Down
26 changes: 8 additions & 18 deletions stwo_cairo_prover/crates/prover/src/input/state_transitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -637,33 +637,21 @@ fn is_small_mul(op0: MemoryValue, op_1: MemoryValue) -> bool {
/// Tests instructions mapping.
#[cfg(test)]
mod mappings_tests {
use std::path::PathBuf;

use cairo_lang_casm::casm;
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;

use super::*;
use crate::cairo_air::tests::test_cfg;
use crate::cairo_air::tests::{test_cfg, test_input};
use crate::cairo_air::{prove_cairo, verify_cairo};
use crate::input::decode::{Instruction, OpcodeExtension};
use crate::input::memory::*;
use crate::input::plain::input_from_plain_casm;
use crate::input::vm_import::adapt_vm_output;
use crate::input::ProverInput;

pub fn all_opcode_coponents_program() -> ProverInput {
let mut d: PathBuf = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("test_data/test_prove_verify_all_components");
adapt_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect(
"
Failed to read test files. Checkout input/README.md.",
)
}

#[test]
#[cfg(feature = "slow-tests")]
fn test_prove_verify_all_components() {
let input = all_opcode_coponents_program();
fn test_prove_verify_all_opcode_components() {
let input = test_input("test_prove_verify_all_components");
for (opcode, n_instances) in input.state_transitions.casm_states_by_opcode.counts() {
// TODO(Stav): Remove when `Blake` opcode is in the VM.
if opcode == "blake2s_opcode" {
Expand All @@ -676,9 +664,11 @@ mod mappings_tests {
opcode
);
}
let cairo_proof =
prove_cairo::<Blake2sMerkleChannel>(all_opcode_coponents_program(), test_cfg())
.unwrap();
let cairo_proof = prove_cairo::<Blake2sMerkleChannel>(
test_input("test_prove_verify_all_components"),
test_cfg(),
)
.unwrap();
verify_cairo::<Blake2sMerkleChannel>(cairo_proof).unwrap();
}

Expand Down
243 changes: 108 additions & 135 deletions stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,142 +190,115 @@ impl<R: Read> Iterator for MemoryEntryIter<'_, R> {
}

#[cfg(test)]
pub mod tests {
use std::path::PathBuf;

use super::*;

pub fn large_cairo_input() -> ProverInput {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("test_data/test_read_from_large_files");

adapt_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect(
"
Failed to read test files. Checkout input/README.md.",
)
#[cfg(feature = "slow-tests")]
pub mod slow_tests {
use crate::cairo_air::tests::test_input;

#[test]
fn test_read_from_large_files() {
let input = test_input("test_read_from_large_files");

// Test opcode components.
let components = input.state_transitions.casm_states_by_opcode;
assert_eq!(components.generic_opcode.len(), 0);
assert_eq!(components.add_ap_opcode.len(), 0);
assert_eq!(components.add_ap_opcode_imm.len(), 36895);
assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 33);
assert_eq!(components.add_opcode_small_imm.len(), 84732);
assert_eq!(components.add_opcode.len(), 189425);
assert_eq!(components.add_opcode_small.len(), 36623);
assert_eq!(components.add_opcode_imm.len(), 22089);
assert_eq!(components.assert_eq_opcode.len(), 233432);
assert_eq!(components.assert_eq_opcode_double_deref.len(), 811061);
assert_eq!(components.assert_eq_opcode_imm.len(), 43184);
assert_eq!(components.call_opcode.len(), 0);
assert_eq!(components.call_opcode_rel.len(), 49439);
assert_eq!(components.call_opcode_op_1_base_fp.len(), 33);
assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 11235);
assert_eq!(components.jnz_opcode.len(), 27032);
assert_eq!(components.jnz_opcode_taken.len(), 51060);
assert_eq!(components.jnz_opcode_dst_base_fp.len(), 5100);
assert_eq!(components.jump_opcode_rel_imm.len(), 31873865);
assert_eq!(components.jump_opcode_rel.len(), 500);
assert_eq!(components.jump_opcode_double_deref.len(), 32);
assert_eq!(components.jump_opcode.len(), 0);
assert_eq!(components.mul_opcode_small_imm.len(), 7234);
assert_eq!(components.mul_opcode_small.len(), 7203);
assert_eq!(components.mul_opcode.len(), 3943);
assert_eq!(components.mul_opcode_imm.len(), 10809);
assert_eq!(components.ret_opcode.len(), 49472);

// Test builtins.
let builtins_segments = input.builtins_segments;
assert_eq!(builtins_segments.add_mod, None);
assert_eq!(builtins_segments.bitwise, None);
assert_eq!(builtins_segments.ec_op, Some((16428600, 16428824).into()));
assert_eq!(builtins_segments.ecdsa, None);
assert_eq!(builtins_segments.keccak, None);
assert_eq!(builtins_segments.mul_mod, None);
assert_eq!(builtins_segments.pedersen, Some((1322552, 1347128).into()));
assert_eq!(
builtins_segments.poseidon,
Some((16920120, 17706552).into())
);
assert_eq!(builtins_segments.range_check_bits_96, None);
assert_eq!(
builtins_segments.range_check_bits_128,
Some((1715768, 1781304).into())
);
}

pub fn small_cairo_input() -> ProverInput {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("test_data/test_read_from_small_files");
adapt_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect(
"
Failed to read test files. Checkout input/README.md.",
)
}

#[cfg(test)]
#[cfg(feature = "slow-tests")]
pub mod slow_tests {

use super::*;

#[test]
fn test_read_from_large_files() {
let input = large_cairo_input();

// Test opcode components.
let components = input.state_transitions.casm_states_by_opcode;
assert_eq!(components.generic_opcode.len(), 0);
assert_eq!(components.add_ap_opcode.len(), 0);
assert_eq!(components.add_ap_opcode_imm.len(), 36895);
assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 33);
assert_eq!(components.add_opcode_small_imm.len(), 84732);
assert_eq!(components.add_opcode.len(), 189425);
assert_eq!(components.add_opcode_small.len(), 36623);
assert_eq!(components.add_opcode_imm.len(), 22089);
assert_eq!(components.assert_eq_opcode.len(), 233432);
assert_eq!(components.assert_eq_opcode_double_deref.len(), 811061);
assert_eq!(components.assert_eq_opcode_imm.len(), 43184);
assert_eq!(components.call_opcode.len(), 0);
assert_eq!(components.call_opcode_rel.len(), 49439);
assert_eq!(components.call_opcode_op_1_base_fp.len(), 33);
assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 11235);
assert_eq!(components.jnz_opcode.len(), 27032);
assert_eq!(components.jnz_opcode_taken.len(), 51060);
assert_eq!(components.jnz_opcode_dst_base_fp.len(), 5100);
assert_eq!(components.jump_opcode_rel_imm.len(), 31873865);
assert_eq!(components.jump_opcode_rel.len(), 500);
assert_eq!(components.jump_opcode_double_deref.len(), 32);
assert_eq!(components.jump_opcode.len(), 0);
assert_eq!(components.mul_opcode_small_imm.len(), 7234);
assert_eq!(components.mul_opcode_small.len(), 7203);
assert_eq!(components.mul_opcode.len(), 3943);
assert_eq!(components.mul_opcode_imm.len(), 10809);
assert_eq!(components.ret_opcode.len(), 49472);

// Test builtins.
let builtins_segments = input.builtins_segments;
assert_eq!(builtins_segments.add_mod, None);
assert_eq!(builtins_segments.bitwise, None);
assert_eq!(builtins_segments.ec_op, Some((16428600, 16428824).into()));
assert_eq!(builtins_segments.ecdsa, None);
assert_eq!(builtins_segments.keccak, None);
assert_eq!(builtins_segments.mul_mod, None);
assert_eq!(builtins_segments.pedersen, Some((1322552, 1347128).into()));
assert_eq!(
builtins_segments.poseidon,
Some((16920120, 17706552).into())
);
assert_eq!(builtins_segments.range_check_bits_96, None);
assert_eq!(
builtins_segments.range_check_bits_128,
Some((1715768, 1781304).into())
);
}

#[test]
fn test_read_from_small_files() {
let input = small_cairo_input();

// Test opcode components.
let components = input.state_transitions.casm_states_by_opcode;
assert_eq!(components.generic_opcode.len(), 0);
assert_eq!(components.add_ap_opcode.len(), 0);
assert_eq!(components.add_ap_opcode_imm.len(), 2);
assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 1);
assert_eq!(components.add_opcode_small_imm.len(), 500);
assert_eq!(components.add_opcode.len(), 0);
assert_eq!(components.add_opcode_small.len(), 0);
assert_eq!(components.add_opcode_imm.len(), 450);
assert_eq!(components.assert_eq_opcode.len(), 55);
assert_eq!(components.assert_eq_opcode_double_deref.len(), 2100);
assert_eq!(components.assert_eq_opcode_imm.len(), 1952);
assert_eq!(components.call_opcode.len(), 0);
assert_eq!(components.call_opcode_rel.len(), 462);
assert_eq!(components.call_opcode_op_1_base_fp.len(), 0);
assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 450);
assert_eq!(components.jnz_opcode.len(), 0);
assert_eq!(components.jnz_opcode_taken.len(), 0);
assert_eq!(components.jnz_opcode_dst_base_fp.len(), 11);
assert_eq!(components.jump_opcode_rel_imm.len(), 124626);
assert_eq!(components.jump_opcode_rel.len(), 0);
assert_eq!(components.jump_opcode_double_deref.len(), 0);
assert_eq!(components.jump_opcode.len(), 0);
assert_eq!(components.mul_opcode_small_imm.len(), 0);
assert_eq!(components.mul_opcode_small.len(), 0);
assert_eq!(components.mul_opcode.len(), 0);
assert_eq!(components.mul_opcode_imm.len(), 0);
assert_eq!(components.ret_opcode.len(), 462);

// Test builtins.
let builtins_segments = input.builtins_segments;
assert_eq!(builtins_segments.add_mod, None);
assert_eq!(builtins_segments.bitwise, Some((22512, 22832).into()));
assert_eq!(builtins_segments.ec_op, Some((63472, 63920).into()));
assert_eq!(builtins_segments.ecdsa, Some((22384, 22512).into()));
assert_eq!(builtins_segments.keccak, Some((64368, 65392).into()));
assert_eq!(builtins_segments.mul_mod, None);
assert_eq!(builtins_segments.pedersen, Some((4464, 4656).into()));
assert_eq!(builtins_segments.poseidon, Some((65392, 65776).into()));
assert_eq!(
builtins_segments.range_check_bits_96,
Some((68464, 68528).into())
);
assert_eq!(
builtins_segments.range_check_bits_128,
Some((6000, 6064).into())
);
}
#[test]
fn test_read_from_small_files() {
let input = test_input("test_read_from_small_files");

// Test opcode components.
let components = input.state_transitions.casm_states_by_opcode;
assert_eq!(components.generic_opcode.len(), 0);
assert_eq!(components.add_ap_opcode.len(), 0);
assert_eq!(components.add_ap_opcode_imm.len(), 2);
assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 1);
assert_eq!(components.add_opcode_small_imm.len(), 500);
assert_eq!(components.add_opcode.len(), 0);
assert_eq!(components.add_opcode_small.len(), 0);
assert_eq!(components.add_opcode_imm.len(), 450);
assert_eq!(components.assert_eq_opcode.len(), 55);
assert_eq!(components.assert_eq_opcode_double_deref.len(), 2100);
assert_eq!(components.assert_eq_opcode_imm.len(), 1952);
assert_eq!(components.call_opcode.len(), 0);
assert_eq!(components.call_opcode_rel.len(), 462);
assert_eq!(components.call_opcode_op_1_base_fp.len(), 0);
assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 450);
assert_eq!(components.jnz_opcode.len(), 0);
assert_eq!(components.jnz_opcode_taken.len(), 0);
assert_eq!(components.jnz_opcode_dst_base_fp.len(), 11);
assert_eq!(components.jump_opcode_rel_imm.len(), 124626);
assert_eq!(components.jump_opcode_rel.len(), 0);
assert_eq!(components.jump_opcode_double_deref.len(), 0);
assert_eq!(components.jump_opcode.len(), 0);
assert_eq!(components.mul_opcode_small_imm.len(), 0);
assert_eq!(components.mul_opcode_small.len(), 0);
assert_eq!(components.mul_opcode.len(), 0);
assert_eq!(components.mul_opcode_imm.len(), 0);
assert_eq!(components.ret_opcode.len(), 462);

// Test builtins.
let builtins_segments = input.builtins_segments;
assert_eq!(builtins_segments.add_mod, None);
assert_eq!(builtins_segments.bitwise, Some((22512, 22832).into()));
assert_eq!(builtins_segments.ec_op, Some((63472, 63920).into()));
assert_eq!(builtins_segments.ecdsa, Some((22384, 22512).into()));
assert_eq!(builtins_segments.keccak, Some((64368, 65392).into()));
assert_eq!(builtins_segments.mul_mod, None);
assert_eq!(builtins_segments.pedersen, Some((4464, 4656).into()));
assert_eq!(builtins_segments.poseidon, Some((65392, 65776).into()));
assert_eq!(
builtins_segments.range_check_bits_96,
Some((68464, 68528).into())
);
assert_eq!(
builtins_segments.range_check_bits_128,
Some((6000, 6064).into())
);
}
}

0 comments on commit 76b11a1

Please sign in to comment.