diff --git a/crates/papyrus_network/src/db_executor/mod.rs b/crates/papyrus_network/src/db_executor/mod.rs index 44fda349e4..66f9b17b51 100644 --- a/crates/papyrus_network/src/db_executor/mod.rs +++ b/crates/papyrus_network/src/db_executor/mod.rs @@ -1,17 +1,19 @@ use std::vec; use async_trait::async_trait; -use futures::channel::mpsc::Sender; -use futures::future::{pending, poll_fn}; +use futures::future::pending; +use futures::{FutureExt, Sink, SinkExt, StreamExt}; use papyrus_protobuf::sync::{ BlockHashOrNumber, ContractDiff, DataOrFin, DeclaredClass, DeprecatedDeclaredClass, + HeaderQuery, Query, SignedBlockHeader, StateDiffChunk, + StateDiffQuery, }; use papyrus_storage::header::HeaderStorageReader; use papyrus_storage::state::StateStorageReader; @@ -20,6 +22,8 @@ use starknet_api::block::BlockNumber; use starknet_api::state::ThinStateDiff; use tracing::error; +use crate::network_manager::SqmrQueryReceiver; + #[cfg(test)] mod test; @@ -59,12 +63,14 @@ impl DBExecutorError { /// A DBExecutor receives inbound queries and returns their corresponding data. #[async_trait] pub trait DBExecutorTrait { - /// Send a query to be executed in the DBExecutor. The query will be run concurrently with the - /// calling code and the result will be over the given channel. - fn register_query( - &self, - query: Query, - sender: Sender>, + fn set_header_queries_receiver( + &mut self, + receiver: SqmrQueryReceiver>, + ); + + fn set_state_diff_queries_receiver( + &mut self, + receiver: SqmrQueryReceiver>, ); /// Polls incoming queries. @@ -74,21 +80,73 @@ pub trait DBExecutorTrait { pub struct DBExecutor { storage_reader: StorageReader, + // TODO(shahak): Make this non-option. + header_queries_receiver: Option>>, + // TODO(shahak): Make this non-option. + state_diff_queries_receiver: + Option>>, +} + +#[async_trait] +impl DBExecutorTrait for DBExecutor { + fn set_header_queries_receiver( + &mut self, + receiver: SqmrQueryReceiver>, + ) { + self.header_queries_receiver = Some(receiver); + } + + fn set_state_diff_queries_receiver( + &mut self, + receiver: SqmrQueryReceiver>, + ) { + self.state_diff_queries_receiver = Some(receiver); + } + + async fn run(&mut self) { + loop { + let header_queries_receiver_future = + if let Some(header_queries_receiver) = self.header_queries_receiver.as_mut() { + header_queries_receiver.next().boxed() + } else { + pending().boxed() + }; + let state_diff_queries_receiver_future = if let Some(state_diff_queries_receiver) = + self.state_diff_queries_receiver.as_mut() + { + state_diff_queries_receiver.next().boxed() + } else { + pending().boxed() + }; + + tokio::select! { + Some((query_result, response_sender)) = header_queries_receiver_future => { + // TODO(shahak): Report if query_result is Err. + if let Ok(query) = query_result { + self.register_query(query.0, response_sender); + } + } + Some((query_result, response_sender)) = state_diff_queries_receiver_future => { + if let Ok(query) = query_result { + self.register_query(query.0, response_sender); + } + } + }; + } + } } impl DBExecutor { pub fn new(storage_reader: StorageReader) -> Self { - Self { storage_reader } + Self { storage_reader, header_queries_receiver: None, state_diff_queries_receiver: None } } -} -#[async_trait] -impl DBExecutorTrait for DBExecutor { - fn register_query( - &self, - query: Query, - sender: Sender>, - ) { + fn register_query(&self, query: Query, sender: Sender) + where + Data: FetchBlockDataFromDb + Send + 'static, + Sender: Sink> + Unpin + Send + 'static, + DBExecutorError: From<>>::Error>, + { let storage_reader_clone = self.storage_reader.clone(); tokio::task::spawn(async move { let result = send_data_for_query(storage_reader_clone, query.clone(), sender).await; @@ -102,12 +160,6 @@ impl DBExecutorTrait for DBExecutor { } }); } - - async fn run(&mut self) { - // TODO(shahak): Parse incoming queries once we receive them through channel instead of - // through function. - pending::<()>().await - } } pub trait FetchBlockDataFromDb: Sized { @@ -197,23 +249,32 @@ pub fn split_thin_state_diff(thin_state_diff: ThinStateDiff) -> Vec( +async fn send_data_for_query( storage_reader: StorageReader, query: Query, - mut sender: Sender>, -) -> Result<(), DBExecutorError> { + mut sender: Sender, +) -> Result<(), DBExecutorError> +where + Data: FetchBlockDataFromDb + Send + 'static, + Sender: Sink> + Unpin + Send + 'static, + DBExecutorError: From<>>::Error>, +{ // If this function fails, we still want to send fin before failing. let result = send_data_without_fin_for_query(&storage_reader, query, &mut sender).await; - poll_fn(|cx| sender.poll_ready(cx)).await?; - sender.start_send(DataOrFin(None))?; + sender.feed(DataOrFin(None)).await?; result } -async fn send_data_without_fin_for_query( +async fn send_data_without_fin_for_query( storage_reader: &StorageReader, query: Query, - sender: &mut Sender>, -) -> Result<(), DBExecutorError> { + sender: &mut Sender, +) -> Result<(), DBExecutorError> +where + Data: FetchBlockDataFromDb + Send + 'static, + Sender: Sink> + Unpin + Send + 'static, + DBExecutorError: From<>>::Error>, +{ let txn = storage_reader.begin_ro_txn()?; let start_block_number = match query.start_block { BlockHashOrNumber::Number(BlockNumber(num)) => num, @@ -229,11 +290,9 @@ async fn send_data_without_fin_for_query(query.clone(), sender); + db_executor.register_query::(query.clone(), sender); // run the executor and collect query results. tokio::select! { @@ -79,7 +79,7 @@ async fn header_query_start_block_given_by_hash() { limit: NUM_OF_BLOCKS, step: 1, }; - db_executor.register_query::(query, sender); + db_executor.register_query::(query, sender); // run the executor and collect query results. tokio::select! { @@ -118,7 +118,7 @@ async fn header_query_some_blocks_are_missing() { limit: NUM_OF_BLOCKS, step: 1, }; - db_executor.register_query::(query, sender); + db_executor.register_query::(query, sender); tokio::select! { _ = db_executor.run() => { diff --git a/crates/papyrus_network/src/lib.rs b/crates/papyrus_network/src/lib.rs index 54ed83f92a..048fa33a96 100644 --- a/crates/papyrus_network/src/lib.rs +++ b/crates/papyrus_network/src/lib.rs @@ -29,9 +29,6 @@ use papyrus_config::converters::{ use papyrus_config::dumping::{ser_optional_param, ser_param, SerializeConfig}; use papyrus_config::validators::validate_vec_u256; use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam}; -use papyrus_protobuf::protobuf; -use papyrus_protobuf::sync::Query; -use prost::Message; use serde::{Deserialize, Serialize}; use validator::Validate; @@ -98,20 +95,6 @@ impl Protocol { Protocol::StateDiff => "/starknet/state_diffs/1", } } - - pub fn bytes_query_to_protobuf_request(&self, query: Vec) -> Query { - // TODO: make this function return errors instead of panicking. - match self { - Protocol::SignedBlockHeader => protobuf::BlockHeadersRequest::decode(&query[..]) - .expect("failed to decode protobuf BlockHeadersRequest") - .try_into() - .expect("failed to convert BlockHeadersRequest"), - Protocol::StateDiff => protobuf::StateDiffsRequest::decode(&query[..]) - .expect("failed to decode protobuf StateDiffsRequest") - .try_into() - .expect("failed to convert StateDiffsRequest"), - } - } } impl From for StreamProtocol { diff --git a/crates/papyrus_network/src/network_manager/mod.rs b/crates/papyrus_network/src/network_manager/mod.rs index 11d8700cca..7450307dab 100644 --- a/crates/papyrus_network/src/network_manager/mod.rs +++ b/crates/papyrus_network/src/network_manager/mod.rs @@ -16,7 +16,13 @@ use libp2p::swarm::SwarmEvent; use libp2p::{PeerId, Swarm}; use metrics::gauge; use papyrus_common::metrics as papyrus_metrics; -use papyrus_protobuf::sync::{DataOrFin, SignedBlockHeader, StateDiffChunk}; +use papyrus_protobuf::sync::{ + DataOrFin, + HeaderQuery, + SignedBlockHeader, + StateDiffChunk, + StateDiffQuery, +}; use papyrus_storage::StorageReader; use sqmr::Bytes; use tracing::{debug, error, info, trace}; @@ -42,6 +48,7 @@ pub struct GenericNetworkManager>>, + sqmr_inbound_query_senders: HashMap)>>, // Splitting the response receivers from the query senders in order to poll all // receivers simultaneously. // Each receiver has a matching sender and vice versa (i.e the maps have the same keys). @@ -63,6 +70,18 @@ pub struct GenericNetworkManager GenericNetworkManager { pub async fn run(mut self) -> Result<(), NetworkError> { + // TODO(shahak): Move this logic to register_sqmr_subscriber. + let header_db_executor_channel = self + .register_protocol_to_db_executor::>( + Protocol::SignedBlockHeader, + ); + self.db_executor.set_header_queries_receiver(header_db_executor_channel); + let state_diff_db_executor_channel = self + .register_protocol_to_db_executor::>( + Protocol::StateDiff, + ); + self.db_executor.set_state_diff_queries_receiver(state_diff_db_executor_channel); + loop { tokio::select! { Some(event) = self.swarm.next() => self.handle_swarm_event(event), @@ -79,6 +98,26 @@ impl GenericNetworkManager( + &mut self, + protocol: Protocol, + ) -> SqmrQueryReceiver + where + Query: TryFrom, + Bytes: From, + { + let (inbound_query_sender, inbound_query_receiver) = + futures::channel::mpsc::channel(self.header_buffer_size); + self.sqmr_inbound_query_senders.insert(protocol, inbound_query_sender); + + inbound_query_receiver.map(|(query_bytes, response_bytes_sender)| { + ( + Query::try_from(query_bytes), + response_bytes_sender.with(|response| ready(Ok(Bytes::from(response)))), + ) + }) + } + pub(crate) fn generic_new( swarm: SwarmT, db_executor: DBExecutorT, @@ -91,6 +130,7 @@ impl GenericNetworkManager GenericNetworkManager { - let (sender, receiver) = futures::channel::mpsc::channel::< - DataOrFin, - >(self.header_buffer_size); - self.db_executor.register_query(internal_query, sender); - self.sqmr_inbound_response_receivers.insert( - inbound_session_id, - receiver - .map(|data| Some(Bytes::from(data))) - .chain(stream::once(ready(None))) - .boxed(), - ); - } - Protocol::StateDiff => { - let (sender, receiver) = futures::channel::mpsc::channel::< - DataOrFin, - >(self.header_buffer_size); - self.db_executor.register_query(internal_query, sender); - self.sqmr_inbound_response_receivers.insert( - inbound_session_id, - receiver - .map(|data| Some(Bytes::from(data))) - .chain(stream::once(ready(None))) - .boxed(), - ); - } - } + let Some(query_sender) = self.sqmr_inbound_query_senders.get_mut(&protocol) else { + return; + }; + let (response_sender, response_receiver) = + futures::channel::mpsc::channel(self.header_buffer_size); + // TODO(shahak): Close the inbound session if the buffer is full. + send_now( + query_sender, + (query, response_sender), + format!( + "Received an inbound query while the buffer is full. Dropping query for \ + session {inbound_session_id:?}" + ), + ); + self.sqmr_inbound_response_receivers.insert( + inbound_session_id, + response_receiver.map(Some).chain(stream::once(ready(None))).boxed(), + ); } sqmr::behaviour::ExternalEvent::ReceivedData { outbound_session_id, data, peer_id } => { trace!( @@ -345,18 +374,15 @@ impl GenericNetworkManager { @@ -525,6 +551,15 @@ impl NetworkManager { // TODO(shahak): Change to a wrapper of PeerId if Box dyn becomes an overhead. pub type ReportCallback = Box; +// TODO(shahak): Add report callback. +pub type SqmrQueryReceiver = + Map)>, ReceivedQueryConverterFn>; + +type ReceivedQueryConverterFn = + fn( + (Bytes, Sender), + ) -> (Result>::Error>, SubscriberSender); + pub type SubscriberSender = With< Sender, Bytes, @@ -551,3 +586,14 @@ pub struct BroadcastSubscriberChannels> { pub messages_to_broadcast_sender: SubscriberSender, pub broadcasted_messages_receiver: SubscriberReceiver, } + +fn send_now(sender: &mut Sender, item: Item, buffer_full_message: String) { + if let Err(error) = sender.try_send(item) { + if error.is_disconnected() { + panic!("Receiver was dropped. This should never happen.") + } else if error.is_full() { + // TODO(shahak): Consider doing something else rather than dropping the message. + error!(buffer_full_message); + } + } +}