From 57fc04f455099f44c524410745feccb4abb94382 Mon Sep 17 00:00:00 2001 From: Arni Hod Date: Thu, 11 Jul 2024 10:43:50 +0300 Subject: [PATCH] refactor: gateway compiler handle declare tx --- crates/gateway/src/compilation.rs | 79 ++++++++++++------- crates/gateway/src/compilation_test.rs | 37 ++++----- crates/gateway/src/gateway.rs | 2 +- crates/gateway/src/gateway_test.rs | 2 +- .../stateful_transaction_validator_test.rs | 2 +- .../src/starknet_api_test_utils.rs | 16 +++- crates/starknet_sierra_compile/src/utils.rs | 10 +-- 7 files changed, 89 insertions(+), 59 deletions(-) diff --git a/crates/gateway/src/compilation.rs b/crates/gateway/src/compilation.rs index 18dd58197..30da8d97f 100644 --- a/crates/gateway/src/compilation.rs +++ b/crates/gateway/src/compilation.rs @@ -5,6 +5,7 @@ use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractCl use cairo_lang_starknet_classes::casm_contract_class::{ CasmContractClass, CasmContractEntryPoints, }; +use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContractClass; use starknet_api::core::CompiledClassHash; use starknet_api::rpc_transaction::RPCDeclareTransaction; use starknet_sierra_compile::compile::compile_sierra_to_casm; @@ -29,44 +30,40 @@ impl GatewayCompiler { /// Formats the contract class for compilation, compiles it, and returns the compiled contract /// class wrapped in a [`ClassInfo`]. /// Assumes the contract class is of a Sierra program which is compiled to Casm. - pub fn compile_contract_class( + pub fn handle_declare_tx( &self, declare_tx: &RPCDeclareTransaction, ) -> GatewayResult { let RPCDeclareTransaction::V3(tx) = declare_tx; - let starknet_api_contract_class = &tx.contract_class; - let cairo_lang_contract_class = - into_contract_class_for_compilation(starknet_api_contract_class); + let rpc_contract_class = &tx.contract_class; + let cairo_lang_contract_class = into_contract_class_for_compilation(rpc_contract_class); - // Compile Sierra to Casm. + let casm_contract_class = self.compile(cairo_lang_contract_class)?; + + validate_compiled_class_hash(&casm_contract_class, tx.compiled_class_hash)?; + self.validate_casm_class(&casm_contract_class)?; + + build_result_class_info( + casm_contract_class, + rpc_contract_class.sierra_program.len(), + rpc_contract_class.abi.len(), + ) + } + + /// TODO(Arni): Pass the compilation args from the config. + fn compile( + &self, + cairo_lang_contract_class: CairoLangContractClass, + ) -> Result { let catch_unwind_result = panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class)); - let casm_contract_class = match catch_unwind_result { - Ok(compilation_result) => compilation_result?, + match catch_unwind_result { + Ok(compilation_result) => Ok(compilation_result?), Err(_) => { // TODO(Arni): Log the panic. - return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic)); + Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic)) } - }; - self.validate_casm_class(&casm_contract_class)?; - - let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash()); - if hash_result != tx.compiled_class_hash { - return Err(GatewayError::CompiledClassHashMismatch { - supplied: tx.compiled_class_hash, - hash_result, - }); } - - // Convert Casm contract class to Starknet contract class directly. - let blockifier_contract_class = - ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?); - let class_info = ClassInfo::new( - &blockifier_contract_class, - starknet_api_contract_class.sierra_program.len(), - starknet_api_contract_class.abi.len(), - )?; - Ok(class_info) } // TODO(Arni): Add test. @@ -101,3 +98,31 @@ fn supported_builtins() -> &'static Vec { SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::>() }) } + +/// Returns a [`ClassInfo`] struct from the compiled contract class. +fn build_result_class_info( + casm_contract_class: CasmContractClass, + sierra_program_len: usize, + abi_len: usize, +) -> GatewayResult { + let blockifier_contract_class = + ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?); + let class_info = ClassInfo::new(&blockifier_contract_class, sierra_program_len, abi_len)?; + Ok(class_info) +} + +/// Validates that the compiled class hash of the compiled contract class matches the supplied +/// compiled class hash. +fn validate_compiled_class_hash( + casm_contract_class: &CasmContractClass, + suppled_compiled_class_hash: CompiledClassHash, +) -> Result<(), GatewayError> { + let compiled_class_hash = CompiledClassHash(casm_contract_class.compiled_class_hash()); + if compiled_class_hash != suppled_compiled_class_hash { + return Err(GatewayError::CompiledClassHashMismatch { + supplied: suppled_compiled_class_hash, + hash_result: compiled_class_hash, + }); + } + Ok(()) +} diff --git a/crates/gateway/src/compilation_test.rs b/crates/gateway/src/compilation_test.rs index 22e81d10d..96d262f90 100644 --- a/crates/gateway/src/compilation_test.rs +++ b/crates/gateway/src/compilation_test.rs @@ -1,13 +1,16 @@ use assert_matches::assert_matches; use blockifier::execution::contract_class::ContractClass; use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError; -use mempool_test_utils::starknet_api_test_utils::declare_tx; +use mempool_test_utils::starknet_api_test_utils::{ + compiled_class_hash, contract_class, declare_tx, +}; use rstest::{fixture, rstest}; use starknet_api::core::CompiledClassHash; use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction}; use starknet_sierra_compile::errors::CompilationUtilError; +use starknet_sierra_compile::utils::into_contract_class_for_compilation; -use crate::compilation::GatewayCompiler; +use crate::compilation::{validate_compiled_class_hash, GatewayCompiler}; use crate::errors::GatewayError; #[fixture] @@ -17,17 +20,12 @@ fn gateway_compiler() -> GatewayCompiler { #[rstest] fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: GatewayCompiler) { - let mut tx = assert_matches!( - declare_tx(), - RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx - ); - let expected_hash_result = tx.compiled_class_hash; + let casm_contract_class = + gateway_compiler.compile(into_contract_class_for_compilation(&contract_class())).unwrap(); + let expected_hash_result = compiled_class_hash(); let supplied_hash = CompiledClassHash::default(); - tx.compiled_class_hash = supplied_hash; - let declare_tx = RPCDeclareTransaction::V3(tx); - - let result = gateway_compiler.compile_contract_class(&declare_tx); + let result = validate_compiled_class_hash(&casm_contract_class, supplied_hash); assert_matches!( result.unwrap_err(), GatewayError::CompiledClassHashMismatch { supplied, hash_result } @@ -37,15 +35,12 @@ fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: G #[rstest] fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) { - let mut tx = assert_matches!( - declare_tx(), - RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx - ); - // Truncate the sierra program to trigger an error. - tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec(); - let declare_tx = RPCDeclareTransaction::V3(tx); + // Create a currupted contract class. + let mut contract_class = contract_class(); + contract_class.sierra_program = contract_class.sierra_program[..100].to_vec(); - let result = gateway_compiler.compile_contract_class(&declare_tx); + let cairo_lang_contract_class = into_contract_class_for_compilation(&contract_class); + let result = gateway_compiler.compile(cairo_lang_contract_class); assert_matches!( result.unwrap_err(), GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError( @@ -55,7 +50,7 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) { } #[rstest] -fn test_compile_contract_class(gateway_compiler: GatewayCompiler) { +fn test_handle_declare_tx(gateway_compiler: GatewayCompiler) { let declare_tx = assert_matches!( declare_tx(), RPCTransaction::Declare(declare_tx) => declare_tx @@ -63,7 +58,7 @@ fn test_compile_contract_class(gateway_compiler: GatewayCompiler) { let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx; let contract_class = &declare_tx_v3.contract_class; - let class_info = gateway_compiler.compile_contract_class(&declare_tx).unwrap(); + let class_info = gateway_compiler.handle_declare_tx(&declare_tx).unwrap(); assert_matches!(class_info.contract_class(), ContractClass::V1(_)); assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len()); assert_eq!(class_info.abi_length(), contract_class.abi.len()); diff --git a/crates/gateway/src/gateway.rs b/crates/gateway/src/gateway.rs index 6c26834b2..9a02930c8 100644 --- a/crates/gateway/src/gateway.rs +++ b/crates/gateway/src/gateway.rs @@ -128,7 +128,7 @@ fn process_tx( // Compile Sierra to Casm. let optional_class_info = match &tx { RPCTransaction::Declare(declare_tx) => { - Some(gateway_compiler.compile_contract_class(declare_tx)?) + Some(gateway_compiler.handle_declare_tx(declare_tx)?) } _ => None, }; diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs index 5216c8797..c443db83c 100644 --- a/crates/gateway/src/gateway_test.rs +++ b/crates/gateway/src/gateway_test.rs @@ -117,7 +117,7 @@ fn calculate_hash( ) -> TransactionHash { let optional_class_info = match &external_tx { RPCTransaction::Declare(declare_tx) => { - Some(gateway_compiler.compile_contract_class(declare_tx).unwrap()) + Some(gateway_compiler.handle_declare_tx(declare_tx).unwrap()) } _ => None, }; diff --git a/crates/gateway/src/stateful_transaction_validator_test.rs b/crates/gateway/src/stateful_transaction_validator_test.rs index bd1216075..093032ba4 100644 --- a/crates/gateway/src/stateful_transaction_validator_test.rs +++ b/crates/gateway/src/stateful_transaction_validator_test.rs @@ -97,7 +97,7 @@ fn test_stateful_tx_validator( let optional_class_info = match &external_tx { RPCTransaction::Declare(declare_tx) => Some( GatewayCompiler { config: GatewayCompilerConfig {} } - .compile_contract_class(declare_tx) + .handle_declare_tx(declare_tx) .unwrap(), ), _ => None, diff --git a/crates/mempool_test_utils/src/starknet_api_test_utils.rs b/crates/mempool_test_utils/src/starknet_api_test_utils.rs index 278cf93aa..7c33b32bd 100644 --- a/crates/mempool_test_utils/src/starknet_api_test_utils.rs +++ b/crates/mempool_test_utils/src/starknet_api_test_utils.rs @@ -88,11 +88,21 @@ pub fn executable_resource_bounds_mapping() -> ResourceBoundsMapping { ) } -pub fn declare_tx() -> RPCTransaction { +/// Get the contract class used for testing. +pub fn contract_class() -> ContractClass { env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir."); let json_file_path = Path::new(CONTRACT_CLASS_FILE); - let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap(); - let compiled_class_hash = CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS)); + serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap() +} + +/// Get the compiled class hash corresponding to the contract class used for testing. +pub fn compiled_class_hash() -> CompiledClassHash { + CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS)) +} + +pub fn declare_tx() -> RPCTransaction { + let contract_class = contract_class(); + let compiled_class_hash = compiled_class_hash(); let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1); let account_address = account_contract.get_instance_address(0); diff --git a/crates/starknet_sierra_compile/src/utils.rs b/crates/starknet_sierra_compile/src/utils.rs index 717eaf176..5ccdcf9fa 100644 --- a/crates/starknet_sierra_compile/src/utils.rs +++ b/crates/starknet_sierra_compile/src/utils.rs @@ -6,7 +6,7 @@ use cairo_lang_starknet_classes::contract_class::{ }; use cairo_lang_utils::bigint::BigUintAsHex; use starknet_api::rpc_transaction::{ - ContractClass as StarknetApiContractClass, EntryPointByType as StarknetApiEntryPointByType, + ContractClass as RpcContractClass, EntryPointByType as StarknetApiEntryPointByType, }; use starknet_api::state::EntryPoint as StarknetApiEntryPoint; use starknet_types_core::felt::Felt; @@ -14,17 +14,17 @@ use starknet_types_core::felt::Felt; /// Retruns a [`CairoLangContractClass`] struct ready for Sierra to Casm compilation. Note the `abi` /// field is None as it is not relevant for the compilation. pub fn into_contract_class_for_compilation( - starknet_api_contract_class: &StarknetApiContractClass, + rpc_contract_class: &RpcContractClass, ) -> CairoLangContractClass { let sierra_program = - starknet_api_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect(); + rpc_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect(); let entry_points_by_type = - into_cairo_lang_contract_entry_points(&starknet_api_contract_class.entry_points_by_type); + into_cairo_lang_contract_entry_points(&rpc_contract_class.entry_points_by_type); CairoLangContractClass { sierra_program, sierra_program_debug_info: None, - contract_class_version: starknet_api_contract_class.contract_class_version.clone(), + contract_class_version: rpc_contract_class.contract_class_version.clone(), entry_points_by_type, abi: None, }