Skip to content
This repository has been archived by the owner on Dec 26, 2024. It is now read-only.

Commit

Permalink
refactor(network): db executor communicates through channels
Browse files Browse the repository at this point in the history
  • Loading branch information
ShahakShama committed Jun 20, 2024
1 parent 7267f39 commit cb8c766
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 96 deletions.
127 changes: 93 additions & 34 deletions crates/papyrus_network/src/db_executor/mod.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
use std::vec;

use async_trait::async_trait;
use futures::channel::mpsc::Sender;
use futures::future::{pending, poll_fn};
use futures::future::pending;
use futures::{FutureExt, Sink, SinkExt, StreamExt};
use papyrus_protobuf::sync::{
BlockHashOrNumber,
ContractDiff,
DataOrFin,
DeclaredClass,
DeprecatedDeclaredClass,
HeaderQuery,
Query,
SignedBlockHeader,
StateDiffChunk,
StateDiffQuery,
};
use papyrus_storage::header::HeaderStorageReader;
use papyrus_storage::state::StateStorageReader;
Expand All @@ -20,6 +22,8 @@ use starknet_api::block::BlockNumber;
use starknet_api::state::ThinStateDiff;
use tracing::error;

use crate::network_manager::SqmrQueryReceiver;

#[cfg(test)]
mod test;

Expand Down Expand Up @@ -59,12 +63,14 @@ impl DBExecutorError {
/// A DBExecutor receives inbound queries and returns their corresponding data.
#[async_trait]
pub trait DBExecutorTrait {
/// Send a query to be executed in the DBExecutor. The query will be run concurrently with the
/// calling code and the result will be over the given channel.
fn register_query<Data: FetchBlockDataFromDb + Send + 'static>(
&self,
query: Query,
sender: Sender<DataOrFin<Data>>,
fn set_header_queries_receiver(
&mut self,
receiver: SqmrQueryReceiver<HeaderQuery, DataOrFin<SignedBlockHeader>>,
);

fn set_state_diff_queries_receiver(
&mut self,
receiver: SqmrQueryReceiver<StateDiffQuery, DataOrFin<StateDiffChunk>>,
);

/// Polls incoming queries.
Expand All @@ -74,21 +80,73 @@ pub trait DBExecutorTrait {

/// Serves inbound sync queries out of storage. Query channels are injected
/// after construction via the `set_*_queries_receiver` trait methods, which is
/// why both receiver fields are `Option` (see the TODOs below).
pub struct DBExecutor {
    // Read-only storage handle; register_query clones it into each spawned
    // per-query task.
    storage_reader: StorageReader,
    // TODO(shahak): Make this non-option.
    // `None` until `set_header_queries_receiver` is called.
    header_queries_receiver: Option<SqmrQueryReceiver<HeaderQuery, DataOrFin<SignedBlockHeader>>>,
    // TODO(shahak): Make this non-option.
    // `None` until `set_state_diff_queries_receiver` is called.
    state_diff_queries_receiver:
        Option<SqmrQueryReceiver<StateDiffQuery, DataOrFin<StateDiffChunk>>>,
}

#[async_trait]
impl DBExecutorTrait for DBExecutor {
    /// Stores the channel over which inbound header queries will arrive.
    /// Overwrites any previously set receiver.
    fn set_header_queries_receiver(
        &mut self,
        receiver: SqmrQueryReceiver<HeaderQuery, DataOrFin<SignedBlockHeader>>,
    ) {
        self.header_queries_receiver = Some(receiver);
    }

    /// Stores the channel over which inbound state-diff queries will arrive.
    /// Overwrites any previously set receiver.
    fn set_state_diff_queries_receiver(
        &mut self,
        receiver: SqmrQueryReceiver<StateDiffQuery, DataOrFin<StateDiffChunk>>,
    ) {
        self.state_diff_queries_receiver = Some(receiver);
    }

    /// Main dispatch loop: awaits the next query from whichever receivers have
    /// been registered and hands each well-formed query to `register_query`,
    /// which executes it concurrently. Loops forever; never returns.
    async fn run(&mut self) {
        loop {
            // An unset receiver is replaced by a never-resolving `pending()`
            // future, effectively disabling its select! arm. Both sides are
            // boxed so the two branches of each `if` have a common type.
            let header_queries_receiver_future =
                if let Some(header_queries_receiver) = self.header_queries_receiver.as_mut() {
                    header_queries_receiver.next().boxed()
                } else {
                    pending().boxed()
                };
            let state_diff_queries_receiver_future = if let Some(state_diff_queries_receiver) =
                self.state_diff_queries_receiver.as_mut()
            {
                state_diff_queries_receiver.next().boxed()
            } else {
                pending().boxed()
            };

            tokio::select! {
                Some((query_result, response_sender)) = header_queries_receiver_future => {
                    // TODO(shahak): Report if query_result is Err.
                    // NOTE(review): a malformed (Err) query is currently
                    // dropped silently.
                    if let Ok(query) = query_result {
                        self.register_query(query.0, response_sender);
                    }
                }
                Some((query_result, response_sender)) = state_diff_queries_receiver_future => {
                    // NOTE(review): Err is silently dropped here as well;
                    // presumably the TODO above applies to this arm too.
                    if let Ok(query) = query_result {
                        self.register_query(query.0, response_sender);
                    }
                }
            };
        }
    }
}

impl DBExecutor {
pub fn new(storage_reader: StorageReader) -> Self {
Self { storage_reader }
Self { storage_reader, header_queries_receiver: None, state_diff_queries_receiver: None }
}
}

#[async_trait]
impl DBExecutorTrait for DBExecutor {
fn register_query<Data: FetchBlockDataFromDb + Send + 'static>(
&self,
query: Query,
sender: Sender<DataOrFin<Data>>,
) {
fn register_query<Data, Sender>(&self, query: Query, sender: Sender)
where
Data: FetchBlockDataFromDb + Send + 'static,
Sender: Sink<DataOrFin<Data>> + Unpin + Send + 'static,
DBExecutorError: From<<Sender as Sink<DataOrFin<Data>>>::Error>,
{
let storage_reader_clone = self.storage_reader.clone();
tokio::task::spawn(async move {
let result = send_data_for_query(storage_reader_clone, query.clone(), sender).await;
Expand All @@ -102,12 +160,6 @@ impl DBExecutorTrait for DBExecutor {
}
});
}

async fn run(&mut self) {
// TODO(shahak): Parse incoming queries once we receive them through channel instead of
// through function.
pending::<()>().await
}
}

pub trait FetchBlockDataFromDb: Sized {
Expand Down Expand Up @@ -197,23 +249,32 @@ pub fn split_thin_state_diff(thin_state_diff: ThinStateDiff) -> Vec<StateDiffChu
state_diff_chunks
}

async fn send_data_for_query<Data: FetchBlockDataFromDb + Send + 'static>(
async fn send_data_for_query<Data, Sender>(
storage_reader: StorageReader,
query: Query,
mut sender: Sender<DataOrFin<Data>>,
) -> Result<(), DBExecutorError> {
mut sender: Sender,
) -> Result<(), DBExecutorError>
where
Data: FetchBlockDataFromDb + Send + 'static,
Sender: Sink<DataOrFin<Data>> + Unpin + Send + 'static,
DBExecutorError: From<<Sender as Sink<DataOrFin<Data>>>::Error>,
{
// If this function fails, we still want to send fin before failing.
let result = send_data_without_fin_for_query(&storage_reader, query, &mut sender).await;
poll_fn(|cx| sender.poll_ready(cx)).await?;
sender.start_send(DataOrFin(None))?;
sender.feed(DataOrFin(None)).await?;
result
}

async fn send_data_without_fin_for_query<Data: FetchBlockDataFromDb + Send + 'static>(
async fn send_data_without_fin_for_query<Data, Sender>(
storage_reader: &StorageReader,
query: Query,
sender: &mut Sender<DataOrFin<Data>>,
) -> Result<(), DBExecutorError> {
sender: &mut Sender,
) -> Result<(), DBExecutorError>
where
Data: FetchBlockDataFromDb + Send + 'static,
Sender: Sink<DataOrFin<Data>> + Unpin + Send + 'static,
DBExecutorError: From<<Sender as Sink<DataOrFin<Data>>>::Error>,
{
let txn = storage_reader.begin_ro_txn()?;
let start_block_number = match query.start_block {
BlockHashOrNumber::Number(BlockNumber(num)) => num,
Expand All @@ -229,11 +290,9 @@ async fn send_data_without_fin_for_query<Data: FetchBlockDataFromDb + Send + 'st
let block_number =
BlockNumber(utils::calculate_block_number(&query, start_block_number, block_counter)?);
let data_vec = Data::fetch_block_data_from_db(block_number, &txn)?;
// Using poll_fn because Sender::poll_ready is not a future
poll_fn(|cx| sender.poll_ready(cx)).await?;
for data in data_vec {
// TODO: consider implement retry mechanism.
sender.start_send(DataOrFin(Some(data)))?;
sender.feed(DataOrFin(Some(data))).await?;
}
}
Ok(())
Expand Down
6 changes: 3 additions & 3 deletions crates/papyrus_network/src/db_executor/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async fn header_query_positive_flow() {
step: 1,
};
let (sender, data_receiver) = futures::channel::mpsc::channel(BUFFER_SIZE);
db_executor.register_query::<SignedBlockHeader>(query.clone(), sender);
db_executor.register_query::<SignedBlockHeader, _>(query.clone(), sender);

// run the executor and collect query results.
tokio::select! {
Expand Down Expand Up @@ -77,7 +77,7 @@ async fn header_query_start_block_given_by_hash() {
limit: NUM_OF_BLOCKS,
step: 1,
};
db_executor.register_query::<SignedBlockHeader>(query, sender);
db_executor.register_query::<SignedBlockHeader, _>(query, sender);

// run the executor and collect query results.
tokio::select! {
Expand Down Expand Up @@ -116,7 +116,7 @@ async fn header_query_some_blocks_are_missing() {
limit: NUM_OF_BLOCKS,
step: 1,
};
db_executor.register_query::<SignedBlockHeader>(query, sender);
db_executor.register_query::<SignedBlockHeader, _>(query, sender);

tokio::select! {
_ = db_executor.run() => {
Expand Down
17 changes: 0 additions & 17 deletions crates/papyrus_network/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ use papyrus_config::converters::{
use papyrus_config::dumping::{ser_optional_param, ser_param, SerializeConfig};
use papyrus_config::validators::validate_vec_u256;
use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam};
use papyrus_protobuf::protobuf;
use papyrus_protobuf::sync::Query;
use prost::Message;
use serde::{Deserialize, Serialize};
use validator::Validate;

Expand Down Expand Up @@ -98,20 +95,6 @@ impl Protocol {
Protocol::StateDiff => "/starknet/state_diffs/1",
}
}

pub fn bytes_query_to_protobuf_request(&self, query: Vec<u8>) -> Query {
// TODO: make this function return errors instead of panicking.
match self {
Protocol::SignedBlockHeader => protobuf::BlockHeadersRequest::decode(&query[..])
.expect("failed to decode protobuf BlockHeadersRequest")
.try_into()
.expect("failed to convert BlockHeadersRequest"),
Protocol::StateDiff => protobuf::StateDiffsRequest::decode(&query[..])
.expect("failed to decode protobuf StateDiffsRequest")
.try_into()
.expect("failed to convert StateDiffsRequest"),
}
}
}

impl From<Protocol> for StreamProtocol {
Expand Down
Loading

0 comments on commit cb8c766

Please sign in to comment.