Skip to content

Commit

Permalink
Make InMemoryNetwork generic
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Mar 10, 2024
1 parent ae80516 commit dfe19fa
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 99 deletions.
2 changes: 1 addition & 1 deletion ipa-core/src/helpers/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::{
/// To avoid proliferation of type parameters, most code references this concrete type alias, rather
/// than a type parameter `T: Transport`.
#[cfg(feature = "in-memory-infra")]
pub type TransportImpl = super::transport::InMemoryTransport;
pub type TransportImpl = super::transport::InMemoryTransport<HelperIdentity>;

#[cfg(feature = "real-world-infra")]
pub type TransportImpl = crate::sync::Arc<crate::net::HttpTransport>;
Expand Down
54 changes: 28 additions & 26 deletions ipa-core/src/helpers/transport/in_memory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ mod transport;
pub use transport::Setup;

use crate::{
helpers::{HelperIdentity, TransportCallbacks},
helpers::{HelperIdentity, TransportCallbacks, TransportIdentity},
sync::{Arc, Weak},
};

pub type InMemoryTransport = Weak<transport::InMemoryTransport>;
pub type InMemoryTransport<I> = Weak<transport::InMemoryTransport<I>>;

/// Container for all active transports
#[derive(Clone)]
pub struct InMemoryNetwork {
pub transports: [Arc<transport::InMemoryTransport>; 3],
pub struct InMemoryNetwork<I> {
pub transports: [Arc<transport::InMemoryTransport<I>>; 3],
}

impl Default for InMemoryNetwork {
impl Default for InMemoryNetwork<HelperIdentity> {
fn default() -> Self {
Self::new([
TransportCallbacks::default(),
Expand All @@ -26,26 +26,10 @@ impl Default for InMemoryNetwork {
}

#[allow(dead_code)]
impl InMemoryNetwork {
#[must_use]
pub fn new(callbacks: [TransportCallbacks<InMemoryTransport>; 3]) -> Self {
let [mut first, mut second, mut third]: [_; 3] =
HelperIdentity::make_three().map(Setup::new);

first.connect(&mut second);
second.connect(&mut third);
third.connect(&mut first);

let [cb1, cb2, cb3] = callbacks;

Self {
transports: [first.start(cb1), second.start(cb2), third.start(cb3)],
}
}

impl<I: TransportIdentity> InMemoryNetwork<I> {
#[must_use]
#[allow(clippy::missing_panics_doc)]
pub fn helper_identities(&self) -> [HelperIdentity; 3] {
pub fn identities(&self) -> [I; 3] {
self.transports
.iter()
.map(|t| t.identity())
Expand All @@ -59,7 +43,7 @@ impl InMemoryNetwork {
/// ## Panics
/// If [`HelperIdentity`] is somehow points to a non-existent helper, which shouldn't happen.
#[must_use]
pub fn transport(&self, id: HelperIdentity) -> InMemoryTransport {
pub fn transport(&self, id: I) -> InMemoryTransport<I> {
self.transports
.iter()
.find(|t| t.identity() == id)
Expand All @@ -68,8 +52,8 @@ impl InMemoryNetwork {

#[allow(clippy::missing_panics_doc)]
#[must_use]
pub fn transports(&self) -> [InMemoryTransport; 3] {
let transports: [InMemoryTransport; 3] = self
pub fn transports(&self) -> [InMemoryTransport<I>; 3] {
let transports: [InMemoryTransport<_>; 3] = self
.transports
.iter()
.map(Arc::downgrade)
Expand All @@ -87,3 +71,21 @@ impl InMemoryNetwork {
}
}
}

impl InMemoryNetwork<HelperIdentity> {
#[must_use]
pub fn new(callbacks: [TransportCallbacks<InMemoryTransport<HelperIdentity>>; 3]) -> Self {
let [mut first, mut second, mut third]: [_; 3] =
HelperIdentity::make_three().map(Setup::new);

first.connect(&mut second);
second.connect(&mut third);
third.connect(&mut first);

let [cb1, cb2, cb3] = callbacks;

Self {
transports: [first.start(cb1), second.start(cb2), third.start(cb3)],
}
}
}
90 changes: 53 additions & 37 deletions ipa-core/src/helpers/transport/in_memory/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,31 @@ use crate::{
helpers::{
query::{PrepareQuery, QueryConfig},
HelperIdentity, NoResourceIdentifier, QueryIdBinding, ReceiveRecords, RouteId, RouteParams,
StepBinding, StreamCollection, Transport, TransportCallbacks,
StepBinding, StreamCollection, Transport, TransportCallbacks, TransportIdentity,
},
protocol::{step::Gate, QueryId},
sync::{Arc, Weak},
};

type Packet = (Addr, InMemoryStream, oneshot::Sender<Result<(), Error>>);
type ConnectionTx = Sender<Packet>;
type ConnectionRx = Receiver<Packet>;
type Packet<I> = (
Addr<I>,
InMemoryStream,
oneshot::Sender<Result<(), Error<I>>>,
);
type ConnectionTx<I> = Sender<Packet<I>>;
type ConnectionRx<I> = Receiver<Packet<I>>;
type StreamItem = Vec<u8>;

#[derive(Debug, thiserror::Error)]
pub enum Error {
pub enum Error<I> {
#[error(transparent)]
Io {
#[from]
inner: io::Error,
},
#[error("Request rejected by remote {dest:?}: {inner:?}")]
Rejected {
dest: HelperIdentity,
dest: I,
#[source]
inner: BoxError,
},
Expand All @@ -54,15 +58,15 @@ pub enum Error {
/// In-memory implementation of [`Transport`] backed by Tokio mpsc channels.
/// Use [`Setup`] to initialize it and call [`Setup::start`] to make it actively listen for
/// incoming messages.
pub struct InMemoryTransport {
identity: HelperIdentity,
connections: HashMap<HelperIdentity, ConnectionTx>,
record_streams: StreamCollection<InMemoryStream>,
pub struct InMemoryTransport<I> {
identity: I,
connections: HashMap<I, ConnectionTx<I>>,
record_streams: StreamCollection<I, InMemoryStream>,
}

impl InMemoryTransport {
impl<I: TransportIdentity> InMemoryTransport<I> {
#[must_use]
fn new(identity: HelperIdentity, connections: HashMap<HelperIdentity, ConnectionTx>) -> Self {
fn new(identity: I, connections: HashMap<I, ConnectionTx<I>>) -> Self {
Self {
identity,
connections,
Expand All @@ -71,15 +75,19 @@ impl InMemoryTransport {
}

#[must_use]
pub fn identity(&self) -> HelperIdentity {
pub fn identity(&self) -> I {
self.identity
}

/// TODO: maybe it shouldn't be active, but rather expose a method that takes the next message
/// out and processes it, the same way as query processor does. That will allow all tasks to be
/// created in one place (driver). It does not affect the [`Transport`] interface,
/// so I'll leave it as is for now.
fn listen(self: &Arc<Self>, callbacks: TransportCallbacks<Weak<Self>>, mut rx: ConnectionRx) {
fn listen(
self: &Arc<Self>,
callbacks: TransportCallbacks<Weak<Self>>,
mut rx: ConnectionRx<I>,
) {
tokio::spawn(
{
let streams = self.record_streams.clone();
Expand Down Expand Up @@ -132,7 +140,7 @@ impl InMemoryTransport {
);
}

fn get_channel(&self, dest: HelperIdentity) -> ConnectionTx {
fn get_channel(&self, dest: I) -> ConnectionTx<I> {
self.connections
.get(&dest)
.unwrap_or_else(|| {
Expand All @@ -151,11 +159,11 @@ impl InMemoryTransport {
}

#[async_trait]
impl Transport<HelperIdentity> for Weak<InMemoryTransport> {
type RecordsStream = ReceiveRecords<InMemoryStream>;
type Error = Error;
impl<I: TransportIdentity> Transport<I> for Weak<InMemoryTransport<I>> {
type RecordsStream = ReceiveRecords<I, InMemoryStream>;
type Error = Error<I>;

fn identity(&self) -> HelperIdentity {
fn identity(&self) -> I {
self.upgrade().unwrap().identity
}

Expand All @@ -166,10 +174,10 @@ impl Transport<HelperIdentity> for Weak<InMemoryTransport> {
R: RouteParams<RouteId, Q, S>,
>(
&self,
dest: HelperIdentity,
dest: I,
route: R,
data: D,
) -> Result<(), Error>
) -> Result<(), Error<I>>
where
Option<QueryId>: From<Q>,
Option<Gate>: From<S>,
Expand Down Expand Up @@ -197,7 +205,7 @@ impl Transport<HelperIdentity> for Weak<InMemoryTransport> {

fn receive<R: RouteParams<NoResourceIdentifier, QueryId, Gate>>(
&self,
from: HelperIdentity,
from: I,
route: R,
) -> Self::RecordsStream {
ReceiveRecords::new(
Expand Down Expand Up @@ -261,18 +269,18 @@ impl Debug for InMemoryStream {
}
}

struct Addr {
struct Addr<I> {
route: RouteId,
origin: Option<HelperIdentity>,
origin: Option<I>,
query_id: Option<QueryId>,
gate: Option<Gate>,
params: String,
}

impl Addr {
impl<I: TransportIdentity> Addr<I> {
#[allow(clippy::needless_pass_by_value)] // to avoid using double-reference at callsites
fn from_route<Q: QueryIdBinding, S: StepBinding, R: RouteParams<RouteId, Q, S>>(
origin: HelperIdentity,
origin: I,
route: R,
) -> Self
where
Expand All @@ -293,7 +301,7 @@ impl Addr {
}

#[cfg(all(test, unit_test))]
fn records(from: HelperIdentity, query_id: QueryId, gate: Gate) -> Self {
fn records(from: I, query_id: QueryId, gate: Gate) -> Self {
Self {
route: RouteId::Records,
origin: Some(from),
Expand All @@ -304,7 +312,7 @@ impl Addr {
}
}

impl Debug for Addr {
impl<I: Debug> Debug for Addr<I> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
Expand All @@ -316,9 +324,9 @@ impl Debug for Addr {

pub struct Setup {
identity: HelperIdentity,
tx: ConnectionTx,
rx: ConnectionRx,
connections: HashMap<HelperIdentity, ConnectionTx>,
tx: ConnectionTx<HelperIdentity>,
rx: ConnectionRx<HelperIdentity>,
connections: HashMap<HelperIdentity, ConnectionTx<HelperIdentity>>,
}

impl Setup {
Expand Down Expand Up @@ -350,8 +358,11 @@ impl Setup {

fn into_active_conn(
self,
callbacks: TransportCallbacks<Weak<InMemoryTransport>>,
) -> (ConnectionTx, Arc<InMemoryTransport>) {
callbacks: TransportCallbacks<Weak<InMemoryTransport<HelperIdentity>>>,
) -> (
ConnectionTx<HelperIdentity>,
Arc<InMemoryTransport<HelperIdentity>>,
) {
let transport = Arc::new(InMemoryTransport::new(self.identity, self.connections));
transport.listen(callbacks, self.rx);

Expand All @@ -361,8 +372,8 @@ impl Setup {
#[must_use]
pub fn start(
self,
callbacks: TransportCallbacks<Weak<InMemoryTransport>>,
) -> Arc<InMemoryTransport> {
callbacks: TransportCallbacks<Weak<InMemoryTransport<HelperIdentity>>>,
) -> Arc<InMemoryTransport<HelperIdentity>> {
self.into_active_conn(callbacks).1
}
}
Expand Down Expand Up @@ -391,14 +402,19 @@ mod tests {
InMemoryNetwork, Setup,
},
HelperIdentity, OrderingSender, RouteId, Transport, TransportCallbacks,
TransportIdentity,
},
protocol::{step::Gate, QueryId},
sync::Arc,
};

const STEP: &str = "in-memory-transport";

async fn send_and_ack(sender: &ConnectionTx, addr: Addr, data: InMemoryStream) {
async fn send_and_ack<I: TransportIdentity>(
sender: &ConnectionTx<I>,
addr: Addr<I>,
data: InMemoryStream,
) {
let (tx, rx) = oneshot::channel();
sender.send((addr, data, tx)).await.unwrap();
rx.await
Expand Down Expand Up @@ -491,7 +507,7 @@ mod tests {
async fn send_and_verify(
from: HelperIdentity,
to: HelperIdentity,
transports: &HashMap<HelperIdentity, Weak<InMemoryTransport>>,
transports: &HashMap<HelperIdentity, Weak<InMemoryTransport<HelperIdentity>>>,
) {
let (stream_tx, stream_rx) = channel(1);
let stream = InMemoryStream::from(stream_rx);
Expand Down
Loading

0 comments on commit dfe19fa

Please sign in to comment.