From 76b11a1f45a96d9b2c00e307274c6132a9cfb4d4 Mon Sep 17 00:00:00 2001 From: Stavbe Date: Thu, 6 Feb 2025 13:44:02 +0200 Subject: [PATCH] add "create_prover_input_from_file" func (#454) --- .../crates/prover/src/cairo_air/mod.rs | 40 ++- .../prover/src/input/state_transitions.rs | 26 +- .../crates/prover/src/input/vm_import/mod.rs | 243 ++++++++---------- 3 files changed, 150 insertions(+), 159 deletions(-) diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs index 696b9b132..4288cd2e0 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs @@ -225,14 +225,38 @@ pub enum CairoVerificationError { #[cfg(test)] pub mod tests { + use std::path::PathBuf; + use cairo_lang_casm::casm; use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; use super::ProverConfig; use crate::cairo_air::{prove_cairo, verify_cairo, ProverInput}; use crate::input::plain::input_from_plain_casm; + use crate::input::vm_import::adapt_vm_output; + + /// Creates a prover input from `pub.json`, `priv.json`, `mem`, and `trace` files. + /// + /// # Expects + /// - These files must be stored in the `test_data/test_name` directory and contain valid Cairo + /// program data. + /// - They can be downloaded from Google Storage using `./scripts/download_test_data.sh`. See + /// `input/README.md` for details. + /// + /// # Panics + /// - If it fails to convert the files into a prover input. + pub fn test_input(test_name: &str) -> ProverInput { + let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + d.push("test_data/"); + d.push(test_name); + + adapt_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect( + " + Failed to read test files. Checkout input/README.md.", + ) + } - fn test_input() -> ProverInput { + fn test_basic_cairo_air_input() -> ProverInput { let u128_max = u128::MAX; let instructions = casm! { // TODO(AlonH): Add actual range check segment. @@ -267,7 +291,8 @@ pub mod tests { #[test] fn test_basic_cairo_air() { - let cairo_proof = prove_cairo::(test_input(), test_cfg()).unwrap(); + let cairo_proof = + prove_cairo::(test_basic_cairo_air_input(), test_cfg()).unwrap(); verify_cairo::(cairo_proof).unwrap(); } @@ -281,12 +306,12 @@ pub mod tests { use stwo_prover::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel; use super::*; - use crate::input::vm_import::tests::small_cairo_input; #[test] fn generate_and_serialise_proof() { let cairo_proof = - prove_cairo::(test_input(), test_cfg()).unwrap(); + prove_cairo::(test_basic_cairo_air_input(), test_cfg()) + .unwrap(); let mut output = Vec::new(); CairoSerialize::serialize(&cairo_proof, &mut output); let proof_str = output.iter().map(|v| v.to_string()).join(","); @@ -297,8 +322,11 @@ pub mod tests { #[test] fn test_full_cairo_air() { - let cairo_proof = - prove_cairo::(small_cairo_input(), test_cfg()).unwrap(); + let cairo_proof = prove_cairo::( + test_input("test_read_from_small_files"), + test_cfg(), + ) + .unwrap(); verify_cairo::(cairo_proof).unwrap(); } } diff --git a/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs b/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs index b9b35b549..917dcba82 100644 --- a/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs +++ b/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs @@ -637,33 +637,21 @@ fn is_small_mul(op0: MemoryValue, op_1: MemoryValue) -> bool { /// Tests instructions mapping. #[cfg(test)] mod mappings_tests { - use std::path::PathBuf; use cairo_lang_casm::casm; use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; use super::*; - use crate::cairo_air::tests::test_cfg; + use crate::cairo_air::tests::{test_cfg, test_input}; use crate::cairo_air::{prove_cairo, verify_cairo}; use crate::input::decode::{Instruction, OpcodeExtension}; use crate::input::memory::*; use crate::input::plain::input_from_plain_casm; - use crate::input::vm_import::adapt_vm_output; - use crate::input::ProverInput; - - pub fn all_opcode_coponents_program() -> ProverInput { - let mut d: PathBuf = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - d.push("test_data/test_prove_verify_all_components"); - adapt_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect( - " - Failed to read test files. Checkout input/README.md.", - ) - } #[test] #[cfg(feature = "slow-tests")] - fn test_prove_verify_all_components() { - let input = all_opcode_coponents_program(); + fn test_prove_verify_all_opcode_components() { + let input = test_input("test_prove_verify_all_components"); for (opcode, n_instances) in input.state_transitions.casm_states_by_opcode.counts() { // TODO(Stav): Remove when `Blake` opcode is in the VM. if opcode == "blake2s_opcode" { @@ -676,9 +664,11 @@ mod mappings_tests { opcode ); } - let cairo_proof = - prove_cairo::(all_opcode_coponents_program(), test_cfg()) - .unwrap(); + let cairo_proof = prove_cairo::( + test_input("test_prove_verify_all_components"), + test_cfg(), + ) + .unwrap(); verify_cairo::(cairo_proof).unwrap(); } diff --git a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs index 93c418fa1..746efc294 100644 --- a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs @@ -190,142 +190,115 @@ impl Iterator for MemoryEntryIter<'_, R> { } #[cfg(test)] -pub mod tests { - use std::path::PathBuf; - - use super::*; - - pub fn large_cairo_input() -> ProverInput { - let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - d.push("test_data/test_read_from_large_files"); - - adapt_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect( - " - Failed to read test files. Checkout input/README.md.", - ) +#[cfg(feature = "slow-tests")] +pub mod slow_tests { + use crate::cairo_air::tests::test_input; + + #[test] + fn test_read_from_large_files() { + let input = test_input("test_read_from_large_files"); + + // Test opcode components. + let components = input.state_transitions.casm_states_by_opcode; + assert_eq!(components.generic_opcode.len(), 0); + assert_eq!(components.add_ap_opcode.len(), 0); + assert_eq!(components.add_ap_opcode_imm.len(), 36895); + assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 33); + assert_eq!(components.add_opcode_small_imm.len(), 84732); + assert_eq!(components.add_opcode.len(), 189425); + assert_eq!(components.add_opcode_small.len(), 36623); + assert_eq!(components.add_opcode_imm.len(), 22089); + assert_eq!(components.assert_eq_opcode.len(), 233432); + assert_eq!(components.assert_eq_opcode_double_deref.len(), 811061); + assert_eq!(components.assert_eq_opcode_imm.len(), 43184); + assert_eq!(components.call_opcode.len(), 0); + assert_eq!(components.call_opcode_rel.len(), 49439); + assert_eq!(components.call_opcode_op_1_base_fp.len(), 33); + assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 11235); + assert_eq!(components.jnz_opcode.len(), 27032); + assert_eq!(components.jnz_opcode_taken.len(), 51060); + assert_eq!(components.jnz_opcode_dst_base_fp.len(), 5100); + assert_eq!(components.jump_opcode_rel_imm.len(), 31873865); + assert_eq!(components.jump_opcode_rel.len(), 500); + assert_eq!(components.jump_opcode_double_deref.len(), 32); + assert_eq!(components.jump_opcode.len(), 0); + assert_eq!(components.mul_opcode_small_imm.len(), 7234); + assert_eq!(components.mul_opcode_small.len(), 7203); + assert_eq!(components.mul_opcode.len(), 3943); + assert_eq!(components.mul_opcode_imm.len(), 10809); + assert_eq!(components.ret_opcode.len(), 49472); + + // Test builtins. + let builtins_segments = input.builtins_segments; + assert_eq!(builtins_segments.add_mod, None); + assert_eq!(builtins_segments.bitwise, None); + assert_eq!(builtins_segments.ec_op, Some((16428600, 16428824).into())); + assert_eq!(builtins_segments.ecdsa, None); + assert_eq!(builtins_segments.keccak, None); + assert_eq!(builtins_segments.mul_mod, None); + assert_eq!(builtins_segments.pedersen, Some((1322552, 1347128).into())); + assert_eq!( + builtins_segments.poseidon, + Some((16920120, 17706552).into()) + ); + assert_eq!(builtins_segments.range_check_bits_96, None); + assert_eq!( + builtins_segments.range_check_bits_128, + Some((1715768, 1781304).into()) + ); } - pub fn small_cairo_input() -> ProverInput { - let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - d.push("test_data/test_read_from_small_files"); - adapt_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect( - " - Failed to read test files. Checkout input/README.md.", - ) - } - - #[cfg(test)] - #[cfg(feature = "slow-tests")] - pub mod slow_tests { - - use super::*; - - #[test] - fn test_read_from_large_files() { - let input = large_cairo_input(); - - // Test opcode components. - let components = input.state_transitions.casm_states_by_opcode; - assert_eq!(components.generic_opcode.len(), 0); - assert_eq!(components.add_ap_opcode.len(), 0); - assert_eq!(components.add_ap_opcode_imm.len(), 36895); - assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 33); - assert_eq!(components.add_opcode_small_imm.len(), 84732); - assert_eq!(components.add_opcode.len(), 189425); - assert_eq!(components.add_opcode_small.len(), 36623); - assert_eq!(components.add_opcode_imm.len(), 22089); - assert_eq!(components.assert_eq_opcode.len(), 233432); - assert_eq!(components.assert_eq_opcode_double_deref.len(), 811061); - assert_eq!(components.assert_eq_opcode_imm.len(), 43184); - assert_eq!(components.call_opcode.len(), 0); - assert_eq!(components.call_opcode_rel.len(), 49439); - assert_eq!(components.call_opcode_op_1_base_fp.len(), 33); - assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 11235); - assert_eq!(components.jnz_opcode.len(), 27032); - assert_eq!(components.jnz_opcode_taken.len(), 51060); - assert_eq!(components.jnz_opcode_dst_base_fp.len(), 5100); - assert_eq!(components.jump_opcode_rel_imm.len(), 31873865); - assert_eq!(components.jump_opcode_rel.len(), 500); - assert_eq!(components.jump_opcode_double_deref.len(), 32); - assert_eq!(components.jump_opcode.len(), 0); - assert_eq!(components.mul_opcode_small_imm.len(), 7234); - assert_eq!(components.mul_opcode_small.len(), 7203); - assert_eq!(components.mul_opcode.len(), 3943); - assert_eq!(components.mul_opcode_imm.len(), 10809); - assert_eq!(components.ret_opcode.len(), 49472); - - // Test builtins. - let builtins_segments = input.builtins_segments; - assert_eq!(builtins_segments.add_mod, None); - assert_eq!(builtins_segments.bitwise, None); - assert_eq!(builtins_segments.ec_op, Some((16428600, 16428824).into())); - assert_eq!(builtins_segments.ecdsa, None); - assert_eq!(builtins_segments.keccak, None); - assert_eq!(builtins_segments.mul_mod, None); - assert_eq!(builtins_segments.pedersen, Some((1322552, 1347128).into())); - assert_eq!( - builtins_segments.poseidon, - Some((16920120, 17706552).into()) - ); - assert_eq!(builtins_segments.range_check_bits_96, None); - assert_eq!( - builtins_segments.range_check_bits_128, - Some((1715768, 1781304).into()) - ); - } - - #[test] - fn test_read_from_small_files() { - let input = small_cairo_input(); - - // Test opcode components. - let components = input.state_transitions.casm_states_by_opcode; - assert_eq!(components.generic_opcode.len(), 0); - assert_eq!(components.add_ap_opcode.len(), 0); - assert_eq!(components.add_ap_opcode_imm.len(), 2); - assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 1); - assert_eq!(components.add_opcode_small_imm.len(), 500); - assert_eq!(components.add_opcode.len(), 0); - assert_eq!(components.add_opcode_small.len(), 0); - assert_eq!(components.add_opcode_imm.len(), 450); - assert_eq!(components.assert_eq_opcode.len(), 55); - assert_eq!(components.assert_eq_opcode_double_deref.len(), 2100); - assert_eq!(components.assert_eq_opcode_imm.len(), 1952); - assert_eq!(components.call_opcode.len(), 0); - assert_eq!(components.call_opcode_rel.len(), 462); - assert_eq!(components.call_opcode_op_1_base_fp.len(), 0); - assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 450); - assert_eq!(components.jnz_opcode.len(), 0); - assert_eq!(components.jnz_opcode_taken.len(), 0); - assert_eq!(components.jnz_opcode_dst_base_fp.len(), 11); - assert_eq!(components.jump_opcode_rel_imm.len(), 124626); - assert_eq!(components.jump_opcode_rel.len(), 0); - assert_eq!(components.jump_opcode_double_deref.len(), 0); - assert_eq!(components.jump_opcode.len(), 0); - assert_eq!(components.mul_opcode_small_imm.len(), 0); - assert_eq!(components.mul_opcode_small.len(), 0); - assert_eq!(components.mul_opcode.len(), 0); - assert_eq!(components.mul_opcode_imm.len(), 0); - assert_eq!(components.ret_opcode.len(), 462); - - // Test builtins. - let builtins_segments = input.builtins_segments; - assert_eq!(builtins_segments.add_mod, None); - assert_eq!(builtins_segments.bitwise, Some((22512, 22832).into())); - assert_eq!(builtins_segments.ec_op, Some((63472, 63920).into())); - assert_eq!(builtins_segments.ecdsa, Some((22384, 22512).into())); - assert_eq!(builtins_segments.keccak, Some((64368, 65392).into())); - assert_eq!(builtins_segments.mul_mod, None); - assert_eq!(builtins_segments.pedersen, Some((4464, 4656).into())); - assert_eq!(builtins_segments.poseidon, Some((65392, 65776).into())); - assert_eq!( - builtins_segments.range_check_bits_96, - Some((68464, 68528).into()) - ); - assert_eq!( - builtins_segments.range_check_bits_128, - Some((6000, 6064).into()) - ); - } + #[test] + fn test_read_from_small_files() { + let input = test_input("test_read_from_small_files"); + + // Test opcode components. + let components = input.state_transitions.casm_states_by_opcode; + assert_eq!(components.generic_opcode.len(), 0); + assert_eq!(components.add_ap_opcode.len(), 0); + assert_eq!(components.add_ap_opcode_imm.len(), 2); + assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 1); + assert_eq!(components.add_opcode_small_imm.len(), 500); + assert_eq!(components.add_opcode.len(), 0); + assert_eq!(components.add_opcode_small.len(), 0); + assert_eq!(components.add_opcode_imm.len(), 450); + assert_eq!(components.assert_eq_opcode.len(), 55); + assert_eq!(components.assert_eq_opcode_double_deref.len(), 2100); + assert_eq!(components.assert_eq_opcode_imm.len(), 1952); + assert_eq!(components.call_opcode.len(), 0); + assert_eq!(components.call_opcode_rel.len(), 462); + assert_eq!(components.call_opcode_op_1_base_fp.len(), 0); + assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 450); + assert_eq!(components.jnz_opcode.len(), 0); + assert_eq!(components.jnz_opcode_taken.len(), 0); + assert_eq!(components.jnz_opcode_dst_base_fp.len(), 11); + assert_eq!(components.jump_opcode_rel_imm.len(), 124626); + assert_eq!(components.jump_opcode_rel.len(), 0); + assert_eq!(components.jump_opcode_double_deref.len(), 0); + assert_eq!(components.jump_opcode.len(), 0); + assert_eq!(components.mul_opcode_small_imm.len(), 0); + assert_eq!(components.mul_opcode_small.len(), 0); + assert_eq!(components.mul_opcode.len(), 0); + assert_eq!(components.mul_opcode_imm.len(), 0); + assert_eq!(components.ret_opcode.len(), 462); + + // Test builtins. + let builtins_segments = input.builtins_segments; + assert_eq!(builtins_segments.add_mod, None); + assert_eq!(builtins_segments.bitwise, Some((22512, 22832).into())); + assert_eq!(builtins_segments.ec_op, Some((63472, 63920).into())); + assert_eq!(builtins_segments.ecdsa, Some((22384, 22512).into())); + assert_eq!(builtins_segments.keccak, Some((64368, 65392).into())); + assert_eq!(builtins_segments.mul_mod, None); + assert_eq!(builtins_segments.pedersen, Some((4464, 4656).into())); + assert_eq!(builtins_segments.poseidon, Some((65392, 65776).into())); + assert_eq!( + builtins_segments.range_check_bits_96, + Some((68464, 68528).into()) + ); + assert_eq!( + builtins_segments.range_check_bits_128, + Some((6000, 6064).into()) + ); } }