diff --git a/crates/gateway/src/compilation.rs b/crates/gateway/src/compilation.rs index 5381ce44c..1df6d0d99 100644 --- a/crates/gateway/src/compilation.rs +++ b/crates/gateway/src/compilation.rs @@ -11,6 +11,7 @@ use starknet_sierra_compile::compile::compile_sierra_to_casm; use starknet_sierra_compile::errors::CompilationUtilError; use starknet_sierra_compile::utils::into_contract_class_for_compilation; +use crate::config::GatewayCompilerConfig; use crate::errors::{GatewayError, GatewayResult}; use crate::utils::is_subsequence; @@ -18,44 +19,73 @@ use crate::utils::is_subsequence; #[path = "compilation_test.rs"] mod compilation_test; -/// Formats the contract class for compilation, compiles it, and returns the compiled contract class -/// wrapped in a [`ClassInfo`]. -/// Assumes the contract class is of a Sierra program which is compiled to Casm. -pub fn compile_contract_class(declare_tx: &RPCDeclareTransaction) -> GatewayResult { - let RPCDeclareTransaction::V3(tx) = declare_tx; - let starknet_api_contract_class = &tx.contract_class; - let cairo_lang_contract_class = - into_contract_class_for_compilation(starknet_api_contract_class); +pub struct GatewayCompiler { + #[allow(dead_code)] + pub config: GatewayCompilerConfig, +} + +impl GatewayCompiler { + /// Formats the contract class for compilation, compiles it, and returns the compiled contract + /// class wrapped in a [`ClassInfo`]. + /// Assumes the contract class is of a Sierra program which is compiled to Casm. + pub fn compile_contract_class( + &self, + declare_tx: &RPCDeclareTransaction, + ) -> GatewayResult { + let RPCDeclareTransaction::V3(tx) = declare_tx; + let starknet_api_contract_class = &tx.contract_class; + let cairo_lang_contract_class = + into_contract_class_for_compilation(starknet_api_contract_class); - // Compile Sierra to Casm. - let catch_unwind_result = - panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class)); - let casm_contract_class = match catch_unwind_result { - Ok(compilation_result) => compilation_result?, - Err(_) => { - // TODO(Arni): Log the panic. - return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic)); + // Compile Sierra to Casm. + let catch_unwind_result = + panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class)); + let casm_contract_class = match catch_unwind_result { + Ok(compilation_result) => compilation_result?, + Err(_) => { + // TODO(Arni): Log the panic. + return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic)); + } + }; + self.validate_casm_class(&casm_contract_class)?; + + let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash()); + if hash_result != tx.compiled_class_hash { + return Err(GatewayError::CompiledClassHashMismatch { + supplied: tx.compiled_class_hash, + hash_result, + }); } - }; - validate_casm_class(&casm_contract_class)?; - let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash()); - if hash_result != tx.compiled_class_hash { - return Err(GatewayError::CompiledClassHashMismatch { - supplied: tx.compiled_class_hash, - hash_result, - }); + // Convert Casm contract class to Starknet contract class directly. + let blockifier_contract_class = + ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?); + let class_info = ClassInfo::new( + &blockifier_contract_class, + starknet_api_contract_class.sierra_program.len(), + starknet_api_contract_class.abi.len(), + )?; + Ok(class_info) } - // Convert Casm contract class to Starknet contract class directly. - let blockifier_contract_class = - ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?); - let class_info = ClassInfo::new( - &blockifier_contract_class, - starknet_api_contract_class.sierra_program.len(), - starknet_api_contract_class.abi.len(), - )?; - Ok(class_info) + // TODO(Arni): Add test. + fn validate_casm_class(&self, contract_class: &CasmContractClass) -> Result<(), GatewayError> { + let CasmContractEntryPoints { external, l1_handler, constructor } = + &contract_class.entry_points_by_type; + let entry_points_iterator = + external.iter().chain(l1_handler.iter()).chain(constructor.iter()); + + for entry_point in entry_points_iterator { + let builtins = &entry_point.builtins; + if !is_subsequence(builtins, supported_builtins()) { + return Err(GatewayError::UnsupportedBuiltins { + builtins: builtins.clone(), + supported_builtins: supported_builtins().to_vec(), + }); + } + } + Ok(()) + } } // TODO(Arni): Add to a config. @@ -70,21 +100,3 @@ fn supported_builtins() -> &'static Vec { SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::>() }) } - -// TODO(Arni): Add test. -fn validate_casm_class(contract_class: &CasmContractClass) -> Result<(), GatewayError> { - let CasmContractEntryPoints { external, l1_handler, constructor } = - &contract_class.entry_points_by_type; - let entry_points_iterator = external.iter().chain(l1_handler.iter()).chain(constructor.iter()); - - for entry_point in entry_points_iterator { - let builtins = &entry_point.builtins; - if !is_subsequence(builtins, supported_builtins()) { - return Err(GatewayError::UnsupportedBuiltins { - builtins: builtins.clone(), - supported_builtins: supported_builtins().to_vec(), - }); - } - } - Ok(()) -} diff --git a/crates/gateway/src/compilation_test.rs b/crates/gateway/src/compilation_test.rs index a1d8a1023..22e81d10d 100644 --- a/crates/gateway/src/compilation_test.rs +++ b/crates/gateway/src/compilation_test.rs @@ -2,15 +2,21 @@ use assert_matches::assert_matches; use blockifier::execution::contract_class::ContractClass; use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError; use mempool_test_utils::starknet_api_test_utils::declare_tx; +use rstest::{fixture, rstest}; use starknet_api::core::CompiledClassHash; use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction}; use starknet_sierra_compile::errors::CompilationUtilError; -use crate::compilation::compile_contract_class; +use crate::compilation::GatewayCompiler; use crate::errors::GatewayError; -#[test] -fn test_compile_contract_class_compiled_class_hash_missmatch() { +#[fixture] +fn gateway_compiler() -> GatewayCompiler { + GatewayCompiler { config: Default::default() } +} + +#[rstest] +fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: GatewayCompiler) { let mut tx = assert_matches!( declare_tx(), RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx @@ -21,7 +27,7 @@ fn test_compile_contract_class_compiled_class_hash_missmatch() { tx.compiled_class_hash = supplied_hash; let declare_tx = RPCDeclareTransaction::V3(tx); - let result = compile_contract_class(&declare_tx); + let result = gateway_compiler.compile_contract_class(&declare_tx); assert_matches!( result.unwrap_err(), GatewayError::CompiledClassHashMismatch { supplied, hash_result } @@ -29,8 +35,8 @@ fn test_compile_contract_class_compiled_class_hash_missmatch() { ); } -#[test] -fn test_compile_contract_class_bad_sierra() { +#[rstest] +fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) { let mut tx = assert_matches!( declare_tx(), RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx @@ -39,7 +45,7 @@ fn test_compile_contract_class_bad_sierra() { tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec(); let declare_tx = RPCDeclareTransaction::V3(tx); - let result = compile_contract_class(&declare_tx); + let result = gateway_compiler.compile_contract_class(&declare_tx); assert_matches!( result.unwrap_err(), GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError( @@ -48,8 +54,8 @@ fn test_compile_contract_class_bad_sierra() { ) } -#[test] -fn test_compile_contract_class() { +#[rstest] +fn test_compile_contract_class(gateway_compiler: GatewayCompiler) { let declare_tx = assert_matches!( declare_tx(), RPCTransaction::Declare(declare_tx) => declare_tx @@ -57,7 +63,7 @@ fn test_compile_contract_class() { let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx; let contract_class = &declare_tx_v3.contract_class; - let class_info = compile_contract_class(&declare_tx).unwrap(); + let class_info = gateway_compiler.compile_contract_class(&declare_tx).unwrap(); assert_matches!(class_info.contract_class(), ContractClass::V1(_)); assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len()); assert_eq!(class_info.abi_length(), contract_class.abi.len()); diff --git a/crates/gateway/src/config.rs b/crates/gateway/src/config.rs index 8817b69eb..389bc4f69 100644 --- a/crates/gateway/src/config.rs +++ b/crates/gateway/src/config.rs @@ -16,6 +16,7 @@ pub struct GatewayConfig { pub network_config: GatewayNetworkConfig, pub stateless_tx_validator_config: StatelessTransactionValidatorConfig, pub stateful_tx_validator_config: StatefulTransactionValidatorConfig, + pub compiler_config: GatewayCompilerConfig, } impl SerializeConfig for GatewayConfig { @@ -30,6 +31,7 @@ impl SerializeConfig for GatewayConfig { self.stateful_tx_validator_config.dump(), "stateful_tx_validator_config", ), + append_sub_config_name(self.compiler_config.dump(), "compiler_config"), ] .into_iter() .flatten() @@ -293,3 +295,12 @@ impl StatefulTransactionValidatorConfig { } } } + +#[derive(Clone, Debug, Default, Serialize, Deserialize, Validate, PartialEq)] +pub struct GatewayCompilerConfig {} + +impl SerializeConfig for GatewayCompilerConfig { + fn dump(&self) -> BTreeMap { + BTreeMap::new() + } +} diff --git a/crates/gateway/src/gateway.rs b/crates/gateway/src/gateway.rs index 25eefe936..613cea6ca 100644 --- a/crates/gateway/src/gateway.rs +++ b/crates/gateway/src/gateway.rs @@ -13,7 +13,7 @@ use starknet_mempool_types::communication::SharedMempoolClient; use starknet_mempool_types::mempool_types::{Account, MempoolInput}; use tracing::{info, instrument}; -use crate::compilation::compile_contract_class; +use crate::compilation::GatewayCompiler; use crate::config::{GatewayConfig, GatewayNetworkConfig, RpcStateReaderConfig}; use crate::errors::{GatewayError, GatewayResult, GatewayRunError}; use crate::rpc_state_reader::RpcStateReaderFactory; @@ -122,7 +122,10 @@ fn process_tx( // Compile Sierra to Casm. let optional_class_info = match &tx { - RPCTransaction::Declare(declare_tx) => Some(compile_contract_class(declare_tx)?), + RPCTransaction::Declare(declare_tx) => { + let gateway_compiler = GatewayCompiler { config: Default::default() }; + Some(gateway_compiler.compile_contract_class(declare_tx)?) + } _ => None, }; diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs index 0a824cb97..e4288afe5 100644 --- a/crates/gateway/src/gateway_test.rs +++ b/crates/gateway/src/gateway_test.rs @@ -18,7 +18,7 @@ use tokio::sync::mpsc::channel; use tokio::task; use crate::config::{StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig}; -use crate::gateway::{add_tx, compile_contract_class, AppState, SharedMempoolClient}; +use crate::gateway::{add_tx, AppState, GatewayCompiler, SharedMempoolClient}; use crate::state_reader_test_utils::{ local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account, TestStateReaderFactory, @@ -110,7 +110,11 @@ async fn to_bytes(res: Response) -> Bytes { fn calculate_hash(external_tx: &RPCTransaction) -> TransactionHash { let optional_class_info = match &external_tx { - RPCTransaction::Declare(declare_tx) => Some(compile_contract_class(declare_tx).unwrap()), + RPCTransaction::Declare(declare_tx) => Some( + GatewayCompiler { config: Default::default() } + .compile_contract_class(declare_tx) + .unwrap(), + ), _ => None, }; diff --git a/crates/gateway/src/stateful_transaction_validator_test.rs b/crates/gateway/src/stateful_transaction_validator_test.rs index 630bc0c52..ffe1df19f 100644 --- a/crates/gateway/src/stateful_transaction_validator_test.rs +++ b/crates/gateway/src/stateful_transaction_validator_test.rs @@ -18,7 +18,7 @@ use starknet_api::rpc_transaction::RPCTransaction; use starknet_api::transaction::TransactionHash; use starknet_types_core::felt::Felt; -use crate::compilation::compile_contract_class; +use crate::compilation::GatewayCompiler; use crate::config::StatefulTransactionValidatorConfig; use crate::errors::{StatefulTransactionValidatorError, StatefulTransactionValidatorResult}; use crate::state_reader_test_utils::{ @@ -95,7 +95,10 @@ fn test_stateful_tx_validator( stateful_validator: StatefulTransactionValidator, ) { let optional_class_info = match &external_tx { - RPCTransaction::Declare(declare_tx) => Some(compile_contract_class(declare_tx).unwrap()), + RPCTransaction::Declare(declare_tx) => { + let gateway_compiler = GatewayCompiler { config: Default::default() }; + Some(gateway_compiler.compile_contract_class(declare_tx).unwrap()) + } _ => None, }; diff --git a/crates/tests-integration/src/integration_test_utils.rs b/crates/tests-integration/src/integration_test_utils.rs index 58f9eeb98..882a953b6 100644 --- a/crates/tests-integration/src/integration_test_utils.rs +++ b/crates/tests-integration/src/integration_test_utils.rs @@ -28,8 +28,14 @@ async fn create_gateway_config() -> GatewayConfig { let socket = get_available_socket().await; let network_config = GatewayNetworkConfig { ip: socket.ip(), port: socket.port() }; let stateful_tx_validator_config = StatefulTransactionValidatorConfig::create_for_testing(); + let gateway_compiler_config = Default::default(); - GatewayConfig { network_config, stateless_tx_validator_config, stateful_tx_validator_config } + GatewayConfig { + network_config, + stateless_tx_validator_config, + stateful_tx_validator_config, + compiler_config: gateway_compiler_config, + } } pub async fn create_config(rpc_server_addr: SocketAddr) -> MempoolNodeConfig {