diff --git a/config/default_config.json b/config/default_config.json index 42fa1ecad..58732f860 100644 --- a/config/default_config.json +++ b/config/default_config.json @@ -9,6 +9,16 @@ "privacy": "Public", "value": true }, + "gateway_config.compiler_config.max_bytecode_size": { + "description": "Limitation of contract bytecode size", + "privacy": "Public", + "value": 81920 + }, + "gateway_config.compiler_config.max_raw_class_size": { + "description": "Limitation of contract class object size", + "privacy": "Public", + "value": 4089446 + }, "gateway_config.network_config.ip": { "description": "The gateway server ip.", "privacy": "Public", diff --git a/crates/gateway/src/compilation.rs b/crates/gateway/src/compilation.rs index 5814bff8b..37e49ca5e 100644 --- a/crates/gateway/src/compilation.rs +++ b/crates/gateway/src/compilation.rs @@ -19,9 +19,9 @@ use crate::utils::is_subsequence; #[path = "compilation_test.rs"] mod compilation_test; +// TODO(Define a function for `compile_contract_class` - which ignores the `config` parameter). #[derive(Clone)] pub struct GatewayCompiler { - #[allow(dead_code)] pub config: GatewayCompilerConfig, } @@ -48,7 +48,7 @@ impl GatewayCompiler { return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic)); } }; - self.validate_casm_class(&casm_contract_class)?; + self.validate_casm(&casm_contract_class)?; let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash()); if hash_result != tx.compiled_class_hash { @@ -69,6 +69,12 @@ impl GatewayCompiler { Ok(class_info) } + fn validate_casm(&self, casm_contract_class: &CasmContractClass) -> Result<(), GatewayError> { + self.validate_casm_class(casm_contract_class)?; + self.validate_casm_class_size(casm_contract_class)?; + Ok(()) + } + // TODO(Arni): Add test. fn validate_casm_class(&self, contract_class: &CasmContractClass) -> Result<(), GatewayError> { let CasmContractEntryPoints { external, l1_handler, constructor } = @@ -87,6 +93,30 @@ impl GatewayCompiler { } Ok(()) } + + fn validate_casm_class_size( + &self, + casm_contract_class: &CasmContractClass, + ) -> Result<(), GatewayError> { + let bytecode_size = casm_contract_class.bytecode.len(); + if bytecode_size > self.config.max_bytecode_size { + return Err(GatewayError::CasmBytecodeSizeTooLarge { + bytecode_size, + max_bytecode_size: self.config.max_bytecode_size, + }); + } + let contract_class_object_size = serde_json::to_string(&casm_contract_class) + .expect("Unexpected error serializing Casm contract class.") + .len(); + if contract_class_object_size > self.config.max_raw_class_size { + return Err(GatewayError::CasmContractClassObjectSizeTooLarge { + contract_class_object_size, + max_contract_class_object_size: self.config.max_raw_class_size, + }); + } + + Ok(()) + } } // TODO(Arni): Add to a config. diff --git a/crates/gateway/src/compilation_config.rs b/crates/gateway/src/compilation_config.rs index 155b81cda..f3368a43f 100644 --- a/crates/gateway/src/compilation_config.rs +++ b/crates/gateway/src/compilation_config.rs @@ -1,15 +1,37 @@ use std::collections::BTreeMap; -use papyrus_config::dumping::SerializeConfig; -use papyrus_config::{ParamPath, SerializedParam}; +use papyrus_config::dumping::{ser_param, SerializeConfig}; +use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam}; use serde::{Deserialize, Serialize}; use validator::Validate; -#[derive(Clone, Debug, Default, Serialize, Deserialize, Validate, PartialEq)] -pub struct GatewayCompilerConfig {} +#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)] +pub struct GatewayCompilerConfig { + pub max_bytecode_size: usize, + pub max_raw_class_size: usize, +} + +impl Default for GatewayCompilerConfig { + fn default() -> Self { + Self { max_bytecode_size: 81920, max_raw_class_size: 4089446 } + } +} impl SerializeConfig for GatewayCompilerConfig { fn dump(&self) -> BTreeMap { - BTreeMap::new() + BTreeMap::from_iter([ + ser_param( + "max_bytecode_size", + &self.max_bytecode_size, + "Limitation of contract bytecode size", + ParamPrivacyInput::Public, + ), + ser_param( + "max_raw_class_size", + &self.max_raw_class_size, + "Limitation of contract class object size", + ParamPrivacyInput::Public, + ), + ]) } } diff --git a/crates/gateway/src/compilation_test.rs b/crates/gateway/src/compilation_test.rs index 3cf682d91..255af31c0 100644 --- a/crates/gateway/src/compilation_test.rs +++ b/crates/gateway/src/compilation_test.rs @@ -2,11 +2,13 @@ use assert_matches::assert_matches; use blockifier::execution::contract_class::ContractClass; use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError; use mempool_test_utils::starknet_api_test_utils::declare_tx; +use rstest::rstest; use starknet_api::core::CompiledClassHash; use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction}; use starknet_sierra_compile::errors::CompilationUtilError; use crate::compilation::GatewayCompiler; +use crate::compilation_config::GatewayCompilerConfig; use crate::errors::GatewayError; #[test] @@ -21,7 +23,10 @@ fn test_compile_contract_class_compiled_class_hash_missmatch() { tx.compiled_class_hash = supplied_hash; let declare_tx = RPCDeclareTransaction::V3(tx); - let result = GatewayCompiler { config: Default::default() }.compile_contract_class(&declare_tx); + let result = GatewayCompiler { + config: GatewayCompilerConfig { max_bytecode_size: 4800, max_raw_class_size: 111037 }, + } + .compile_contract_class(&declare_tx); assert_matches!( result.unwrap_err(), GatewayError::CompiledClassHashMismatch { supplied, hash_result } @@ -29,6 +34,50 @@ fn test_compile_contract_class_compiled_class_hash_missmatch() { ); } +#[rstest] +#[case::bytecode_size( + GatewayCompilerConfig { max_bytecode_size: 1, max_raw_class_size: usize::MAX}, + GatewayError::CasmBytecodeSizeTooLarge { bytecode_size: 4800, max_bytecode_size: 1 } +)] +#[case::raw_class_size( + GatewayCompilerConfig { max_bytecode_size: usize::MAX, max_raw_class_size: 1}, + GatewayError::CasmContractClassObjectSizeTooLarge { + contract_class_object_size: 111037, max_contract_class_object_size: 1 + } +)] +fn test_compile_contract_class_size_validation( + #[case] sierra_to_casm_compilation_config: GatewayCompilerConfig, + #[case] expected_error: GatewayError, +) { + let declare_tx = match declare_tx() { + RPCTransaction::Declare(declare_tx) => declare_tx, + _ => panic!("Invalid transaction type"), + }; + + let gateway_compiler = GatewayCompiler { config: sierra_to_casm_compilation_config }; + let result = gateway_compiler.compile_contract_class(&declare_tx); + if let GatewayError::CasmBytecodeSizeTooLarge { + bytecode_size: expected_bytecode_size, .. + } = expected_error + { + assert_matches!( + result.unwrap_err(), + GatewayError::CasmBytecodeSizeTooLarge { bytecode_size, .. } + if bytecode_size == expected_bytecode_size + ) + } else if let GatewayError::CasmContractClassObjectSizeTooLarge { + contract_class_object_size: expected_contract_class_object_size, + .. + } = expected_error + { + assert_matches!( + result.unwrap_err(), + GatewayError::CasmContractClassObjectSizeTooLarge { contract_class_object_size, .. } + if contract_class_object_size == expected_contract_class_object_size + ) + } +} + #[test] fn test_compile_contract_class_bad_sierra() { let mut tx = assert_matches!( @@ -57,8 +106,9 @@ fn test_compile_contract_class() { let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx; let contract_class = &declare_tx_v3.contract_class; - let class_info = - GatewayCompiler { config: Default::default() }.compile_contract_class(&declare_tx).unwrap(); + let class_info = GatewayCompiler { config: GatewayCompilerConfig::default() } + .compile_contract_class(&declare_tx) + .unwrap(); assert_matches!(class_info.contract_class(), ContractClass::V1(_)); assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len()); assert_eq!(class_info.abi_length(), contract_class.abi.len()); diff --git a/crates/gateway/src/errors.rs b/crates/gateway/src/errors.rs index 24d5ef74b..9517a969e 100644 --- a/crates/gateway/src/errors.rs +++ b/crates/gateway/src/errors.rs @@ -19,6 +19,19 @@ use crate::compiler_version::{VersionId, VersionIdError}; /// Errors directed towards the end-user, as a result of gateway requests. #[derive(Debug, Error)] pub enum GatewayError { + #[error( + "Cannot declare Casm contract class with bytecode size of {bytecode_size}; max allowed \ + size: {max_bytecode_size}." + )] + CasmBytecodeSizeTooLarge { bytecode_size: usize, max_bytecode_size: usize }, + #[error( + "Cannot declare Casm contract class with size of {contract_class_object_size}; max \ + allowed size: {max_contract_class_object_size}." + )] + CasmContractClassObjectSizeTooLarge { + contract_class_object_size: usize, + max_contract_class_object_size: usize, + }, #[error(transparent)] CompilationError(#[from] CompilationUtilError), #[error( diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs index e9e70d287..9d6a40038 100644 --- a/crates/gateway/src/gateway_test.rs +++ b/crates/gateway/src/gateway_test.rs @@ -39,21 +39,28 @@ pub fn app_state( mempool_client: SharedMempoolClient, state_reader_factory: TestStateReaderFactory, ) -> AppState { + const MAX_BYTECODE_SIZE: usize = 10000; + const MAX_RAW_CLASS_SIZE: usize = 1000000; AppState { stateless_tx_validator: StatelessTransactionValidator { config: StatelessTransactionValidatorConfig { validate_non_zero_l1_gas_fee: true, max_calldata_length: 10, max_signature_length: 2, - max_bytecode_size: 10000, - max_raw_class_size: 1000000, + max_bytecode_size: MAX_BYTECODE_SIZE, + max_raw_class_size: MAX_RAW_CLASS_SIZE, ..Default::default() }, }, stateful_tx_validator: Arc::new(StatefulTransactionValidator { config: StatefulTransactionValidatorConfig::create_for_testing(), }), - gateway_compiler: GatewayCompiler { config: GatewayCompilerConfig {} }, + gateway_compiler: GatewayCompiler { + config: GatewayCompilerConfig { + max_bytecode_size: MAX_BYTECODE_SIZE, + max_raw_class_size: MAX_RAW_CLASS_SIZE, + }, + }, state_reader_factory: Arc::new(state_reader_factory), mempool_client, } @@ -113,9 +120,14 @@ async fn to_bytes(res: Response) -> Bytes { fn calculate_hash(external_tx: &RPCTransaction) -> TransactionHash { let optional_class_info = match &external_tx { RPCTransaction::Declare(declare_tx) => Some( - GatewayCompiler { config: GatewayCompilerConfig {} } - .compile_contract_class(declare_tx) - .unwrap(), + GatewayCompiler { + config: GatewayCompilerConfig { + max_bytecode_size: 4800, + max_raw_class_size: 111037, + }, + } + .compile_contract_class(declare_tx) + .unwrap(), ), _ => None, }; diff --git a/crates/gateway/src/stateful_transaction_validator_test.rs b/crates/gateway/src/stateful_transaction_validator_test.rs index 9800ac700..dc6db8902 100644 --- a/crates/gateway/src/stateful_transaction_validator_test.rs +++ b/crates/gateway/src/stateful_transaction_validator_test.rs @@ -82,11 +82,10 @@ fn test_stateful_tx_validator( }, }; let optional_class_info = match &external_tx { - RPCTransaction::Declare(declare_tx) => Some( - GatewayCompiler { config: GatewayCompilerConfig {} } - .compile_contract_class(declare_tx) - .unwrap(), - ), + RPCTransaction::Declare(declare_tx) => { + let gateway_compiler = GatewayCompiler { config: GatewayCompilerConfig::default() }; + Some(gateway_compiler.compile_contract_class(declare_tx).unwrap()) + } _ => None, };