diff --git a/crates/gateway/src/compilation.rs b/crates/gateway/src/compilation.rs new file mode 100644 index 000000000..7e3f148cf --- /dev/null +++ b/crates/gateway/src/compilation.rs @@ -0,0 +1,109 @@ +use std::panic; +use std::sync::OnceLock; + +use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractClassV1}; +use blockifier::execution::execution_utils::felt_to_stark_felt; +use cairo_lang_starknet_classes::casm_contract_class::{ + CasmContractClass, CasmContractEntryPoints, +}; +use starknet_api::core::CompiledClassHash; +use starknet_api::rpc_transaction::RPCDeclareTransaction; +use starknet_api::transaction::Builtin; +use starknet_sierra_compile::compile::compile_sierra_to_casm; +use starknet_sierra_compile::errors::CompilationUtilError; +use starknet_sierra_compile::utils::into_contract_class_for_compilation; + +use crate::errors::{GatewayError, GatewayResult}; +use crate::utils::{is_subsequence, IntoOsOrderEnumIteratorExt}; + +#[cfg(test)] +#[path = "compilation_test.rs"] +mod compilation_test; + +/// Formats the contract class for compilation, compiles it, and returns the compiled contract class +/// wrapped in a [`ClassInfo`]. +/// Assumes the contract class is of a Sierra program which is compiled to Casm. +pub fn compile_contract_class(declare_tx: &RPCDeclareTransaction) -> GatewayResult { + let RPCDeclareTransaction::V3(tx) = declare_tx; + let starknet_api_contract_class = &tx.contract_class; + let cairo_lang_contract_class = + into_contract_class_for_compilation(starknet_api_contract_class); + + // Compile Sierra to Casm. + let catch_unwind_result = + panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class)); + let casm_contract_class = match catch_unwind_result { + Ok(compilation_result) => compilation_result?, + Err(_) => { + // TODO(Arni): Log the panic. + return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic)); + } + }; + validate_casm_class(&casm_contract_class)?; + + let hash_result = + CompiledClassHash(felt_to_stark_felt(&casm_contract_class.compiled_class_hash())); + if hash_result != tx.compiled_class_hash { + return Err(GatewayError::CompiledClassHashMismatch { + supplied: tx.compiled_class_hash, + hash_result, + }); + } + + // Convert Casm contract class to Starknet contract class directly. + let blockifier_contract_class = + ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?); + let class_info = ClassInfo::new( + &blockifier_contract_class, + starknet_api_contract_class.sierra_program.len(), + starknet_api_contract_class.abi.len(), + )?; + Ok(class_info) +} + +// List of supported builtins. +// This is an explicit function so that it is explicitly desiced which builtins are supported. +// If new builtins are added, they should be added here. +fn is_supported_builtin(builtin: &Builtin) -> bool { + match builtin { + Builtin::RangeCheck + | Builtin::Pedersen + | Builtin::Poseidon + | Builtin::EcOp + | Builtin::Ecdsa + | Builtin::Bitwise + | Builtin::SegmentArena => true, + Builtin::Keccak => false, + } +} + +// TODO(Arni): Add to a config. +// TODO(Arni): Use the Builtin enum from Starknet-api, and explicitly tag each builtin as supported +// or unsupported so that the compiler would alert us on new builtins. +fn supported_builtins() -> &'static Vec { + static SUPPORTED_BUILTINS: OnceLock> = OnceLock::new(); + SUPPORTED_BUILTINS.get_or_init(|| { + Builtin::os_order_iter() + .filter(is_supported_builtin) + .map(|builtin| builtin.name().to_string()) + .collect::>() + }) +} + +// TODO(Arni): Add test. +fn validate_casm_class(contract_class: &CasmContractClass) -> Result<(), GatewayError> { + let CasmContractEntryPoints { external, l1_handler, constructor } = + &contract_class.entry_points_by_type; + let entry_points_iterator = external.iter().chain(l1_handler.iter()).chain(constructor.iter()); + + for entry_point in entry_points_iterator { + let builtins = &entry_point.builtins; + if !is_subsequence(builtins, supported_builtins()) { + return Err(GatewayError::UnsupportedBuiltins { + builtins: builtins.clone(), + supported_builtins: supported_builtins().to_vec(), + }); + } + } + Ok(()) +} diff --git a/crates/gateway/src/compilation_test.rs b/crates/gateway/src/compilation_test.rs new file mode 100644 index 000000000..ea0336ded --- /dev/null +++ b/crates/gateway/src/compilation_test.rs @@ -0,0 +1,64 @@ +use assert_matches::assert_matches; +use blockifier::execution::contract_class::ContractClass; +use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError; +use starknet_api::core::CompiledClassHash; +use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction}; +use starknet_sierra_compile::errors::CompilationUtilError; +use test_utils::starknet_api_test_utils::declare_tx; + +use crate::compilation::compile_contract_class; +use crate::errors::GatewayError; + +#[test] +fn test_compile_contract_class_compiled_class_hash_missmatch() { + let mut tx = assert_matches!( + declare_tx(), + RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx + ); + let expected_hash_result = tx.compiled_class_hash; + let supplied_hash = CompiledClassHash::default(); + + tx.compiled_class_hash = supplied_hash; + let declare_tx = RPCDeclareTransaction::V3(tx); + + let result = compile_contract_class(&declare_tx); + assert_matches!( + result.unwrap_err(), + GatewayError::CompiledClassHashMismatch { supplied, hash_result } + if supplied == supplied_hash && hash_result == expected_hash_result + ); +} + +#[test] +fn test_compile_contract_class_bad_sierra() { + let mut tx = assert_matches!( + declare_tx(), + RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx + ); + // Truncate the sierra program to trigger an error. + tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec(); + let declare_tx = RPCDeclareTransaction::V3(tx); + + let result = compile_contract_class(&declare_tx); + assert_matches!( + result.unwrap_err(), + GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError( + AllowedLibfuncsError::SierraProgramError + )) + ) +} + +#[test] +fn test_compile_contract_class() { + let declare_tx = assert_matches!( + declare_tx(), + RPCTransaction::Declare(declare_tx) => declare_tx + ); + let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx; + let contract_class = &declare_tx_v3.contract_class; + + let class_info = compile_contract_class(&declare_tx).unwrap(); + assert_matches!(class_info.contract_class(), ContractClass::V1(_)); + assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len()); + assert_eq!(class_info.abi_length(), contract_class.abi.len()); +} diff --git a/crates/gateway/src/errors.rs b/crates/gateway/src/errors.rs index ac43ed592..24d5ef74b 100644 --- a/crates/gateway/src/errors.rs +++ b/crates/gateway/src/errors.rs @@ -42,6 +42,8 @@ pub enum GatewayError { UnsupportedBuiltins { builtins: Vec, supported_builtins: Vec }, } +pub type GatewayResult = Result; + impl IntoResponse for GatewayError { // TODO(Arni, 1/5/2024): Be more fine tuned about the error response. Not all Gateway errors // are internal server errors. diff --git a/crates/gateway/src/gateway.rs b/crates/gateway/src/gateway.rs index 9102c8c86..352a80d51 100644 --- a/crates/gateway/src/gateway.rs +++ b/crates/gateway/src/gateway.rs @@ -1,43 +1,30 @@ use std::clone::Clone; use std::net::SocketAddr; -use std::panic; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use async_trait::async_trait; use axum::extract::State; use axum::routing::{get, post}; use axum::{Json, Router}; -use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractClassV1}; -use blockifier::execution::execution_utils::felt_to_stark_felt; -use cairo_lang_starknet_classes::casm_contract_class::{ - CasmContractClass, CasmContractEntryPoints, -}; -use starknet_api::core::CompiledClassHash; -use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction}; -use starknet_api::transaction::{Builtin, TransactionHash}; +use starknet_api::rpc_transaction::RPCTransaction; +use starknet_api::transaction::TransactionHash; use starknet_mempool_infra::component_runner::{ComponentRunner, ComponentStartError}; use starknet_mempool_types::communication::SharedMempoolClient; use starknet_mempool_types::mempool_types::{Account, MempoolInput}; -use starknet_sierra_compile::compile::compile_sierra_to_casm; -use starknet_sierra_compile::errors::CompilationUtilError; -use starknet_sierra_compile::utils::into_contract_class_for_compilation; +use crate::compilation::compile_contract_class; use crate::config::{GatewayConfig, GatewayNetworkConfig, RpcStateReaderConfig}; -use crate::errors::{GatewayError, GatewayRunError}; +use crate::errors::{GatewayError, GatewayResult, GatewayRunError}; use crate::rpc_state_reader::RpcStateReaderFactory; use crate::state_reader::StateReaderFactory; use crate::stateful_transaction_validator::StatefulTransactionValidator; use crate::stateless_transaction_validator::StatelessTransactionValidator; -use crate::utils::{ - external_tx_to_thin_tx, get_sender_address, is_subsequence, IntoOsOrderEnumIteratorExt, -}; +use crate::utils::{external_tx_to_thin_tx, get_sender_address}; #[cfg(test)] #[path = "gateway_test.rs"] pub mod gateway_test; -pub type GatewayResult = Result; - pub struct Gateway { pub config: GatewayConfig, app_state: AppState, @@ -147,47 +134,6 @@ fn process_tx( }) } -/// Formats the contract class for compilation, compiles it, and returns the compiled contract class -/// wrapped in a [`ClassInfo`]. -/// Assumes the contract class is of a Sierra program which is compiled to Casm. -pub fn compile_contract_class(declare_tx: &RPCDeclareTransaction) -> GatewayResult { - let RPCDeclareTransaction::V3(tx) = declare_tx; - let starknet_api_contract_class = &tx.contract_class; - let cairo_lang_contract_class = - into_contract_class_for_compilation(starknet_api_contract_class); - - // Compile Sierra to Casm. - let catch_unwind_result = - panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class)); - let casm_contract_class = match catch_unwind_result { - Ok(compilation_result) => compilation_result?, - Err(_) => { - // TODO(Arni): Log the panic. - return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic)); - } - }; - validate_casm_class(&casm_contract_class)?; - - let hash_result = - CompiledClassHash(felt_to_stark_felt(&casm_contract_class.compiled_class_hash())); - if hash_result != tx.compiled_class_hash { - return Err(GatewayError::CompiledClassHashMismatch { - supplied: tx.compiled_class_hash, - hash_result, - }); - } - - // Convert Casm contract class to Starknet contract class directly. - let blockifier_contract_class = - ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?); - let class_info = ClassInfo::new( - &blockifier_contract_class, - starknet_api_contract_class.sierra_program.len(), - starknet_api_contract_class.abi.len(), - )?; - Ok(class_info) -} - pub fn create_gateway( config: GatewayConfig, rpc_state_reader_config: RpcStateReaderConfig, @@ -205,50 +151,3 @@ impl ComponentRunner for Gateway { Ok(()) } } - -// List of supported builtins. -// This is an explicit function so that it is explicitly desiced which builtins are supported. -// If new builtins are added, they should be added here. -fn is_supported_builtin(builtin: &Builtin) -> bool { - match builtin { - Builtin::RangeCheck - | Builtin::Pedersen - | Builtin::Poseidon - | Builtin::EcOp - | Builtin::Ecdsa - | Builtin::Bitwise - | Builtin::SegmentArena => true, - Builtin::Keccak => false, - } -} - -// TODO(Arni): Add to a config. -// TODO(Arni): Use the Builtin enum from Starknet-api, and explicitly tag each builtin as supported -// or unsupported so that the compiler would alert us on new builtins. -fn supported_builtins() -> &'static Vec { - static SUPPORTED_BUILTINS: OnceLock> = OnceLock::new(); - SUPPORTED_BUILTINS.get_or_init(|| { - Builtin::os_order_iter() - .filter(is_supported_builtin) - .map(|builtin| builtin.name().to_string()) - .collect::>() - }) -} - -// TODO(Arni): Add test. -fn validate_casm_class(contract_class: &CasmContractClass) -> Result<(), GatewayError> { - let CasmContractEntryPoints { external, l1_handler, constructor } = - &contract_class.entry_points_by_type; - let entry_points_iterator = external.iter().chain(l1_handler.iter()).chain(constructor.iter()); - - for entry_point in entry_points_iterator { - let builtins = &entry_point.builtins; - if !is_subsequence(builtins, supported_builtins()) { - return Err(GatewayError::UnsupportedBuiltins { - builtins: builtins.clone(), - supported_builtins: supported_builtins().to_vec(), - }); - } - } - Ok(()) -} diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs index 5207672fe..25f15d954 100644 --- a/crates/gateway/src/gateway_test.rs +++ b/crates/gateway/src/gateway_test.rs @@ -1,29 +1,23 @@ use std::sync::Arc; -use assert_matches::assert_matches; use axum::body::{Bytes, HttpBody}; use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use blockifier::context::ChainInfo; -use blockifier::execution::contract_class::ContractClass; use blockifier::test_utils::CairoVersion; -use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError; use rstest::{fixture, rstest}; -use starknet_api::core::CompiledClassHash; -use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction}; +use starknet_api::rpc_transaction::RPCTransaction; use starknet_api::transaction::TransactionHash; use starknet_mempool::communication::create_mempool_server; use starknet_mempool::mempool::Mempool; use starknet_mempool_infra::component_server::ComponentServerStarter; use starknet_mempool_types::communication::{MempoolClientImpl, MempoolRequestAndResponseSender}; -use starknet_sierra_compile::errors::CompilationUtilError; use test_utils::starknet_api_test_utils::{declare_tx, deploy_account_tx, invoke_tx}; use tokio::sync::mpsc::channel; use tokio::task; use crate::config::{StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig}; -use crate::errors::GatewayError; use crate::gateway::{add_tx, compile_contract_class, AppState, SharedMempoolClient}; use crate::state_reader_test_utils::{ local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account, @@ -110,60 +104,6 @@ async fn test_add_tx( assert_eq!(tx_hash, serde_json::from_slice(response_bytes).unwrap()); } -#[test] -fn test_compile_contract_class_compiled_class_hash_missmatch() { - let mut tx = assert_matches!( - declare_tx(), - RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx - ); - let expected_hash_result = tx.compiled_class_hash; - let supplied_hash = CompiledClassHash::default(); - - tx.compiled_class_hash = supplied_hash; - let declare_tx = RPCDeclareTransaction::V3(tx); - - let result = compile_contract_class(&declare_tx); - assert_matches!( - result.unwrap_err(), - GatewayError::CompiledClassHashMismatch { supplied, hash_result } - if supplied == supplied_hash && hash_result == expected_hash_result - ); -} - -#[test] -fn test_compile_contract_class_bad_sierra() { - let mut tx = assert_matches!( - declare_tx(), - RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx - ); - // Truncate the sierra program to trigger an error. - tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec(); - let declare_tx = RPCDeclareTransaction::V3(tx); - - let result = compile_contract_class(&declare_tx); - assert_matches!( - result.unwrap_err(), - GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError( - AllowedLibfuncsError::SierraProgramError - )) - ) -} - -#[test] -fn test_compile_contract_class() { - let declare_tx = assert_matches!( - declare_tx(), - RPCTransaction::Declare(declare_tx) => declare_tx - ); - let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx; - let contract_class = &declare_tx_v3.contract_class; - - let class_info = compile_contract_class(&declare_tx).unwrap(); - assert_matches!(class_info.contract_class(), ContractClass::V1(_)); - assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len()); - assert_eq!(class_info.abi_length(), contract_class.abi.len()); -} - async fn to_bytes(res: Response) -> Bytes { res.into_body().collect().await.unwrap().to_bytes() } diff --git a/crates/gateway/src/lib.rs b/crates/gateway/src/lib.rs index c649eb2dd..59f79bf88 100644 --- a/crates/gateway/src/lib.rs +++ b/crates/gateway/src/lib.rs @@ -1,4 +1,5 @@ pub mod communication; +pub mod compilation; pub mod compiler_version; pub mod config; pub mod errors; diff --git a/crates/gateway/src/stateful_transaction_validator_test.rs b/crates/gateway/src/stateful_transaction_validator_test.rs index d3821782f..746d203e4 100644 --- a/crates/gateway/src/stateful_transaction_validator_test.rs +++ b/crates/gateway/src/stateful_transaction_validator_test.rs @@ -11,9 +11,9 @@ use test_utils::starknet_api_test_utils::{ VALID_L1_GAS_MAX_PRICE_PER_UNIT, }; +use crate::compilation::compile_contract_class; use crate::config::StatefulTransactionValidatorConfig; use crate::errors::{StatefulTransactionValidatorError, StatefulTransactionValidatorResult}; -use crate::gateway::compile_contract_class; use crate::state_reader_test_utils::{ local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account, TestStateReaderFactory,