From b12db17267f79c6554c3e57df06f723f7f746bd4 Mon Sep 17 00:00:00 2001 From: Arni Hod Date: Thu, 13 Jun 2024 11:24:12 +0300 Subject: [PATCH] feat: check that the compiled class hash matches the supplied class --- crates/gateway/src/errors.rs | 6 ++ crates/gateway/src/gateway.rs | 11 ++++ crates/gateway/src/gateway_test.rs | 62 ++++++++++++++++++- crates/gateway/src/starknet_api_test_utils.rs | 9 ++- .../stateful_transaction_validator_test.rs | 2 +- crates/test_utils/src/lib.rs | 2 + 6 files changed, 88 insertions(+), 4 deletions(-) diff --git a/crates/gateway/src/errors.rs b/crates/gateway/src/errors.rs index fb465f67..3b65f783 100644 --- a/crates/gateway/src/errors.rs +++ b/crates/gateway/src/errors.rs @@ -7,6 +7,7 @@ use blockifier::transaction::errors::TransactionExecutionError; use cairo_vm::types::errors::program_errors::ProgramError; use serde_json::{Error as SerdeError, Value}; use starknet_api::block::{BlockNumber, GasPrice}; +use starknet_api::core::CompiledClassHash; use starknet_api::transaction::{Resource, ResourceBounds}; use starknet_api::StarknetApiError; use thiserror::Error; @@ -19,6 +20,11 @@ use crate::compiler_version::{VersionId, VersionIdError}; pub enum GatewayError { #[error(transparent)] CompilationError(#[from] starknet_sierra_compile::compile::CompilationUtilError), + #[error( + "The supplied compiled class hash {supplied:?} does not match the hash of the Casm class \ + compiled from the supplied Sierra {hash_result:?}." + )] + CompiledClassHashMismatch { supplied: CompiledClassHash, hash_result: CompiledClassHash }, #[error(transparent)] DeclaredContractClassError(#[from] ContractClassError), #[error(transparent)] diff --git a/crates/gateway/src/gateway.rs b/crates/gateway/src/gateway.rs index 54e3f38c..12464589 100644 --- a/crates/gateway/src/gateway.rs +++ b/crates/gateway/src/gateway.rs @@ -7,6 +7,8 @@ use axum::extract::State; use axum::routing::{get, post}; use axum::{Json, Router}; use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractClassV1}; +use blockifier::execution::execution_utils::felt_to_stark_felt; +use starknet_api::core::CompiledClassHash; use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction}; use starknet_api::transaction::TransactionHash; use starknet_mempool_types::communication::SharedMempoolClient; @@ -157,6 +159,15 @@ pub fn compile_contract_class(declare_tx: &RPCDeclareTransaction) -> GatewayResu } }; + let hash_result = + CompiledClassHash(felt_to_stark_felt(&casm_contract_class.compiled_class_hash())); + if hash_result != tx.compiled_class_hash { + return Err(GatewayError::CompiledClassHashMismatch { + supplied: tx.compiled_class_hash, + hash_result, + }); + } + // Convert Casm contract class to Starknet contract class directly. let blockifier_contract_class = ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?); diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs index cd89df0b..7c1fb857 100644 --- a/crates/gateway/src/gateway_test.rs +++ b/crates/gateway/src/gateway_test.rs @@ -1,21 +1,27 @@ use std::sync::Arc; +use assert_matches::assert_matches; use axum::body::{Bytes, HttpBody}; use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use blockifier::context::ChainInfo; +use blockifier::execution::contract_class::ContractClass; use blockifier::test_utils::CairoVersion; +use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError; use rstest::{fixture, rstest}; -use starknet_api::rpc_transaction::RPCTransaction; +use starknet_api::core::CompiledClassHash; +use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction}; use starknet_api::transaction::TransactionHash; use starknet_mempool::communication::create_mempool_server; use starknet_mempool::mempool::Mempool; use starknet_mempool_types::communication::{MempoolClientImpl, MempoolRequestAndResponseSender}; +use starknet_sierra_compile::compile::CompilationUtilError; use tokio::sync::mpsc::channel; use tokio::task; use crate::config::{StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig}; +use crate::errors::GatewayError; use crate::gateway::{add_tx, compile_contract_class, AppState, SharedMempoolClient}; use crate::starknet_api_test_utils::{declare_tx, deploy_account_tx, invoke_tx}; use crate::state_reader_test_utils::{ @@ -103,6 +109,60 @@ async fn test_add_tx( assert_eq!(tx_hash, serde_json::from_slice(response_bytes).unwrap()); } +#[test] +fn test_compile_contract_class_compiled_class_hash_missmatch() { + let mut tx = assert_matches!( + declare_tx(), + RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx + ); + let expected_hash_result = tx.compiled_class_hash; + let supplied_hash = CompiledClassHash::default(); + + tx.compiled_class_hash = supplied_hash; + let declare_tx = RPCDeclareTransaction::V3(tx); + + let result = compile_contract_class(&declare_tx); + assert_matches!( + result.unwrap_err(), + GatewayError::CompiledClassHashMismatch { supplied, hash_result } + if supplied == supplied_hash && hash_result == expected_hash_result + ); +} + +#[test] +fn test_compile_contract_class_bad_sierra() { + let mut tx = assert_matches!( + declare_tx(), + RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx + ); + // Truncate the sierra program to trigger an error. + tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec(); + let declare_tx = RPCDeclareTransaction::V3(tx); + + let result = compile_contract_class(&declare_tx); + assert_matches!( + result.unwrap_err(), + GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError( + AllowedLibfuncsError::SierraProgramError + )) + ) +} + +#[test] +fn test_compile_contract_class() { + let declare_tx = assert_matches!( + declare_tx(), + RPCTransaction::Declare(declare_tx) => declare_tx + ); + let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx; + let contract_class = &declare_tx_v3.contract_class; + + let class_info = compile_contract_class(&declare_tx).unwrap(); + assert_matches!(class_info.contract_class(), ContractClass::V1(_)); + assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len()); + assert_eq!(class_info.abi_length(), contract_class.abi.len()); +} + async fn to_bytes(res: Response) -> Bytes { res.into_body().collect().await.unwrap().to_bytes() } diff --git a/crates/gateway/src/starknet_api_test_utils.rs b/crates/gateway/src/starknet_api_test_utils.rs index 92ba7843..10076c9a 100644 --- a/crates/gateway/src/starknet_api_test_utils.rs +++ b/crates/gateway/src/starknet_api_test_utils.rs @@ -18,7 +18,10 @@ use starknet_api::transaction::{ TransactionSignature, TransactionVersion, }; use starknet_api::{calldata, stark_felt}; -use test_utils::{get_absolute_path, CONTRACT_CLASS_FILE, TEST_FILES_FOLDER}; +use test_utils::{ + get_absolute_path, COMPILED_CLASS_HASH_OF_CONTRACT_CLASS, CONTRACT_CLASS_FILE, + TEST_FILES_FOLDER, +}; use crate::{declare_tx_args, deploy_account_tx_args, invoke_tx_args}; @@ -97,6 +100,7 @@ pub fn declare_tx() -> RPCTransaction { env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir."); let json_file_path = Path::new(CONTRACT_CLASS_FILE); let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap(); + let compiled_class_hash = CompiledClassHash(stark_felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS)); let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1); let account_address = account_contract.get_instance_address(0); @@ -108,7 +112,8 @@ pub fn declare_tx() -> RPCTransaction { sender_address: account_address, resource_bounds: executable_resource_bounds_mapping(), nonce, - contract_class + class_hash: compiled_class_hash, + contract_class, )) } diff --git a/crates/gateway/src/stateful_transaction_validator_test.rs b/crates/gateway/src/stateful_transaction_validator_test.rs index f8ead84f..5553922a 100644 --- a/crates/gateway/src/stateful_transaction_validator_test.rs +++ b/crates/gateway/src/stateful_transaction_validator_test.rs @@ -46,7 +46,7 @@ use crate::stateful_transaction_validator::StatefulTransactionValidator; declare_tx(), local_test_state_reader_factory(CairoVersion::Cairo1, false), Ok(TransactionHash(StarkFelt::try_from( - "0x0278ed2700d5a30254a6b895d4e1140438d7d1a3b2b2ce0c096a9d5ee1c61f39" + "0x02da54b89e00d2e201f8e3ed2bcc715a69e89aefdce88aff2d2facb8dec55c0a" ).unwrap())) )] #[case::invalid_tx( diff --git a/crates/test_utils/src/lib.rs b/crates/test_utils/src/lib.rs index 23634bc6..571ec7f3 100644 --- a/crates/test_utils/src/lib.rs +++ b/crates/test_utils/src/lib.rs @@ -3,6 +3,8 @@ use std::path::{Path, PathBuf}; pub const TEST_FILES_FOLDER: &str = "crates/test_utils/test_files"; pub const CONTRACT_CLASS_FILE: &str = "contract_class.json"; +pub const COMPILED_CLASS_HASH_OF_CONTRACT_CLASS: &str = + "0x01e4f1248860f32c336f93f2595099aaa4959be515e40b75472709ef5243ae17"; /// Returns the absolute path from the project root. pub fn get_absolute_path(relative_path: &str) -> PathBuf {