feat(network): remove stream in DBExecutorTrait and return Result (#2089)

feat(network): remove stream in DBExecutorTrait
asmaastarkware authored Jun 10, 2024
1 parent b6f3076 commit 8dfbf74
Showing 6 changed files with 76 additions and 194 deletions.
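In short: the executor no longer implements `Stream` (previously, each stream item reported one query's completion as `Ok(QueryId)` or `Err(DBExecutorError)`); data now travels only over each query's channel, and a new async `run` method drives the executor. A minimal sketch of the new trait shape, with stub types standing in for the crate's real ones and the `data_type` parameter omitted for brevity:

```rust
use async_trait::async_trait;
use futures::channel::mpsc::Sender;

// Stubs for illustration only; the real types live in papyrus_network and
// papyrus_protobuf.
pub struct Query;
pub struct Data;
pub struct QueryId(pub usize);

#[async_trait]
pub trait DBExecutorTrait {
    /// Spawn the query to run concurrently; results flow out over `sender`.
    fn register_query(&mut self, query: Query, sender: Sender<Vec<Data>>) -> QueryId;

    /// Drive the executor's inbound-query handling; expected to never return.
    async fn run(&mut self);
}
```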
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions crates/papyrus_network/Cargo.toml
@@ -45,6 +45,7 @@ thiserror.workspace = true
tokio = { workspace = true, features = ["full", "sync"] }
tracing.workspace = true
unsigned-varint = { workspace = true, features = ["std"] }
async-trait.workspace = true

# Binaries dependencies
clap = { workspace = true, optional = true, features = ["derive"] }
71 changes: 28 additions & 43 deletions crates/papyrus_network/src/db_executor/mod.rs
@@ -1,13 +1,10 @@
use std::pin::Pin;
use std::task::Poll;
use std::vec;

use async_trait::async_trait;
use bytes::BufMut;
use derive_more::Display;
use futures::channel::mpsc::Sender;
use futures::future::poll_fn;
use futures::stream::FuturesUnordered;
use futures::{Stream, StreamExt};
use futures::future::{pending, poll_fn};
#[cfg(test)]
use mockall::automock;
use papyrus_protobuf::converters::common::volition_domain_to_enum_int;
@@ -29,7 +26,7 @@ use papyrus_storage::{db, StorageReader, StorageTxn};
use prost::Message;
use starknet_api::block::BlockNumber;
use starknet_api::state::ThinStateDiff;
use tokio::task::JoinHandle;
use tracing::error;

use crate::DataType;

@@ -182,32 +179,36 @@ impl DBExecutorError {
}
}

/// DBExecutorTrait is a stream of queries. Each result marks the end of a query fulfillment.
/// A query can either succeed (and return Ok(QueryId)) or fail (and return Err(DBExecutorError)).
/// The stream is never exhausted, and it is the responsibility of the user to poll it.
pub trait DBExecutorTrait: Stream<Item = Result<QueryId, DBExecutorError>> + Unpin {
// TODO: add writer functionality
/// A DBExecutor receives inbound queries and returns their corresponding data.
#[async_trait]
pub trait DBExecutorTrait {
/// Send a query to be executed in the DBExecutor. The query will be run concurrently with the
/// calling code and the result will be returned over the given channel.
fn register_query(
&mut self,
query: Query,
data_type: impl FetchBlockDataFromDb + Send + 'static,
sender: Sender<Vec<Data>>,
// TODO(shahak): Remove QueryId.
) -> QueryId;

/// Polls incoming queries.
// TODO(shahak): Consume self.
async fn run(&mut self);
}
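For orientation, here is a caller-side sketch distilled from the tests later in this diff (not part of the commit; the query values and buffer size are illustrative, and `DataType::SignedBlockHeader` is the `FetchBlockDataFromDb` implementor the tests pass in):

```rust
use futures::channel::mpsc;
use futures::StreamExt;
use papyrus_protobuf::sync::{BlockHashOrNumber, Direction, Query};
use starknet_api::block::BlockNumber;

use crate::db_executor::{DBExecutor, DBExecutorTrait};
use crate::DataType;

async fn drive_executor(storage_reader: papyrus_storage::StorageReader) {
    let mut db_executor = DBExecutor::new(storage_reader);
    let (sender, mut receiver) = mpsc::channel(10);
    // The query runs concurrently; its data arrives in chunks on `receiver`.
    db_executor.register_query(
        Query {
            start_block: BlockHashOrNumber::Number(BlockNumber(0)),
            direction: Direction::Forward,
            limit: 10,
            step: 1,
        },
        DataType::SignedBlockHeader,
        sender,
    );
    tokio::select! {
        // `run` is never expected to complete.
        _ = db_executor.run() => panic!("DB executor should never finish its run."),
        _ = async {
            while let Some(data) = receiver.next().await {
                // Each message is one Vec<Data> chunk of the query's results.
                let _chunk_len = data.len();
            }
        } => {}
    }
}
```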

// TODO: currently this executor returns only block headers and signatures.
pub struct DBExecutor {
next_query_id: usize,
storage_reader: StorageReader,
query_execution_set: FuturesUnordered<JoinHandle<Result<QueryId, DBExecutorError>>>,
}

impl DBExecutor {
pub fn new(storage_reader: StorageReader) -> Self {
Self { next_query_id: 0, storage_reader, query_execution_set: FuturesUnordered::new() }
Self { next_query_id: 0, storage_reader }
}
}

#[async_trait]
impl DBExecutorTrait for DBExecutor {
fn register_query(
&mut self,
@@ -218,8 +219,8 @@ impl DBExecutorTrait for DBExecutor {
let query_id = QueryId(self.next_query_id);
self.next_query_id += 1;
let storage_reader_clone = self.storage_reader.clone();
self.query_execution_set.push(tokio::task::spawn(async move {
{
tokio::task::spawn(async move {
let result: Result<QueryId, DBExecutorError> = {
let txn = storage_reader_clone.begin_ro_txn().map_err(|err| {
DBExecutorError::DBInternalError { query_id, storage_error: err }
})?;
@@ -261,37 +262,21 @@ impl DBExecutorTrait for DBExecutor {
}
}
Ok(query_id)
};
if let Err(error) = &result {
if error.should_log_in_error_level() {
error!("Running inbound query {query:?} failed on {error:?}");
}
}
}));
result
});
query_id
}
}

impl Stream for DBExecutor {
type Item = Result<QueryId, DBExecutorError>;

fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
poll_query_execution_set(&mut Pin::into_inner(self).query_execution_set, cx)
}
}

pub(crate) fn poll_query_execution_set(
query_execution_set: &mut FuturesUnordered<JoinHandle<Result<QueryId, DBExecutorError>>>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<QueryId, DBExecutorError>>> {
match query_execution_set.poll_next_unpin(cx) {
Poll::Ready(Some(join_result)) => {
let res = join_result?;
Poll::Ready(Some(res))
}
Poll::Ready(None) => {
*query_execution_set = FuturesUnordered::new();
Poll::Pending
}
Poll::Pending => Poll::Pending,
async fn run(&mut self) {
// TODO(shahak): Parse incoming queries once we receive them through channel instead of
// through function.
pending::<()>().await
}
}
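The new `run` body parks on `futures::future::pending`, a future that never resolves, which keeps the executor alive as a placeholder until the TODO's channel-based query intake lands. A self-contained illustration of the idiom (the sleep duration is arbitrary):

```rust
use std::time::Duration;

use futures::future::pending;

#[tokio::main]
async fn main() {
    tokio::select! {
        // This arm can never fire: `pending` never completes.
        _ = pending::<()>() => unreachable!("pending never resolves"),
        _ = tokio::time::sleep(Duration::from_millis(10)) => {
            println!("the other select arm always wins");
        }
    }
}
```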

142 changes: 29 additions & 113 deletions crates/papyrus_network/src/db_executor/test.rs
@@ -1,12 +1,8 @@
use std::task::Poll;

use assert_matches::assert_matches;
use futures::channel::mpsc::Receiver;
use futures::future::poll_fn;
use futures::stream::SelectAll;
use futures::{FutureExt, StreamExt};
use papyrus_common::state::create_random_state_diff;
use papyrus_protobuf::sync::{BlockHashOrNumber, Direction, Query, SignedBlockHeader};
use papyrus_protobuf::sync::{BlockHashOrNumber, Direction, Query};
use papyrus_storage::header::{HeaderStorageReader, HeaderStorageWriter};
use papyrus_storage::state::StateStorageWriter;
use papyrus_storage::test_utils::get_test_storage;
@@ -15,14 +11,7 @@ use rand::random;
use starknet_api::block::{BlockHash, BlockHeader, BlockNumber, BlockSignature};
use test_utils::get_rng;

use super::Data::BlockHeaderAndSignature;
use crate::db_executor::{
DBExecutorError,
DBExecutorTrait,
Data,
MockFetchBlockDataFromDb,
QueryId,
};
use crate::db_executor::{DBExecutorError, DBExecutorTrait, Data, MockFetchBlockDataFromDb};
use crate::DataType;

const BUFFER_SIZE: usize = 10;
@@ -44,14 +33,13 @@ async fn header_db_executor_can_register_and_run_a_query() {
step: 1,
};
type ReceiversType = Vec<(Receiver<Vec<Data>>, DataType)>;
let (query_ids, mut receivers): (Vec<QueryId>, ReceiversType) =
enum_iterator::all::<DataType>()
.map(|data_type| {
let (sender, receiver) = futures::channel::mpsc::channel(BUFFER_SIZE);
let query_id = db_executor.register_query(query.clone(), data_type, sender);
(query_id, (receiver, data_type))
})
.unzip();
let mut receivers: ReceiversType = enum_iterator::all::<DataType>()
.map(|data_type| {
let (sender, receiver) = futures::channel::mpsc::channel(BUFFER_SIZE);
db_executor.register_query(query.clone(), data_type, sender);
(receiver, data_type)
})
.collect();
let mut receivers_stream = SelectAll::new();
receivers
.iter_mut()
@@ -66,20 +54,18 @@

// run the executor and collect query results.
tokio::select! {
res = db_executor.next() => {
let poll_res = res.unwrap();
let res_query_id = poll_res.unwrap();
assert!(query_ids.iter().any(|query_id| query_id == &res_query_id));
_ = db_executor.run() => {
panic!("DB executor should never finish its run.");
},
_ = async {
while let Some(res) = receivers_stream.next().await {
let (data, requested_data_type) = res.await;
assert_eq!(data.len(), NUM_OF_BLOCKS as usize);
for (i, data) in data.iter().enumerate() {
for (i, data) in data.into_iter().enumerate() {
for data in data.iter() {
match data {
Data::BlockHeaderAndSignature(SignedBlockHeader { block_header: BlockHeader { block_number: BlockNumber(block_number), .. }, .. }) => {
assert_eq!(block_number, &(i as u64));
Data::BlockHeaderAndSignature(signed_header) => {
assert_eq!(signed_header.block_header.block_number.0, i as u64);
assert_eq!(*requested_data_type, DataType::SignedBlockHeader);
}
Data::StateDiffChunk (_state_diff) => {
@@ -121,23 +107,27 @@ async fn header_db_executor_start_block_given_by_hash() {
limit: NUM_OF_BLOCKS,
step: 1,
};
let query_id = db_executor.register_query(query, DataType::SignedBlockHeader, sender);
db_executor.register_query(query, DataType::SignedBlockHeader, sender);

// run the executor and collect query results.
tokio::select! {
res = db_executor.next() => {
let poll_res = res.unwrap();
let res_query_id = poll_res.unwrap();
assert_eq!(res_query_id, query_id);
}
_ = db_executor.run() => {
panic!("DB executor should never finish its run.");
},
res = receiver.collect::<Vec<_>>() => {
assert_eq!(res.len(), NUM_OF_BLOCKS as usize);
for (i, data) in res.iter().enumerate() {
assert_matches!(data.first().unwrap(), BlockHeaderAndSignature(SignedBlockHeader{block_header: BlockHeader { block_number: BlockNumber(block_number), .. }, ..}) if block_number == &(i as u64));
for (i, data) in res.into_iter().enumerate() {
for data in data.iter() {
let Data::BlockHeaderAndSignature(signed_header) = data else {
panic!("Unexpected data type");
};
assert_eq!(signed_header.block_header.block_number.0, i as u64);
}
}
}
}
}

#[tokio::test]
async fn header_db_executor_query_of_missing_block() {
let ((storage_reader, mut storage_writer), _temp_dir) = get_test_storage();
@@ -171,89 +161,15 @@ async fn header_db_executor_query_of_missing_block() {
let _query_id = db_executor.register_query(query, mock_data_type, sender);

tokio::select! {
res = db_executor.next() => {
let poll_res = res.unwrap();
assert_matches!(poll_res, Err(DBExecutorError::BlockNotFound{..}));
}
_ = db_executor.run() => {
panic!("DB executor should never finish its run.");
},
res = receiver.collect::<Vec<_>>() => {
assert_eq!(res.len(), (BLOCKS_DELTA) as usize);
}
}
}

#[test]
fn header_db_executor_stream_pending_with_no_query() {
let ((storage_reader, _), _temp_dir) = get_test_storage();
let mut db_executor = super::DBExecutor::new(storage_reader);

// poll without registering a query.
assert!(poll_fn(|cx| db_executor.poll_next_unpin(cx)).now_or_never().is_none());
}

#[tokio::test]
async fn header_db_executor_can_receive_queries_after_stream_is_exhausted() {
let ((storage_reader, mut storage_writer), _temp_dir) = get_test_storage();
let mut db_executor = super::DBExecutor::new(storage_reader);

const NUM_OF_BLOCKS: u64 = 10;
insert_to_storage_test_blocks_up_to(NUM_OF_BLOCKS, &mut storage_writer);

for _ in 0..2 {
// register a query.
let (sender, receiver) = futures::channel::mpsc::channel(BUFFER_SIZE);
let query = Query {
start_block: BlockHashOrNumber::Number(BlockNumber(0)),
direction: Direction::Forward,
limit: NUM_OF_BLOCKS,
step: 1,
};
let mut mock_data_type = MockFetchBlockDataFromDb::new();
mock_data_type
.expect_fetch_block_data_from_db()
.times(NUM_OF_BLOCKS as usize)
.returning(|_, _, _| Ok(vec![Data::default()]));
let query_id = db_executor.register_query(query, mock_data_type, sender);

// run the executor and collect query results.
receiver.collect::<Vec<_>>().await;
let res = db_executor.next().await;
assert_eq!(res.unwrap().unwrap(), query_id);

// make sure the stream is pending.
let res = poll_fn(|cx| match db_executor.poll_next_unpin(cx) {
Poll::Pending => Poll::Ready(Ok(())),
Poll::Ready(ready) => Poll::Ready(Err(ready)),
})
.await;
assert!(res.is_ok());
}
}

#[tokio::test]
async fn header_db_executor_drop_receiver_before_query_is_done() {
let ((storage_reader, mut storage_writer), _temp_dir) = get_test_storage();
let mut db_executor = super::DBExecutor::new(storage_reader);

const NUM_OF_BLOCKS: u64 = 10;
insert_to_storage_test_blocks_up_to(NUM_OF_BLOCKS, &mut storage_writer);

let (sender, receiver) = futures::channel::mpsc::channel(BUFFER_SIZE);
let query = Query {
start_block: BlockHashOrNumber::Number(BlockNumber(1)),
direction: Direction::Forward,
limit: NUM_OF_BLOCKS,
step: 1,
};
drop(receiver);

// register a query.
let _query_id = db_executor.register_query(query, MockFetchBlockDataFromDb::new(), sender);

// executor should return an error.
let res = db_executor.next().await;
assert!(res.unwrap().is_err());
}

fn insert_to_storage_test_blocks_up_to(num_of_blocks: u64, storage_writer: &mut StorageWriter) {
let mut rng = get_rng();
let thin_state_diffs =
23 changes: 2 additions & 21 deletions crates/papyrus_network/src/network_manager/mod.rs
@@ -21,7 +21,7 @@ use tracing::{debug, error, info, trace};

use self::swarm_trait::SwarmTrait;
use crate::bin_utils::build_swarm;
use crate::db_executor::{self, DBExecutor, DBExecutorTrait, Data, QueryId};
use crate::db_executor::{DBExecutor, DBExecutorTrait, Data, QueryId};
use crate::gossipsub_impl::Topic;
use crate::mixed_behaviour::{self, BridgedBehaviour};
use crate::sqmr::{self, InboundSessionId, OutboundSessionId, SessionId};
@@ -66,7 +66,7 @@ impl<DBExecutorT: DBExecutorTrait, SwarmT: SwarmTrait> GenericNetworkManager<DBE
loop {
tokio::select! {
Some(event) = self.swarm.next() => self.handle_swarm_event(event),
Some(res) = self.db_executor.next() => self.handle_db_executor_result(res),
_ = self.db_executor.run() => panic!("DB executor should never finish."),
Some(res) = self.query_results_router.next() => self.handle_query_result_routing_to_other_peer(res),
Some((protocol, query)) = self.sqmr_query_receivers.next() => {
self.handle_local_sqmr_query(protocol, query)
@@ -248,25 +248,6 @@ impl<DBExecutorT: DBExecutorTrait, SwarmT: SwarmTrait> GenericNetworkManager<DBE
}
}

fn handle_db_executor_result(
&mut self,
res: Result<db_executor::QueryId, db_executor::DBExecutorError>,
) {
match res {
Ok(query_id) => {
// TODO: in case we want to do bookkeeping, this is the place.
debug!("Query completed successfully. query_id: {query_id:?}");
}
Err(err) => {
if err.should_log_in_error_level() {
error!("Query failed. error: {err:?}");
} else {
debug!("Query failed. error: {err:?}");
}
}
};
}

fn handle_behaviour_event(&mut self, event: mixed_behaviour::Event) {
match event {
mixed_behaviour::Event::ExternalEvent(external_event) => {