diff --git a/crates/papyrus_network/src/network_manager/test.rs b/crates/papyrus_network/src/network_manager/test.rs index dec9c603b7..9bb0ecd38d 100644 --- a/crates/papyrus_network/src/network_manager/test.rs +++ b/crates/papyrus_network/src/network_manager/test.rs @@ -1,5 +1,3 @@ -// TODO(shahak): Remove protobuf from these tests. - use std::collections::{HashMap, HashSet}; use std::pin::Pin; use std::sync::Arc; @@ -14,19 +12,13 @@ use futures::channel::oneshot; use futures::future::{pending, poll_fn, FutureExt}; use futures::stream::{FuturesUnordered, Stream}; use futures::{pin_mut, Future, SinkExt, StreamExt}; +use lazy_static::lazy_static; use libp2p::core::ConnectedPoint; use libp2p::gossipsub::{SubscriptionError, TopicHash}; use libp2p::swarm::ConnectionId; use libp2p::{Multiaddr, PeerId}; use papyrus_protobuf::protobuf; -use papyrus_protobuf::sync::{ - BlockHashOrNumber, - DataOrFin, - Direction, - HeaderQuery, - Query, - SignedBlockHeader, -}; +use papyrus_protobuf::sync::{BlockHashOrNumber, DataOrFin, Direction, Query, SignedBlockHeader}; use prost::Message; use starknet_api::block::{BlockHeader, BlockNumber}; use tokio::select; @@ -47,7 +39,6 @@ const TIMEOUT: Duration = Duration::from_secs(1); #[derive(Default)] struct MockSwarm { pub pending_events: Queue, - pub sent_queries: Vec<(Query, PeerId)>, pub subscribed_topics: HashSet, broadcasted_messages_senders: Vec>, reported_peer_senders: Vec>, @@ -101,32 +92,18 @@ impl MockSwarm { receiver } - fn create_received_data_events_for_query( + fn create_response_events_for_query_each_num_becomes_response( &self, - query: Query, + query: Vec, outbound_session_id: OutboundSessionId, + peer_id: PeerId, ) { - let BlockHashOrNumber::Number(BlockNumber(start_block_number)) = query.start_block else { - unimplemented!("test does not support start block as block hash") - }; - let block_max_number = start_block_number + (query.step * query.limit); - for block_number in (start_block_number..block_max_number) - .step_by(query.step.try_into().expect("step too large to convert to usize")) - { - let signed_header = SignedBlockHeader { - block_header: BlockHeader { - block_number: BlockNumber(block_number), - ..Default::default() - }, - signatures: vec![], - }; - let data_bytes = - protobuf::BlockHeadersResponse::from(Some(signed_header)).encode_to_vec(); + for data in query { self.pending_events.push(Event::Behaviour(mixed_behaviour::Event::ExternalEvent( mixed_behaviour::ExternalEvent::Sqmr(GenericEvent::ReceivedData { - data: data_bytes, + data: vec![data], outbound_session_id, - peer_id: PeerId::random(), + peer_id, }), ))); } @@ -161,13 +138,12 @@ impl SwarmTrait for MockSwarm { peer_id: PeerId, _protocol: crate::Protocol, ) -> Result { - let query: Query = protobuf::BlockHeadersRequest::decode(&query[..]) - .expect("failed to decode protobuf BlockHeadersRequest") - .try_into() - .expect("failed to convert BlockHeadersRequest"); - self.sent_queries.push((query.clone(), peer_id)); let outbound_session_id = OutboundSessionId { value: self.next_outbound_session_id }; - self.create_received_data_events_for_query(query, outbound_session_id); + self.create_response_events_for_query_each_num_becomes_response( + query, + outbound_session_id, + peer_id, + ); self.next_outbound_session_id += 1; Ok(outbound_session_id) } @@ -259,7 +235,7 @@ impl DBExecutorTrait for MockDBExecutor { const BUFFER_SIZE: usize = 100; #[tokio::test] -async fn register_subscriber_and_use_channels() { +async fn register_sqmr_subscriber_and_use_channels() { // mock swarm to send and track connection established event let mut mock_swarm = MockSwarm::default(); let peer_id = PeerId::random(); @@ -270,40 +246,27 @@ async fn register_subscriber_and_use_channels() { // network manager to register subscriber and send query let mut network_manager = GenericNetworkManager::generic_new(mock_swarm, MockDBExecutor::default(), BUFFER_SIZE); - // define query - let query_limit: usize = 5; - let start_block_number = 0; - let query = Query { - start_block: BlockHashOrNumber::Number(BlockNumber(start_block_number)), - direction: Direction::Forward, - limit: query_limit.try_into().unwrap(), - step: 1, - }; // register subscriber and send query let SqmrSubscriberChannels { mut query_sender, response_receiver } = network_manager - .register_sqmr_subscriber::>( - crate::Protocol::SignedBlockHeader, - ); + .register_sqmr_subscriber::, Vec>(crate::Protocol::SignedBlockHeader); let response_receiver_length = Arc::new(Mutex::new(0)); let cloned_response_receiver_length = Arc::clone(&response_receiver_length); let response_receiver_collector = response_receiver .enumerate() - .take(query_limit) - .map(|(i, (signed_block_header_result, _report_callback))| { - let signed_block_header = signed_block_header_result.unwrap(); - assert_eq!( - signed_block_header.clone().0.unwrap().block_header.block_number.0, - i as u64 - ); - signed_block_header + .take(VEC.len()) + .map(|(i, (result, _report_callback))| { + let result = result.unwrap(); + // this simulates how the mock swarm parses the query and sends responses to it + assert_eq!(result, vec![VEC[i]]); + result }) .collect::>(); tokio::select! { _ = network_manager.run() => panic!("network manager ended"), _ = poll_fn(|cx| event_listner.poll_unpin(cx)).then(|_| async move { - query_sender.send(HeaderQuery(query)).await.unwrap()}) + query_sender.send(VEC.clone()).await.unwrap()}) .then(|_| async move { *cloned_response_receiver_length.lock().await = response_receiver_collector.await.len(); }) => {}, @@ -311,7 +274,7 @@ async fn register_subscriber_and_use_channels() { panic!("Test timed out"); } } - assert_eq!(*response_receiver_length.lock().await, query_limit); + assert_eq!(*response_receiver_length.lock().await, VEC.len()); } #[tokio::test] @@ -528,3 +491,7 @@ fn get_test_connection_established_event(mock_peer_id: PeerId) -> Event { established_in: Duration::from_secs(0), } } + +lazy_static! { + static ref VEC: Vec = vec![1, 2, 3, 4, 5]; +}