diff --git a/crates/papyrus_network/src/db_executor/mod.rs b/crates/papyrus_network/src/db_executor/mod.rs index 20543d9c9f..44fda349e4 100644 --- a/crates/papyrus_network/src/db_executor/mod.rs +++ b/crates/papyrus_network/src/db_executor/mod.rs @@ -1,12 +1,8 @@ use std::vec; use async_trait::async_trait; -use bytes::BufMut; use futures::channel::mpsc::Sender; use futures::future::{pending, poll_fn}; -#[cfg(test)] -use mockall::automock; -use papyrus_protobuf::protobuf; use papyrus_protobuf::sync::{ BlockHashOrNumber, ContractDiff, @@ -20,56 +16,15 @@ use papyrus_protobuf::sync::{ use papyrus_storage::header::HeaderStorageReader; use papyrus_storage::state::StateStorageReader; use papyrus_storage::{db, StorageReader, StorageTxn}; -use prost::Message; use starknet_api::block::BlockNumber; use starknet_api::state::ThinStateDiff; use tracing::error; -use crate::DataType; - #[cfg(test)] mod test; mod utils; -#[derive(thiserror::Error, Debug)] -#[error("Failed to encode data")] -pub struct DataEncodingError; - -#[cfg_attr(test, derive(Debug, PartialEq, Eq))] -#[derive(Clone)] -pub enum Data { - BlockHeaderAndSignature(DataOrFin), - StateDiffChunk(DataOrFin), -} - -impl Default for Data { - fn default() -> Self { - // TODO: consider this default data type. - Self::BlockHeaderAndSignature(DataOrFin(None)) - } -} - -impl Data { - pub fn encode(self, buf: &mut B) -> Result<(), DataEncodingError> - where - B: BufMut, - { - match self { - Data::BlockHeaderAndSignature(maybe_signed_block_header) => { - let block_headers_response = - protobuf::BlockHeadersResponse::from(maybe_signed_block_header); - block_headers_response.encode(buf).map_err(|_| DataEncodingError) - } - Data::StateDiffChunk(maybe_state_diff_chunk) => { - let state_diffs_response = - protobuf::StateDiffsResponse::from(maybe_state_diff_chunk); - state_diffs_response.encode(buf).map_err(|_| DataEncodingError) - } - } - } -} - #[derive(thiserror::Error, Debug)] pub enum DBExecutorError { #[error(transparent)] @@ -106,11 +61,10 @@ impl DBExecutorError { pub trait DBExecutorTrait { /// Send a query to be executed in the DBExecutor. The query will be run concurrently with the /// calling code and the result will be over the given channel. - fn register_query( - &mut self, + fn register_query( + &self, query: Query, - data_type: impl FetchBlockDataFromDb + Send + 'static, - sender: Sender, + sender: Sender>, ); /// Polls incoming queries. @@ -130,16 +84,14 @@ impl DBExecutor { #[async_trait] impl DBExecutorTrait for DBExecutor { - fn register_query( - &mut self, + fn register_query( + &self, query: Query, - data_type: impl FetchBlockDataFromDb + Send + 'static, - sender: Sender, + sender: Sender>, ) { let storage_reader_clone = self.storage_reader.clone(); tokio::task::spawn(async move { - let result = - send_data_for_query(storage_reader_clone, query.clone(), data_type, sender).await; + let result = send_data_for_query(storage_reader_clone, query.clone(), sender).await; if let Err(error) = result { if error.should_log_in_error_level() { error!("Running inbound query {query:?} failed on {error:?}"); @@ -158,69 +110,49 @@ impl DBExecutorTrait for DBExecutor { } } -#[cfg_attr(test, automock)] -// we need to tell clippy to ignore the "needless" lifetime warning because it's not true. -// we do need the lifetime for the automock, following clippy's suggestion will break the code. -#[allow(clippy::needless_lifetimes)] -pub trait FetchBlockDataFromDb { - fn fetch_block_data_from_db<'a>( - &self, +pub trait FetchBlockDataFromDb: Sized { + fn fetch_block_data_from_db( block_number: BlockNumber, - txn: &StorageTxn<'a, db::RO>, - ) -> Result, DBExecutorError>; - - fn fin(&self) -> Data; + txn: &StorageTxn<'_, db::RO>, + ) -> Result, DBExecutorError>; } -impl FetchBlockDataFromDb for DataType { +impl FetchBlockDataFromDb for SignedBlockHeader { fn fetch_block_data_from_db( - &self, block_number: BlockNumber, txn: &StorageTxn<'_, db::RO>, - ) -> Result, DBExecutorError> { - match self { - DataType::SignedBlockHeader => { - let mut header = - txn.get_block_header(block_number)?.ok_or(DBExecutorError::BlockNotFound { + ) -> Result, DBExecutorError> { + let mut header = + txn.get_block_header(block_number)?.ok_or(DBExecutorError::BlockNotFound { + block_hash_or_number: BlockHashOrNumber::Number(block_number), + })?; + // TODO(shahak) Remove this once central sync fills the state_diff_length field. + if header.state_diff_length.is_none() { + header.state_diff_length = Some( + txn.get_state_diff(block_number)? + .ok_or(DBExecutorError::BlockNotFound { block_hash_or_number: BlockHashOrNumber::Number(block_number), - })?; - // TODO(shahak) Remove this once central sync fills the state_diff_length field. - if header.state_diff_length.is_none() { - header.state_diff_length = Some( - txn.get_state_diff(block_number)? - .ok_or(DBExecutorError::BlockNotFound { - block_hash_or_number: BlockHashOrNumber::Number(block_number), - })? - .len(), - ); - } - let signature = txn - .get_block_signature(block_number)? - .ok_or(DBExecutorError::SignatureNotFound { block_number })?; - Ok(vec![Data::BlockHeaderAndSignature(DataOrFin(Some(SignedBlockHeader { - block_header: header, - signatures: vec![signature], - })))]) - } - DataType::StateDiff => { - let thin_state_diff = - txn.get_state_diff(block_number)?.ok_or(DBExecutorError::BlockNotFound { - block_hash_or_number: BlockHashOrNumber::Number(block_number), - })?; - let vec_data = split_thin_state_diff(thin_state_diff) - .into_iter() - .map(|state_diff_chunk| Data::StateDiffChunk(DataOrFin(Some(state_diff_chunk)))) - .collect(); - Ok(vec_data) - } + })? + .len(), + ); } + let signature = txn + .get_block_signature(block_number)? + .ok_or(DBExecutorError::SignatureNotFound { block_number })?; + Ok(vec![SignedBlockHeader { block_header: header, signatures: vec![signature] }]) } +} - fn fin(&self) -> Data { - match self { - DataType::SignedBlockHeader => Data::BlockHeaderAndSignature(DataOrFin(None)), - DataType::StateDiff => Data::StateDiffChunk(DataOrFin(None)), - } +impl FetchBlockDataFromDb for StateDiffChunk { + fn fetch_block_data_from_db( + block_number: BlockNumber, + txn: &StorageTxn<'_, db::RO>, + ) -> Result, DBExecutorError> { + let thin_state_diff = + txn.get_state_diff(block_number)?.ok_or(DBExecutorError::BlockNotFound { + block_hash_or_number: BlockHashOrNumber::Number(block_number), + })?; + Ok(split_thin_state_diff(thin_state_diff)) } } @@ -265,26 +197,22 @@ pub fn split_thin_state_diff(thin_state_diff: ThinStateDiff) -> Vec( storage_reader: StorageReader, query: Query, - data_type: impl FetchBlockDataFromDb + Send + 'static, - mut sender: Sender, + mut sender: Sender>, ) -> Result<(), DBExecutorError> { - let fin = data_type.fin(); // If this function fails, we still want to send fin before failing. - let result = - send_data_without_fin_for_query(&storage_reader, query, data_type, &mut sender).await; + let result = send_data_without_fin_for_query(&storage_reader, query, &mut sender).await; poll_fn(|cx| sender.poll_ready(cx)).await?; - sender.start_send(fin)?; + sender.start_send(DataOrFin(None))?; result } -async fn send_data_without_fin_for_query( +async fn send_data_without_fin_for_query( storage_reader: &StorageReader, query: Query, - data_type: impl FetchBlockDataFromDb + Send + 'static, - sender: &mut Sender, + sender: &mut Sender>, ) -> Result<(), DBExecutorError> { let txn = storage_reader.begin_ro_txn()?; let start_block_number = match query.start_block { @@ -300,12 +228,12 @@ async fn send_data_without_fin_for_query( for block_counter in 0..query.limit { let block_number = BlockNumber(utils::calculate_block_number(&query, start_block_number, block_counter)?); - let data_vec = data_type.fetch_block_data_from_db(block_number, &txn)?; + let data_vec = Data::fetch_block_data_from_db(block_number, &txn)?; // Using poll_fn because Sender::poll_ready is not a future poll_fn(|cx| sender.poll_ready(cx)).await?; for data in data_vec { // TODO: consider implement retry mechanism. - sender.start_send(data)?; + sender.start_send(DataOrFin(Some(data)))?; } } Ok(()) diff --git a/crates/papyrus_network/src/db_executor/test.rs b/crates/papyrus_network/src/db_executor/test.rs index b5b85d9125..d5105ce818 100644 --- a/crates/papyrus_network/src/db_executor/test.rs +++ b/crates/papyrus_network/src/db_executor/test.rs @@ -1,8 +1,6 @@ -use futures::channel::mpsc::Receiver; -use futures::stream::SelectAll; -use futures::{FutureExt, StreamExt}; +use futures::StreamExt; use papyrus_common::state::create_random_state_diff; -use papyrus_protobuf::sync::{BlockHashOrNumber, DataOrFin, Direction, Query}; +use papyrus_protobuf::sync::{BlockHashOrNumber, DataOrFin, Direction, Query, SignedBlockHeader}; use papyrus_storage::header::{HeaderStorageReader, HeaderStorageWriter}; use papyrus_storage::state::StateStorageWriter; use papyrus_storage::test_utils::get_test_storage; @@ -11,13 +9,13 @@ use rand::random; use starknet_api::block::{BlockHash, BlockHeader, BlockNumber, BlockSignature}; use test_utils::get_rng; -use crate::db_executor::{DBExecutorError, DBExecutorTrait, Data, MockFetchBlockDataFromDb}; -use crate::DataType; +use crate::db_executor::DBExecutorTrait; const BUFFER_SIZE: usize = 10; +// TODO(shahak): Add test for state_diff_query_positive_flow. #[tokio::test] -async fn header_db_executor_can_register_and_run_a_query() { +async fn header_query_positive_flow() { let ((storage_reader, mut storage_writer), _temp_dir) = get_test_storage(); let mut db_executor = super::DBExecutor::new(storage_reader); @@ -32,66 +30,31 @@ async fn header_db_executor_can_register_and_run_a_query() { limit: NUM_OF_BLOCKS, step: 1, }; - type ReceiversType = Vec<(Receiver, DataType)>; - let mut receivers: ReceiversType = enum_iterator::all::() - .map(|data_type| { - let (sender, receiver) = futures::channel::mpsc::channel(BUFFER_SIZE); - db_executor.register_query(query.clone(), data_type, sender); - (receiver, data_type) - }) - .collect(); - let mut receivers_stream = SelectAll::new(); - receivers - .iter_mut() - .map(|(receiver, requested_data_type)| { - receiver - .collect::>() - .map(|collected| async move { (collected, requested_data_type) }) - }) - .for_each(|fut| { - receivers_stream.push(fut.into_stream()); - }); + let (sender, data_receiver) = futures::channel::mpsc::channel(BUFFER_SIZE); + db_executor.register_query::(query.clone(), sender); // run the executor and collect query results. tokio::select! { _ = db_executor.run() => { panic!("DB executor should never finish its run."); }, - _ = async { - while let Some(res) = receivers_stream.next().await { - let (data, requested_data_type) = res.await; - let len = data.len(); - if matches!(requested_data_type, DataType::SignedBlockHeader) { - assert_eq!(len, NUM_OF_BLOCKS as usize + 1); - } - for (i, data) in data.into_iter().enumerate() { - match &data { - Data::BlockHeaderAndSignature(_) => { - assert_eq!(*requested_data_type, DataType::SignedBlockHeader); - } - Data::StateDiffChunk(_) => { - assert_eq!(*requested_data_type, DataType::StateDiff); - } - } - match data { - Data::BlockHeaderAndSignature(DataOrFin(Some(signed_header))) => { - assert_eq!(signed_header.block_header.block_number.0, i as u64); - } - Data::StateDiffChunk(DataOrFin(Some(_state_diff))) => { - // TODO: check the state diff. - } - _ => { - assert_eq!(i, len - 1); - } + all_data = data_receiver.collect::>() => { + let len = all_data.len(); + assert_eq!(len, NUM_OF_BLOCKS as usize + 1); + for (i, data) in all_data.into_iter().enumerate() { + match data { + DataOrFin(Some(signed_header)) => { + assert_eq!(signed_header.block_header.block_number.0, i as u64); } + DataOrFin(None) => assert_eq!(i, len - 1), } } - } => {} + } } } #[tokio::test] -async fn header_db_executor_start_block_given_by_hash() { +async fn header_query_start_block_given_by_hash() { let ((storage_reader, mut storage_writer), _temp_dir) = get_test_storage(); // put some data in the storage. @@ -116,7 +79,7 @@ async fn header_db_executor_start_block_given_by_hash() { limit: NUM_OF_BLOCKS, step: 1, }; - db_executor.register_query(query, DataType::SignedBlockHeader, sender); + db_executor.register_query::(query, sender); // run the executor and collect query results. tokio::select! { @@ -128,11 +91,10 @@ async fn header_db_executor_start_block_given_by_hash() { assert_eq!(len, NUM_OF_BLOCKS as usize + 1); for (i, data) in res.into_iter().enumerate() { match data { - Data::BlockHeaderAndSignature(DataOrFin(Some(signed_header))) => { + DataOrFin(Some(signed_header)) => { assert_eq!(signed_header.block_header.block_number.0, i as u64); } - Data::BlockHeaderAndSignature(DataOrFin(None)) => assert_eq!(i, len - 1), - _ => panic!("Unexpected data type"), + DataOrFin(None) => assert_eq!(i, len - 1), }; } } @@ -140,7 +102,7 @@ async fn header_db_executor_start_block_given_by_hash() { } #[tokio::test] -async fn header_db_executor_query_of_missing_block() { +async fn header_query_some_blocks_are_missing() { let ((storage_reader, mut storage_writer), _temp_dir) = get_test_storage(); let mut db_executor = super::DBExecutor::new(storage_reader); @@ -156,20 +118,7 @@ async fn header_db_executor_query_of_missing_block() { limit: NUM_OF_BLOCKS, step: 1, }; - let mut mock_data_type = MockFetchBlockDataFromDb::new(); - mock_data_type.expect_fetch_block_data_from_db().times((BLOCKS_DELTA + 1) as usize).returning( - |block_number, _| { - if block_number.0 == NUM_OF_BLOCKS { - Err(DBExecutorError::BlockNotFound { - block_hash_or_number: BlockHashOrNumber::Number(block_number), - }) - } else { - Ok(vec![Data::default()]) - } - }, - ); - mock_data_type.expect_fin().times(1).returning(Data::default); - db_executor.register_query(query, mock_data_type, sender); + db_executor.register_query::(query, sender); tokio::select! { _ = db_executor.run() => { @@ -177,6 +126,9 @@ async fn header_db_executor_query_of_missing_block() { }, res = receiver.collect::>() => { assert_eq!(res.len(), (BLOCKS_DELTA + 1) as usize); + for (i, data) in res.into_iter().enumerate() { + assert_eq!(i == usize::try_from(BLOCKS_DELTA).unwrap(), data.0.is_none()); + } } } } diff --git a/crates/papyrus_network/src/network_manager/mod.rs b/crates/papyrus_network/src/network_manager/mod.rs index d11d0cf35f..11d8700cca 100644 --- a/crates/papyrus_network/src/network_manager/mod.rs +++ b/crates/papyrus_network/src/network_manager/mod.rs @@ -1,32 +1,34 @@ mod swarm_trait; -#[cfg(test)] -mod test; +// TODO(shahak): Uncomment +// #[cfg(test)] +// mod test; use std::collections::HashMap; use futures::channel::mpsc::{Receiver, SendError, Sender, UnboundedReceiver, UnboundedSender}; use futures::future::{ready, Ready}; use futures::sink::With; -use futures::stream::{self, Chain, Map, Once}; +use futures::stream::{self, BoxStream, Map}; use futures::{SinkExt, StreamExt}; use libp2p::gossipsub::{SubscriptionError, TopicHash}; use libp2p::swarm::SwarmEvent; use libp2p::{PeerId, Swarm}; use metrics::gauge; use papyrus_common::metrics as papyrus_metrics; +use papyrus_protobuf::sync::{DataOrFin, SignedBlockHeader, StateDiffChunk}; use papyrus_storage::StorageReader; use sqmr::Bytes; use tracing::{debug, error, info, trace}; use self::swarm_trait::SwarmTrait; use crate::bin_utils::build_swarm; -use crate::db_executor::{DBExecutor, DBExecutorTrait, Data}; +use crate::db_executor::{DBExecutor, DBExecutorTrait}; use crate::gossipsub_impl::Topic; use crate::mixed_behaviour::{self, BridgedBehaviour}; use crate::sqmr::{self, InboundSessionId, OutboundSessionId, SessionId}; use crate::utils::StreamHashMap; -use crate::{gossipsub_impl, DataType, NetworkConfig, Protocol}; +use crate::{gossipsub_impl, NetworkConfig, Protocol}; #[derive(thiserror::Error, Debug)] pub enum NetworkError { @@ -38,7 +40,8 @@ pub struct GenericNetworkManager>, + sqmr_inbound_response_receivers: + StreamHashMap>>, // Splitting the response receivers from the query senders in order to poll all // receivers simultaneously. // Each receiver has a matching sender and vice versa (i.e the maps have the same keys). @@ -297,19 +300,38 @@ impl GenericNetworkManager Option = Some; - self.sqmr_inbound_response_receivers.insert( - inbound_session_id, - receiver.map(response_fn).chain(stream::once(ready(None))), - ); + match protocol { + Protocol::SignedBlockHeader => { + let (sender, receiver) = futures::channel::mpsc::channel::< + DataOrFin, + >(self.header_buffer_size); + self.db_executor.register_query(internal_query, sender); + self.sqmr_inbound_response_receivers.insert( + inbound_session_id, + receiver + .map(|data| Some(Bytes::from(data))) + .chain(stream::once(ready(None))) + .boxed(), + ); + } + Protocol::StateDiff => { + let (sender, receiver) = futures::channel::mpsc::channel::< + DataOrFin, + >(self.header_buffer_size); + self.db_executor.register_query(internal_query, sender); + self.sqmr_inbound_response_receivers.insert( + inbound_session_id, + receiver + .map(|data| Some(Bytes::from(data))) + .chain(stream::once(ready(None))) + .boxed(), + ); + } + } } sqmr::behaviour::ExternalEvent::ReceivedData { outbound_session_id, data, peer_id } => { trace!( @@ -382,13 +404,11 @@ impl GenericNetworkManager)) { + fn handle_response_for_inbound_query(&mut self, res: (InboundSessionId, Option)) { let (inbound_session_id, maybe_data) = res; match maybe_data { Some(data) => { - let mut data_bytes = vec![]; - data.encode(&mut data_bytes).expect("failed to encode data"); - self.swarm.send_data(data_bytes, inbound_session_id).unwrap_or_else(|e| { + self.swarm.send_data(data, inbound_session_id).unwrap_or_else(|e| { error!( "Failed to send data to peer. Session id: {inbound_session_id:?} not \ found error: {e:?}" @@ -531,6 +551,3 @@ pub struct BroadcastSubscriberChannels> { pub messages_to_broadcast_sender: SubscriberSender, pub broadcasted_messages_receiver: SubscriberReceiver, } - -type SqmrResponseReceiver = - Chain, fn(Response) -> Option>, Once>>>; diff --git a/crates/papyrus_protobuf/src/converters/state_diff.rs b/crates/papyrus_protobuf/src/converters/state_diff.rs index 7ce2d042c6..c54b7b527d 100644 --- a/crates/papyrus_protobuf/src/converters/state_diff.rs +++ b/crates/papyrus_protobuf/src/converters/state_diff.rs @@ -39,7 +39,6 @@ impl TryFrom for DataOrFin { } } auto_impl_try_from_vec_u8!(DataOrFin, protobuf::StateDiffsResponse); -auto_impl_try_from_vec_u8!(DataOrFin, protobuf::StateDiffsResponse); impl TryFrom for DataOrFin { type Error = ProtobufConversionError; @@ -87,6 +86,7 @@ impl From> for protobuf::StateDiffsResponse { protobuf::StateDiffsResponse { state_diff_message: Some(state_diff_message) } } } +auto_impl_into_and_try_from_vec_u8!(DataOrFin, protobuf::StateDiffsResponse); impl TryFrom for ThinStateDiff { type Error = ProtobufConversionError;