Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(network): don't check fin in network manager #2114

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 43 additions & 40 deletions crates/papyrus_network/src/network_manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::collections::HashMap;
use futures::channel::mpsc::{Receiver, SendError, Sender, UnboundedReceiver, UnboundedSender};
use futures::future::{ready, Ready};
use futures::sink::With;
use futures::stream::{BoxStream, Map, SelectAll};
use futures::stream::{self, Chain, Map, Once};
use futures::{SinkExt, StreamExt};
use libp2p::gossipsub::{SubscriptionError, TopicHash};
use libp2p::swarm::SwarmEvent;
Expand All @@ -28,8 +28,6 @@ use crate::sqmr::{self, InboundSessionId, OutboundSessionId, SessionId};
use crate::utils::StreamHashMap;
use crate::{gossipsub_impl, DataType, NetworkConfig, Protocol};

type StreamCollection = SelectAll<BoxStream<'static, (Data, InboundSessionId)>>;

#[derive(thiserror::Error, Debug)]
pub enum NetworkError {
#[error(transparent)]
Expand All @@ -40,12 +38,12 @@ pub struct GenericNetworkManager<DBExecutorT: DBExecutorTrait, SwarmT: SwarmTrai
swarm: SwarmT,
db_executor: DBExecutorT,
header_buffer_size: usize,
query_results_router: StreamCollection,
sqmr_inbound_response_receivers: StreamHashMap<InboundSessionId, SqmrResponseReceiver<Data>>,
// Splitting the response receivers from the query senders in order to poll all
// receivers simultaneously.
// Each receiver has a matching sender and vice versa (i.e the maps have the same keys).
sqmr_query_receivers: StreamHashMap<Protocol, Receiver<Bytes>>,
sqmr_response_senders: HashMap<Protocol, Sender<(Bytes, ReportCallback)>>,
sqmr_outbound_query_receivers: StreamHashMap<Protocol, Receiver<Bytes>>,
sqmr_outbound_response_senders: HashMap<Protocol, Sender<(Bytes, ReportCallback)>>,
// Splitting the broadcast receivers from the broadcasted senders in order to poll all
// receivers simultaneously.
// Each receiver has a matching sender and vice versa (i.e the maps have the same keys).
Expand All @@ -66,8 +64,8 @@ impl<DBExecutorT: DBExecutorTrait, SwarmT: SwarmTrait> GenericNetworkManager<DBE
tokio::select! {
Some(event) = self.swarm.next() => self.handle_swarm_event(event),
_ = self.db_executor.run() => panic!("DB executor should never finish."),
Some(res) = self.query_results_router.next() => self.handle_query_result_routing_to_other_peer(res),
Some((protocol, query)) = self.sqmr_query_receivers.next() => {
Some(res) = self.sqmr_inbound_response_receivers.next() => self.handle_response_for_inbound_query(res),
Some((protocol, query)) = self.sqmr_outbound_query_receivers.next() => {
self.handle_local_sqmr_query(protocol, query)
}
Some((topic_hash, message)) = self.messages_to_broadcast_receivers.next() => {
Expand All @@ -89,9 +87,9 @@ impl<DBExecutorT: DBExecutorTrait, SwarmT: SwarmTrait> GenericNetworkManager<DBE
swarm,
db_executor,
header_buffer_size,
query_results_router: StreamCollection::new(),
sqmr_query_receivers: StreamHashMap::new(HashMap::new()),
sqmr_response_senders: HashMap::new(),
sqmr_inbound_response_receivers: StreamHashMap::new(HashMap::new()),
sqmr_outbound_query_receivers: StreamHashMap::new(HashMap::new()),
sqmr_outbound_response_senders: HashMap::new(),
messages_to_broadcast_receivers: StreamHashMap::new(HashMap::new()),
broadcasted_messages_senders: HashMap::new(),
outbound_session_id_to_protocol: HashMap::new(),
Expand Down Expand Up @@ -119,11 +117,11 @@ impl<DBExecutorT: DBExecutorTrait, SwarmT: SwarmTrait> GenericNetworkManager<DBE
let (response_sender, response_receiver) =
futures::channel::mpsc::channel(self.header_buffer_size);

let insert_result = self.sqmr_query_receivers.insert(protocol, query_receiver);
let insert_result = self.sqmr_outbound_query_receivers.insert(protocol, query_receiver);
if insert_result.is_some() {
panic!("Protocol '{}' has already been registered.", protocol);
}
let insert_result = self.sqmr_response_senders.insert(protocol, response_sender);
let insert_result = self.sqmr_outbound_response_senders.insert(protocol, response_sender);
if insert_result.is_some() {
panic!("Protocol '{}' has already been registered.", protocol);
}
Expand Down Expand Up @@ -307,8 +305,11 @@ impl<DBExecutorT: DBExecutorTrait, SwarmT: SwarmTrait> GenericNetworkManager<DBE
let internal_query = protocol.bytes_query_to_protobuf_request(query);
let data_type = DataType::from(protocol);
self.db_executor.register_query(internal_query, data_type, sender);
self.query_results_router
.push(receiver.map(move |data| (data, inbound_session_id)).boxed());
let response_fn: fn(Data) -> Option<Data> = Some;
self.sqmr_inbound_response_receivers.insert(
inbound_session_id,
receiver.map(response_fn).chain(stream::once(ready(None))),
);
}
sqmr::behaviour::ExternalEvent::ReceivedData { outbound_session_id, data, peer_id } => {
trace!(
Expand All @@ -320,7 +321,8 @@ impl<DBExecutorT: DBExecutorTrait, SwarmT: SwarmTrait> GenericNetworkManager<DBE
.get(&outbound_session_id)
.expect("Received data from an unknown session id");
let report_callback = self.create_external_callback_for_received_data(peer_id);
if let Some(response_sender) = self.sqmr_response_senders.get_mut(protocol) {
if let Some(response_sender) = self.sqmr_outbound_response_senders.get_mut(protocol)
{
// TODO(shahak): Implement the report callback, while removing code duplication
// with broadcast.
if let Err(error) = response_sender.try_send((data, report_callback)) {
Expand Down Expand Up @@ -380,30 +382,28 @@ impl<DBExecutorT: DBExecutorTrait, SwarmT: SwarmTrait> GenericNetworkManager<DBE
}
}

fn handle_query_result_routing_to_other_peer(&mut self, res: (Data, InboundSessionId)) {
if self.query_results_router.is_empty() {
// We're done handling all the queries we had and the stream is exhausted.
// Creating a new stream collection to process new queries.
self.query_results_router = StreamCollection::new();
}
let (data, inbound_session_id) = res;
let is_fin = matches!(data, Data::Fin(_));
let mut data_bytes = vec![];
data.encode(&mut data_bytes).expect("failed to encode data");
self.swarm.send_data(data_bytes, inbound_session_id).unwrap_or_else(|e| {
error!(
"Failed to send data to peer. Session id: {inbound_session_id:?} not found error: \
{e:?}"
);
});
if is_fin {
self.swarm.close_inbound_session(inbound_session_id).unwrap_or_else(|e| {
error!(
"Failed to close session after Fin. Session id: {inbound_session_id:?} not \
found error: {e:?}"
)
});
}
/// Routes one item from an inbound SQMR session's response stream: `Some(data)`
/// is encoded and sent to the peer, while `None` (the end-of-stream marker
/// appended when the responder's channel closes) closes the inbound session.
fn handle_response_for_inbound_query(&mut self, res: (InboundSessionId, Option<Data>)) {
    let (inbound_session_id, maybe_data) = res;
    if let Some(data) = maybe_data {
        // Serialize the response before handing it to the swarm.
        let mut encoded = Vec::new();
        data.encode(&mut encoded).expect("failed to encode data");
        if let Err(e) = self.swarm.send_data(encoded, inbound_session_id) {
            error!(
                "Failed to send data to peer. Session id: {inbound_session_id:?} not \
                 found error: {e:?}"
            );
        }
    } else if let Err(e) = self.swarm.close_inbound_session(inbound_session_id) {
        error!(
            "Failed to close session after sending all data. Session id: \
             {inbound_session_id:?} not found error: {e:?}"
        )
    }
}

fn handle_local_sqmr_query(&mut self, protocol: Protocol, query: Bytes) {
Expand Down Expand Up @@ -531,3 +531,6 @@ pub struct BroadcastSubscriberChannels<T: TryFrom<Bytes>> {
pub messages_to_broadcast_sender: SubscriberSender<T>,
pub broadcasted_messages_receiver: SubscriberReceiver<T>,
}

/// Response stream for a single inbound SQMR session: each item from the
/// underlying channel is wrapped in `Some`, and once the channel is exhausted a
/// single trailing `None` is emitted (built via
/// `receiver.map(Some).chain(stream::once(ready(None)))`), which the network
/// manager uses as the signal to close the inbound session.
type SqmrResponseReceiver<Response> =
Chain<Map<Receiver<Response>, fn(Response) -> Option<Response>>, Once<Ready<Option<Response>>>>;
1 change: 1 addition & 0 deletions crates/papyrus_network/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ impl<K: Unpin + Clone + Eq + Hash, V: Stream + Unpin> Stream for StreamHashMap<K
}
}
if finished {
// TODO(shahak): Make StreamHashMap not end in order to accept new inserted streams.
return Poll::Ready(None);
}
Poll::Pending
Expand Down
Loading