From bb1bf2e3e493a5fe9aef6016b0cb48a221bb9f1a Mon Sep 17 00:00:00 2001 From: eitanm-starkware <144585602+eitanm-starkware@users.noreply.github.com> Date: Wed, 15 May 2024 11:30:36 +0300 Subject: [PATCH] refactor: change chainid from struct to enum (#250) --- src/core.rs | 53 ++++++++++++++++++++++++++--- src/transaction_hash.rs | 74 ++++++++++++++++++----------------------- 2 files changed, 82 insertions(+), 45 deletions(-) diff --git a/src/core.rs b/src/core.rs index 52c969f..d7bc8f9 100644 --- a/src/core.rs +++ b/src/core.rs @@ -2,12 +2,13 @@ #[path = "core_test.rs"] mod core_test; +use core::fmt::Display; use std::fmt::Debug; use derive_more::Display; use once_cell::sync::Lazy; use primitive_types::H160; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use starknet_crypto::FieldElement; use crate::crypto::utils::PublicKey; @@ -17,12 +18,56 @@ use crate::transaction::{Calldata, ContractAddressSalt}; use crate::{impl_from_through_intermediate, StarknetApiError}; /// A chain id. -#[derive(Clone, Debug, Display, Eq, PartialEq, Hash, Deserialize, Serialize, PartialOrd, Ord)] -pub struct ChainId(pub String); +#[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord)] +pub enum ChainId { + Mainnet, + Sepolia, + IntegrationSepolia, + Other(String), +} + +impl Serialize for ChainId { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for ChainId { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Ok(ChainId::from(s)) + } +} +impl From for ChainId { + fn from(s: String) -> Self { + match s.as_ref() { + "SN_MAIN" => ChainId::Mainnet, + "SN_SEPOLIA" => ChainId::Sepolia, + "SN_INTEGRATION_SEPOLIA" => ChainId::IntegrationSepolia, + other => ChainId::Other(other.to_owned()), + } + } +} +impl Display for ChainId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ChainId::Mainnet => write!(f, "SN_MAIN"), + ChainId::Sepolia => write!(f, "SN_SEPOLIA"), + ChainId::IntegrationSepolia => write!(f, "SN_INTEGRATION_SEPOLIA"), + ChainId::Other(ref s) => write!(f, "{}", s), + } + } +} impl ChainId { pub fn as_hex(&self) -> String { - format!("0x{}", hex::encode(&self.0)) + format!("0x{}", hex::encode(self.to_string())) } } diff --git a/src/transaction_hash.rs b/src/transaction_hash.rs index 622d888..52b11db 100644 --- a/src/transaction_hash.rs +++ b/src/transaction_hash.rs @@ -104,40 +104,32 @@ fn get_deprecated_transaction_hashes( transaction: &Transaction, transaction_version: &TransactionVersion, ) -> Result, StarknetApiError> { - Ok( - if chain_id == &ChainId("SN_MAIN".to_string()) - && block_number > &MAINNET_TRANSACTION_HASH_WITH_VERSION - { - vec![] - } else { - match transaction { - Transaction::Declare(_) => vec![], - Transaction::Deploy(deploy) => { - vec![get_deprecated_deploy_transaction_hash( - deploy, + Ok(if chain_id == &ChainId::Mainnet && block_number > &MAINNET_TRANSACTION_HASH_WITH_VERSION { + vec![] + } else { + match transaction { + Transaction::Declare(_) => vec![], + Transaction::Deploy(deploy) => { + vec![get_deprecated_deploy_transaction_hash(deploy, chain_id, transaction_version)?] + } + Transaction::DeployAccount(_) => vec![], + Transaction::Invoke(invoke) => match invoke { + InvokeTransaction::V0(invoke_v0) => { + vec![get_deprecated_invoke_transaction_v0_hash( + invoke_v0, chain_id, transaction_version, )?] } - Transaction::DeployAccount(_) => vec![], - Transaction::Invoke(invoke) => match invoke { - InvokeTransaction::V0(invoke_v0) => { - vec![get_deprecated_invoke_transaction_v0_hash( - invoke_v0, - chain_id, - transaction_version, - )?] - } - InvokeTransaction::V1(_) | InvokeTransaction::V3(_) => vec![], - }, - Transaction::L1Handler(l1_handler) => get_deprecated_l1_handler_transaction_hashes( - l1_handler, - chain_id, - transaction_version, - )?, - } - }, - ) + InvokeTransaction::V1(_) | InvokeTransaction::V3(_) => vec![], + }, + Transaction::L1Handler(l1_handler) => get_deprecated_l1_handler_transaction_hashes( + l1_handler, + chain_id, + transaction_version, + )?, + } + }) } /// Validates the hash of a starknet transaction. @@ -274,7 +266,7 @@ fn get_common_deploy_transaction_hash( None } }) - .chain(&ascii_as_felt(chain_id.0.as_str())?) + .chain(&ascii_as_felt(chain_id.to_string().as_str())?) .get_pedersen_hash(), )) } @@ -309,7 +301,7 @@ fn get_common_invoke_transaction_v0_hash( .chain(&transaction.entry_point_selector.0) .chain(&HashChain::new().chain_iter(transaction.calldata.0.iter()).get_pedersen_hash()) .chain_if_fn(|| if !is_deprecated { Some(transaction.max_fee.0.into()) } else { None }) - .chain(&ascii_as_felt(chain_id.0.as_str())?) + .chain(&ascii_as_felt(chain_id.to_string().as_str())?) .get_pedersen_hash(), )) } @@ -327,7 +319,7 @@ pub(crate) fn get_invoke_transaction_v1_hash( .chain(&StarkFelt::ZERO) // No entry point selector in invoke transaction. .chain(&HashChain::new().chain_iter(transaction.calldata.0.iter()).get_pedersen_hash()) .chain(&transaction.max_fee.0.into()) - .chain(&ascii_as_felt(chain_id.0.as_str())?) + .chain(&ascii_as_felt(chain_id.to_string().as_str())?) .chain(&transaction.nonce.0) .get_pedersen_hash(), )) @@ -359,7 +351,7 @@ pub(crate) fn get_invoke_transaction_v3_hash( .chain(transaction.sender_address.0.key()) .chain(&tip_resource_bounds_hash) .chain(&paymaster_data_hash) - .chain(&ascii_as_felt(chain_id.0.as_str())?) + .chain(&ascii_as_felt(chain_id.to_string().as_str())?) .chain(&transaction.nonce.0) .chain(&data_availability_mode) .chain(&account_deployment_data_hash) @@ -442,7 +434,7 @@ fn get_common_l1_handler_transaction_hash( None } }) - .chain(&ascii_as_felt(chain_id.0.as_str())?) + .chain(&ascii_as_felt(chain_id.to_string().as_str())?) .chain_if_fn(|| { if version > L1HandlerVersions::AsInvoke { Some(transaction.nonce.0) @@ -467,7 +459,7 @@ pub(crate) fn get_declare_transaction_v0_hash( .chain(&StarkFelt::ZERO) // No entry point selector in declare transaction. .chain(&HashChain::new().get_pedersen_hash()) .chain(&transaction.max_fee.0.into()) - .chain(&ascii_as_felt(chain_id.0.as_str())?) + .chain(&ascii_as_felt(chain_id.to_string().as_str())?) .chain(&transaction.class_hash.0) .get_pedersen_hash(), )) @@ -486,7 +478,7 @@ pub(crate) fn get_declare_transaction_v1_hash( .chain(&StarkFelt::ZERO) // No entry point selector in declare transaction. .chain(&HashChain::new().chain(&transaction.class_hash.0).get_pedersen_hash()) .chain(&transaction.max_fee.0.into()) - .chain(&ascii_as_felt(chain_id.0.as_str())?) + .chain(&ascii_as_felt(chain_id.to_string().as_str())?) .chain(&transaction.nonce.0) .get_pedersen_hash(), )) @@ -505,7 +497,7 @@ pub(crate) fn get_declare_transaction_v2_hash( .chain(&StarkFelt::ZERO) // No entry point selector in declare transaction. .chain(&HashChain::new().chain(&transaction.class_hash.0).get_pedersen_hash()) .chain(&transaction.max_fee.0.into()) - .chain(&ascii_as_felt(chain_id.0.as_str())?) + .chain(&ascii_as_felt(chain_id.to_string().as_str())?) .chain(&transaction.nonce.0) .chain(&transaction.compiled_class_hash.0) .get_pedersen_hash(), @@ -536,7 +528,7 @@ pub(crate) fn get_declare_transaction_v3_hash( .chain(transaction.sender_address.0.key()) .chain(&tip_resource_bounds_hash) .chain(&paymaster_data_hash) - .chain(&ascii_as_felt(chain_id.0.as_str())?) + .chain(&ascii_as_felt(chain_id.to_string().as_str())?) .chain(&transaction.nonce.0) .chain(&data_availability_mode) .chain(&account_deployment_data_hash) @@ -572,7 +564,7 @@ pub(crate) fn get_deploy_account_transaction_v1_hash( .chain(&StarkFelt::ZERO) // No entry point selector in deploy account transaction. .chain(&calldata_hash) .chain(&transaction.max_fee.0.into()) - .chain(&ascii_as_felt(chain_id.0.as_str())?) + .chain(&ascii_as_felt(chain_id.to_string().as_str())?) .chain(&transaction.nonce.0) .get_pedersen_hash(), )) @@ -607,7 +599,7 @@ pub(crate) fn get_deploy_account_transaction_v3_hash( .chain(contract_address.0.key()) .chain(&tip_resource_bounds_hash) .chain(&paymaster_data_hash) - .chain(&ascii_as_felt(chain_id.0.as_str())?) + .chain(&ascii_as_felt(chain_id.to_string().as_str())?) .chain(&data_availability_mode) .chain(&transaction.nonce.0) .chain(&constructor_calldata_hash)