diff --git a/crates/papyrus_network/src/db_executor/mod.rs b/crates/papyrus_network/src/db_executor/mod.rs
index eab277a45a..20543d9c9f 100644
--- a/crates/papyrus_network/src/db_executor/mod.rs
+++ b/crates/papyrus_network/src/db_executor/mod.rs
@@ -39,15 +39,14 @@ pub struct DataEncodingError;
 #[cfg_attr(test, derive(Debug, PartialEq, Eq))]
 #[derive(Clone)]
 pub enum Data {
-    BlockHeaderAndSignature(SignedBlockHeader),
-    StateDiffChunk(StateDiffChunk),
-    Fin(DataType),
+    BlockHeaderAndSignature(DataOrFin<SignedBlockHeader>),
+    StateDiffChunk(DataOrFin<StateDiffChunk>),
 }
 
 impl Default for Data {
     fn default() -> Self {
         // TODO: consider this default data type.
-        Data::Fin(DataType::SignedBlockHeader)
+        Self::BlockHeaderAndSignature(DataOrFin(None))
     }
 }
 
@@ -57,33 +56,16 @@ impl Data {
         B: BufMut,
     {
         match self {
-            Data::BlockHeaderAndSignature(signed_block_header) => {
-                let data: protobuf::BlockHeadersResponse = Some(signed_block_header).into();
-                data.encode(buf).map_err(|_| DataEncodingError)
+            Data::BlockHeaderAndSignature(maybe_signed_block_header) => {
+                let block_headers_response =
+                    protobuf::BlockHeadersResponse::from(maybe_signed_block_header);
+                block_headers_response.encode(buf).map_err(|_| DataEncodingError)
             }
-            Data::StateDiffChunk(state_diff) => {
-                let state_diff_chunk = DataOrFin(Some(state_diff));
-                let state_diffs_response = protobuf::StateDiffsResponse::from(state_diff_chunk);
+            Data::StateDiffChunk(maybe_state_diff_chunk) => {
+                let state_diffs_response =
+                    protobuf::StateDiffsResponse::from(maybe_state_diff_chunk);
                 state_diffs_response.encode(buf).map_err(|_| DataEncodingError)
             }
-            Data::Fin(data_type) => match data_type {
-                DataType::SignedBlockHeader => {
-                    let block_header_response = protobuf::BlockHeadersResponse {
-                        header_message: Some(protobuf::block_headers_response::HeaderMessage::Fin(
-                            protobuf::Fin {},
-                        )),
-                    };
-                    block_header_response.encode(buf).map_err(|_| DataEncodingError)
-                }
-                DataType::StateDiff => {
-                    let state_diff_response = protobuf::StateDiffsResponse {
-                        state_diff_message: Some(
-                            protobuf::state_diffs_response::StateDiffMessage::Fin(protobuf::Fin {}),
-                        ),
-                    };
-                    state_diff_response.encode(buf).map_err(|_| DataEncodingError)
-                }
-            },
         }
     }
 }
@@ -215,10 +197,10 @@ impl FetchBlockDataFromDb for DataType {
                 let signature = txn
                     .get_block_signature(block_number)?
                     .ok_or(DBExecutorError::SignatureNotFound { block_number })?;
-                Ok(vec![Data::BlockHeaderAndSignature(SignedBlockHeader {
+                Ok(vec![Data::BlockHeaderAndSignature(DataOrFin(Some(SignedBlockHeader {
                     block_header: header,
                     signatures: vec![signature],
-                })])
+                })))])
             }
             DataType::StateDiff => {
                 let thin_state_diff =
@@ -227,7 +209,7 @@ impl FetchBlockDataFromDb for DataType {
                     })?;
                 let vec_data = split_thin_state_diff(thin_state_diff)
                     .into_iter()
-                    .map(Data::StateDiffChunk)
+                    .map(|state_diff_chunk| Data::StateDiffChunk(DataOrFin(Some(state_diff_chunk))))
                     .collect();
                 Ok(vec_data)
             }
@@ -235,7 +217,10 @@
     }
 
     fn fin(&self) -> Data {
-        Data::Fin(*self)
+        match self {
+            DataType::SignedBlockHeader => Data::BlockHeaderAndSignature(DataOrFin(None)),
+            DataType::StateDiff => Data::StateDiffChunk(DataOrFin(None)),
+        }
     }
 }
 
diff --git a/crates/papyrus_network/src/db_executor/test.rs b/crates/papyrus_network/src/db_executor/test.rs
index a7dd8d7627..b5b85d9125 100644
--- a/crates/papyrus_network/src/db_executor/test.rs
+++ b/crates/papyrus_network/src/db_executor/test.rs
@@ -2,7 +2,7 @@ use futures::channel::mpsc::Receiver;
 use futures::stream::SelectAll;
 use futures::{FutureExt, StreamExt};
 use papyrus_common::state::create_random_state_diff;
-use papyrus_protobuf::sync::{BlockHashOrNumber, Direction, Query};
+use papyrus_protobuf::sync::{BlockHashOrNumber, DataOrFin, Direction, Query};
 use papyrus_storage::header::{HeaderStorageReader, HeaderStorageWriter};
 use papyrus_storage::state::StateStorageWriter;
 use papyrus_storage::test_utils::get_test_storage;
@@ -65,17 +65,22 @@ async fn header_db_executor_can_register_and_run_a_query() {
             assert_eq!(len, NUM_OF_BLOCKS as usize + 1);
         }
         for (i, data) in data.into_iter().enumerate() {
+            match &data {
+                Data::BlockHeaderAndSignature(_) => {
+                    assert_eq!(*requested_data_type, DataType::SignedBlockHeader);
+                }
+                Data::StateDiffChunk(_) => {
+                    assert_eq!(*requested_data_type, DataType::StateDiff);
+                }
+            }
             match data {
-                Data::BlockHeaderAndSignature(signed_header) => {
+                Data::BlockHeaderAndSignature(DataOrFin(Some(signed_header))) => {
                     assert_eq!(signed_header.block_header.block_number.0, i as u64);
-                    assert_eq!(*requested_data_type, DataType::SignedBlockHeader);
                 }
-                Data::StateDiffChunk(_state_diff) => {
+                Data::StateDiffChunk(DataOrFin(Some(_state_diff))) => {
                     // TODO: check the state diff.
-                    assert_eq!(*requested_data_type, DataType::StateDiff);
                 }
-                Data::Fin(data_type) => {
-                    assert_eq!(data_type, *requested_data_type);
+                _ => {
                     assert_eq!(i, len - 1);
                 }
             }
@@ -123,10 +128,10 @@ async fn header_db_executor_start_block_given_by_hash() {
     assert_eq!(len, NUM_OF_BLOCKS as usize + 1);
     for (i, data) in res.into_iter().enumerate() {
         match data {
-            Data::BlockHeaderAndSignature(signed_header) => {
+            Data::BlockHeaderAndSignature(DataOrFin(Some(signed_header))) => {
                 assert_eq!(signed_header.block_header.block_number.0, i as u64);
             }
-            Data::Fin(DataType::SignedBlockHeader) => assert_eq!(i, len - 1),
+            Data::BlockHeaderAndSignature(DataOrFin(None)) => assert_eq!(i, len - 1),
             _ => panic!("Unexpected data type"),
         };
     }
diff --git a/crates/papyrus_network/src/network_manager/test.rs b/crates/papyrus_network/src/network_manager/test.rs
index 0c497f32e9..dec9c603b7 100644
--- a/crates/papyrus_network/src/network_manager/test.rs
+++ b/crates/papyrus_network/src/network_manager/test.rs
@@ -38,9 +38,9 @@
 use super::swarm_trait::{Event, SwarmTrait};
 use super::{GenericNetworkManager, SqmrSubscriberChannels};
 use crate::db_executor::{DBExecutorError, DBExecutorTrait, Data, FetchBlockDataFromDb};
 use crate::gossipsub_impl::{self, Topic};
+use crate::mixed_behaviour;
 use crate::sqmr::behaviour::{PeerNotConnected, SessionIdNotFoundError};
 use crate::sqmr::{Bytes, GenericEvent, InboundSessionId, OutboundSessionId};
-use crate::{mixed_behaviour, DataType};
 
 const TIMEOUT: Duration = Duration::from_secs(1);
@@ -143,15 +140,12 @@ impl SwarmTrait for MockSwarm {
             .inbound_session_id_to_data_sender
             .get(&inbound_session_id)
             .expect("Called send_data without calling get_data_sent_to_inbound_session first");
-        let decoded_data =
-            protobuf::BlockHeadersResponse::decode(&data[..]).unwrap().try_into().unwrap();
-        let (data, is_fin) = match decoded_data {
-            Some(signed_block_header) => {
-                (Data::BlockHeaderAndSignature(signed_block_header), false)
-            }
-            None => (Data::Fin(DataType::SignedBlockHeader), true),
-        };
-        data_sender.unbounded_send(data).unwrap();
+        let data = DataOrFin::<SignedBlockHeader>::try_from(
+            protobuf::BlockHeadersResponse::decode(&data[..]).unwrap(),
+        )
+        .unwrap();
+        let is_fin = data.0.is_none();
+        data_sender.unbounded_send(Data::BlockHeaderAndSignature(data)).unwrap();
         if is_fin {
             data_sender.close_channel();
         }
@@ -236,13 +233,12 @@ impl DBExecutorTrait for MockDBExecutor {
                 for header in headers.iter().cloned() {
                     // Using poll_fn because Sender::poll_ready is not a future
                     if let Ok(()) = poll_fn(|cx| sender.poll_ready(cx)).await {
-                        sender.start_send(Data::BlockHeaderAndSignature(SignedBlockHeader {
-                            block_header: header,
-                            signatures: vec![],
-                        }))?;
+                        sender.start_send(Data::BlockHeaderAndSignature(DataOrFin(Some(
+                            SignedBlockHeader { block_header: header, signatures: vec![] },
+                        ))))?;
                     }
                 }
-                sender.start_send(Data::Fin(DataType::SignedBlockHeader))?;
+                sender.start_send(Data::BlockHeaderAndSignature(DataOrFin(None)))?;
                 Ok(())
             }
         }));
@@ -372,12 +368,12 @@ async fn process_incoming_query() {
             let mut expected_data = headers
                 .into_iter()
                 .map(|header| {
-                    Data::BlockHeaderAndSignature(SignedBlockHeader {
+                    Data::BlockHeaderAndSignature(DataOrFin(Some(SignedBlockHeader {
                         block_header: header, signatures: vec![]
-                    })
+                    })))
                 })
                 .collect::<Vec<_>>();
-            expected_data.push(Data::Fin(DataType::SignedBlockHeader));
+            expected_data.push(Data::BlockHeaderAndSignature(DataOrFin(None)));
             assert_eq!(inbound_session_data, expected_data);
         }
         _ = network_manager.run() => {
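Note on the shape of the refactor above: the explicit Data::Fin variant is gone, and end-of-stream is now signaled by a DataOrFin(None) carried inside the payload variant itself. The sketch below shows how a consumer might detect Fin under the new shape. It assumes DataOrFin is a newtype over Option<T>, as the DataOrFin(Some(..)) / DataOrFin(None) constructions in this diff suggest, and the is_fin helper is hypothetical, not part of this change.

// Sketch only, not part of the diff. Imports mirror those used in the test files above.
use papyrus_protobuf::sync::DataOrFin;

use crate::db_executor::Data;

// With Fin folded into each variant, "stream finished" now means the inner Option is None
// rather than a dedicated Data::Fin variant.
fn is_fin(data: &Data) -> bool {
    match data {
        Data::BlockHeaderAndSignature(DataOrFin(inner)) => inner.is_none(),
        Data::StateDiffChunk(DataOrFin(inner)) => inner.is_none(),
    }
}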