Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add full record of visited_pcs to support starknet-replay #2

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/blockifier/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ workspace = true

[features]
concurrency = []
full_visited_pcs = []
Eagle941 marked this conversation as resolved.
Show resolved Hide resolved
jemalloc = ["dep:tikv-jemallocator"]
testing = ["rand", "rstest"]

Expand Down
21 changes: 13 additions & 8 deletions crates/blockifier/src/blockifier/transaction_executor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#[cfg(feature = "concurrency")]
use std::collections::{HashMap, HashSet};
#[cfg(feature = "concurrency")]
use std::sync::Arc;
#[cfg(feature = "concurrency")]
use std::sync::Mutex;
Expand All @@ -15,7 +13,9 @@ use crate::bouncer::{Bouncer, BouncerWeights};
#[cfg(feature = "concurrency")]
use crate::concurrency::worker_logic::WorkerExecutor;
use crate::context::BlockContext;
use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState};
use crate::state::cached_state::{
CachedState, CommitmentStateDiff, TransactionalState, VisitedPcs,
};
use crate::state::errors::StateError;
use crate::state::state_api::StateReader;
use crate::transaction::errors::TransactionExecutionError;
Expand Down Expand Up @@ -150,14 +150,19 @@ impl<S: StateReader> TransactionExecutor<S> {
.as_ref()
.expect(BLOCK_STATE_ACCESS_ERR)
.visited_pcs
.iter()
.map(|(class_hash, class_visited_pcs)| -> TransactionExecutorResult<_> {
.keys()
.map(|class_hash| -> TransactionExecutorResult<_> {
let contract_class = self
.block_state
.as_ref()
.expect(BLOCK_STATE_ACCESS_ERR)
.get_compiled_contract_class(*class_hash)?;
Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?))
let visited_pcs_set = self
.block_state
.as_ref()
.expect(BLOCK_STATE_ACCESS_ERR)
.get_set_visited_pcs(class_hash);
Ok((*class_hash, contract_class.get_visited_segments(&visited_pcs_set)?))
})
.collect::<TransactionExecutorResult<_>>()?;

Expand Down Expand Up @@ -243,7 +248,7 @@ impl<S: StateReader + Send + Sync> TransactionExecutor<S> {

let n_committed_txs = worker_executor.scheduler.get_n_committed_txs();
let mut tx_execution_results = Vec::new();
let mut visited_pcs: HashMap<ClassHash, HashSet<usize>> = HashMap::new();
let mut visited_pcs: VisitedPcs = VisitedPcs::new();
for execution_output in worker_executor.execution_outputs.iter() {
if tx_execution_results.len() >= n_committed_txs {
break;
Expand All @@ -256,7 +261,7 @@ impl<S: StateReader + Send + Sync> TransactionExecutor<S> {
tx_execution_results
.push(locked_execution_output.result.map_err(TransactionExecutorError::from));
for (class_hash, class_visited_pcs) in locked_execution_output.visited_pcs {
visited_pcs.entry(class_hash).or_default().extend(class_visited_pcs);
visited_pcs.entry(class_hash).or_default().extend(class_visited_pcs.clone());
}
}

Expand Down
7 changes: 3 additions & 4 deletions crates/blockifier/src/concurrency/versioned_state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex, MutexGuard};

use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
Expand All @@ -8,7 +7,7 @@ use starknet_types_core::felt::Felt;
use crate::concurrency::versioned_storage::VersionedStorage;
use crate::concurrency::TxIndex;
use crate::execution::contract_class::ContractClass;
use crate::state::cached_state::{ContractClassMapping, StateMaps};
use crate::state::cached_state::{ContractClassMapping, StateMaps, VisitedPcs};
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult, UpdatableState};

Expand Down Expand Up @@ -202,7 +201,7 @@ impl<U: UpdatableState> VersionedState<U> {
pub fn commit_chunk_and_recover_block_state(
mut self,
n_committed_txs: usize,
visited_pcs: HashMap<ClassHash, HashSet<usize>>,
visited_pcs: VisitedPcs,
) -> U {
if n_committed_txs == 0 {
return self.into_initial_state();
Expand Down Expand Up @@ -277,7 +276,7 @@ impl<S: StateReader> UpdatableState for VersionedStateProxy<S> {
&mut self,
writes: &StateMaps,
class_hash_to_class: &ContractClassMapping,
_visited_pcs: &HashMap<ClassHash, HashSet<usize>>,
_visited_pcs: &VisitedPcs,
) {
self.state().apply_writes(self.tx_index, writes, class_hash_to_class)
}
Expand Down
10 changes: 4 additions & 6 deletions crates/blockifier/src/concurrency/worker_logic.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;

use starknet_api::core::ClassHash;

use super::versioned_state::VersionedState;
use crate::blockifier::transaction_executor::TransactionExecutorError;
use crate::bouncer::Bouncer;
Expand All @@ -16,7 +14,7 @@ use crate::concurrency::versioned_state::ThreadSafeVersionedState;
use crate::concurrency::TxIndex;
use crate::context::BlockContext;
use crate::state::cached_state::{
ContractClassMapping, StateChanges, StateMaps, TransactionalState,
ContractClassMapping, StateChanges, StateMaps, TransactionalState, VisitedPcs,
};
use crate::state::state_api::{StateReader, UpdatableState};
use crate::transaction::objects::{TransactionExecutionInfo, TransactionExecutionResult};
Expand All @@ -34,7 +32,7 @@ pub struct ExecutionTaskOutput {
pub reads: StateMaps,
pub writes: StateMaps,
pub contract_classes: ContractClassMapping,
pub visited_pcs: HashMap<ClassHash, HashSet<usize>>,
pub visited_pcs: VisitedPcs,
pub result: TransactionExecutionResult<TransactionExecutionInfo>,
}

Expand Down Expand Up @@ -264,7 +262,7 @@ impl<'a, U: UpdatableState> WorkerExecutor<'a, U> {
pub fn commit_chunk_and_recover_block_state(
self,
n_committed_txs: usize,
visited_pcs: HashMap<ClassHash, HashSet<usize>>,
visited_pcs: VisitedPcs,
) -> U {
self.state
.into_inner_state()
Expand Down
15 changes: 11 additions & 4 deletions crates/blockifier/src/execution/entry_point_execution.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::collections::HashSet;

use cairo_vm::types::builtin_name::BuiltinName;
use cairo_vm::types::layout_name::LayoutName;
use cairo_vm::types::relocatable::{MaybeRelocatable, Relocatable};
Expand All @@ -19,6 +17,7 @@ use crate::execution::execution_utils::{
read_execution_retdata, write_felt, write_maybe_relocatable, Args, ReadOnlySegments,
};
use crate::execution::syscalls::hint_processor::SyscallHintProcessor;
use crate::state::cached_state::Pcs;
use crate::state::state_api::State;

// TODO(spapini): Try to refactor this file into a StarknetRunner struct.
Expand Down Expand Up @@ -109,7 +108,14 @@ fn register_visited_pcs(
program_segment_size: usize,
bytecode_length: usize,
) -> EntryPointExecutionResult<()> {
let mut class_visited_pcs = HashSet::new();
fn add_element(pcs: &mut Pcs, element: usize) {
#[cfg(not(feature = "full_visited_pcs"))]
pcs.insert(element);

#[cfg(feature = "full_visited_pcs")]
pcs.push(element);
}
let mut class_visited_pcs = Pcs::new();
// Relocate the trace, putting the program segment at address 1 and the execution segment right
// after it.
// TODO(lior): Avoid unnecessary relocation once the VM has a non-relocated `get_trace()`
Expand All @@ -126,10 +132,11 @@ fn register_visited_pcs(
// Jumping to a PC that is not inside the bytecode is possible. For example, to obtain
// the builtin costs. Filter out these values.
if real_pc < bytecode_length {
class_visited_pcs.insert(real_pc);
add_element(&mut class_visited_pcs, real_pc);
}
}
state.add_visited_pcs(class_hash, &class_visited_pcs);

Ok(())
}

Expand Down
7 changes: 1 addition & 6 deletions crates/blockifier/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
// length to pointer type ([not necessarily true](https://github.com/rust-lang/rust/issues/65473),
// but it is a reasonable assumption for now), this attribute protects against potential overflow
// when converting usize to u128.
#![cfg(any(
target_pointer_width = "16",
target_pointer_width = "32",
target_pointer_width = "64",
target_pointer_width = "128"
))]
#![cfg(any(target_pointer_width = "16", target_pointer_width = "32", target_pointer_width = "64"))]
Eagle941 marked this conversation as resolved.
Show resolved Hide resolved

#[cfg(feature = "jemalloc")]
// Override default allocator.
Expand Down
53 changes: 47 additions & 6 deletions crates/blockifier/src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ mod test;

pub type ContractClassMapping = HashMap<ClassHash, ContractClass>;

#[cfg(not(feature = "full_visited_pcs"))]
pub type Pcs = HashSet<usize>;

#[cfg(feature = "full_visited_pcs")]
pub type Pcs = Vec<usize>;

#[cfg(not(feature = "full_visited_pcs"))]
pub type VisitedPcs = HashMap<ClassHash, Pcs>;

#[cfg(feature = "full_visited_pcs")]
pub type VisitedPcs = HashMap<ClassHash, Vec<Pcs>>;

/// Caches read and write requests.
///
/// Writer functionality is builtin, whereas Reader functionality is injected through
Expand All @@ -33,7 +45,7 @@ pub struct CachedState<S: StateReader> {
pub(crate) cache: RefCell<StateCache>,
pub(crate) class_hash_to_class: RefCell<ContractClassMapping>,
/// A map from class hash to the set of PC values that were visited in the class.
pub visited_pcs: HashMap<ClassHash, HashSet<usize>>,
pub visited_pcs: VisitedPcs,
}

impl<S: StateReader> CachedState<S> {
Expand All @@ -59,6 +71,25 @@ impl<S: StateReader> CachedState<S> {
Ok(self.to_state_diff()?.into())
}

pub fn get_set_visited_pcs(&self, class_hash: &ClassHash) -> HashSet<usize> {
#[cfg(not(feature = "full_visited_pcs"))]
fn from_set(class_hash: &ClassHash, visited_pcs: &VisitedPcs) -> HashSet<usize> {
return visited_pcs.get(class_hash).unwrap().clone();
}

#[cfg(feature = "full_visited_pcs")]
fn from_set(class_hash: &ClassHash, visited_pcs: &VisitedPcs) -> HashSet<usize> {
let class_visited_pcs = visited_pcs.get(class_hash).unwrap();
let mut visited_pcs_set: HashSet<usize> = HashSet::new();
for pcs in class_visited_pcs {
visited_pcs_set.extend(pcs.iter());
}
visited_pcs_set
}

from_set(class_hash, &self.visited_pcs)
}

pub fn update_cache(
&mut self,
write_updates: &StateMaps,
Expand All @@ -73,9 +104,9 @@ impl<S: StateReader> CachedState<S> {
self.class_hash_to_class.get_mut().extend(local_contract_cache_updates);
}

pub fn update_visited_pcs_cache(&mut self, visited_pcs: &HashMap<ClassHash, HashSet<usize>>) {
pub fn update_visited_pcs_cache(&mut self, visited_pcs: &VisitedPcs) {
for (class_hash, class_visited_pcs) in visited_pcs {
self.add_visited_pcs(*class_hash, class_visited_pcs);
self.visited_pcs.entry(*class_hash).or_default().extend(class_visited_pcs.clone());
}
}

Expand Down Expand Up @@ -112,7 +143,7 @@ impl<S: StateReader> UpdatableState for CachedState<S> {
&mut self,
writes: &StateMaps,
class_hash_to_class: &ContractClassMapping,
visited_pcs: &HashMap<ClassHash, HashSet<usize>>,
visited_pcs: &VisitedPcs,
) {
// TODO(OriF,15/5/24): Reconsider the clone.
self.update_cache(writes, class_hash_to_class.clone());
Expand Down Expand Up @@ -275,8 +306,18 @@ impl<S: StateReader> State for CachedState<S> {
Ok(())
}

fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet<usize>) {
self.visited_pcs.entry(class_hash).or_default().extend(pcs);
fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &Pcs) {
#[cfg(not(feature = "full_visited_pcs"))]
fn from_set(visited_pcs: &mut VisitedPcs, class_hash: ClassHash, pcs: &Pcs) {
visited_pcs.entry(class_hash).or_default().extend(pcs);
}

#[cfg(feature = "full_visited_pcs")]
fn from_set(visited_pcs: &mut VisitedPcs, class_hash: ClassHash, pcs: &Pcs) {
visited_pcs.entry(class_hash).or_default().push(pcs.to_vec());
}

from_set(&mut self.visited_pcs, class_hash, pcs);
}
}

Expand Down
8 changes: 3 additions & 5 deletions crates/blockifier/src/state/state_api.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::collections::{HashMap, HashSet};

use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
use starknet_api::state::StorageKey;
use starknet_types_core::felt::Felt;

use super::cached_state::{ContractClassMapping, StateMaps};
use super::cached_state::{ContractClassMapping, Pcs, StateMaps, VisitedPcs};
use crate::abi::abi_utils::get_fee_token_var_address;
use crate::abi::sierra_types::next_storage_key;
use crate::execution::contract_class::ContractClass;
Expand Down Expand Up @@ -107,7 +105,7 @@ pub trait State: StateReader {
/// Marks the given set of PC values as visited for the given class hash.
// TODO(lior): Once we have a BlockResources object, move this logic there. Make sure reverted
// entry points do not affect the final set of PCs.
fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet<usize>);
fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &Pcs);
}

/// A class defining the API for updating a state with transactions writes.
Expand All @@ -116,6 +114,6 @@ pub trait UpdatableState: StateReader {
&mut self,
writes: &StateMaps,
class_hash_to_class: &ContractClassMapping,
visited_pcs: &HashMap<ClassHash, HashSet<usize>>,
visited_pcs: &VisitedPcs,
);
}
7 changes: 1 addition & 6 deletions crates/native_blockifier/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
// The blockifier crate supports only these specific architectures.
#![cfg(any(
target_pointer_width = "16",
target_pointer_width = "32",
target_pointer_width = "64",
target_pointer_width = "128"
))]
#![cfg(any(target_pointer_width = "16", target_pointer_width = "32", target_pointer_width = "64"))]

pub mod errors;
pub mod py_block_executor;
Expand Down