Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Marijam/block storage values update #2

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/cheatnet/src/forking/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl ToString for ForkCacheContent {
}
}

#[derive(Debug)]
#[derive(Debug, Default)]
pub struct ForkCache {
fork_cache_content: ForkCacheContent,
cache_file: Utf8PathBuf,
Expand Down
133 changes: 122 additions & 11 deletions crates/cheatnet/src/forking/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use flate2::read::GzDecoder;
use num_bigint::BigUint;
use runtime::starknet::context::SerializableGasPrices;
use starknet::core::types::{
BlockId, ContractClass as ContractClassStarknet, FieldElement, MaybePendingBlockWithTxHashes,
StarknetError,
BlockId, ContractClass as ContractClassStarknet, ContractStorageDiffItem, FieldElement,
MaybePendingBlockWithTxHashes, StarknetError, TransactionTrace,
};
use starknet::providers::jsonrpc::HttpTransport;
use starknet::providers::{JsonRpcClient, Provider, ProviderError};
Expand All @@ -37,18 +37,38 @@ pub struct ForkStateReader {
client: JsonRpcClient<HttpTransport>,
block_number: BlockNumber,
cache: RefCell<ForkCache>,
storage_diff: HashMap<ContractAddress, HashMap<StorageKey, StarkFelt>>,
}

impl ForkStateReader {
pub fn new(url: Url, block_number: BlockNumber, cache_dir: &str) -> Result<Self> {
Ok(ForkStateReader {
cache: RefCell::new(
ForkCache::load_or_new(&url, block_number, cache_dir)
.context("Could not create fork cache")?,
),
client: JsonRpcClient::new(HttpTransport::new(url)),
block_number,
})
pub fn new(
url: Url,
block_number: BlockNumber,
transaction_index: usize,
cache_dir: &str,
) -> Result<Self> {
let fork_cache = ForkCache::load_or_new(&url, block_number, cache_dir)
.context("Could not load fork cache")?;
let mut fork_state_reader = ForkStateReader {
cache: RefCell::new(fork_cache),
client: JsonRpcClient::new(HttpTransport::new(url.clone())),
block_number: BlockNumber(block_number.0),
storage_diff: HashMap::new(),
};

let real_block_id = BlockId::Number(block_number.0 + 1);
let tx_in_block = fork_state_reader
.get_block_transaction_count(real_block_id)
.context("Unable to get block transactions count from node provider")?;
if tx_in_block > 1 {
//Get over all transaction till transaction_index and store new storage values in
//storage_diff hash map
fork_state_reader
.get_transactions_storage_diff(real_block_id, transaction_index)
.context("Unable to get trace block transactions from node provider")?;
}
// Return the initialized and state updated ForkStateReader
Ok(fork_state_reader)
}

fn block_id(&self) -> BlockId {
Expand All @@ -64,6 +84,83 @@ impl ForkStateReader {
.get_compiled_contract_class(&class_hash)
.cloned()
}

pub fn get_block_transaction_count(&self, block_id: BlockId) -> Result<u64, StateError> {
let result = tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(self.client.get_block_transaction_count(block_id))
})
.map_err(|err| {
StateError::StateReadError(format!(
"Unable to get block transactions count from fork ({err})"
))
})?;

Ok(result)
}

pub fn get_transactions_storage_diff(
&mut self,
block_id: BlockId,
transaction_index: usize,
) -> Result<(), StateError> {
let results = tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(self.client.trace_block_transactions(block_id))
})
.map_err(|err| {
StateError::StateReadError(format!(
"Unable to get trace block transactions from fork ({err})"
))
})?;

for (index, result) in results.into_iter().enumerate() {
if index == transaction_index {
break;
}
match &result.trace_root {
TransactionTrace::Invoke(invoke_trace) => {
if let Some(state_diff) = &invoke_trace.state_diff {
let contract_storage_diff = &state_diff.storage_diffs;
self.collect_storage_diffs(contract_storage_diff);
}
}
TransactionTrace::Declare(declare_trace) => {
if let Some(state_diff) = &declare_trace.state_diff {
let contract_storage_diff = &state_diff.storage_diffs;
self.collect_storage_diffs(contract_storage_diff);
}
}
TransactionTrace::DeployAccount(deploy_trace) => {
if let Some(state_diff) = &deploy_trace.state_diff {
let contract_storage_diff = &state_diff.storage_diffs;
self.collect_storage_diffs(contract_storage_diff);
}
}
TransactionTrace::L1Handler(l1handler_trace) => {
if let Some(state_diff) = &l1handler_trace.state_diff {
let contract_storage_diff = &state_diff.storage_diffs;
self.collect_storage_diffs(contract_storage_diff);
}
}
}
}

Ok(())
}

fn collect_storage_diffs(&mut self, storage_diffs: &[ContractStorageDiffItem]) {
for storage_diff in storage_diffs.iter() {
let contract_address: ContractAddress =
ContractAddress::try_from(StarkFelt::from(storage_diff.address)).unwrap();
let contract_storage = self.storage_diff.entry(contract_address).or_default();
for storage_entry in storage_diff.storage_entries.iter() {
let key = StorageKey::try_from(StarkFelt::from(storage_entry.key)).unwrap();
let new_value: StarkFelt = storage_entry.value.into_();
contract_storage.insert(key, new_value);
}
}
}
}

#[allow(clippy::needless_pass_by_value)]
Expand Down Expand Up @@ -121,17 +218,31 @@ impl StateReader for ForkStateReader {
contract_address: ContractAddress,
key: StorageKey,
) -> StateResult<StarkFelt> {
// First check cache
if let Some(cache_hit) = self.cache.borrow().get_storage_at(&contract_address, &key) {
return Ok(cache_hit);
}

// Second check the storage_diff hash map
if let Some(contract_updates) = self.storage_diff.get(&contract_address) {
if let Some(&value) = contract_updates.get(&key) {
self.cache
.borrow_mut()
.cache_get_storage_at(contract_address, key, value);

return Ok(value);
}
}

// Third ping provider
match tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(self.client.get_storage_at(
FieldElement::from_(contract_address),
FieldElement::from_(*key.0.key()),
self.block_id(),
))
}) {

// match self.runtime.block_on(self.client.get_storage_at(
// FieldElement::from_(contract_address),
// FieldElement::from_(*key.0.key()),
Expand Down
Loading