diff --git a/ipa-core/src/error.rs b/ipa-core/src/error.rs index 5fb22dfca..0bd38cda2 100644 --- a/ipa-core/src/error.rs +++ b/ipa-core/src/error.rs @@ -6,7 +6,7 @@ use std::{ use thiserror::Error; -use crate::{report::InvalidReportError, task::JoinError}; +use crate::{helpers::Role, report::InvalidReportError, sharding::ShardIndex, task::JoinError}; /// An error raised by the IPA protocol. /// @@ -52,8 +52,10 @@ pub enum Error { #[error("failed to parse json: {0}")] #[cfg(feature = "enable-serde")] Serde(#[from] serde_json::Error), - #[error("Infrastructure error: {0}")] - InfraError(#[from] crate::helpers::Error), + #[error("MPC Infrastructure error: {0}")] + MpcInfraError(#[from] crate::helpers::Error), + #[error("Shard Infrastructure error: {0}")] + ShardInfraError(#[from] crate::helpers::Error), #[error("Value truncation error: {0}")] FieldValueTruncation(String), #[error("Invalid query parameter: {0}")] diff --git a/ipa-core/src/helpers/buffers/unordered_receiver.rs b/ipa-core/src/helpers/buffers/unordered_receiver.rs index 4a236a59d..f9ba225bb 100644 --- a/ipa-core/src/helpers/buffers/unordered_receiver.rs +++ b/ipa-core/src/helpers/buffers/unordered_receiver.rs @@ -11,7 +11,7 @@ use generic_array::GenericArray; use typenum::Unsigned; use crate::{ - helpers::{Error, Message}, + helpers::{Error, Message, Role}, protocol::RecordId, sync::{Arc, Mutex}, }; @@ -160,7 +160,7 @@ pub enum ReceiveError { #[error("Error deserializing {0:?} record: {1}")] DeserializationError(RecordId, #[source] M::DeserializationError), #[error(transparent)] - InfraError(#[from] Error), + InfraError(#[from] Error), } impl OperatingState diff --git a/ipa-core/src/helpers/error.rs b/ipa-core/src/helpers/error.rs index d73c38359..9f4fd2156 100644 --- a/ipa-core/src/helpers/error.rs +++ b/ipa-core/src/helpers/error.rs @@ -1,35 +1,17 @@ use thiserror::Error; -use tokio::sync::mpsc::error::SendError; use crate::{ error::BoxError, - helpers::{ChannelId, HelperIdentity, Message, Role, TotalRecords}, - protocol::{step::Gate, RecordId}, + helpers::{ChannelId, TotalRecords, TransportIdentity}, + protocol::RecordId, }; /// An error raised by the IPA supporting infrastructure. #[derive(Error, Debug)] -pub enum Error { - #[error("An error occurred while sending data to {channel:?}: {inner}")] - SendError { - channel: ChannelId, - - #[source] - inner: BoxError, - }, - #[error("An error occurred while sending data over a reordering channel: {inner}")] - OrderedChannelError { - #[source] - inner: BoxError, - }, - #[error("An error occurred while sending data to unknown helper: {inner}")] - PollSendError { - #[source] - inner: BoxError, - }, +pub enum Error { #[error("An error occurred while receiving data from {source:?}/{step}: {inner}")] ReceiveError { - source: Role, + source: I, step: String, #[source] inner: BoxError, @@ -39,54 +21,10 @@ pub enum Error { // TODO(mt): add more fields, like step and role. record_id: RecordId, }, - #[error("An error occurred while serializing or deserializing data for {record_id:?} and step {step}: {inner}")] - SerializationError { - record_id: RecordId, - step: String, - #[source] - inner: BoxError, - }, - #[error("Encountered unknown identity {0:?}")] - UnknownIdentity(HelperIdentity), #[error("record ID {record_id:?} is out of range for {channel_id:?} (expected {total_records:?} records)")] TooManyRecords { record_id: RecordId, - channel_id: ChannelId, + channel_id: ChannelId, total_records: TotalRecords, }, } - -impl Error { - pub fn send_error>>( - channel: ChannelId, - inner: E, - ) -> Error { - Self::SendError { - channel, - inner: inner.into(), - } - } - - #[must_use] - pub fn serialization_error>( - record_id: RecordId, - gate: &Gate, - inner: E, - ) -> Error { - Self::SerializationError { - record_id, - step: String::from(gate.as_ref()), - inner: inner.into(), - } - } -} - -impl From> for Error { - fn from(_: SendError<(usize, M)>) -> Self { - Self::OrderedChannelError { - inner: "ordered string".into(), - } - } -} - -pub type Result = std::result::Result; diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index 018431d62..a9449cd99 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -15,10 +15,12 @@ pub(super) use stall_detection::InstrumentedGateway; use crate::{ helpers::{ + buffers::UnorderedReceiver, gateway::{ receive::GatewayReceivers, send::GatewaySenders, transport::RoleResolvingTransport, }, - ChannelId, Message, Role, RoleAssignment, TotalRecords, Transport, + HelperChannelId, HelperIdentity, Message, Role, RoleAssignment, RouteId, TotalRecords, + Transport, }, protocol::QueryId, }; @@ -28,17 +30,18 @@ use crate::{ /// To avoid proliferation of type parameters, most code references this concrete type alias, rather /// than a type parameter `T: Transport`. #[cfg(feature = "in-memory-infra")] -pub type TransportImpl = super::transport::InMemoryTransport; +pub type TransportImpl = super::transport::InMemoryTransport; #[cfg(feature = "real-world-infra")] pub type TransportImpl = crate::sync::Arc; -pub type TransportError = ::Error; +pub type TransportError = >::Error; /// Gateway into IPA Network infrastructure. It allows helpers send and receive messages. pub struct Gateway { config: GatewayConfig, transport: RoleResolvingTransport, + query_id: QueryId, #[cfg(feature = "stall-detection")] inner: crate::sync::Arc, #[cfg(not(feature = "stall-detection"))] @@ -74,12 +77,11 @@ impl Gateway { ) -> Self { #[allow(clippy::useless_conversion)] // not useless in stall-detection build Self { + query_id, config, transport: RoleResolvingTransport { - query_id, roles, inner: transport, - config, }, inner: State::default().into(), } @@ -87,7 +89,7 @@ impl Gateway { #[must_use] pub fn role(&self) -> Role { - self.transport.role() + self.transport.identity() } #[must_use] @@ -101,7 +103,7 @@ impl Gateway { #[must_use] pub fn get_sender( &self, - channel_id: &ChannelId, + channel_id: &HelperChannelId, total_records: TotalRecords, ) -> send::SendingEnd { let (tx, maybe_stream) = self.inner.senders.get_or_create::( @@ -113,10 +115,15 @@ impl Gateway { tokio::spawn({ let channel_id = channel_id.clone(); let transport = self.transport.clone(); + let query_id = self.query_id; async move { // TODO(651): In the HTTP case we probably need more robust error handling here. transport - .send(&channel_id, stream) + .send( + channel_id.peer, + (RouteId::Records, query_id, channel_id.gate), + stream, + ) .await .expect("{channel_id:?} receiving end should be accepted by transport"); } @@ -127,12 +134,21 @@ impl Gateway { } #[must_use] - pub fn get_receiver(&self, channel_id: &ChannelId) -> receive::ReceivingEnd { + pub fn get_receiver( + &self, + channel_id: &HelperChannelId, + ) -> receive::ReceivingEnd { receive::ReceivingEnd::new( channel_id.clone(), - self.inner - .receivers - .get_or_create(channel_id, || self.transport.receive(channel_id)), + self.inner.receivers.get_or_create(channel_id, || { + UnorderedReceiver::new( + Box::pin( + self.transport + .receive(channel_id.peer, (self.query_id, channel_id.gate.clone())), + ), + self.config.active_work(), + ) + }), ) } } diff --git a/ipa-core/src/helpers/gateway/receive.rs b/ipa-core/src/helpers/gateway/receive.rs index 0b2686ff1..6326cf5c3 100644 --- a/ipa-core/src/helpers/gateway/receive.rs +++ b/ipa-core/src/helpers/gateway/receive.rs @@ -4,13 +4,16 @@ use dashmap::{mapref::entry::Entry, DashMap}; use futures::Stream; use crate::{ - helpers::{buffers::UnorderedReceiver, ChannelId, Error, Message, Transport, TransportImpl}, + helpers::{ + buffers::UnorderedReceiver, gateway::transport::RoleResolvingTransport, Error, + HelperChannelId, Message, Role, Transport, + }, protocol::RecordId, }; -/// Receiving end end of the gateway channel. +/// Receiving end of the gateway channel. pub struct ReceivingEnd { - channel_id: ChannelId, + channel_id: HelperChannelId, unordered_rx: UR, _phantom: PhantomData, } @@ -18,16 +21,16 @@ pub struct ReceivingEnd { /// Receiving channels, indexed by (role, step). #[derive(Default)] pub(super) struct GatewayReceivers { - pub(super) inner: DashMap, + pub(super) inner: DashMap, } pub(super) type UR = UnorderedReceiver< - ::RecordsStream, - <::RecordsStream as Stream>::Item, + >::RecordsStream, + <>::RecordsStream as Stream>::Item, >; impl ReceivingEnd { - pub(super) fn new(channel_id: ChannelId, rx: UR) -> Self { + pub(super) fn new(channel_id: HelperChannelId, rx: UR) -> Self { Self { channel_id, unordered_rx: rx, @@ -44,13 +47,13 @@ impl ReceivingEnd { /// ## Panics /// This will panic if message size does not fit into 8 bytes and it somehow got serialized /// and sent to this helper. - #[tracing::instrument(level = "trace", "receive", skip_all, fields(i = %record_id, from = ?self.channel_id.role, gate = ?self.channel_id.gate.as_ref()))] - pub async fn receive(&self, record_id: RecordId) -> Result { + #[tracing::instrument(level = "trace", "receive", skip_all, fields(i = %record_id, from = ?self.channel_id.peer, gate = ?self.channel_id.gate.as_ref()))] + pub async fn receive(&self, record_id: RecordId) -> Result> { self.unordered_rx .recv::(record_id) .await .map_err(|e| Error::ReceiveError { - source: self.channel_id.role, + source: self.channel_id.peer, step: self.channel_id.gate.to_string(), inner: Box::new(e), }) @@ -58,7 +61,7 @@ impl ReceivingEnd { } impl GatewayReceivers { - pub fn get_or_create UR>(&self, channel_id: &ChannelId, ctr: F) -> UR { + pub fn get_or_create UR>(&self, channel_id: &HelperChannelId, ctr: F) -> UR { // TODO: raw entry API if it becomes available to avoid cloning the key match self.inner.entry(channel_id.clone()) { Entry::Occupied(entry) => entry.get().clone(), diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index 00d8de096..473deb486 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -11,7 +11,7 @@ use futures::Stream; use typenum::Unsigned; use crate::{ - helpers::{buffers::OrderingSender, ChannelId, Error, Message, Role, TotalRecords}, + helpers::{buffers::OrderingSender, Error, HelperChannelId, Message, Role, TotalRecords}, protocol::RecordId, sync::Arc, telemetry::{ @@ -23,7 +23,7 @@ use crate::{ /// Sending end of the gateway channel. pub struct SendingEnd { sender_role: Role, - channel_id: ChannelId, + channel_id: HelperChannelId, inner: Arc, _phantom: PhantomData, } @@ -31,11 +31,11 @@ pub struct SendingEnd { /// Sending channels, indexed by (role, step). #[derive(Default)] pub(super) struct GatewaySenders { - pub(super) inner: DashMap>, + pub(super) inner: DashMap>, } pub(super) struct GatewaySender { - channel_id: ChannelId, + channel_id: HelperChannelId, ordering_tx: OrderingSender, total_records: TotalRecords, } @@ -45,7 +45,7 @@ pub(super) struct GatewaySendStream { } impl GatewaySender { - fn new(channel_id: ChannelId, tx: OrderingSender, total_records: TotalRecords) -> Self { + fn new(channel_id: HelperChannelId, tx: OrderingSender, total_records: TotalRecords) -> Self { Self { channel_id, ordering_tx: tx, @@ -57,7 +57,7 @@ impl GatewaySender { &self, record_id: RecordId, msg: B, - ) -> Result<(), Error> { + ) -> Result<(), Error> { debug_assert!( self.total_records.is_specified(), "total_records cannot be unspecified when sending" @@ -95,7 +95,11 @@ impl GatewaySender { } impl SendingEnd { - pub(super) fn new(sender: Arc, role: Role, channel_id: &ChannelId) -> Self { + pub(super) fn new( + sender: Arc, + role: Role, + channel_id: &HelperChannelId, + ) -> Self { Self { sender_role: role, channel_id: channel_id.clone(), @@ -113,8 +117,8 @@ impl SendingEnd { /// call. /// /// [`set_total_records`]: crate::protocol::context::Context::set_total_records - #[tracing::instrument(level = "trace", "send", skip_all, fields(i = %record_id, total = %self.inner.total_records, to = ?self.channel_id.role, gate = ?self.channel_id.gate.as_ref()))] - pub async fn send>(&self, record_id: RecordId, msg: B) -> Result<(), Error> { + #[tracing::instrument(level = "trace", "send", skip_all, fields(i = %record_id, total = %self.inner.total_records, to = ?self.channel_id.peer, gate = ?self.channel_id.gate.as_ref()))] + pub async fn send>(&self, record_id: RecordId, msg: B) -> Result<(), Error> { let r = self.inner.send(record_id, msg).await; metrics::increment_counter!(RECORDS_SENT, STEP => self.channel_id.gate.as_ref().to_string(), @@ -135,7 +139,7 @@ impl GatewaySenders { /// messages to get through. pub(crate) fn get_or_create( &self, - channel_id: &ChannelId, + channel_id: &HelperChannelId, capacity: NonZeroUsize, total_records: TotalRecords, // TODO track children for indeterminate senders ) -> (Arc, Option) { diff --git a/ipa-core/src/helpers/gateway/stall_detection.rs b/ipa-core/src/helpers/gateway/stall_detection.rs index 9a1b28732..654fbb11c 100644 --- a/ipa-core/src/helpers/gateway/stall_detection.rs +++ b/ipa-core/src/helpers/gateway/stall_detection.rs @@ -74,8 +74,8 @@ mod gateway { use crate::{ helpers::{ gateway::{Gateway, State}, - ChannelId, GatewayConfig, Message, ReceivingEnd, Role, RoleAssignment, SendingEnd, - TotalRecords, TransportImpl, + GatewayConfig, HelperChannelId, Message, ReceivingEnd, Role, RoleAssignment, + SendingEnd, TotalRecords, TransportImpl, }, protocol::QueryId, sync::Arc, @@ -149,7 +149,7 @@ mod gateway { #[must_use] pub fn get_sender( &self, - channel_id: &ChannelId, + channel_id: &HelperChannelId, total_records: TotalRecords, ) -> SendingEnd { Observed::wrap( @@ -159,7 +159,7 @@ mod gateway { } #[must_use] - pub fn get_receiver(&self, channel_id: &ChannelId) -> ReceivingEnd { + pub fn get_receiver(&self, channel_id: &HelperChannelId) -> ReceivingEnd { Observed::wrap( Weak::clone(self.get_sn()), self.inner().gateway.get_receiver(channel_id), @@ -221,7 +221,7 @@ mod receive { helpers::{ error::Error, gateway::{receive::GatewayReceivers, ReceivingEnd}, - ChannelId, Message, + HelperChannelId, Message, Role, }, protocol::RecordId, }; @@ -230,12 +230,12 @@ mod receive { delegate::delegate! { to { self.advance(); self.inner() } { #[inline] - pub async fn receive(&self, record_id: RecordId) -> Result; + pub async fn receive(&self, record_id: RecordId) -> Result>; } } } - pub struct WaitingTasks(BTreeMap>); + pub struct WaitingTasks(BTreeMap>); impl Debug for WaitingTasks { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { @@ -243,7 +243,7 @@ mod receive { write!( f, "\n\"{:?}\", from={:?}. Waiting to receive records {:?}.", - channel.gate, channel.role, records + channel.gate, channel.peer, records )?; } @@ -280,7 +280,7 @@ mod send { helpers::{ error::Error, gateway::send::{GatewaySender, GatewaySenders}, - ChannelId, Message, TotalRecords, + HelperChannelId, Message, Role, TotalRecords, }, protocol::RecordId, }; @@ -289,12 +289,12 @@ mod send { delegate::delegate! { to { self.advance(); self.inner() } { #[inline] - pub async fn send>(&self, record_id: RecordId, msg: B) -> Result<(), Error>; + pub async fn send>(&self, record_id: RecordId, msg: B) -> Result<(), Error>; } } } - pub struct WaitingTasks(BTreeMap)>); + pub struct WaitingTasks(BTreeMap)>); impl Debug for WaitingTasks { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { @@ -302,7 +302,7 @@ mod send { write!( f, "\n\"{:?}\", to={:?}. Waiting to send records {:?} out of {total:?}.", - channel.gate, channel.role, records + channel.gate, channel.peer, records )?; } diff --git a/ipa-core/src/helpers/gateway/transport.rs b/ipa-core/src/helpers/gateway/transport.rs index 8c90a29ee..efbc90970 100644 --- a/ipa-core/src/helpers/gateway/transport.rs +++ b/ipa-core/src/helpers/gateway/transport.rs @@ -1,64 +1,94 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use async_trait::async_trait; +use futures::Stream; + use crate::{ helpers::{ - buffers::UnorderedReceiver, - gateway::{receive::UR, send::GatewaySendStream}, - ChannelId, GatewayConfig, Role, RoleAssignment, RouteId, Transport, TransportImpl, + HelperIdentity, NoResourceIdentifier, QueryIdBinding, Role, RoleAssignment, RouteId, + RouteParams, StepBinding, Transport, TransportImpl, }, - protocol::QueryId, + protocol::{step::Gate, QueryId}, }; +#[derive(Debug, thiserror::Error)] +#[error("Failed to send to {0:?}: {1:?}")] +pub struct SendToRoleError(Role, >::Error); + +/// This struct exists to hide the generic type used to index streams internally. +#[pin_project::pin_project] +pub struct RoleRecordsStream(#[pin] >::RecordsStream); + /// Transport adapter that resolves [`Role`] -> [`HelperIdentity`] mapping. As gateways created /// per query, it is not ambiguous. /// /// [`HelperIdentity`]: crate::helpers::HelperIdentity #[derive(Clone)] -pub(super) struct RoleResolvingTransport { - pub query_id: QueryId, - pub roles: RoleAssignment, - pub config: GatewayConfig, - pub inner: TransportImpl, +pub struct RoleResolvingTransport { + pub(super) roles: RoleAssignment, + pub(super) inner: TransportImpl, +} + +impl Stream for RoleRecordsStream { + type Item = Vec; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().0.poll_next(cx) + } } -impl RoleResolvingTransport { - pub(crate) async fn send( +#[async_trait] +impl Transport for RoleResolvingTransport { + type RecordsStream = RoleRecordsStream; + type Error = SendToRoleError; + + fn identity(&self) -> Role { + let helper_identity = self.inner.identity(); + self.roles.role(helper_identity) + } + + async fn send< + D: Stream> + Send + 'static, + Q: QueryIdBinding, + S: StepBinding, + R: RouteParams, + >( &self, - channel_id: &ChannelId, - data: GatewaySendStream, - ) -> Result<(), ::Error> { - let dest_identity = self.roles.identity(channel_id.role); + dest: Role, + route: R, + data: D, + ) -> Result<(), Self::Error> + where + Option: From, + Option: From, + { + let dest_helper = self.roles.identity(dest); assert_ne!( - dest_identity, + dest_helper, self.inner.identity(), "can't send message to itself" ); - self.inner - .send( - dest_identity, - (RouteId::Records, self.query_id, channel_id.gate.clone()), - data, - ) + .send(dest_helper, route, data) .await + .map_err(|e| SendToRoleError(dest, e)) } - pub(crate) fn receive(&self, channel_id: &ChannelId) -> UR { - let peer = self.roles.identity(channel_id.role); + fn receive>( + &self, + from: Role, + route: R, + ) -> Self::RecordsStream { + let origin_helper = self.roles.identity(from); assert_ne!( - peer, + origin_helper, self.inner.identity(), "can't receive message from itself" ); - UnorderedReceiver::new( - Box::pin( - self.inner - .receive(peer, (self.query_id, channel_id.gate.clone())), - ), - self.config.active_work(), - ) - } - - pub(crate) fn role(&self) -> Role { - self.roles.role(self.inner.identity()) + RoleRecordsStream(self.inner.receive(origin_helper, route)) } } diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index 4544c2795..f9f9acc42 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -18,7 +18,7 @@ use std::ops::{Index, IndexMut}; /// to validate that transport can actually send streams of this type #[cfg(test)] pub use buffers::OrderingSender; -pub use error::{Error, Result}; +pub use error::Error; pub use futures::MaybeFuture; #[cfg(feature = "stall-detection")] @@ -51,9 +51,10 @@ pub use prss_protocol::negotiate as negotiate_prss; #[cfg(feature = "web-app")] pub use transport::WrappedAxumBodyStream; pub use transport::{ - callbacks::*, query, BodyStream, BytesStream, LengthDelimitedStream, LogErrors, - NoResourceIdentifier, QueryIdBinding, ReceiveRecords, RecordsStream, RouteId, RouteParams, - StepBinding, StreamCollection, StreamKey, Transport, WrappedBoxBodyStream, + callbacks::*, query, BodyStream, BytesStream, Identity as TransportIdentity, + LengthDelimitedStream, LogErrors, NoResourceIdentifier, QueryIdBinding, ReceiveRecords, + RecordsStream, RouteId, RouteParams, StepBinding, StreamCollection, StreamKey, Transport, + WrappedBoxBodyStream, }; #[cfg(feature = "in-memory-infra")] pub use transport::{InMemoryNetwork, InMemoryTransport}; @@ -407,23 +408,26 @@ impl TryFrom<[Role; 3]> for RoleAssignment { /// Combination of helper role and step that uniquely identifies a single channel of communication /// between two helpers. #[derive(Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] -pub struct ChannelId { - pub role: Role, +pub struct ChannelId { + /// Entity we are talking to through this channel. It can be a source or a destination. + pub peer: I, // TODO: step could be either reference or owned value. references are convenient to use inside // gateway , owned values can be used inside lookup tables. pub gate: Gate, } -impl ChannelId { +pub type HelperChannelId = ChannelId; + +impl ChannelId { #[must_use] - pub fn new(role: Role, gate: Gate) -> Self { - Self { role, gate } + pub fn new(peer: I, gate: Gate) -> Self { + Self { peer, gate } } } -impl Debug for ChannelId { +impl Debug for ChannelId { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "channel[{:?},{:?}]", self.role, self.gate.as_ref()) + write!(f, "channel[{:?},{:?}]", self.peer, self.gate.as_ref()) } } diff --git a/ipa-core/src/helpers/prss_protocol.rs b/ipa-core/src/helpers/prss_protocol.rs index 4dddd21fb..348a36596 100644 --- a/ipa-core/src/helpers/prss_protocol.rs +++ b/ipa-core/src/helpers/prss_protocol.rs @@ -3,7 +3,7 @@ use rand_core::{CryptoRng, RngCore}; use x25519_dalek::PublicKey; use crate::{ - helpers::{ChannelId, Direction, Error, Gateway, TotalRecords}, + helpers::{ChannelId, Direction, Error, Gateway, Role, TotalRecords}, protocol::{ prss, step::{Gate, Step, StepNarrow}, @@ -28,7 +28,7 @@ pub async fn negotiate( gateway: &Gateway, gate: &Gate, rng: &mut R, -) -> Result { +) -> Result> { // setup protocol to exchange prss public keys. This protocol sends one message per peer. // Each message contains this helper's public key. At the end of this protocol, all helpers // have completed key exchange and each of them have established a shared secret with each peer. diff --git a/ipa-core/src/helpers/transport/in_memory/mod.rs b/ipa-core/src/helpers/transport/in_memory/mod.rs index 564caee0c..36a9eaed4 100644 --- a/ipa-core/src/helpers/transport/in_memory/mod.rs +++ b/ipa-core/src/helpers/transport/in_memory/mod.rs @@ -3,19 +3,19 @@ mod transport; pub use transport::Setup; use crate::{ - helpers::{HelperIdentity, TransportCallbacks}, + helpers::{HelperIdentity, TransportCallbacks, TransportIdentity}, sync::{Arc, Weak}, }; -pub type InMemoryTransport = Weak; +pub type InMemoryTransport = Weak>; /// Container for all active transports #[derive(Clone)] -pub struct InMemoryNetwork { - pub transports: [Arc; 3], +pub struct InMemoryNetwork { + pub transports: [Arc>; 3], } -impl Default for InMemoryNetwork { +impl Default for InMemoryNetwork { fn default() -> Self { Self::new([ TransportCallbacks::default(), @@ -26,26 +26,10 @@ impl Default for InMemoryNetwork { } #[allow(dead_code)] -impl InMemoryNetwork { - #[must_use] - pub fn new(callbacks: [TransportCallbacks; 3]) -> Self { - let [mut first, mut second, mut third]: [_; 3] = - HelperIdentity::make_three().map(Setup::new); - - first.connect(&mut second); - second.connect(&mut third); - third.connect(&mut first); - - let [cb1, cb2, cb3] = callbacks; - - Self { - transports: [first.start(cb1), second.start(cb2), third.start(cb3)], - } - } - +impl InMemoryNetwork { #[must_use] #[allow(clippy::missing_panics_doc)] - pub fn helper_identities(&self) -> [HelperIdentity; 3] { + pub fn identities(&self) -> [I; 3] { self.transports .iter() .map(|t| t.identity()) @@ -59,7 +43,7 @@ impl InMemoryNetwork { /// ## Panics /// If [`HelperIdentity`] is somehow points to a non-existent helper, which shouldn't happen. #[must_use] - pub fn transport(&self, id: HelperIdentity) -> InMemoryTransport { + pub fn transport(&self, id: I) -> InMemoryTransport { self.transports .iter() .find(|t| t.identity() == id) @@ -68,8 +52,8 @@ impl InMemoryNetwork { #[allow(clippy::missing_panics_doc)] #[must_use] - pub fn transports(&self) -> [InMemoryTransport; 3] { - let transports: [InMemoryTransport; 3] = self + pub fn transports(&self) -> [InMemoryTransport; 3] { + let transports: [InMemoryTransport<_>; 3] = self .transports .iter() .map(Arc::downgrade) @@ -87,3 +71,21 @@ impl InMemoryNetwork { } } } + +impl InMemoryNetwork { + #[must_use] + pub fn new(callbacks: [TransportCallbacks>; 3]) -> Self { + let [mut first, mut second, mut third]: [_; 3] = + HelperIdentity::make_three().map(Setup::new); + + first.connect(&mut second); + second.connect(&mut third); + third.connect(&mut first); + + let [cb1, cb2, cb3] = callbacks; + + Self { + transports: [first.start(cb1), second.start(cb2), third.start(cb3)], + } + } +} diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index f23d586bc..ab8173766 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -25,19 +25,23 @@ use crate::{ helpers::{ query::{PrepareQuery, QueryConfig}, HelperIdentity, NoResourceIdentifier, QueryIdBinding, ReceiveRecords, RouteId, RouteParams, - StepBinding, StreamCollection, Transport, TransportCallbacks, + StepBinding, StreamCollection, Transport, TransportCallbacks, TransportIdentity, }, protocol::{step::Gate, QueryId}, sync::{Arc, Weak}, }; -type Packet = (Addr, InMemoryStream, oneshot::Sender>); -type ConnectionTx = Sender; -type ConnectionRx = Receiver; +type Packet = ( + Addr, + InMemoryStream, + oneshot::Sender>>, +); +type ConnectionTx = Sender>; +type ConnectionRx = Receiver>; type StreamItem = Vec; #[derive(Debug, thiserror::Error)] -pub enum Error { +pub enum Error { #[error(transparent)] Io { #[from] @@ -45,7 +49,7 @@ pub enum Error { }, #[error("Request rejected by remote {dest:?}: {inner:?}")] Rejected { - dest: HelperIdentity, + dest: I, #[source] inner: BoxError, }, @@ -54,15 +58,15 @@ pub enum Error { /// In-memory implementation of [`Transport`] backed by Tokio mpsc channels. /// Use [`Setup`] to initialize it and call [`Setup::start`] to make it actively listen for /// incoming messages. -pub struct InMemoryTransport { - identity: HelperIdentity, - connections: HashMap, - record_streams: StreamCollection, +pub struct InMemoryTransport { + identity: I, + connections: HashMap>, + record_streams: StreamCollection, } -impl InMemoryTransport { +impl InMemoryTransport { #[must_use] - fn new(identity: HelperIdentity, connections: HashMap) -> Self { + fn new(identity: I, connections: HashMap>) -> Self { Self { identity, connections, @@ -71,7 +75,7 @@ impl InMemoryTransport { } #[must_use] - pub fn identity(&self) -> HelperIdentity { + pub fn identity(&self) -> I { self.identity } @@ -79,7 +83,11 @@ impl InMemoryTransport { /// out and processes it, the same way as query processor does. That will allow all tasks to be /// created in one place (driver). It does not affect the [`Transport`] interface, /// so I'll leave it as is for now. - fn listen(self: &Arc, callbacks: TransportCallbacks>, mut rx: ConnectionRx) { + fn listen( + self: &Arc, + callbacks: TransportCallbacks>, + mut rx: ConnectionRx, + ) { tokio::spawn( { let streams = self.record_streams.clone(); @@ -132,7 +140,7 @@ impl InMemoryTransport { ); } - fn get_channel(&self, dest: HelperIdentity) -> ConnectionTx { + fn get_channel(&self, dest: I) -> ConnectionTx { self.connections .get(&dest) .unwrap_or_else(|| { @@ -151,11 +159,11 @@ impl InMemoryTransport { } #[async_trait] -impl Transport for Weak { - type RecordsStream = ReceiveRecords; - type Error = Error; +impl Transport for Weak> { + type RecordsStream = ReceiveRecords; + type Error = Error; - fn identity(&self) -> HelperIdentity { + fn identity(&self) -> I { self.upgrade().unwrap().identity } @@ -166,10 +174,10 @@ impl Transport for Weak { R: RouteParams, >( &self, - dest: HelperIdentity, + dest: I, route: R, data: D, - ) -> Result<(), Error> + ) -> Result<(), Error> where Option: From, Option: From, @@ -197,7 +205,7 @@ impl Transport for Weak { fn receive>( &self, - from: HelperIdentity, + from: I, route: R, ) -> Self::RecordsStream { ReceiveRecords::new( @@ -261,18 +269,18 @@ impl Debug for InMemoryStream { } } -struct Addr { +struct Addr { route: RouteId, - origin: Option, + origin: Option, query_id: Option, gate: Option, params: String, } -impl Addr { +impl Addr { #[allow(clippy::needless_pass_by_value)] // to avoid using double-reference at callsites fn from_route>( - origin: HelperIdentity, + origin: I, route: R, ) -> Self where @@ -293,7 +301,7 @@ impl Addr { } #[cfg(all(test, unit_test))] - fn records(from: HelperIdentity, query_id: QueryId, gate: Gate) -> Self { + fn records(from: I, query_id: QueryId, gate: Gate) -> Self { Self { route: RouteId::Records, origin: Some(from), @@ -304,7 +312,7 @@ impl Addr { } } -impl Debug for Addr { +impl Debug for Addr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( f, @@ -316,9 +324,9 @@ impl Debug for Addr { pub struct Setup { identity: HelperIdentity, - tx: ConnectionTx, - rx: ConnectionRx, - connections: HashMap, + tx: ConnectionTx, + rx: ConnectionRx, + connections: HashMap>, } impl Setup { @@ -350,8 +358,11 @@ impl Setup { fn into_active_conn( self, - callbacks: TransportCallbacks>, - ) -> (ConnectionTx, Arc) { + callbacks: TransportCallbacks>>, + ) -> ( + ConnectionTx, + Arc>, + ) { let transport = Arc::new(InMemoryTransport::new(self.identity, self.connections)); transport.listen(callbacks, self.rx); @@ -361,8 +372,8 @@ impl Setup { #[must_use] pub fn start( self, - callbacks: TransportCallbacks>, - ) -> Arc { + callbacks: TransportCallbacks>>, + ) -> Arc> { self.into_active_conn(callbacks).1 } } @@ -391,6 +402,7 @@ mod tests { InMemoryNetwork, Setup, }, HelperIdentity, OrderingSender, RouteId, Transport, TransportCallbacks, + TransportIdentity, }, protocol::{step::Gate, QueryId}, sync::Arc, @@ -398,7 +410,11 @@ mod tests { const STEP: &str = "in-memory-transport"; - async fn send_and_ack(sender: &ConnectionTx, addr: Addr, data: InMemoryStream) { + async fn send_and_ack( + sender: &ConnectionTx, + addr: Addr, + data: InMemoryStream, + ) { let (tx, rx) = oneshot::channel(); sender.send((addr, data, tx)).await.unwrap(); rx.await @@ -491,7 +507,7 @@ mod tests { async fn send_and_verify( from: HelperIdentity, to: HelperIdentity, - transports: &HashMap>, + transports: &HashMap>>, ) { let (stream_tx, stream_rx) = channel(1); let stream = InMemoryStream::from(stream_rx); diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index acbbb8e8e..2bdcb5ace 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -1,4 +1,4 @@ -use std::borrow::Borrow; +use std::{borrow::Borrow, fmt::Debug, hash::Hash}; use async_trait::async_trait; use futures::Stream; @@ -25,6 +25,19 @@ pub use stream::{ WrappedBoxBodyStream, }; +use crate::{helpers::Role, sharding::ShardIndex}; + +/// An identity of a peer that can be communicated with using [`Transport`]. There are currently two +/// types of peers - helpers and shards. +pub trait Identity: Copy + Clone + Debug + PartialEq + Eq + Hash + Send + Sync + 'static {} + +impl Identity for ShardIndex {} +impl Identity for HelperIdentity {} + +/// Role is an identifier of helper peer, only valid within a given query. For every query, there +/// exists a static mapping from role to helper identity. +impl Identity for Role {} + pub trait ResourceIdentifier: Sized {} pub trait QueryIdBinding: Sized where @@ -125,21 +138,16 @@ impl RouteParams for (RouteId, QueryId, Gate) { /// Transport that supports per-query,per-step channels #[async_trait] -pub trait Transport: Clone + Send + Sync + 'static { +pub trait Transport: Clone + Send + Sync + 'static { type RecordsStream: Stream> + Send + Unpin; type Error: std::fmt::Debug; - fn identity(&self) -> HelperIdentity; + fn identity(&self) -> I; /// Sends a new request to the given destination helper party. /// Depending on the specific request, it may or may not require acknowledgment by the remote /// party - async fn send( - &self, - dest: HelperIdentity, - route: R, - data: D, - ) -> Result<(), Self::Error> + async fn send(&self, dest: I, route: R, data: D) -> Result<(), Self::Error> where Option: From, Option: From, @@ -152,7 +160,7 @@ pub trait Transport: Clone + Send + Sync + 'static { /// and step fn receive>( &self, - from: HelperIdentity, + from: I, route: R, ) -> Self::RecordsStream; diff --git a/ipa-core/src/helpers/transport/receive.rs b/ipa-core/src/helpers/transport/receive.rs index e797e7dc9..fec775d0b 100644 --- a/ipa-core/src/helpers/transport/receive.rs +++ b/ipa-core/src/helpers/transport/receive.rs @@ -1,15 +1,19 @@ use std::{ - pin::Pin, + pin::{pin, Pin}, task::{Context, Poll}, }; use futures::Stream; use futures_util::StreamExt; +use pin_project::pin_project; use tracing::error; use crate::{ error::BoxError, - helpers::transport::stream::{StreamCollection, StreamKey}, + helpers::{ + transport::stream::{StreamCollection, StreamKey}, + TransportIdentity, + }, }; /// Adapt a stream of `Result>, Error>` to a stream of `Vec`. @@ -66,47 +70,49 @@ where /// If stream is not received yet, each poll generates a waker that is used internally to wake up /// the task when stream is received. /// Once stream is received, it is moved to this struct and it acts as a proxy to it. -pub struct ReceiveRecords { - inner: ReceiveRecordsInner, +#[pin_project] +pub struct ReceiveRecords { + #[pin] + inner: ReceiveRecordsInner, } -impl ReceiveRecords { - pub(crate) fn new(key: StreamKey, coll: StreamCollection) -> Self { +impl ReceiveRecords { + pub(crate) fn new(key: StreamKey, coll: StreamCollection) -> Self { Self { inner: ReceiveRecordsInner::Pending(key, coll), } } } -impl Stream for ReceiveRecords { +impl Stream for ReceiveRecords { type Item = S::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::get_mut(self).inner.poll_next_unpin(cx) + self.project().inner.poll_next(cx) } } /// Inner state for [`ReceiveRecords`] struct -enum ReceiveRecordsInner { - Pending(StreamKey, StreamCollection), - Ready(S), +#[pin_project(project = ReceiveRecordsInnerProj)] +enum ReceiveRecordsInner { + Pending(StreamKey, StreamCollection), + Ready(#[pin] S), } -impl Stream for ReceiveRecordsInner { +impl Stream for ReceiveRecordsInner { type Item = S::Item; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = Pin::get_mut(self); + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - match this { - Self::Pending(key, streams) => { + match self.as_mut().project() { + ReceiveRecordsInnerProj::Pending(key, streams) => { if let Some(stream) = streams.add_waker(key, cx.waker()) { - *this = Self::Ready(stream); + self.set(Self::Ready(stream)); } else { return Poll::Pending; } } - Self::Ready(stream) => return stream.poll_next_unpin(cx), + ReceiveRecordsInnerProj::Ready(stream) => return stream.poll_next(cx), } } } diff --git a/ipa-core/src/helpers/transport/stream/collection.rs b/ipa-core/src/helpers/transport/stream/collection.rs index a60abca55..b5b1a8bcf 100644 --- a/ipa-core/src/helpers/transport/stream/collection.rs +++ b/ipa-core/src/helpers/transport/stream/collection.rs @@ -7,14 +7,14 @@ use std::{ use futures::Stream; use crate::{ - helpers::HelperIdentity, + helpers::TransportIdentity, protocol::{step::Gate, QueryId}, sync::{Arc, Mutex}, }; /// Each stream is indexed by query id, the identity of helper where stream is originated from /// and step. -pub type StreamKey = (QueryId, HelperIdentity, Gate); +pub type StreamKey = (QueryId, I, Gate); /// Thread-safe append-only collection of homogeneous record streams. /// Streams are indexed by [`StreamKey`] and the lifecycle of each stream is described by the @@ -22,11 +22,11 @@ pub type StreamKey = (QueryId, HelperIdentity, Gate); /// /// Each stream can be inserted and taken away exactly once, any deviation from this behaviour will /// result in panic. -pub struct StreamCollection { - inner: Arc>>>, +pub struct StreamCollection { + inner: Arc, StreamState>>>, } -impl Default for StreamCollection { +impl Default for StreamCollection { fn default() -> Self { Self { inner: Arc::new(Mutex::new(HashMap::default())), @@ -34,7 +34,7 @@ impl Default for StreamCollection { } } -impl Clone for StreamCollection { +impl Clone for StreamCollection { fn clone(&self) -> Self { Self { inner: Arc::clone(&self.inner), @@ -42,12 +42,12 @@ impl Clone for StreamCollection { } } -impl StreamCollection { +impl StreamCollection { /// Adds a new stream associated with the given key. /// /// ## Panics /// If there was another stream associated with the same key some time in the past. - pub fn add_stream(&self, key: StreamKey, stream: S) { + pub fn add_stream(&self, key: StreamKey, stream: S) { let mut streams = self.inner.lock().unwrap(); match streams.entry(key) { Entry::Occupied(mut entry) => match entry.get_mut() { @@ -77,7 +77,7 @@ impl StreamCollection { /// /// ## Panics /// If [`Waker`] that exists already inside this collection will not wake the given one. - pub fn add_waker(&self, key: &StreamKey, waker: &Waker) -> Option { + pub fn add_waker(&self, key: &StreamKey, waker: &Waker) -> Option { let mut streams = self.inner.lock().unwrap(); match streams.entry(key.clone()) { diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 50e2d98d6..abb5a3aa1 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -33,6 +33,7 @@ mod exact; mod seq_join; #[cfg(feature = "enable-serde")] mod serde; +mod sharding; pub use app::{HelperApp, Setup as AppSetup}; diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index fcdc63c33..19c048e6c 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -33,7 +33,7 @@ pub struct HttpTransport { clients: [MpcHelperClient; 3], // TODO(615): supporting multiple queries likely require a hashmap here. It will be ok if we // only allow one query at a time. - record_streams: StreamCollection, + record_streams: StreamCollection, } impl HttpTransport { @@ -123,8 +123,8 @@ impl HttpTransport { } #[async_trait] -impl Transport for Arc { - type RecordsStream = ReceiveRecords; +impl Transport for Arc { + type RecordsStream = ReceiveRecords; type Error = Error; fn identity(&self) -> HelperIdentity { diff --git a/ipa-core/src/sharding.rs b/ipa-core/src/sharding.rs new file mode 100644 index 000000000..564d9cd6f --- /dev/null +++ b/ipa-core/src/sharding.rs @@ -0,0 +1,116 @@ +use std::fmt::{Display, Formatter}; + +/// A unique zero-based index of the helper shard. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ShardIndex(u32); + +impl Display for ShardIndex { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + +/// Shard-specific configuration required by sharding API. Each shard must know its own index and +/// the total number of shards in the system. +pub trait ShardConfiguration { + /// Returns the index of the current shard. + fn shard_id(&self) -> ShardIndex; + + /// Total number of shards present on this helper. It is expected that all helpers have the + /// same number of shards. + fn shard_count(&self) -> ShardIndex; + + /// Returns an iterator that yields shard indices for all shards present in the system, except + /// this one. Shards are yielded in ascending order. + /// + /// ## Panics + /// if current shard index is greater or equal to the total number of shards. + fn peer_shards(&self) -> impl Iterator { + let this = self.shard_id(); + let max = self.shard_count(); + assert!( + this < max, + "Current shard index '{this}' >= '{max}' (total number of shards)" + ); + + max.iter().filter(move |&v| v != this) + } +} + +impl ShardIndex { + pub const FIRST: Self = Self(0); + + /// Returns an iterator over all shard indices that precede this one, excluding this one. + pub fn iter(self) -> impl Iterator { + (0..self.0).map(Self) + } +} + +impl From for ShardIndex { + fn from(value: u32) -> Self { + Self(value) + } +} + +#[cfg(target_pointer_width = "64")] +impl From for usize { + fn from(value: ShardIndex) -> Self { + usize::try_from(value.0).unwrap() + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use std::iter::empty; + + use crate::sharding::ShardIndex; + + fn shards>(input: I) -> impl Iterator { + input.into_iter().map(ShardIndex) + } + + #[test] + fn iter() { + assert!(ShardIndex::FIRST.iter().eq(empty())); + assert!(shards([0, 1, 2]).eq(ShardIndex::from(3).iter())); + } + + /// It is often useful to keep a collection of elements indexed by shard. + #[test] + fn indexing() { + let arr = [0, 1, 2]; + assert_eq!(0, arr[usize::from(ShardIndex::FIRST)]); + } + + mod conf { + use crate::sharding::{tests::shards, ShardConfiguration, ShardIndex}; + + struct StaticConfig(u32, u32); + impl ShardConfiguration for StaticConfig { + fn shard_id(&self) -> ShardIndex { + self.0.into() + } + + fn shard_count(&self) -> ShardIndex { + self.1.into() + } + } + + #[test] + fn excludes_this_shard() { + assert!(shards([0, 1, 2, 4]).eq(StaticConfig(3, 5).peer_shards())); + } + + #[test] + #[should_panic(expected = "Current shard index '5' >= '5' (total number of shards)")] + fn shard_index_eq_shard_count() { + let _ = StaticConfig(5, 5).peer_shards(); + } + + #[test] + #[should_panic(expected = "Current shard index '7' >= '5' (total number of shards)")] + fn shard_index_gt_shard_count() { + let _ = StaticConfig(7, 5).peer_shards(); + } + } +} diff --git a/ipa-core/src/test_fixture/app.rs b/ipa-core/src/test_fixture/app.rs index d38e95da5..d32ce0b23 100644 --- a/ipa-core/src/test_fixture/app.rs +++ b/ipa-core/src/test_fixture/app.rs @@ -8,7 +8,7 @@ use crate::{ ff::Serializable, helpers::{ query::{QueryConfig, QueryInput}, - InMemoryNetwork, InMemoryTransport, + HelperIdentity, InMemoryNetwork, }, protocol::QueryId, query::QueryStatus, @@ -50,7 +50,7 @@ where /// [`TestWorld`]: crate::test_fixture::TestWorld pub struct TestApp { drivers: [HelperApp; 3], - network: InMemoryNetwork, + network: InMemoryNetwork, } fn unzip_tuple_array(input: [(T, U); 3]) -> ([T; 3], [U; 3]) { @@ -68,7 +68,7 @@ impl Default for TestApp { .transports() .iter() .zip(setup) - .map(|(t, s)| s.connect(::clone(t))) + .map(|(t, s)| s.connect(Clone::clone(t))) .collect::>() .try_into() .map_err(|_| "infallible") diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 1b5a691d2..0f9cb4159 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -8,7 +8,7 @@ use rand_core::{RngCore, SeedableRng}; use tracing::{Instrument, Level, Span}; use crate::{ - helpers::{Gateway, GatewayConfig, InMemoryNetwork, Role, RoleAssignment}, + helpers::{Gateway, GatewayConfig, HelperIdentity, InMemoryNetwork, Role, RoleAssignment}, protocol::{ context::{ Context, MaliciousContext, SemiHonestContext, UpgradableContext, UpgradeContext, @@ -49,7 +49,7 @@ pub struct TestWorld { participants: [PrssEndpoint; 3], executions: AtomicUsize, metrics_handle: MetricsHandle, - _network: InMemoryNetwork, + _network: InMemoryNetwork, } #[derive(Clone)] @@ -112,7 +112,7 @@ impl TestWorld { let network = InMemoryNetwork::default(); let role_assignment = config .role_assignment - .unwrap_or_else(|| RoleAssignment::new(network.helper_identities())); + .unwrap_or_else(|| RoleAssignment::new(network.identities())); let mut gateways = [None, None, None]; for i in 0..3 {