Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(network): remove protobuf from network_manager test #2109

Merged
merged 1 commit into from
Jun 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 27 additions & 60 deletions crates/papyrus_network/src/network_manager/test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// TODO(shahak): Remove protobuf from these tests.

use std::collections::{HashMap, HashSet};
use std::pin::Pin;
use std::sync::Arc;
Expand All @@ -14,19 +12,13 @@ use futures::channel::oneshot;
use futures::future::{pending, poll_fn, FutureExt};
use futures::stream::{FuturesUnordered, Stream};
use futures::{pin_mut, Future, SinkExt, StreamExt};
use lazy_static::lazy_static;
use libp2p::core::ConnectedPoint;
use libp2p::gossipsub::{SubscriptionError, TopicHash};
use libp2p::swarm::ConnectionId;
use libp2p::{Multiaddr, PeerId};
use papyrus_protobuf::protobuf;
use papyrus_protobuf::sync::{
BlockHashOrNumber,
DataOrFin,
Direction,
HeaderQuery,
Query,
SignedBlockHeader,
};
use papyrus_protobuf::sync::{BlockHashOrNumber, DataOrFin, Direction, Query, SignedBlockHeader};
use prost::Message;
use starknet_api::block::{BlockHeader, BlockNumber};
use tokio::select;
Expand All @@ -47,7 +39,6 @@ const TIMEOUT: Duration = Duration::from_secs(1);
#[derive(Default)]
struct MockSwarm {
pub pending_events: Queue<Event>,
pub sent_queries: Vec<(Query, PeerId)>,
pub subscribed_topics: HashSet<TopicHash>,
broadcasted_messages_senders: Vec<UnboundedSender<(Bytes, TopicHash)>>,
reported_peer_senders: Vec<UnboundedSender<PeerId>>,
Expand Down Expand Up @@ -101,32 +92,18 @@ impl MockSwarm {
receiver
}

fn create_received_data_events_for_query(
fn create_response_events_for_query_each_num_becomes_response(
&self,
query: Query,
query: Vec<u8>,
outbound_session_id: OutboundSessionId,
peer_id: PeerId,
) {
let BlockHashOrNumber::Number(BlockNumber(start_block_number)) = query.start_block else {
unimplemented!("test does not support start block as block hash")
};
let block_max_number = start_block_number + (query.step * query.limit);
for block_number in (start_block_number..block_max_number)
.step_by(query.step.try_into().expect("step too large to convert to usize"))
{
let signed_header = SignedBlockHeader {
block_header: BlockHeader {
block_number: BlockNumber(block_number),
..Default::default()
},
signatures: vec![],
};
let data_bytes =
protobuf::BlockHeadersResponse::from(Some(signed_header)).encode_to_vec();
for data in query {
self.pending_events.push(Event::Behaviour(mixed_behaviour::Event::ExternalEvent(
mixed_behaviour::ExternalEvent::Sqmr(GenericEvent::ReceivedData {
data: data_bytes,
data: vec![data],
outbound_session_id,
peer_id: PeerId::random(),
peer_id,
}),
)));
}
Expand Down Expand Up @@ -161,13 +138,12 @@ impl SwarmTrait for MockSwarm {
peer_id: PeerId,
_protocol: crate::Protocol,
) -> Result<OutboundSessionId, PeerNotConnected> {
let query: Query = protobuf::BlockHeadersRequest::decode(&query[..])
.expect("failed to decode protobuf BlockHeadersRequest")
.try_into()
.expect("failed to convert BlockHeadersRequest");
self.sent_queries.push((query.clone(), peer_id));
let outbound_session_id = OutboundSessionId { value: self.next_outbound_session_id };
self.create_received_data_events_for_query(query, outbound_session_id);
self.create_response_events_for_query_each_num_becomes_response(
query,
outbound_session_id,
peer_id,
);
self.next_outbound_session_id += 1;
Ok(outbound_session_id)
}
Expand Down Expand Up @@ -259,7 +235,7 @@ impl DBExecutorTrait for MockDBExecutor {
const BUFFER_SIZE: usize = 100;

#[tokio::test]
async fn register_subscriber_and_use_channels() {
async fn register_sqmr_subscriber_and_use_channels() {
// mock swarm to send and track connection established event
let mut mock_swarm = MockSwarm::default();
let peer_id = PeerId::random();
Expand All @@ -270,48 +246,35 @@ async fn register_subscriber_and_use_channels() {
// network manager to register subscriber and send query
let mut network_manager =
GenericNetworkManager::generic_new(mock_swarm, MockDBExecutor::default(), BUFFER_SIZE);
// define query
let query_limit: usize = 5;
let start_block_number = 0;
let query = Query {
start_block: BlockHashOrNumber::Number(BlockNumber(start_block_number)),
direction: Direction::Forward,
limit: query_limit.try_into().unwrap(),
step: 1,
};

// register subscriber and send query
let SqmrSubscriberChannels { mut query_sender, response_receiver } = network_manager
.register_sqmr_subscriber::<HeaderQuery, DataOrFin<SignedBlockHeader>>(
crate::Protocol::SignedBlockHeader,
);
.register_sqmr_subscriber::<Vec<u8>, Vec<u8>>(crate::Protocol::SignedBlockHeader);

let response_receiver_length = Arc::new(Mutex::new(0));
let cloned_response_receiver_length = Arc::clone(&response_receiver_length);
let response_receiver_collector = response_receiver
.enumerate()
.take(query_limit)
.map(|(i, (signed_block_header_result, _report_callback))| {
let signed_block_header = signed_block_header_result.unwrap();
assert_eq!(
signed_block_header.clone().0.unwrap().block_header.block_number.0,
i as u64
);
signed_block_header
.take(VEC.len())
.map(|(i, (result, _report_callback))| {
let result = result.unwrap();
// this simulates how the mock swarm parses the query and sends responses to it
assert_eq!(result, vec![VEC[i]]);
result
})
.collect::<Vec<_>>();
tokio::select! {
_ = network_manager.run() => panic!("network manager ended"),
_ = poll_fn(|cx| event_listner.poll_unpin(cx)).then(|_| async move {
query_sender.send(HeaderQuery(query)).await.unwrap()})
query_sender.send(VEC.clone()).await.unwrap()})
.then(|_| async move {
*cloned_response_receiver_length.lock().await = response_receiver_collector.await.len();
}) => {},
_ = sleep(Duration::from_secs(5)) => {
panic!("Test timed out");
}
}
assert_eq!(*response_receiver_length.lock().await, query_limit);
assert_eq!(*response_receiver_length.lock().await, VEC.len());
}

#[tokio::test]
Expand Down Expand Up @@ -528,3 +491,7 @@ fn get_test_connection_established_event(mock_peer_id: PeerId) -> Event {
established_in: Duration::from_secs(0),
}
}

lazy_static! {
static ref VEC: Vec<u8> = vec![1, 2, 3, 4, 5];
}
Loading