Skip to content

Commit

Permalink
allow TcpStreamConnector's to be made in factories
Browse files Browse the repository at this point in the history
  • Loading branch information
GlenDC committed Nov 18, 2024
1 parent e8a709e commit a748478
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 27 deletions.
6 changes: 3 additions & 3 deletions rama-tcp/src/client/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use tokio::{

/// Trait used internally by [`tcp_connect`] and the `TcpConnector`
/// to actually establish the [`TcpStream`.]
pub trait TcpStreamConnector: Send + Sync + 'static {
pub trait TcpStreamConnector: Clone + Send + Sync + 'static {
/// Type of error that can occurr when establishing the connection failed.
type Error;

Expand Down Expand Up @@ -60,7 +60,7 @@ impl<T: TcpStreamConnector> TcpStreamConnector for Arc<T> {

impl<ConnectFn, ConnectFnFut, ConnectFnErr> TcpStreamConnector for ConnectFn
where
ConnectFn: Fn(SocketAddr) -> ConnectFnFut + Send + Sync + 'static,
ConnectFn: FnOnce(SocketAddr) -> ConnectFnFut + Clone + Send + Sync + 'static,
ConnectFnFut: Future<Output = Result<TcpStream, ConnectFnErr>> + Send + 'static,
ConnectFnErr: Into<BoxError> + Send + 'static,
{
Expand All @@ -70,7 +70,7 @@ where
&self,
addr: SocketAddr,
) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
(self)(addr)
(self.clone())(addr)
}
}

Expand Down
66 changes: 42 additions & 24 deletions rama-tcp/src/client/service/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ use tokio::net::TcpStream;

use crate::client::connect::TcpStreamConnector;

use super::{CreatedTcpStreamConnector, TcpStreamConnectorCloneFactory, TcpStreamConnectorFactory};

#[derive(Debug, Clone)]
#[non_exhaustive]
/// A connector which can be used to establish a TCP connection to a server.
pub struct TcpConnector<Dns = HickoryDns, Connector = ()> {
pub struct TcpConnector<Dns = HickoryDns, ConnectorFactory = ()> {
dns: Dns,
connector: Connector,
connector_factory: ConnectorFactory,
}

impl<Dns, Connector> TcpConnector<Dns, Connector> {}
Expand All @@ -30,33 +32,43 @@ impl TcpConnector {
pub fn new() -> Self {
Self {
dns: HickoryDns::default(),
connector: (),
connector_factory: (),
}
}
}

impl<Dns, Connector> TcpConnector<Dns, Connector> {
impl<Dns, ConnectorFactory> TcpConnector<Dns, ConnectorFactory> {
/// Consume `self` to attach the given `dns` (a [`DnsResolver`]) as a new [`TcpConnector`].
pub fn with_dns<OtherDns>(self, dns: OtherDns) -> TcpConnector<OtherDns, Connector>
pub fn with_dns<OtherDns>(self, dns: OtherDns) -> TcpConnector<OtherDns, ConnectorFactory>
where
OtherDns: DnsResolver<Error: Into<BoxError>> + Clone,
{
TcpConnector {
dns,
connector: self.connector,
connector_factory: self.connector_factory,
}
}
}

impl<Dns> TcpConnector<Dns, ()> {
/// Consume `self` to attach the given `Connector` (a [`TcpStreamConnector`]) as a new [`TcpConnector`].
pub fn with_connector<Connector>(self, connector: Connector) -> TcpConnector<Dns, Connector>
where
Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
{
pub fn with_connector<Connector>(
self,
connector: Connector,
) -> TcpConnector<Dns, TcpStreamConnectorCloneFactory<Connector>>
where {
TcpConnector {
dns: self.dns,
connector_factory: TcpStreamConnectorCloneFactory(connector),
}
}

/// Consume `self` to attach the given `Factory` (a [`TcpStreamConnectorFactory`]) as a new [`TcpConnector`].
pub fn with_connector_factory<Factory>(self, factory: Factory) -> TcpConnector<Dns, Factory>
where {
TcpConnector {
dns: self.dns,
connector,
connector_factory: factory,
}
}
}
Expand All @@ -67,29 +79,40 @@ impl Default for TcpConnector {
}
}

impl<State, Request, Dns, Connector> Service<State, Request> for TcpConnector<Dns, Connector>
impl<State, Request, Dns, ConnectorFactory> Service<State, Request>
for TcpConnector<Dns, ConnectorFactory>
where
State: Clone + Send + Sync + 'static,
Request: TryRefIntoTransportContext<State> + Send + 'static,
Request::Error: Into<BoxError> + Send + Sync + 'static,
Dns: DnsResolver<Error: Into<BoxError>> + Clone,
Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
ConnectorFactory: TcpStreamConnectorFactory<
State,
Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static>,
Error: Into<BoxError> + Send + 'static,
> + Clone,
{
type Response = EstablishedClientConnection<TcpStream, State, Request>;
type Error = BoxError;

async fn serve(
&self,
mut ctx: Context<State>,
ctx: Context<State>,
req: Request,
) -> Result<Self::Response, Self::Error> {
let CreatedTcpStreamConnector { mut ctx, connector } = self
.connector_factory
.make_connector(ctx)
.await
.map_err(Into::into)?;

if let Some(proxy) = ctx.get::<ProxyAddress>() {
let (conn, addr) = crate::client::tcp_connect(
&ctx,
proxy.authority.clone(),
true,
self.dns.clone(),
self.connector.clone(),
connector,
)
.await
.context("tcp connector: conncept to proxy")?;
Expand Down Expand Up @@ -120,15 +143,10 @@ where
}

let authority = transport_ctx.authority.clone();
let (conn, addr) = crate::client::tcp_connect(
&ctx,
authority,
false,
self.dns.clone(),
self.connector.clone(),
)
.await
.context("tcp connector: connect to server")?;
let (conn, addr) =
crate::client::tcp_connect(&ctx, authority, false, self.dns.clone(), connector)
.await
.context("tcp connector: connect to server")?;

Ok(EstablishedClientConnection {
ctx,
Expand Down
6 changes: 6 additions & 0 deletions rama-tcp/src/client/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@ pub use forward::{ForwardAuthority, Forwarder};
mod connector;
#[doc(inline)]
pub use connector::TcpConnector;

mod select;
#[doc(inline)]
pub use select::{
CreatedTcpStreamConnector, TcpStreamConnectorCloneFactory, TcpStreamConnectorFactory,
};
126 changes: 126 additions & 0 deletions rama-tcp/src/client/service/select.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
/// Trait used internally by [`tcp_connect`] and the `TcpConnector`
/// to actually establish the [`TcpStream`.]
pub trait TcpStreamConnector: Send + Sync + 'static {
/// Type of error that can occurr when establishing the connection failed.
type Error;
/// Connect to the target via the given [`SocketAddr`]ess to establish a [`TcpStream`].
fn connect(
&self,
addr: SocketAddr,
) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_;
}
*/

use rama_core::error::BoxError;
use rama_core::Context;
use std::{convert::Infallible, future::Future, sync::Arc};

use crate::client::TcpStreamConnector;

pub struct CreatedTcpStreamConnector<State, Connector> {
pub ctx: Context<State>,
pub connector: Connector,
}

pub trait TcpStreamConnectorFactory<State>: Send + Sync + 'static {
type Connector: TcpStreamConnector;
type Error;

fn make_connector(
&self,
ctx: Context<State>,
) -> impl Future<Output = Result<CreatedTcpStreamConnector<State, Self::Connector>, Self::Error>>
+ Send
+ '_;
}

impl<State: Send + Sync + 'static> TcpStreamConnectorFactory<State> for () {
type Connector = ();
type Error = Infallible;

fn make_connector(
&self,
ctx: Context<State>,
) -> impl Future<Output = Result<CreatedTcpStreamConnector<State, Self::Connector>, Self::Error>>
+ Send
+ '_ {
std::future::ready(Ok(CreatedTcpStreamConnector { ctx, connector: () }))
}
}

pub struct TcpStreamConnectorCloneFactory<C>(pub(super) C);

impl<State, C> TcpStreamConnectorFactory<State> for TcpStreamConnectorCloneFactory<C>
where
C: TcpStreamConnector + Clone,
State: Send + Sync + 'static,
{
type Connector = C;
type Error = Infallible;

fn make_connector(
&self,
ctx: Context<State>,
) -> impl Future<Output = Result<CreatedTcpStreamConnector<State, Self::Connector>, Self::Error>>
+ Send
+ '_ {
std::future::ready(Ok(CreatedTcpStreamConnector {
ctx,
connector: self.0.clone(),
}))
}
}

impl<State, F> TcpStreamConnectorFactory<State> for Arc<F>
where
F: TcpStreamConnectorFactory<State>,
State: Send + Sync + 'static,
{
type Connector = F::Connector;
type Error = F::Error;

fn make_connector(
&self,
ctx: Context<State>,
) -> impl Future<Output = Result<CreatedTcpStreamConnector<State, Self::Connector>, Self::Error>>
+ Send
+ '_ {
(**self).make_connector(ctx)
}
}

macro_rules! impl_stream_connector_factory_either {
($id:ident, $($param:ident),+ $(,)?) => {
impl<State, $($param),+> TcpStreamConnectorFactory<State> for ::rama_core::combinators::$id<$($param),+>
where
State: Send + Sync + 'static,
$(
$param: TcpStreamConnectorFactory<State, Connector: TcpStreamConnector<Error: Into<BoxError>>, Error: Into<BoxError>>,
)+
{
type Connector = ::rama_core::combinators::$id<$($param::Connector),+>;
type Error = BoxError;

async fn make_connector(
&self,
ctx: Context<State>,
) -> Result<CreatedTcpStreamConnector<State, Self::Connector>, Self::Error> {
match self {
$(
::rama_core::combinators::$id::$param(s) => match s.make_connector(ctx).await {
Err(e) => Err(e.into()),
Ok(CreatedTcpStreamConnector{ ctx, connector }) => Ok(CreatedTcpStreamConnector{
ctx,
connector: ::rama_core::combinators::$id::$param(connector),
}),
},
)+
}
}
}
};
}

::rama_core::combinators::impl_either!(impl_stream_connector_factory_either);

0 comments on commit a748478

Please sign in to comment.