refactor(network): remove protobuf from network_manager test
asmaastarkware committed Jun 17, 2024
1 parent 14ee3bb commit 000b74e
Showing 1 changed file with 28 additions and 62 deletions.
90 changes: 28 additions & 62 deletions crates/papyrus_network/src/network_manager/test.rs
@@ -1,5 +1,3 @@
-// TODO(shahak): Remove protobuf from these tests.
-
 use std::collections::{HashMap, HashSet};
 use std::pin::Pin;
 use std::sync::Arc;
@@ -14,19 +12,13 @@ use futures::channel::oneshot;
 use futures::future::{pending, poll_fn, FutureExt};
 use futures::stream::{FuturesUnordered, Stream};
 use futures::{pin_mut, Future, SinkExt, StreamExt};
+use lazy_static::lazy_static;
 use libp2p::core::ConnectedPoint;
 use libp2p::gossipsub::{SubscriptionError, TopicHash};
 use libp2p::swarm::ConnectionId;
 use libp2p::{Multiaddr, PeerId};
 use papyrus_protobuf::protobuf;
-use papyrus_protobuf::sync::{
-    BlockHashOrNumber,
-    DataOrFin,
-    Direction,
-    HeaderQuery,
-    Query,
-    SignedBlockHeader,
-};
+use papyrus_protobuf::sync::{BlockHashOrNumber, DataOrFin, Direction, Query, SignedBlockHeader};
 use prost::Message;
 use starknet_api::block::{BlockHeader, BlockNumber};
 use tokio::select;
@@ -47,7 +39,6 @@ const TIMEOUT: Duration = Duration::from_secs(1);
 #[derive(Default)]
 struct MockSwarm {
     pub pending_events: Queue<Event>,
-    pub sent_queries: Vec<(Query, PeerId)>,
     pub subscribed_topics: HashSet<TopicHash>,
     broadcasted_messages_senders: Vec<UnboundedSender<(Bytes, TopicHash)>>,
     reported_peer_senders: Vec<UnboundedSender<PeerId>>,
@@ -101,32 +92,18 @@ impl MockSwarm {
         receiver
     }

-    fn create_received_data_events_for_query(
+    fn create_response_events_for_query_each_num_becomes_response(
         &self,
-        query: Query,
+        query: Vec<u8>,
         outbound_session_id: OutboundSessionId,
+        peer_id: PeerId,
     ) {
-        let BlockHashOrNumber::Number(BlockNumber(start_block_number)) = query.start_block else {
-            unimplemented!("test does not support start block as block hash")
-        };
-        let block_max_number = start_block_number + (query.step * query.limit);
-        for block_number in (start_block_number..block_max_number)
-            .step_by(query.step.try_into().expect("step too large to convert to usize"))
-        {
-            let signed_header = SignedBlockHeader {
-                block_header: BlockHeader {
-                    block_number: BlockNumber(block_number),
-                    ..Default::default()
-                },
-                signatures: vec![],
-            };
-            let data_bytes =
-                protobuf::BlockHeadersResponse::from(Some(signed_header)).encode_to_vec();
+        for data in query {
             self.pending_events.push(Event::Behaviour(mixed_behaviour::Event::ExternalEvent(
                 mixed_behaviour::ExternalEvent::Sqmr(GenericEvent::ReceivedData {
-                    data: data_bytes,
+                    data: vec![data],
                     outbound_session_id,
-                    peer_id: PeerId::random(),
+                    peer_id,
                 }),
             )));
         }
@@ -143,8 +120,7 @@ impl SwarmTrait for MockSwarm {
             .inbound_session_id_to_data_sender
             .get(&inbound_session_id)
             .expect("Called send_data without calling get_data_sent_to_inbound_session first");
-        let decoded_data =
-            protobuf::BlockHeadersResponse::decode(&data[..]).unwrap().try_into().unwrap();
+        let decoded_data = DataOrFin::try_from(data).unwrap().0;
         let (data, is_fin) = match decoded_data {
             Some(signed_block_header) => {
                 (Data::BlockHeaderAndSignature(signed_block_header), false)
@@ -164,13 +140,12 @@
         peer_id: PeerId,
         _protocol: crate::Protocol,
     ) -> Result<OutboundSessionId, PeerNotConnected> {
-        let query: Query = protobuf::BlockHeadersRequest::decode(&query[..])
-            .expect("failed to decode protobuf BlockHeadersRequest")
-            .try_into()
-            .expect("failed to convert BlockHeadersRequest");
-        self.sent_queries.push((query.clone(), peer_id));
         let outbound_session_id = OutboundSessionId { value: self.next_outbound_session_id };
-        self.create_received_data_events_for_query(query, outbound_session_id);
+        self.create_response_events_for_query_each_num_becomes_response(
+            query,
+            outbound_session_id,
+            peer_id,
+        );
         self.next_outbound_session_id += 1;
         Ok(outbound_session_id)
     }
@@ -262,7 +237,7 @@ impl DBExecutorTrait for MockDBExecutor {
 const BUFFER_SIZE: usize = 100;

 #[tokio::test]
-async fn register_subscriber_and_use_channels() {
+async fn register_sqmr_subscriber_and_use_channels() {
     // mock swarm to send and track connection established event
     let mut mock_swarm = MockSwarm::default();
     let peer_id = PeerId::random();
@@ -273,48 +248,35 @@ async fn register_subscriber_and_use_channels() {
     // network manager to register subscriber and send query
     let mut network_manager =
         GenericNetworkManager::generic_new(mock_swarm, MockDBExecutor::default(), BUFFER_SIZE);
-    // define query
-    let query_limit: usize = 5;
-    let start_block_number = 0;
-    let query = Query {
-        start_block: BlockHashOrNumber::Number(BlockNumber(start_block_number)),
-        direction: Direction::Forward,
-        limit: query_limit.try_into().unwrap(),
-        step: 1,
-    };

     // register subscriber and send query
     let SqmrSubscriberChannels { mut query_sender, response_receiver } = network_manager
-        .register_sqmr_subscriber::<HeaderQuery, DataOrFin<SignedBlockHeader>>(
-            crate::Protocol::SignedBlockHeader,
-        );
+        .register_sqmr_subscriber::<Vec<u8>, Vec<u8>>(crate::Protocol::SignedBlockHeader);

     let response_receiver_length = Arc::new(Mutex::new(0));
     let cloned_response_receiver_length = Arc::clone(&response_receiver_length);
     let response_receiver_collector = response_receiver
         .enumerate()
-        .take(query_limit)
-        .map(|(i, (signed_block_header_result, _report_callback))| {
-            let signed_block_header = signed_block_header_result.unwrap();
-            assert_eq!(
-                signed_block_header.clone().0.unwrap().block_header.block_number.0,
-                i as u64
-            );
-            signed_block_header
+        .take(VEC.len())
+        .map(|(i, (result, _report_callback))| {
+            let result = result.unwrap();
+            // this simulates how the mock swarm parses the query and sends responses to it
+            assert_eq!(result, vec![VEC[i]]);
+            result
         })
         .collect::<Vec<_>>();
     tokio::select! {
         _ = network_manager.run() => panic!("network manager ended"),
         _ = poll_fn(|cx| event_listner.poll_unpin(cx)).then(|_| async move {
-            query_sender.send(HeaderQuery(query)).await.unwrap()})
+            query_sender.send(VEC.clone()).await.unwrap()})
             .then(|_| async move {
                 *cloned_response_receiver_length.lock().await = response_receiver_collector.await.len();
             }) => {},
         _ = sleep(Duration::from_secs(5)) => {
             panic!("Test timed out");
         }
     }
-    assert_eq!(*response_receiver_length.lock().await, query_limit);
+    assert_eq!(*response_receiver_length.lock().await, VEC.len());
 }

 #[tokio::test]
@@ -531,3 +493,7 @@ fn get_test_connection_established_event(mock_peer_id: PeerId) -> Event {
         established_in: Duration::from_secs(0),
     }
 }
+
+lazy_static! {
+    static ref VEC: Vec<u8> = vec![1, 2, 3, 4, 5];
+}
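Note (not part of the diff): a minimal, self-contained sketch of the response pattern the reworked mock uses. create_response_events_for_query_each_num_becomes_response turns each byte of the raw query into its own single-byte response, which is what the test's assert_eq!(result, vec![VEC[i]]) checks. The helper responses_for_query below is hypothetical and exists only for illustration.

// Illustration only (hypothetical helper, not in the commit): mirrors how the
// mock swarm answers a query: every byte of the query becomes one single-byte
// response, in order.
fn responses_for_query(query: &[u8]) -> Vec<Vec<u8>> {
    query.iter().map(|&byte| vec![byte]).collect()
}

fn main() {
    // Mirrors the VEC constant added at the bottom of the test file.
    let query = vec![1u8, 2, 3, 4, 5];
    let responses = responses_for_query(&query);
    // The test's collector asserts exactly this shape: response i == vec![VEC[i]].
    assert_eq!(responses, vec![vec![1], vec![2], vec![3], vec![4], vec![5]]);
}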
