Skip to content

Commit

Permalink
feat(gateway): return GatewaySpecErrr to match the SN spec
Browse files Browse the repository at this point in the history
  • Loading branch information
yair-starkware committed Jul 24, 2024
1 parent 56c44a3 commit d0e9d09
Show file tree
Hide file tree
Showing 13 changed files with 185 additions and 118 deletions.
22 changes: 22 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,6 @@ tokio = { version = "1.37.0", features = ["full"] }
tokio-test = "0.4.4"
tracing = "0.1.37"
tracing-subscriber = "0.3.16"
tracing-test = "0.2"
url = "2.5.0"
validator = "0.12"
1 change: 1 addition & 0 deletions crates/gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@ num-bigint.workspace = true
pretty_assertions.workspace = true
rstest.workspace = true
starknet_mempool = { path = "../mempool", version = "0.0" }
tracing-test.workspace = true
43 changes: 30 additions & 13 deletions crates/gateway/src/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContr
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::RPCDeclareTransaction;
use starknet_sierra_compile::compile::compile_sierra_to_casm;
use starknet_sierra_compile::errors::CompilationUtilError;
use starknet_sierra_compile::utils::into_contract_class_for_compilation;
use tracing::{debug, error};

use crate::config::GatewayCompilerConfig;
use crate::errors::{GatewayError, GatewayResult};
use crate::errors::{GatewayResult, GatewaySpecError};

#[cfg(test)]
#[path = "compilation_test.rs"]
Expand Down Expand Up @@ -39,22 +39,38 @@ impl GatewayCompiler {

validate_compiled_class_hash(&casm_contract_class, &tx.compiled_class_hash)?;

Ok(ClassInfo::new(
&ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?),
ClassInfo::new(
&ContractClass::V1(ContractClassV1::try_from(casm_contract_class).map_err(|e| {
error!("Failed to convert CasmContractClass to Blockifier ContractClass: {:?}", e);
GatewaySpecError::UnexpectedError("Internal server error.".to_owned())
})?),
rpc_contract_class.sierra_program.len(),
rpc_contract_class.abi.len(),
)?)
)
.map_err(|e| {
error!("Failed to convert Blockifier ContractClass to Blockifier ClassInfo: {:?}", e);
GatewaySpecError::UnexpectedError("Internal server error.".to_owned())
})
}

// TODO(Arni): Pass the compilation args from the config.
fn compile(
&self,
cairo_lang_contract_class: CairoLangContractClass,
) -> Result<CasmContractClass, GatewayError> {
) -> GatewayResult<CasmContractClass> {
let catch_unwind_result =
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
let casm_contract_class =
catch_unwind_result.map_err(|_| CompilationUtilError::CompilationPanic)??;
let casm_contract_class = match catch_unwind_result {
Ok(compilation_result) => compilation_result.map_err(|e| {
debug!("Compilation failed: {:?}", e);
GatewaySpecError::CompilationFailed
})?,
Err(_panicked_compilation) => {
// TODO(Arni): Log the panic.
error!("Compilation panicked.");
return Err(GatewaySpecError::UnexpectedError("Internal server error.".to_owned()));
}
};

Ok(casm_contract_class)
}
Expand All @@ -65,13 +81,14 @@ impl GatewayCompiler {
fn validate_compiled_class_hash(
casm_contract_class: &CasmContractClass,
supplied_compiled_class_hash: &CompiledClassHash,
) -> Result<(), GatewayError> {
) -> GatewayResult<()> {
let compiled_class_hash = CompiledClassHash(casm_contract_class.compiled_class_hash());
if compiled_class_hash != *supplied_compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: *supplied_compiled_class_hash,
hash_result: compiled_class_hash,
});
debug!(
"Compiled class hash mismatch. Supplied: {:?}, Hash result: {:?}",
supplied_compiled_class_hash, compiled_class_hash
);
return Err(GatewaySpecError::CompiledClassHashMismatch);
}
Ok(())
}
33 changes: 19 additions & 14 deletions crates/gateway/src/compilation_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ use rstest::{fixture, rstest};
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
use starknet_sierra_compile::errors::CompilationUtilError;
use tracing_test::traced_test;

use crate::compilation::GatewayCompiler;
use crate::errors::GatewayError;
use crate::errors::GatewaySpecError;

#[fixture]
fn gateway_compiler() -> GatewayCompiler {
GatewayCompiler { config: Default::default() }
}

// TODO(Arni): Redesign this test once the compiler is passed with dependancy injection.
#[traced_test]
#[rstest]
fn test_compile_contract_class_compiled_class_hash_mismatch(gateway_compiler: GatewayCompiler) {
let mut tx = assert_matches!(
Expand All @@ -27,14 +29,18 @@ fn test_compile_contract_class_compiled_class_hash_mismatch(gateway_compiler: Ga
tx.compiled_class_hash = wrong_supplied_hash;
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = gateway_compiler.process_declare_tx(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
if supplied == wrong_supplied_hash && hash_result == expected_hash
);
let err = gateway_compiler.process_declare_tx(&declare_tx).unwrap_err();
assert_eq!(err, GatewaySpecError::CompiledClassHashMismatch);
assert!(logs_contain(
format!(
"Compiled class hash mismatch. Supplied: {:?}, Hash result: {:?}",
wrong_supplied_hash, expected_hash
)
.as_str()
));
}

#[traced_test]
#[rstest]
fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
let mut tx = assert_matches!(
Expand All @@ -45,13 +51,12 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
let declare_tx = RPCDeclareTransaction::V3(tx);

let result = gateway_compiler.process_declare_tx(&declare_tx);
assert_matches!(
result.unwrap_err(),
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
AllowedLibfuncsError::SierraProgramError
))
)
let err = gateway_compiler.process_declare_tx(&declare_tx).unwrap_err();
assert_eq!(err, GatewaySpecError::CompilationFailed);

let expected_compilation_error =
CompilationUtilError::AllowedLibfuncsError(AllowedLibfuncsError::SierraProgramError);
assert!(logs_contain(format!("Compilation failed: {:?}", expected_compilation_error).as_str()));
}

#[rstest]
Expand Down
90 changes: 37 additions & 53 deletions crates/gateway/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,27 @@
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use blockifier::blockifier::stateful_validator::StatefulValidatorError;
use blockifier::execution::errors::ContractClassError;
use blockifier::state::errors::StateError;
use blockifier::transaction::errors::TransactionExecutionError;
use cairo_vm::types::errors::program_errors::ProgramError;
use enum_assoc::Assoc;
use serde::Serialize;
use serde_json::{Error as SerdeError, Value};
use starknet_api::block::GasPrice;
use starknet_api::core::CompiledClassHash;
use starknet_api::transaction::{Resource, ResourceBounds};
use starknet_api::StarknetApiError;
use starknet_sierra_compile::errors::CompilationUtilError;
use strum::EnumIter;
use thiserror::Error;
use tokio::task::JoinError;

use crate::compiler_version::{VersionId, VersionIdError};

/// Errors directed towards the end-user, as a result of gateway requests.
#[derive(Debug, Error)]
pub enum GatewayError {
#[error(transparent)]
CompilationError(#[from] CompilationUtilError),
#[error(
"The supplied compiled class hash {supplied:?} does not match the hash of the Casm class \
compiled from the supplied Sierra {hash_result:?}"
)]
CompiledClassHashMismatch { supplied: CompiledClassHash, hash_result: CompiledClassHash },
#[error(transparent)]
DeclaredContractClassError(#[from] ContractClassError),
#[error(transparent)]
DeclaredContractProgramError(#[from] ProgramError),
#[error("Internal server error: {0}")]
InternalServerError(#[from] JoinError),
#[error("Error sending message: {0}")]
MessageSendError(String),
#[error(transparent)]
StatefulTransactionValidatorError(#[from] StatefulTransactionValidatorError),
#[error(transparent)]
StatelessTransactionValidatorError(#[from] StatelessTransactionValidatorError),
#[error("{builtins:?} is not a subsquence of {supported_builtins:?}")]
UnsupportedBuiltins { builtins: Vec<String>, supported_builtins: Vec<String> },
}

pub type GatewayResult<T> = Result<T, GatewayError>;
pub type GatewayResult<T> = Result<T, GatewaySpecError>;

impl IntoResponse for GatewayError {
// TODO(Arni, 1/5/2024): Be more fine tuned about the error response. Not all Gateway errors
// are internal server errors.
impl IntoResponse for GatewaySpecError {
fn into_response(self) -> Response {
let body = self.to_string();
(StatusCode::INTERNAL_SERVER_ERROR, body).into_response()
(StatusCode::from_u16(self.code()).expect("Expecting a valid error code"), body)
.into_response()
}
}

#[derive(Error, Debug, Assoc, Clone, EnumIter, Serialize)]
#[derive(Error, Debug, Assoc, Clone, EnumIter, Serialize, PartialEq)]
#[func(pub fn code(&self) -> u16)]
#[func(pub fn data(&self) -> Option<&str>)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
Expand Down Expand Up @@ -146,21 +111,40 @@ pub enum StatelessTransactionValidatorError {
EntryPointsNotUniquelySorted,
}

pub type StatelessTransactionValidatorResult<T> = Result<T, StatelessTransactionValidatorError>;

#[derive(Debug, Error)]
pub enum StatefulTransactionValidatorError {
#[error(transparent)]
StarknetApiError(#[from] StarknetApiError),
#[error(transparent)]
StateError(#[from] StateError),
#[error(transparent)]
StatefulValidatorError(#[from] StatefulValidatorError),
#[error(transparent)]
TransactionExecutionError(#[from] TransactionExecutionError),
impl From<StatelessTransactionValidatorError> for GatewaySpecError {
fn from(e: StatelessTransactionValidatorError) -> Self {
match e {
StatelessTransactionValidatorError::ZeroResourceBounds { .. } => {
GatewaySpecError::ValidationFailure(e.to_string())
}
StatelessTransactionValidatorError::CalldataTooLong { .. } => {
GatewaySpecError::ValidationFailure(e.to_string())
}
StatelessTransactionValidatorError::SignatureTooLong { .. } => {
GatewaySpecError::ValidationFailure(e.to_string())
}
StatelessTransactionValidatorError::InvalidSierraVersion(..) => {
GatewaySpecError::ValidationFailure(e.to_string())
}
StatelessTransactionValidatorError::UnsupportedSierraVersion { .. } => {
GatewaySpecError::UnsupportedContractClassVersion
}
StatelessTransactionValidatorError::BytecodeSizeTooLarge { .. } => {
GatewaySpecError::ContractClassSizeIsTooLarge
}
StatelessTransactionValidatorError::ContractClassObjectSizeTooLarge { .. } => {
GatewaySpecError::ContractClassSizeIsTooLarge
}
StatelessTransactionValidatorError::EntryPointsNotUniquelySorted => {
GatewaySpecError::ValidationFailure(e.to_string())
}
}
}
}

pub type StatefulTransactionValidatorResult<T> = Result<T, StatefulTransactionValidatorError>;
pub type StatelessTransactionValidatorResult<T> = Result<T, StatelessTransactionValidatorError>;

pub type StatefulTransactionValidatorResult<T> = Result<T, GatewaySpecError>;

/// Errors originating from `[`Gateway::run`]` command, to be handled by infrastructure code.
#[derive(Debug, Error)]
Expand Down
19 changes: 11 additions & 8 deletions crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ use starknet_api::transaction::TransactionHash;
use starknet_mempool_infra::component_runner::{ComponentStartError, ComponentStarter};
use starknet_mempool_types::communication::SharedMempoolClient;
use starknet_mempool_types::mempool_types::{Account, MempoolInput};
use tracing::{info, instrument};
use tracing::{error, info, instrument};

use crate::compilation::GatewayCompiler;
use crate::config::{GatewayConfig, GatewayNetworkConfig, RpcStateReaderConfig};
use crate::errors::{GatewayError, GatewayResult, GatewayRunError};
use crate::errors::{GatewayResult, GatewayRunError, GatewaySpecError};
use crate::rpc_state_reader::RpcStateReaderFactory;
use crate::state_reader::StateReaderFactory;
use crate::stateful_transaction_validator::StatefulTransactionValidator;
Expand Down Expand Up @@ -100,15 +100,18 @@ async fn add_tx(
tx,
)
})
.await??;
.await
.map_err(|join_err| {
error!("Failed to process tx: {}", join_err);
GatewaySpecError::UnexpectedError("Internal server error".to_owned())
})??;

let tx_hash = mempool_input.tx.tx_hash;

app_state
.mempool_client
.add_tx(mempool_input)
.await
.map_err(|e| GatewayError::MessageSendError(e.to_string()))?;
app_state.mempool_client.add_tx(mempool_input).await.map_err(|e| {
error!("Failed to send tx to mempool: {}", e);
GatewaySpecError::UnexpectedError("Internal server error".to_owned())
})?;
// TODO: Also return `ContractAddress` for deploy and `ClassHash` for Declare.
Ok(Json(tx_hash))
}
Expand Down
18 changes: 14 additions & 4 deletions crates/gateway/src/stateful_transaction_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ use starknet_api::core::{ContractAddress, Nonce};
use starknet_api::rpc_transaction::{RPCInvokeTransaction, RPCTransaction};
use starknet_api::transaction::TransactionHash;
use starknet_types_core::felt::Felt;
use tracing::error;

use crate::config::StatefulTransactionValidatorConfig;
use crate::errors::StatefulTransactionValidatorResult;
use crate::errors::{GatewaySpecError, StatefulTransactionValidatorResult};
use crate::state_reader::{MempoolStateReader, StateReaderFactory};
use crate::utils::{external_tx_to_account_tx, get_sender_address, get_tx_hash};

Expand Down Expand Up @@ -75,9 +76,15 @@ impl StatefulTransactionValidator {
&self.config.chain_info.chain_id,
)?;
let tx_hash = get_tx_hash(&account_tx);
let account_nonce = validator.get_nonce(get_sender_address(external_tx))?;
let sender_address = get_sender_address(external_tx);
let account_nonce = validator.get_nonce(sender_address).map_err(|e| {
error!("Failed to get nonce for sender address {}: {}", sender_address, e);
GatewaySpecError::UnexpectedError("Internal server error.".to_owned())
})?;
let skip_validate = skip_stateful_validations(external_tx, account_nonce);
validator.validate(account_tx, skip_validate)?;
validator
.validate(account_tx, skip_validate)
.map_err(|err| GatewaySpecError::ValidationFailure(err.to_string()))?;
Ok(tx_hash)
}

Expand Down Expand Up @@ -128,5 +135,8 @@ pub fn get_latest_block_info(
state_reader_factory: &dyn StateReaderFactory,
) -> StatefulTransactionValidatorResult<BlockInfo> {
let state_reader = state_reader_factory.get_state_reader_from_latest_block();
Ok(state_reader.get_block_info()?)
state_reader.get_block_info().map_err(|e| {
error!("Failed to get latest block info: {}", e);
GatewaySpecError::UnexpectedError("Internal server error.".to_owned())
})
}
Loading

0 comments on commit d0e9d09

Please sign in to comment.