Skip to content

Commit

Permalink
chore: create struct for gateway compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
ArniStarkware committed Jul 10, 2024
1 parent 025f4c3 commit 4cbcaea
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 68 deletions.
114 changes: 63 additions & 51 deletions crates/gateway/src/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,51 +11,81 @@ use starknet_sierra_compile::compile::compile_sierra_to_casm;
use starknet_sierra_compile::errors::CompilationUtilError;
use starknet_sierra_compile::utils::into_contract_class_for_compilation;

use crate::config::GatewayCompilerConfig;
use crate::errors::{GatewayError, GatewayResult};
use crate::utils::is_subsequence;

#[cfg(test)]
#[path = "compilation_test.rs"]
mod compilation_test;

/// Formats the contract class for compilation, compiles it, and returns the compiled contract class
/// wrapped in a [`ClassInfo`].
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
pub fn compile_contract_class(declare_tx: &RPCDeclareTransaction) -> GatewayResult<ClassInfo> {
let RPCDeclareTransaction::V3(tx) = declare_tx;
let starknet_api_contract_class = &tx.contract_class;
let cairo_lang_contract_class =
into_contract_class_for_compilation(starknet_api_contract_class);
pub struct GatewayCompiler {
#[allow(dead_code)]
pub config: GatewayCompilerConfig,
}

impl GatewayCompiler {
/// Formats the contract class for compilation, compiles it, and returns the compiled contract
/// class wrapped in a [`ClassInfo`].
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
pub fn compile_contract_class(
&self,
declare_tx: &RPCDeclareTransaction,
) -> GatewayResult<ClassInfo> {
let RPCDeclareTransaction::V3(tx) = declare_tx;
let starknet_api_contract_class = &tx.contract_class;
let cairo_lang_contract_class =
into_contract_class_for_compilation(starknet_api_contract_class);

// Compile Sierra to Casm.
let catch_unwind_result =
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
let casm_contract_class = match catch_unwind_result {
Ok(compilation_result) => compilation_result?,
Err(_) => {
// TODO(Arni): Log the panic.
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
// Compile Sierra to Casm.
let catch_unwind_result =
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
let casm_contract_class = match catch_unwind_result {
Ok(compilation_result) => compilation_result?,
Err(_) => {
// TODO(Arni): Log the panic.
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
}
};
self.validate_casm_class(&casm_contract_class)?;

let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash());
if hash_result != tx.compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: tx.compiled_class_hash,
hash_result,
});
}
};
validate_casm_class(&casm_contract_class)?;

let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash());
if hash_result != tx.compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: tx.compiled_class_hash,
hash_result,
});
// Convert Casm contract class to Starknet contract class directly.
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
let class_info = ClassInfo::new(
&blockifier_contract_class,
starknet_api_contract_class.sierra_program.len(),
starknet_api_contract_class.abi.len(),
)?;
Ok(class_info)
}

// Convert Casm contract class to Starknet contract class directly.
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
let class_info = ClassInfo::new(
&blockifier_contract_class,
starknet_api_contract_class.sierra_program.len(),
starknet_api_contract_class.abi.len(),
)?;
Ok(class_info)
// TODO(Arni): Add test.
fn validate_casm_class(&self, contract_class: &CasmContractClass) -> Result<(), GatewayError> {
let CasmContractEntryPoints { external, l1_handler, constructor } =
&contract_class.entry_points_by_type;
let entry_points_iterator =
external.iter().chain(l1_handler.iter()).chain(constructor.iter());

for entry_point in entry_points_iterator {
let builtins = &entry_point.builtins;
if !is_subsequence(builtins, supported_builtins()) {
return Err(GatewayError::UnsupportedBuiltins {
builtins: builtins.clone(),
supported_builtins: supported_builtins().to_vec(),
});
}
}
Ok(())
}
}

// TODO(Arni): Add to a config.
Expand All @@ -70,21 +100,3 @@ fn supported_builtins() -> &'static Vec<String> {
SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::<Vec<String>>()
})
}

// TODO(Arni): Add test.
fn validate_casm_class(contract_class: &CasmContractClass) -> Result<(), GatewayError> {
let CasmContractEntryPoints { external, l1_handler, constructor } =
&contract_class.entry_points_by_type;
let entry_points_iterator = external.iter().chain(l1_handler.iter()).chain(constructor.iter());

for entry_point in entry_points_iterator {
let builtins = &entry_point.builtins;
if !is_subsequence(builtins, supported_builtins()) {
return Err(GatewayError::UnsupportedBuiltins {
builtins: builtins.clone(),
supported_builtins: supported_builtins().to_vec(),
});
}
}
Ok(())
}
26 changes: 16 additions & 10 deletions crates/gateway/src/compilation_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@ use assert_matches::assert_matches;
use blockifier::execution::contract_class::ContractClass;
use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError;
use mempool_test_utils::starknet_api_test_utils::declare_tx;
use rstest::{fixture, rstest};
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
use starknet_sierra_compile::errors::CompilationUtilError;

use crate::compilation::compile_contract_class;
use crate::compilation::GatewayCompiler;
use crate::errors::GatewayError;

#[test]
fn test_compile_contract_class_compiled_class_hash_missmatch() {
#[fixture]
fn gateway_compiler() -> GatewayCompiler {
GatewayCompiler { config: Default::default() }
}

#[rstest]
fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: GatewayCompiler) {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
Expand All @@ -21,16 +27,16 @@ fn test_compile_contract_class_compiled_class_hash_missmatch() {
tx.compiled_class_hash = supplied_hash;
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = compile_contract_class(&declare_tx);
let result = gateway_compiler.compile_contract_class(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
if supplied == supplied_hash && hash_result == expected_hash_result
);
}

#[test]
fn test_compile_contract_class_bad_sierra() {
#[rstest]
fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
Expand All @@ -39,7 +45,7 @@ fn test_compile_contract_class_bad_sierra() {
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = compile_contract_class(&declare_tx);
let result = gateway_compiler.compile_contract_class(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
Expand All @@ -48,16 +54,16 @@ fn test_compile_contract_class_bad_sierra() {
)
}

#[test]
fn test_compile_contract_class() {
#[rstest]
fn test_compile_contract_class(gateway_compiler: GatewayCompiler) {
let declare_tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(declare_tx) => declare_tx
);
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
let contract_class = &declare_tx_v3.contract_class;

let class_info = compile_contract_class(&declare_tx).unwrap();
let class_info = gateway_compiler.compile_contract_class(&declare_tx).unwrap();
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
assert_eq!(class_info.abi_length(), contract_class.abi.len());
Expand Down
11 changes: 11 additions & 0 deletions crates/gateway/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub struct GatewayConfig {
pub network_config: GatewayNetworkConfig,
pub stateless_tx_validator_config: StatelessTransactionValidatorConfig,
pub stateful_tx_validator_config: StatefulTransactionValidatorConfig,
pub compiler_config: GatewayCompilerConfig,
}

impl SerializeConfig for GatewayConfig {
Expand All @@ -30,6 +31,7 @@ impl SerializeConfig for GatewayConfig {
self.stateful_tx_validator_config.dump(),
"stateful_tx_validator_config",
),
append_sub_config_name(self.compiler_config.dump(), "compiler_config"),
]
.into_iter()
.flatten()
Expand Down Expand Up @@ -293,3 +295,12 @@ impl StatefulTransactionValidatorConfig {
}
}
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, Validate, PartialEq)]
pub struct GatewayCompilerConfig {}

impl SerializeConfig for GatewayCompilerConfig {
fn dump(&self) -> BTreeMap<ParamPath, SerializedParam> {
BTreeMap::new()
}
}
7 changes: 5 additions & 2 deletions crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use starknet_mempool_types::communication::SharedMempoolClient;
use starknet_mempool_types::mempool_types::{Account, MempoolInput};
use tracing::info;

use crate::compilation::compile_contract_class;
use crate::compilation::GatewayCompiler;
use crate::config::{GatewayConfig, GatewayNetworkConfig, RpcStateReaderConfig};
use crate::errors::{GatewayError, GatewayResult, GatewayRunError};
use crate::rpc_state_reader::RpcStateReaderFactory;
Expand Down Expand Up @@ -120,7 +120,10 @@ fn process_tx(

// Compile Sierra to Casm.
let optional_class_info = match &tx {
RPCTransaction::Declare(declare_tx) => Some(compile_contract_class(declare_tx)?),
RPCTransaction::Declare(declare_tx) => {
let gateway_compiler = GatewayCompiler { config: Default::default() };
Some(gateway_compiler.compile_contract_class(declare_tx)?)
}
_ => None,
};

Expand Down
8 changes: 6 additions & 2 deletions crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use tokio::sync::mpsc::channel;
use tokio::task;

use crate::config::{StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig};
use crate::gateway::{add_tx, compile_contract_class, AppState, SharedMempoolClient};
use crate::gateway::{add_tx, AppState, GatewayCompiler, SharedMempoolClient};
use crate::state_reader_test_utils::{
local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account,
TestStateReaderFactory,
Expand Down Expand Up @@ -110,7 +110,11 @@ async fn to_bytes(res: Response) -> Bytes {

fn calculate_hash(external_tx: &RPCTransaction) -> TransactionHash {
let optional_class_info = match &external_tx {
RPCTransaction::Declare(declare_tx) => Some(compile_contract_class(declare_tx).unwrap()),
RPCTransaction::Declare(declare_tx) => Some(
GatewayCompiler { config: Default::default() }
.compile_contract_class(declare_tx)
.unwrap(),
),
_ => None,
};

Expand Down
7 changes: 5 additions & 2 deletions crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use starknet_api::felt;
use starknet_api::rpc_transaction::RPCTransaction;
use starknet_api::transaction::TransactionHash;

use crate::compilation::compile_contract_class;
use crate::compilation::GatewayCompiler;
use crate::config::StatefulTransactionValidatorConfig;
use crate::errors::{StatefulTransactionValidatorError, StatefulTransactionValidatorResult};
use crate::state_reader_test_utils::{
Expand Down Expand Up @@ -81,7 +81,10 @@ fn test_stateful_tx_validator(
},
};
let optional_class_info = match &external_tx {
RPCTransaction::Declare(declare_tx) => Some(compile_contract_class(declare_tx).unwrap()),
RPCTransaction::Declare(declare_tx) => {
let gateway_compiler = GatewayCompiler { config: Default::default() };
Some(gateway_compiler.compile_contract_class(declare_tx).unwrap())
}
_ => None,
};

Expand Down
8 changes: 7 additions & 1 deletion crates/tests-integration/src/integration_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@ async fn create_gateway_config() -> GatewayConfig {
let socket = get_available_socket().await;
let network_config = GatewayNetworkConfig { ip: socket.ip(), port: socket.port() };
let stateful_tx_validator_config = StatefulTransactionValidatorConfig::create_for_testing();
let gateway_compiler_config = Default::default();

GatewayConfig { network_config, stateless_tx_validator_config, stateful_tx_validator_config }
GatewayConfig {
network_config,
stateless_tx_validator_config,
stateful_tx_validator_config,
compiler_config: gateway_compiler_config,
}
}

pub async fn create_config(rpc_server_addr: SocketAddr) -> MempoolNodeConfig {
Expand Down

0 comments on commit 4cbcaea

Please sign in to comment.