diff --git a/bin/relayer/src/agent.rs b/bin/relayer/src/agent.rs index 51613d8..7de88c6 100644 --- a/bin/relayer/src/agent.rs +++ b/bin/relayer/src/agent.rs @@ -55,14 +55,14 @@ impl Clone for AgentSession { } } -pub enum AgentListenerEvent { +pub enum AgentListenerEvent { Connected(AgentId, AgentSession), - IncomingStream(AgentId, S), + IncomingStream(AgentId, C, S), Disconnected(AgentId, AgentSessionId), } -pub trait AgentListener { - async fn recv(&mut self) -> anyhow::Result>; +pub trait AgentListener { + async fn recv(&mut self) -> anyhow::Result>; async fn shutdown(&mut self); } diff --git a/bin/relayer/src/agent/quic.rs b/bin/relayer/src/agent/quic.rs index ba9bb6f..a9c36a3 100644 --- a/bin/relayer/src/agent/quic.rs +++ b/bin/relayer/src/agent/quic.rs @@ -2,7 +2,11 @@ use std::{marker::PhantomData, net::SocketAddr, sync::Arc, time::Instant}; use anyhow::anyhow; use metrics::histogram; -use protocol::{key::ClusterValidator, proxy::AgentId, stream::TunnelStream}; +use protocol::{ + key::{ClusterRequest, ClusterValidator}, + proxy::AgentId, + stream::TunnelStream, +}; use quinn::{Endpoint, Incoming, VarInt}; use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; use serde::de::DeserializeOwned; @@ -19,15 +23,15 @@ use crate::{ use super::{AgentListener, AgentListenerEvent, AgentSession, AgentSessionId}; -pub struct AgentQuicListener { +pub struct AgentQuicListener { validate: Arc, endpoint: Endpoint, - internal_tx: Sender>, - internal_rx: Receiver>, + internal_tx: Sender>, + internal_rx: Receiver>, _tmp: PhantomData, } -impl AgentQuicListener { +impl AgentQuicListener { pub async fn new(addr: SocketAddr, priv_key: PrivatePkcs8KeyDer<'static>, cert: CertificateDer<'static>, validate: VALIDATE) -> anyhow::Result { let endpoint = make_server_endpoint(addr, priv_key, cert)?; let (internal_tx, internal_rx) = channel(10); @@ -42,8 +46,8 @@ impl AgentQuicListener { } } -impl, REQ: DeserializeOwned + Send + Sync + 'static> AgentListener for AgentQuicListener { - async fn recv(&mut self) -> anyhow::Result> { +impl, REQ: DeserializeOwned + Send + Sync + 'static + ClusterRequest> AgentListener for AgentQuicListener { + async fn recv(&mut self) -> anyhow::Result> { loop { select! { incoming = self.endpoint.accept() => { @@ -59,7 +63,11 @@ impl, REQ: DeserializeOwned + Send + Sync + 'sta } } -async fn run_connection, REQ>(validate: Arc, incoming: Incoming, internal_tx: Sender>) -> anyhow::Result<()> { +async fn run_connection, REQ: ClusterRequest>( + validate: Arc, + incoming: Incoming, + internal_tx: Sender>, +) -> anyhow::Result<()> { let started = Instant::now(); log::info!("[AgentQuic] new connection from {}", incoming.remote_address()); @@ -74,6 +82,7 @@ async fn run_connection, REQ>(validate: Arc, REQ>(validate: Arc { let stream = TunnelStream::new(recv, send); let internal_tx = internal_tx.clone(); + let agent_ctx = agent_ctx.clone(); tokio::spawn(async move { - internal_tx.send(AgentListenerEvent::IncomingStream(agent_id, stream)).await.expect("should send to main loop"); + internal_tx.send(AgentListenerEvent::IncomingStream(agent_id, agent_ctx, stream)).await.expect("should send to main loop"); }); }, Err(err) => { diff --git a/bin/relayer/src/agent/tcp.rs b/bin/relayer/src/agent/tcp.rs index a737494..5240ce7 100644 --- a/bin/relayer/src/agent/tcp.rs +++ b/bin/relayer/src/agent/tcp.rs @@ -2,7 +2,10 @@ use std::{marker::PhantomData, net::SocketAddr, sync::Arc, time::Instant}; use futures::StreamExt; use metrics::histogram; -use protocol::{key::ClusterValidator, proxy::AgentId}; +use protocol::{ + key::{ClusterRequest, ClusterValidator}, + proxy::AgentId, +}; use serde::de::DeserializeOwned; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, @@ -18,15 +21,15 @@ use super::{AgentListener, AgentListenerEvent, AgentSession, AgentSessionId}; pub type TunnelTcpStream = StreamHandle; -pub struct AgentTcpListener { +pub struct AgentTcpListener { validate: Arc, listener: TcpListener, - internal_tx: Sender>, - internal_rx: Receiver>, + internal_tx: Sender>, + internal_rx: Receiver>, _tmp: PhantomData, } -impl AgentTcpListener { +impl AgentTcpListener { pub async fn new(addr: SocketAddr, validate: VALIDATE) -> anyhow::Result { let (internal_tx, internal_rx) = channel(10); @@ -40,8 +43,8 @@ impl AgentTcpListener { } } -impl, REQ: DeserializeOwned + Send + Sync + 'static> AgentListener for AgentTcpListener { - async fn recv(&mut self) -> anyhow::Result> { +impl, REQ: DeserializeOwned + Send + Sync + 'static + ClusterRequest> AgentListener for AgentTcpListener { + async fn recv(&mut self) -> anyhow::Result> { loop { select! { incoming = self.listener.accept() => { @@ -56,11 +59,11 @@ impl, REQ: DeserializeOwned + Send + Sync + 'sta async fn shutdown(&mut self) {} } -async fn run_connection, REQ>( +async fn run_connection, REQ: ClusterRequest>( validate: Arc, mut in_stream: TcpStream, remote: SocketAddr, - internal_tx: Sender>, + internal_tx: Sender>, ) -> anyhow::Result<()> { let started = Instant::now(); log::info!("[AgentTcp] new connection from {}", remote); @@ -74,6 +77,7 @@ async fn run_connection, REQ>( let domain = validate.generate_domain(&req)?; let agent_id = AgentId::try_from_domain(&domain)?; let session_id = AgentSessionId::rand(); + let agent_ctx = req.context(); log::info!("[AgentTcp] new connection validated with domain {domain} agent_id: {agent_id}, session uuid: {session_id}"); @@ -118,8 +122,9 @@ async fn run_connection, REQ>( accept = session.next() => match accept { Some(Ok(stream)) => { let internal_tx = internal_tx.clone(); + let agent_ctx = agent_ctx.clone(); tokio::spawn(async move { - internal_tx.send(AgentListenerEvent::IncomingStream(agent_id, stream)).await.expect("should send to main loop"); + internal_tx.send(AgentListenerEvent::IncomingStream(agent_id, agent_ctx, stream)).await.expect("should send to main loop"); }); }, Some(Err(err)) => { diff --git a/bin/relayer/src/lib.rs b/bin/relayer/src/lib.rs index 2bd3377..e75ef12 100644 --- a/bin/relayer/src/lib.rs +++ b/bin/relayer/src/lib.rs @@ -13,7 +13,7 @@ use p2p::{ }; use protocol::{ cluster::{write_object, AgentTunnelRequest}, - key::ClusterValidator, + key::{ClusterRequest, ClusterValidator}, proxy::{AgentId, ProxyDestination}, }; use quic::TunnelQuicStream; @@ -45,9 +45,9 @@ pub struct TunnelServiceCtx { } /// This service take care how we process a incoming request from agent -pub trait TunnelServiceHandle { +pub trait TunnelServiceHandle { fn start(&mut self, _ctx: &TunnelServiceCtx); - fn on_agent_conn(&mut self, _ctx: &TunnelServiceCtx, _agent_id: AgentId, _stream: S); + fn on_agent_conn(&mut self, _ctx: &TunnelServiceCtx, _agent_id: AgentId, ctx: Ctx, _stream: S); fn on_cluster_event(&mut self, _ctx: &TunnelServiceCtx, _event: P2pServiceEvent); } @@ -78,7 +78,7 @@ pub enum QuicRelayerEvent { Continue, } -pub struct QuicRelayer { +pub struct QuicRelayer { agent_quic: AgentQuicListener, agent_tcp: AgentTcpListener, http_proxy: ProxyTcpListener, @@ -100,12 +100,12 @@ pub struct QuicRelayer { agent_tcp_sessions: HashMap, AliasGuard)>>, } -impl QuicRelayer +impl QuicRelayer where SECURE: HandshakeProtocol, VALIDATE: ClusterValidator, REQ: DeserializeOwned + Send + Sync + 'static, - TSH: TunnelServiceHandle + Send + Sync + 'static, + TSH: TunnelServiceHandle + Send + Sync + 'static, { pub async fn new(mut cfg: QuicRelayerConfig, validate: VALIDATE) -> anyhow::Result { let mut sdn = P2pNetwork::new(P2pNetworkConfig { @@ -237,8 +237,8 @@ where gauge!(METRICS_AGENT_LIVE).increment(1.0); Ok(QuicRelayerEvent::AgentConnected(agent_id, session_id, domain)) }, - AgentListenerEvent::IncomingStream(agent_id, stream) => { - self.tunnel_service_handle.on_agent_conn(&self.tunnel_service_ctx, agent_id, stream); + AgentListenerEvent::IncomingStream(agent_id, agent_ctx, stream) => { + self.tunnel_service_handle.on_agent_conn(&self.tunnel_service_ctx, agent_id, agent_ctx, stream); Ok(QuicRelayerEvent::Continue) } AgentListenerEvent::Disconnected(agent_id, session_id) => { @@ -265,8 +265,8 @@ where gauge!(METRICS_AGENT_LIVE).increment(1.0); Ok(QuicRelayerEvent::AgentConnected(agent_id, session_id, domain)) }, - AgentListenerEvent::IncomingStream(agent_id, stream) => { - self.tunnel_service_handle.on_agent_conn(&self.tunnel_service_ctx, agent_id, stream); + AgentListenerEvent::IncomingStream(agent_id, agent_ctx, stream) => { + self.tunnel_service_handle.on_agent_conn(&self.tunnel_service_ctx, agent_id, agent_ctx, stream); Ok(QuicRelayerEvent::Continue) } AgentListenerEvent::Disconnected(agent_id, session_id) => { diff --git a/bin/relayer/src/main.rs b/bin/relayer/src/main.rs index 33b92ab..b0be7fc 100644 --- a/bin/relayer/src/main.rs +++ b/bin/relayer/src/main.rs @@ -102,10 +102,17 @@ async fn main() { struct DummyTunnelHandle; -impl TunnelServiceHandle for DummyTunnelHandle { +impl TunnelServiceHandle> for DummyTunnelHandle { fn start(&mut self, _ctx: &atm0s_reverse_proxy_relayer::TunnelServiceCtx) {} - fn on_agent_conn(&mut self, _ctx: &atm0s_reverse_proxy_relayer::TunnelServiceCtx, _agent_id: protocol::proxy::AgentId, _stream: S) {} + fn on_agent_conn( + &mut self, + _ctx: &atm0s_reverse_proxy_relayer::TunnelServiceCtx, + _agent_id: protocol::proxy::AgentId, + _metadata: Vec, + _stream: S, + ) { + } fn on_cluster_event(&mut self, _ctx: &atm0s_reverse_proxy_relayer::TunnelServiceCtx, _event: p2p::P2pServiceEvent) {} } diff --git a/crates/protocol/src/key.rs b/crates/protocol/src/key.rs index e37846b..8288f83 100644 --- a/crates/protocol/src/key.rs +++ b/crates/protocol/src/key.rs @@ -1,9 +1,16 @@ +use std::fmt::Debug; + +pub trait ClusterRequest { + type Context: Clone + Send + Sync + 'static + Debug; + fn context(&self) -> Self::Context; +} + pub trait AgentSigner { fn sign_connect_req(&self) -> Vec; fn validate_connect_res(&self, resp: &[u8]) -> anyhow::Result; } -pub trait ClusterValidator: Send + Sync + Clone + 'static { +pub trait ClusterValidator: Send + Sync + Clone + 'static { fn validate_connect_req(&self, req: &[u8]) -> anyhow::Result; fn generate_domain(&self, req: &REQ) -> anyhow::Result; fn sign_response_res(&self, m: &REQ, err: Option) -> Vec; diff --git a/crates/protocol_ed25519/src/lib.rs b/crates/protocol_ed25519/src/lib.rs index 351cbc5..0960825 100644 --- a/crates/protocol_ed25519/src/lib.rs +++ b/crates/protocol_ed25519/src/lib.rs @@ -2,7 +2,7 @@ use ed25519_dalek::pkcs8::spki::der::pem::LineEnding; use ed25519_dalek::pkcs8::{DecodePrivateKey, EncodePrivateKey}; use ed25519_dalek::SigningKey; use ed25519_dalek::{Signer, Verifier}; -use protocol::key::{AgentSigner, ClusterValidator}; +use protocol::key::{AgentSigner, ClusterRequest, ClusterValidator}; use rand::rngs::OsRng; use serde::{Deserialize, Serialize}; @@ -12,6 +12,14 @@ pub struct RegisterRequest { pub signature: ed25519_dalek::Signature, } +impl ClusterRequest for RegisterRequest { + type Context = Vec; + + fn context(&self) -> Self::Context { + self.pub_key.to_bytes().to_vec() + } +} + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct RegisterResponse { pub response: Result,