diff --git a/crates/cheatnet/src/forking/cache.rs b/crates/cheatnet/src/forking/cache.rs index c5222dbf26..a15a84a4ca 100644 --- a/crates/cheatnet/src/forking/cache.rs +++ b/crates/cheatnet/src/forking/cache.rs @@ -81,7 +81,7 @@ impl ToString for ForkCacheContent { } } -#[derive(Debug)] +#[derive(Debug, Default)] pub struct ForkCache { fork_cache_content: ForkCacheContent, cache_file: Utf8PathBuf, diff --git a/crates/cheatnet/src/forking/state.rs b/crates/cheatnet/src/forking/state.rs index c818344d52..18e6003878 100644 --- a/crates/cheatnet/src/forking/state.rs +++ b/crates/cheatnet/src/forking/state.rs @@ -14,8 +14,8 @@ use flate2::read::GzDecoder; use num_bigint::BigUint; use runtime::starknet::context::SerializableGasPrices; use starknet::core::types::{ - BlockId, ContractClass as ContractClassStarknet, FieldElement, MaybePendingBlockWithTxHashes, - StarknetError, + BlockId, ContractClass as ContractClassStarknet, ContractStorageDiffItem, FieldElement, + MaybePendingBlockWithTxHashes, StarknetError, TransactionTrace, }; use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::{JsonRpcClient, Provider, ProviderError}; @@ -37,18 +37,38 @@ pub struct ForkStateReader { client: JsonRpcClient, block_number: BlockNumber, cache: RefCell, + storage_diff: HashMap>, } impl ForkStateReader { - pub fn new(url: Url, block_number: BlockNumber, cache_dir: &str) -> Result { - Ok(ForkStateReader { - cache: RefCell::new( - ForkCache::load_or_new(&url, block_number, cache_dir) - .context("Could not create fork cache")?, - ), - client: JsonRpcClient::new(HttpTransport::new(url)), - block_number, - }) + pub fn new( + url: Url, + block_number: BlockNumber, + transaction_index: usize, + cache_dir: &str, + ) -> Result { + let fork_cache = ForkCache::load_or_new(&url, block_number, cache_dir) + .context("Could not load fork cache")?; + let mut fork_state_reader = ForkStateReader { + cache: RefCell::new(fork_cache), + client: JsonRpcClient::new(HttpTransport::new(url.clone())), + block_number: BlockNumber(block_number.0), + storage_diff: HashMap::new(), + }; + + let real_block_id = BlockId::Number(block_number.0 + 1); + let tx_in_block = fork_state_reader + .get_block_transaction_count(real_block_id) + .context("Unable to get block transactions count from node provider")?; + if tx_in_block > 1 { + //Get over all transaction till transaction_index and store new storage values in + //storage_diff hash map + fork_state_reader + .get_transactions_storage_diff(real_block_id, transaction_index) + .context("Unable to get trace block transactions from node provider")?; + } + // Return the initialized and state updated ForkStateReader + Ok(fork_state_reader) } fn block_id(&self) -> BlockId { @@ -64,6 +84,83 @@ impl ForkStateReader { .get_compiled_contract_class(&class_hash) .cloned() } + + pub fn get_block_transaction_count(&self, block_id: BlockId) -> Result { + let result = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current() + .block_on(self.client.get_block_transaction_count(block_id)) + }) + .map_err(|err| { + StateError::StateReadError(format!( + "Unable to get block transactions count from fork ({err})" + )) + })?; + + Ok(result) + } + + pub fn get_transactions_storage_diff( + &mut self, + block_id: BlockId, + transaction_index: usize, + ) -> Result<(), StateError> { + let results = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current() + .block_on(self.client.trace_block_transactions(block_id)) + }) + .map_err(|err| { + StateError::StateReadError(format!( + "Unable to get trace block transactions from fork ({err})" + )) + })?; + + for (index, result) in results.into_iter().enumerate() { + if index == transaction_index { + break; + } + match &result.trace_root { + TransactionTrace::Invoke(invoke_trace) => { + if let Some(state_diff) = &invoke_trace.state_diff { + let contract_storage_diff = &state_diff.storage_diffs; + self.collect_storage_diffs(contract_storage_diff); + } + } + TransactionTrace::Declare(declare_trace) => { + if let Some(state_diff) = &declare_trace.state_diff { + let contract_storage_diff = &state_diff.storage_diffs; + self.collect_storage_diffs(contract_storage_diff); + } + } + TransactionTrace::DeployAccount(deploy_trace) => { + if let Some(state_diff) = &deploy_trace.state_diff { + let contract_storage_diff = &state_diff.storage_diffs; + self.collect_storage_diffs(contract_storage_diff); + } + } + TransactionTrace::L1Handler(l1handler_trace) => { + if let Some(state_diff) = &l1handler_trace.state_diff { + let contract_storage_diff = &state_diff.storage_diffs; + self.collect_storage_diffs(contract_storage_diff); + } + } + } + } + + Ok(()) + } + + fn collect_storage_diffs(&mut self, storage_diffs: &[ContractStorageDiffItem]) { + for storage_diff in storage_diffs.iter() { + let contract_address: ContractAddress = + ContractAddress::try_from(StarkFelt::from(storage_diff.address)).unwrap(); + let contract_storage = self.storage_diff.entry(contract_address).or_default(); + for storage_entry in storage_diff.storage_entries.iter() { + let key = StorageKey::try_from(StarkFelt::from(storage_entry.key)).unwrap(); + let new_value: StarkFelt = storage_entry.value.into_(); + contract_storage.insert(key, new_value); + } + } + } } #[allow(clippy::needless_pass_by_value)] @@ -121,10 +218,23 @@ impl StateReader for ForkStateReader { contract_address: ContractAddress, key: StorageKey, ) -> StateResult { + // First check cache if let Some(cache_hit) = self.cache.borrow().get_storage_at(&contract_address, &key) { return Ok(cache_hit); } + // Second check the storage_diff hash map + if let Some(contract_updates) = self.storage_diff.get(&contract_address) { + if let Some(&value) = contract_updates.get(&key) { + self.cache + .borrow_mut() + .cache_get_storage_at(contract_address, key, value); + + return Ok(value); + } + } + + // Third ping provider match tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(self.client.get_storage_at( FieldElement::from_(contract_address), @@ -132,6 +242,7 @@ impl StateReader for ForkStateReader { self.block_id(), )) }) { + // match self.runtime.block_on(self.client.get_storage_at( // FieldElement::from_(contract_address), // FieldElement::from_(*key.0.key()),