Skip to content

Commit

Permalink
chore: added test only methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Eagle941 committed Sep 4, 2024
1 parent b51bc6f commit 7042bd1
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 76 deletions.
7 changes: 3 additions & 4 deletions crates/blockifier/src/bouncer_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ use crate::blockifier::transaction_executor::{
use crate::bouncer::{verify_tx_weights_in_bounds, Bouncer, BouncerWeights, BuiltinCount};
use crate::context::BlockContext;
use crate::execution::call_info::ExecutionSummary;
use crate::state::cached_state::{StateChangesKeys, TransactionalState};
use crate::state::visited_pcs::VisitedPcsSet;
use crate::state::cached_state::StateChangesKeys;
use crate::storage_key;
use crate::test_utils::initial_test_state::test_state;
use crate::transaction::errors::TransactionExecutionError;
Expand Down Expand Up @@ -185,11 +184,11 @@ fn test_bouncer_try_update(
) {
use cairo_vm::vm::runners::cairo_runner::ExecutionResources;

use crate::state::cached_state::TransactionalState;
use crate::transaction::objects::TransactionResources;

let state = &mut test_state(&BlockContext::create_for_account_testing().chain_info, 0, &[]);
let mut transactional_state: TransactionalState<'_, _, VisitedPcsSet> =
TransactionalState::create_transactional(state);
let mut transactional_state = TransactionalState::create_transactional_for_testing(state);

// Setup the bouncer.
let block_max_capacity = BouncerWeights {
Expand Down
13 changes: 5 additions & 8 deletions crates/blockifier/src/concurrency/flow_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use starknet_api::{contract_address, felt, patricia_key};
use crate::abi::sierra_types::{SierraType, SierraU128};
use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus};
use crate::concurrency::test_utils::{safe_versioned_state_for_testing, DEFAULT_CHUNK_SIZE};
use crate::concurrency::versioned_state::{ThreadSafeVersionedState, VersionedStateProxy};
use crate::concurrency::versioned_state::ThreadSafeVersionedState;
use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps};
use crate::state::state_api::UpdatableState;
use crate::state::visited_pcs::VisitedPcsSet;
Expand All @@ -29,7 +29,6 @@ fn scheduler_flow_test(
// transaction sequentially advances a counter by reading the previous value and bumping it by
// 1.

use crate::concurrency::versioned_state::VersionedStateProxy;
use crate::state::visited_pcs::VisitedPcsSet;
let scheduler = Arc::new(Scheduler::new(DEFAULT_CHUNK_SIZE));
let versioned_state =
Expand Down Expand Up @@ -76,8 +75,7 @@ fn scheduler_flow_test(
Task::AskForTask
}
Task::ValidationTask(tx_index) => {
let state_proxy: VersionedStateProxy<_, VisitedPcsSet> =
versioned_state.pin_version(tx_index);
let state_proxy = versioned_state.pin_version_for_testing(tx_index);
let (reads, writes) =
get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state);
let read_set_valid = state_proxy.validate_reads(&reads);
Expand Down Expand Up @@ -129,14 +127,14 @@ fn get_reads_writes_for(
) -> (StateMaps, StateMaps) {
match task {
Task::ExecutionTask(tx_index) => {
let state_proxy: VersionedStateProxy<_, VisitedPcsSet> = match tx_index {
let state_proxy = match tx_index {
0 => {
return (
state_maps_with_single_storage_entry(0),
state_maps_with_single_storage_entry(1),
);
}
_ => versioned_state.pin_version(tx_index - 1),
_ => versioned_state.pin_version_for_testing(tx_index - 1),
};
let tx_written_value = SierraU128::from_storage(
&state_proxy,
Expand All @@ -151,8 +149,7 @@ fn get_reads_writes_for(
)
}
Task::ValidationTask(tx_index) => {
let state_proxy: VersionedStateProxy<_, VisitedPcsSet> =
versioned_state.pin_version(tx_index);
let state_proxy = versioned_state.pin_version_for_testing(tx_index);
let tx_written_value = SierraU128::from_storage(
&state_proxy,
&contract_address!(CONTRACT_ADDRESS),
Expand Down
8 changes: 8 additions & 0 deletions crates/blockifier/src/concurrency/versioned_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,14 @@ impl<S: StateReader> ThreadSafeVersionedState<S> {
VersionedStateProxy { tx_index, state: self.0.clone(), _marker: PhantomData }
}

#[cfg(test)]
pub fn pin_version_for_testing(
&self,
tx_index: TxIndex,
) -> VersionedStateProxy<S, crate::state::visited_pcs::VisitedPcsSet> {
VersionedStateProxy { tx_index, state: self.0.clone(), _marker: PhantomData }
}

pub fn into_inner_state(self) -> VersionedState<S> {
Arc::try_unwrap(self.0)
.unwrap_or_else(|_| {
Expand Down
86 changes: 42 additions & 44 deletions crates/blockifier/src/concurrency/versioned_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ use crate::abi::abi_utils::{get_fee_token_var_address, get_storage_var_address};
use crate::concurrency::test_utils::{
class_hash, contract_address, safe_versioned_state_for_testing,
};
use crate::concurrency::versioned_state::{
ThreadSafeVersionedState, VersionedState, VersionedStateProxy,
};
use crate::concurrency::versioned_state::{ThreadSafeVersionedState, VersionedState};
use crate::concurrency::TxIndex;
use crate::context::BlockContext;
use crate::state::cached_state::{
Expand Down Expand Up @@ -73,9 +71,8 @@ fn test_versioned_state_proxy() {
let versioned_state = Arc::new(Mutex::new(VersionedState::new(cached_state)));

let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state));
let versioned_state_proxys: Vec<
VersionedStateProxy<CachedState<DictStateReader, VisitedPcsSet>, VisitedPcsSet>,
> = (0..20).map(|i| safe_versioned_state.pin_version(i)).collect();
let versioned_state_proxys: Vec<_> =
(0..20).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect();

// Read initial data
assert_eq!(versioned_state_proxys[5].get_nonce_at(contract_address).unwrap(), nonce);
Expand Down Expand Up @@ -210,14 +207,12 @@ fn test_run_parallel_txs(max_resource_bounds: ResourceBoundsMapping) {
))));

let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state));
let mut versioned_state_proxy_1: VersionedStateProxy<_, VisitedPcsSet> =
safe_versioned_state.pin_version(1);
let mut state_1: TransactionalState<'_, _, VisitedPcsSet> =
TransactionalState::create_transactional(&mut versioned_state_proxy_1);
let mut versioned_state_proxy_2: VersionedStateProxy<_, VisitedPcsSet> =
safe_versioned_state.pin_version(2);
let mut state_2: TransactionalState<'_, _, VisitedPcsSet> =
TransactionalState::create_transactional(&mut versioned_state_proxy_2);
let mut versioned_state_proxy_1 = safe_versioned_state.pin_version_for_testing(1);
let mut state_1 =
TransactionalState::create_transactional_for_testing(&mut versioned_state_proxy_1);
let mut versioned_state_proxy_2 = safe_versioned_state.pin_version_for_testing(2);
let mut state_2 =
TransactionalState::create_transactional_for_testing(&mut versioned_state_proxy_2);

// Prepare transactions
let deploy_account_tx_1 = deploy_account_tx(
Expand Down Expand Up @@ -288,10 +283,9 @@ fn test_validate_reads(
) {
let storage_key = storage_key!(0x10_u8);

let mut version_state_proxy: VersionedStateProxy<_, VisitedPcsSet> =
safe_versioned_state.pin_version(1);
let transactional_state: TransactionalState<'_, _, VisitedPcsSet> =
TransactionalState::create_transactional(&mut version_state_proxy);
let mut version_state_proxy = safe_versioned_state.pin_version_for_testing(1);
let transactional_state =
TransactionalState::create_transactional_for_testing(&mut version_state_proxy);

// Validating tx index 0 always succeeds.
assert!(
Expand Down Expand Up @@ -380,8 +374,7 @@ fn test_false_validate_reads(
#[case] tx_0_writes: StateMaps,
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader, VisitedPcsSet>>,
) {
let version_state_proxy: VersionedStateProxy<_, VisitedPcsSet> =
safe_versioned_state.pin_version(0);
let version_state_proxy = safe_versioned_state.pin_version_for_testing(0);
version_state_proxy.state().apply_writes(0, &tx_0_writes, &HashMap::default());
assert!(!safe_versioned_state.pin_version::<VisitedPcsSet>(1).validate_reads(&tx_1_reads));
}
Expand All @@ -398,8 +391,7 @@ fn test_false_validate_reads_declared_contracts(
declared_contracts: HashMap::from([(class_hash!(1_u8), true)]),
..Default::default()
};
let version_state_proxy: VersionedStateProxy<_, VisitedPcsSet> =
safe_versioned_state.pin_version(0);
let version_state_proxy = safe_versioned_state.pin_version_for_testing(0);
let compiled_contract_calss = FeatureContract::TestContract(CairoVersion::Cairo1).get_class();
let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_calss)]);
version_state_proxy.state().apply_writes(0, &tx_0_writes, &class_hash_to_class);
Expand All @@ -412,10 +404,12 @@ fn test_apply_writes(
class_hash: ClassHash,
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader, VisitedPcsSet>>,
) {
let mut versioned_proxy_states: Vec<VersionedStateProxy<_, VisitedPcsSet>> =
(0..2).map(|i| safe_versioned_state.pin_version(i)).collect();
let mut transactional_states: Vec<TransactionalState<'_, _, VisitedPcsSet>> =
versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect();
let mut versioned_proxy_states: Vec<_> =
(0..2).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect();
let mut transactional_states: Vec<_> = versioned_proxy_states
.iter_mut()
.map(TransactionalState::create_transactional_for_testing)
.collect();

// Transaction 0 class hash.
let class_hash_0 = class_hash!(76_u8);
Expand All @@ -429,7 +423,7 @@ fn test_apply_writes(
transactional_states[0].set_contract_class(class_hash, contract_class_0.clone()).unwrap();
assert_eq!(transactional_states[0].class_hash_to_class.borrow().len(), 1);

safe_versioned_state.pin_version(0).apply_writes(
safe_versioned_state.pin_version_for_testing(0).apply_writes(
&transactional_states[0].cache.borrow().writes,
&transactional_states[0].class_hash_to_class.borrow().clone(),
&VisitedPcsSet::default(),
Expand All @@ -447,10 +441,12 @@ fn test_apply_writes_reexecute_scenario(
class_hash: ClassHash,
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader, VisitedPcsSet>>,
) {
let mut versioned_proxy_states: Vec<VersionedStateProxy<_, VisitedPcsSet>> =
(0..2).map(|i| safe_versioned_state.pin_version(i)).collect();
let mut transactional_states: Vec<TransactionalState<'_, _, VisitedPcsSet>> =
versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect();
let mut versioned_proxy_states: Vec<_> =
(0..2).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect();
let mut transactional_states: Vec<_> = versioned_proxy_states
.iter_mut()
.map(TransactionalState::create_transactional_for_testing)
.collect();

// Transaction 0 class hash.
let class_hash_0 = class_hash!(76_u8);
Expand All @@ -460,7 +456,7 @@ fn test_apply_writes_reexecute_scenario(
// updated.
assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash);

safe_versioned_state.pin_version(0).apply_writes(
safe_versioned_state.pin_version_for_testing(0).apply_writes(
&transactional_states[0].cache.borrow().writes,
&transactional_states[0].class_hash_to_class.borrow().clone(),
&VisitedPcsSet::default(),
Expand All @@ -471,7 +467,7 @@ fn test_apply_writes_reexecute_scenario(

// TODO: Use re-execution native util once it's ready.
// "Re-execute" the transaction.
let mut versioned_state_proxy = safe_versioned_state.pin_version(1);
let mut versioned_state_proxy = safe_versioned_state.pin_version_for_testing(1);
transactional_states[1] = TransactionalState::create_transactional(&mut versioned_state_proxy);
// The class hash should be updated.
assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0);
Expand All @@ -483,10 +479,12 @@ fn test_delete_writes(
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader, VisitedPcsSet>>,
) {
let num_of_txs = 3;
let mut versioned_proxy_states: Vec<VersionedStateProxy<_, VisitedPcsSet>> =
(0..num_of_txs).map(|i| safe_versioned_state.pin_version(i)).collect();
let mut transactional_states: Vec<TransactionalState<'_, _, VisitedPcsSet>> =
versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect();
let mut versioned_proxy_states: Vec<_> =
(0..num_of_txs).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect();
let mut transactional_states: Vec<_> = versioned_proxy_states
.iter_mut()
.map(TransactionalState::create_transactional_for_testing)
.collect();

// Setting 2 instances of the contract to ensure `delete_writes` removes information from
// multiple keys. Class hash values are not checked in this test.
Expand All @@ -504,7 +502,7 @@ fn test_delete_writes(
tx_state
.set_contract_class(feature_contract.get_class_hash(), feature_contract.get_class())
.unwrap();
safe_versioned_state.pin_version(i).apply_writes(
safe_versioned_state.pin_version_for_testing(i).apply_writes(
&tx_state.cache.borrow().writes,
&tx_state.class_hash_to_class.borrow(),
&VisitedPcsSet::default(),
Expand Down Expand Up @@ -564,7 +562,7 @@ fn test_delete_writes_completeness(
HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_class())]);

let tx_index = 0;
let mut versioned_state_proxy = safe_versioned_state.pin_version(tx_index);
let mut versioned_state_proxy = safe_versioned_state.pin_version_for_testing(tx_index);

versioned_state_proxy.apply_writes(
&state_maps_writes,
Expand Down Expand Up @@ -608,13 +606,13 @@ fn test_versioned_proxy_state_flow(
let contract_address = contract_address!("0x1");
let class_hash = ClassHash(felt!(27_u8));

let mut versioned_proxy_states: Vec<VersionedStateProxy<_, VisitedPcsSet>> =
(0..4).map(|i| safe_versioned_state.pin_version(i)).collect();
let mut versioned_proxy_states: Vec<_> =
(0..4).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect();

let mut transactional_states: Vec<TransactionalState<'_, _, VisitedPcsSet>> =
Vec::with_capacity(4);
let mut transactional_states = Vec::with_capacity(4);
for proxy_state in &mut versioned_proxy_states {
transactional_states.push(TransactionalState::create_transactional(proxy_state));
transactional_states
.push(TransactionalState::create_transactional_for_testing(proxy_state));
}

// Clients class hash values.
Expand Down
13 changes: 12 additions & 1 deletion crates/blockifier/src/concurrency/worker_logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::state::cached_state::{
ContractClassMapping, StateChanges, StateMaps, TransactionalState,
};
use crate::state::state_api::{StateReader, UpdatableState};
use crate::state::visited_pcs::VisitedPcs;
use crate::state::visited_pcs::{VisitedPcs, VisitedPcsSet};
use crate::transaction::objects::{TransactionExecutionInfo, TransactionExecutionResult};
use crate::transaction::transaction_execution::Transaction;
use crate::transaction::transactions::{ExecutableTransaction, ExecutionFlags};
Expand Down Expand Up @@ -45,6 +45,17 @@ pub struct WorkerExecutor<'a, S: StateReader, V: VisitedPcs> {
pub block_context: &'a BlockContext,
pub bouncer: Mutex<&'a mut Bouncer>,
}
impl<'a, S: StateReader> WorkerExecutor<'a, S, VisitedPcsSet> {
#[cfg(test)]
pub fn new_for_testing(
state: ThreadSafeVersionedState<S>,
chunk: &'a [Transaction],
block_context: &'a BlockContext,
bouncer: Mutex<&'a mut Bouncer>,
) -> WorkerExecutor<'a, S, VisitedPcsSet> {
WorkerExecutor::new(state, chunk, block_context, bouncer)
}
}
impl<'a, S: StateReader, V: VisitedPcs> WorkerExecutor<'a, S, V> {
pub fn new(
state: ThreadSafeVersionedState<S>,
Expand Down
Loading

0 comments on commit 7042bd1

Please sign in to comment.