diff --git a/crates/gateway/src/compilation.rs b/crates/gateway/src/compilation.rs index 18dd58197..30da8d97f 100644 --- a/crates/gateway/src/compilation.rs +++ b/crates/gateway/src/compilation.rs @@ -5,6 +5,7 @@ use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractCl use cairo_lang_starknet_classes::casm_contract_class::{ CasmContractClass, CasmContractEntryPoints, }; +use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContractClass; use starknet_api::core::CompiledClassHash; use starknet_api::rpc_transaction::RPCDeclareTransaction; use starknet_sierra_compile::compile::compile_sierra_to_casm; @@ -29,44 +30,40 @@ impl GatewayCompiler { /// Formats the contract class for compilation, compiles it, and returns the compiled contract /// class wrapped in a [`ClassInfo`]. /// Assumes the contract class is of a Sierra program which is compiled to Casm. - pub fn compile_contract_class( + pub fn handle_declare_tx( &self, declare_tx: &RPCDeclareTransaction, ) -> GatewayResult { let RPCDeclareTransaction::V3(tx) = declare_tx; - let starknet_api_contract_class = &tx.contract_class; - let cairo_lang_contract_class = - into_contract_class_for_compilation(starknet_api_contract_class); + let rpc_contract_class = &tx.contract_class; + let cairo_lang_contract_class = into_contract_class_for_compilation(rpc_contract_class); - // Compile Sierra to Casm. + let casm_contract_class = self.compile(cairo_lang_contract_class)?; + + validate_compiled_class_hash(&casm_contract_class, tx.compiled_class_hash)?; + self.validate_casm_class(&casm_contract_class)?; + + build_result_class_info( + casm_contract_class, + rpc_contract_class.sierra_program.len(), + rpc_contract_class.abi.len(), + ) + } + + /// TODO(Arni): Pass the compilation args from the config. + fn compile( + &self, + cairo_lang_contract_class: CairoLangContractClass, + ) -> Result { let catch_unwind_result = panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class)); - let casm_contract_class = match catch_unwind_result { - Ok(compilation_result) => compilation_result?, + match catch_unwind_result { + Ok(compilation_result) => Ok(compilation_result?), Err(_) => { // TODO(Arni): Log the panic. - return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic)); + Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic)) } - }; - self.validate_casm_class(&casm_contract_class)?; - - let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash()); - if hash_result != tx.compiled_class_hash { - return Err(GatewayError::CompiledClassHashMismatch { - supplied: tx.compiled_class_hash, - hash_result, - }); } - - // Convert Casm contract class to Starknet contract class directly. - let blockifier_contract_class = - ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?); - let class_info = ClassInfo::new( - &blockifier_contract_class, - starknet_api_contract_class.sierra_program.len(), - starknet_api_contract_class.abi.len(), - )?; - Ok(class_info) } // TODO(Arni): Add test. @@ -101,3 +98,31 @@ fn supported_builtins() -> &'static Vec { SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::>() }) } + +/// Returns a [`ClassInfo`] struct from the compiled contract class. +fn build_result_class_info( + casm_contract_class: CasmContractClass, + sierra_program_len: usize, + abi_len: usize, +) -> GatewayResult { + let blockifier_contract_class = + ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?); + let class_info = ClassInfo::new(&blockifier_contract_class, sierra_program_len, abi_len)?; + Ok(class_info) +} + +/// Validates that the compiled class hash of the compiled contract class matches the supplied +/// compiled class hash. +fn validate_compiled_class_hash( + casm_contract_class: &CasmContractClass, + suppled_compiled_class_hash: CompiledClassHash, +) -> Result<(), GatewayError> { + let compiled_class_hash = CompiledClassHash(casm_contract_class.compiled_class_hash()); + if compiled_class_hash != suppled_compiled_class_hash { + return Err(GatewayError::CompiledClassHashMismatch { + supplied: suppled_compiled_class_hash, + hash_result: compiled_class_hash, + }); + } + Ok(()) +} diff --git a/crates/gateway/src/compilation_test.rs b/crates/gateway/src/compilation_test.rs index 22e81d10d..fc1a0cef3 100644 --- a/crates/gateway/src/compilation_test.rs +++ b/crates/gateway/src/compilation_test.rs @@ -27,7 +27,7 @@ fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: G tx.compiled_class_hash = supplied_hash; let declare_tx = RPCDeclareTransaction::V3(tx); - let result = gateway_compiler.compile_contract_class(&declare_tx); + let result = gateway_compiler.handle_declare_tx(&declare_tx); assert_matches!( result.unwrap_err(), GatewayError::CompiledClassHashMismatch { supplied, hash_result } @@ -45,7 +45,7 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) { tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec(); let declare_tx = RPCDeclareTransaction::V3(tx); - let result = gateway_compiler.compile_contract_class(&declare_tx); + let result = gateway_compiler.handle_declare_tx(&declare_tx); assert_matches!( result.unwrap_err(), GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError( @@ -63,7 +63,7 @@ fn test_compile_contract_class(gateway_compiler: GatewayCompiler) { let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx; let contract_class = &declare_tx_v3.contract_class; - let class_info = gateway_compiler.compile_contract_class(&declare_tx).unwrap(); + let class_info = gateway_compiler.handle_declare_tx(&declare_tx).unwrap(); assert_matches!(class_info.contract_class(), ContractClass::V1(_)); assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len()); assert_eq!(class_info.abi_length(), contract_class.abi.len()); diff --git a/crates/gateway/src/gateway.rs b/crates/gateway/src/gateway.rs index 50b51ce32..60fd16889 100644 --- a/crates/gateway/src/gateway.rs +++ b/crates/gateway/src/gateway.rs @@ -125,7 +125,7 @@ fn process_tx( // Compile Sierra to Casm. let optional_class_info = match &tx { RPCTransaction::Declare(declare_tx) => { - Some(gateway_compiler.compile_contract_class(declare_tx)?) + Some(gateway_compiler.handle_declare_tx(declare_tx)?) } _ => None, }; diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs index 4eeedb8bc..3f66a20d7 100644 --- a/crates/gateway/src/gateway_test.rs +++ b/crates/gateway/src/gateway_test.rs @@ -115,7 +115,7 @@ fn calculate_hash(external_tx: &RPCTransaction) -> TransactionHash { let optional_class_info = match &external_tx { RPCTransaction::Declare(declare_tx) => Some( GatewayCompiler { config: GatewayCompilerConfig {} } - .compile_contract_class(declare_tx) + .handle_declare_tx(declare_tx) .unwrap(), ), _ => None, diff --git a/crates/gateway/src/stateful_transaction_validator_test.rs b/crates/gateway/src/stateful_transaction_validator_test.rs index 1a8b9fb4e..5fd889272 100644 --- a/crates/gateway/src/stateful_transaction_validator_test.rs +++ b/crates/gateway/src/stateful_transaction_validator_test.rs @@ -83,7 +83,7 @@ fn test_stateful_tx_validator( let optional_class_info = match &external_tx { RPCTransaction::Declare(declare_tx) => Some( GatewayCompiler { config: GatewayCompilerConfig {} } - .compile_contract_class(declare_tx) + .handle_declare_tx(declare_tx) .unwrap(), ), _ => None, diff --git a/crates/starknet_sierra_compile/src/utils.rs b/crates/starknet_sierra_compile/src/utils.rs index 717eaf176..5ccdcf9fa 100644 --- a/crates/starknet_sierra_compile/src/utils.rs +++ b/crates/starknet_sierra_compile/src/utils.rs @@ -6,7 +6,7 @@ use cairo_lang_starknet_classes::contract_class::{ }; use cairo_lang_utils::bigint::BigUintAsHex; use starknet_api::rpc_transaction::{ - ContractClass as StarknetApiContractClass, EntryPointByType as StarknetApiEntryPointByType, + ContractClass as RpcContractClass, EntryPointByType as StarknetApiEntryPointByType, }; use starknet_api::state::EntryPoint as StarknetApiEntryPoint; use starknet_types_core::felt::Felt; @@ -14,17 +14,17 @@ use starknet_types_core::felt::Felt; /// Retruns a [`CairoLangContractClass`] struct ready for Sierra to Casm compilation. Note the `abi` /// field is None as it is not relevant for the compilation. pub fn into_contract_class_for_compilation( - starknet_api_contract_class: &StarknetApiContractClass, + rpc_contract_class: &RpcContractClass, ) -> CairoLangContractClass { let sierra_program = - starknet_api_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect(); + rpc_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect(); let entry_points_by_type = - into_cairo_lang_contract_entry_points(&starknet_api_contract_class.entry_points_by_type); + into_cairo_lang_contract_entry_points(&rpc_contract_class.entry_points_by_type); CairoLangContractClass { sierra_program, sierra_program_debug_info: None, - contract_class_version: starknet_api_contract_class.contract_class_version.clone(), + contract_class_version: rpc_contract_class.contract_class_version.clone(), entry_points_by_type, abi: None, }