Skip to content

Commit

Permalink
refactor(network): use DataOrFin in db executor's Data
Browse files Browse the repository at this point in the history
  • Loading branch information
ShahakShama committed Jun 17, 2024
1 parent 25c2370 commit eb6b58e
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 60 deletions.
49 changes: 17 additions & 32 deletions crates/papyrus_network/src/db_executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,14 @@ pub struct DataEncodingError;
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
#[derive(Clone)]
pub enum Data {
BlockHeaderAndSignature(SignedBlockHeader),
StateDiffChunk(StateDiffChunk),
Fin(DataType),
BlockHeaderAndSignature(DataOrFin<SignedBlockHeader>),
StateDiffChunk(DataOrFin<StateDiffChunk>),
}

impl Default for Data {
fn default() -> Self {
// TODO: consider this default data type.
Data::Fin(DataType::SignedBlockHeader)
Self::BlockHeaderAndSignature(DataOrFin(None))
}
}

Expand All @@ -57,33 +56,16 @@ impl Data {
B: BufMut,
{
match self {
Data::BlockHeaderAndSignature(signed_block_header) => {
let data: protobuf::BlockHeadersResponse = Some(signed_block_header).into();
data.encode(buf).map_err(|_| DataEncodingError)
Data::BlockHeaderAndSignature(maybe_signed_block_header) => {
let block_headers_response =
protobuf::BlockHeadersResponse::from(maybe_signed_block_header);
block_headers_response.encode(buf).map_err(|_| DataEncodingError)
}
Data::StateDiffChunk(state_diff) => {
let state_diff_chunk = DataOrFin(Some(state_diff));
let state_diffs_response = protobuf::StateDiffsResponse::from(state_diff_chunk);
Data::StateDiffChunk(maybe_state_diff_chunk) => {
let state_diffs_response =
protobuf::StateDiffsResponse::from(maybe_state_diff_chunk);
state_diffs_response.encode(buf).map_err(|_| DataEncodingError)
}
Data::Fin(data_type) => match data_type {
DataType::SignedBlockHeader => {
let block_header_response = protobuf::BlockHeadersResponse {
header_message: Some(protobuf::block_headers_response::HeaderMessage::Fin(
protobuf::Fin {},
)),
};
block_header_response.encode(buf).map_err(|_| DataEncodingError)
}
DataType::StateDiff => {
let state_diff_response = protobuf::StateDiffsResponse {
state_diff_message: Some(
protobuf::state_diffs_response::StateDiffMessage::Fin(protobuf::Fin {}),
),
};
state_diff_response.encode(buf).map_err(|_| DataEncodingError)
}
},
}
}
}
Expand Down Expand Up @@ -215,10 +197,10 @@ impl FetchBlockDataFromDb for DataType {
let signature = txn
.get_block_signature(block_number)?
.ok_or(DBExecutorError::SignatureNotFound { block_number })?;
Ok(vec![Data::BlockHeaderAndSignature(SignedBlockHeader {
Ok(vec![Data::BlockHeaderAndSignature(DataOrFin(Some(SignedBlockHeader {
block_header: header,
signatures: vec![signature],
})])
})))])
}
DataType::StateDiff => {
let thin_state_diff =
Expand All @@ -227,15 +209,18 @@ impl FetchBlockDataFromDb for DataType {
})?;
let vec_data = split_thin_state_diff(thin_state_diff)
.into_iter()
.map(Data::StateDiffChunk)
.map(|state_diff_chunk| Data::StateDiffChunk(DataOrFin(Some(state_diff_chunk))))
.collect();
Ok(vec_data)
}
}
}

fn fin(&self) -> Data {
Data::Fin(*self)
match self {
DataType::SignedBlockHeader => Data::BlockHeaderAndSignature(DataOrFin(None)),
DataType::StateDiff => Data::StateDiffChunk(DataOrFin(None)),
}
}
}

Expand Down
23 changes: 14 additions & 9 deletions crates/papyrus_network/src/db_executor/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use futures::channel::mpsc::Receiver;
use futures::stream::SelectAll;
use futures::{FutureExt, StreamExt};
use papyrus_common::state::create_random_state_diff;
use papyrus_protobuf::sync::{BlockHashOrNumber, Direction, Query};
use papyrus_protobuf::sync::{BlockHashOrNumber, DataOrFin, Direction, Query};
use papyrus_storage::header::{HeaderStorageReader, HeaderStorageWriter};
use papyrus_storage::state::StateStorageWriter;
use papyrus_storage::test_utils::get_test_storage;
Expand Down Expand Up @@ -65,17 +65,22 @@ async fn header_db_executor_can_register_and_run_a_query() {
assert_eq!(len, NUM_OF_BLOCKS as usize + 1);
}
for (i, data) in data.into_iter().enumerate() {
match &data {
Data::BlockHeaderAndSignature(_) => {
assert_eq!(*requested_data_type, DataType::SignedBlockHeader);
}
Data::StateDiffChunk(_) => {
assert_eq!(*requested_data_type, DataType::StateDiff);
}
}
match data {
Data::BlockHeaderAndSignature(signed_header) => {
Data::BlockHeaderAndSignature(DataOrFin(Some(signed_header))) => {
assert_eq!(signed_header.block_header.block_number.0, i as u64);
assert_eq!(*requested_data_type, DataType::SignedBlockHeader);
}
Data::StateDiffChunk (_state_diff) => {
Data::StateDiffChunk(DataOrFin(Some(_state_diff))) => {
// TODO: check the state diff.
assert_eq!(*requested_data_type, DataType::StateDiff);
}
Data::Fin(data_type) => {
assert_eq!(data_type, *requested_data_type);
_ => {
assert_eq!(i, len - 1);
}
}
Expand Down Expand Up @@ -123,10 +128,10 @@ async fn header_db_executor_start_block_given_by_hash() {
assert_eq!(len, NUM_OF_BLOCKS as usize + 1);
for (i, data) in res.into_iter().enumerate() {
match data {
Data::BlockHeaderAndSignature(signed_header) => {
Data::BlockHeaderAndSignature(DataOrFin(Some(signed_header))) => {
assert_eq!(signed_header.block_header.block_number.0, i as u64);
}
Data::Fin(DataType::SignedBlockHeader) => assert_eq!(i, len - 1),
Data::BlockHeaderAndSignature(DataOrFin(None)) => assert_eq!(i, len - 1),
_ => panic!("Unexpected data type"),
};
}
Expand Down
6 changes: 5 additions & 1 deletion crates/papyrus_network/src/network_manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use libp2p::swarm::SwarmEvent;
use libp2p::{PeerId, Swarm};
use metrics::gauge;
use papyrus_common::metrics as papyrus_metrics;
use papyrus_protobuf::sync::DataOrFin;
use papyrus_storage::StorageReader;
use sqmr::Bytes;
use tracing::{debug, error, info, trace};
Expand Down Expand Up @@ -387,7 +388,10 @@ impl<DBExecutorT: DBExecutorTrait, SwarmT: SwarmTrait> GenericNetworkManager<DBE
self.query_results_router = StreamCollection::new();
}
let (data, inbound_session_id) = res;
let is_fin = matches!(data, Data::Fin(_));
let is_fin = matches!(
data,
Data::BlockHeaderAndSignature(DataOrFin(None)) | Data::StateDiffChunk(DataOrFin(None))
);
let mut data_bytes = vec![];
data.encode(&mut data_bytes).expect("failed to encode data");
self.swarm.send_data(data_bytes, inbound_session_id).unwrap_or_else(|e| {
Expand Down
32 changes: 14 additions & 18 deletions crates/papyrus_network/src/network_manager/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ use super::swarm_trait::{Event, SwarmTrait};
use super::{GenericNetworkManager, SqmrSubscriberChannels};
use crate::db_executor::{DBExecutorError, DBExecutorTrait, Data, FetchBlockDataFromDb};
use crate::gossipsub_impl::{self, Topic};
use crate::mixed_behaviour;
use crate::sqmr::behaviour::{PeerNotConnected, SessionIdNotFoundError};
use crate::sqmr::{Bytes, GenericEvent, InboundSessionId, OutboundSessionId};
use crate::{mixed_behaviour, DataType};

const TIMEOUT: Duration = Duration::from_secs(1);

Expand Down Expand Up @@ -143,15 +143,12 @@ impl SwarmTrait for MockSwarm {
.inbound_session_id_to_data_sender
.get(&inbound_session_id)
.expect("Called send_data without calling get_data_sent_to_inbound_session first");
let decoded_data =
protobuf::BlockHeadersResponse::decode(&data[..]).unwrap().try_into().unwrap();
let (data, is_fin) = match decoded_data {
Some(signed_block_header) => {
(Data::BlockHeaderAndSignature(signed_block_header), false)
}
None => (Data::Fin(DataType::SignedBlockHeader), true),
};
data_sender.unbounded_send(data).unwrap();
let data = DataOrFin::<SignedBlockHeader>::try_from(
protobuf::BlockHeadersResponse::decode(&data[..]).unwrap(),
)
.unwrap();
let is_fin = data.0.is_none();
data_sender.unbounded_send(Data::BlockHeaderAndSignature(data)).unwrap();
if is_fin {
data_sender.close_channel();
}
Expand Down Expand Up @@ -236,13 +233,12 @@ impl DBExecutorTrait for MockDBExecutor {
for header in headers.iter().cloned() {
// Using poll_fn because Sender::poll_ready is not a future
if let Ok(()) = poll_fn(|cx| sender.poll_ready(cx)).await {
sender.start_send(Data::BlockHeaderAndSignature(SignedBlockHeader {
block_header: header,
signatures: vec![],
}))?;
sender.start_send(Data::BlockHeaderAndSignature(DataOrFin(Some(
SignedBlockHeader { block_header: header, signatures: vec![] },
))))?;
}
}
sender.start_send(Data::Fin(DataType::SignedBlockHeader))?;
sender.start_send(Data::BlockHeaderAndSignature(DataOrFin(None)))?;
Ok(())
}
}));
Expand Down Expand Up @@ -372,12 +368,12 @@ async fn process_incoming_query() {
let mut expected_data = headers
.into_iter()
.map(|header| {
Data::BlockHeaderAndSignature(SignedBlockHeader {
Data::BlockHeaderAndSignature(DataOrFin(Some(SignedBlockHeader {
block_header: header, signatures: vec![]
})
})))
})
.collect::<Vec<_>>();
expected_data.push(Data::Fin(DataType::SignedBlockHeader));
expected_data.push(Data::BlockHeaderAndSignature(DataOrFin(None)));
assert_eq!(inbound_session_data, expected_data);
}
_ = network_manager.run() => {
Expand Down

0 comments on commit eb6b58e

Please sign in to comment.