Skip to content

Commit

Permalink
feat: add agent context uses in forward from agent to service (#90)
Browse files Browse the repository at this point in the history
* feat: add agent context uses in forward from agent to service

* remove log

---------

Co-authored-by: giangndm <[email protected]>
  • Loading branch information
marverlous811 and giangndm authored Feb 14, 2025
1 parent 8ca0ba0 commit d7082c9
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 37 deletions.
8 changes: 4 additions & 4 deletions bin/relayer/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ impl<S> Clone for AgentSession<S> {
}
}

pub enum AgentListenerEvent<S> {
pub enum AgentListenerEvent<C, S> {
Connected(AgentId, AgentSession<S>),
IncomingStream(AgentId, S),
IncomingStream(AgentId, C, S),
Disconnected(AgentId, AgentSessionId),
}

pub trait AgentListener<S: AsyncRead + AsyncWrite> {
async fn recv(&mut self) -> anyhow::Result<AgentListenerEvent<S>>;
pub trait AgentListener<C, S: AsyncRead + AsyncWrite> {
async fn recv(&mut self) -> anyhow::Result<AgentListenerEvent<C, S>>;
async fn shutdown(&mut self);
}

Expand Down
28 changes: 19 additions & 9 deletions bin/relayer/src/agent/quic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ use std::{marker::PhantomData, net::SocketAddr, sync::Arc, time::Instant};

use anyhow::anyhow;
use metrics::histogram;
use protocol::{key::ClusterValidator, proxy::AgentId, stream::TunnelStream};
use protocol::{
key::{ClusterRequest, ClusterValidator},
proxy::AgentId,
stream::TunnelStream,
};
use quinn::{Endpoint, Incoming, VarInt};
use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
use serde::de::DeserializeOwned;
Expand All @@ -19,15 +23,15 @@ use crate::{

use super::{AgentListener, AgentListenerEvent, AgentSession, AgentSessionId};

pub struct AgentQuicListener<VALIDATE, HANDSHAKE> {
pub struct AgentQuicListener<VALIDATE, HANDSHAKE: ClusterRequest> {
validate: Arc<VALIDATE>,
endpoint: Endpoint,
internal_tx: Sender<AgentListenerEvent<TunnelQuicStream>>,
internal_rx: Receiver<AgentListenerEvent<TunnelQuicStream>>,
internal_tx: Sender<AgentListenerEvent<HANDSHAKE::Context, TunnelQuicStream>>,
internal_rx: Receiver<AgentListenerEvent<HANDSHAKE::Context, TunnelQuicStream>>,
_tmp: PhantomData<HANDSHAKE>,
}

impl<VALIDATE, HANDSHAKE> AgentQuicListener<VALIDATE, HANDSHAKE> {
impl<VALIDATE, HANDSHAKE: ClusterRequest> AgentQuicListener<VALIDATE, HANDSHAKE> {
pub async fn new(addr: SocketAddr, priv_key: PrivatePkcs8KeyDer<'static>, cert: CertificateDer<'static>, validate: VALIDATE) -> anyhow::Result<Self> {
let endpoint = make_server_endpoint(addr, priv_key, cert)?;
let (internal_tx, internal_rx) = channel(10);
Expand All @@ -42,8 +46,8 @@ impl<VALIDATE, HANDSHAKE> AgentQuicListener<VALIDATE, HANDSHAKE> {
}
}

impl<VALIDATE: ClusterValidator<REQ>, REQ: DeserializeOwned + Send + Sync + 'static> AgentListener<TunnelQuicStream> for AgentQuicListener<VALIDATE, REQ> {
async fn recv(&mut self) -> anyhow::Result<AgentListenerEvent<TunnelQuicStream>> {
impl<VALIDATE: ClusterValidator<REQ>, REQ: DeserializeOwned + Send + Sync + 'static + ClusterRequest> AgentListener<REQ::Context, TunnelQuicStream> for AgentQuicListener<VALIDATE, REQ> {
async fn recv(&mut self) -> anyhow::Result<AgentListenerEvent<REQ::Context, TunnelQuicStream>> {
loop {
select! {
incoming = self.endpoint.accept() => {
Expand All @@ -59,7 +63,11 @@ impl<VALIDATE: ClusterValidator<REQ>, REQ: DeserializeOwned + Send + Sync + 'sta
}
}

async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ>(validate: Arc<VALIDATE>, incoming: Incoming, internal_tx: Sender<AgentListenerEvent<TunnelQuicStream>>) -> anyhow::Result<()> {
async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ: ClusterRequest>(
validate: Arc<VALIDATE>,
incoming: Incoming,
internal_tx: Sender<AgentListenerEvent<REQ::Context, TunnelQuicStream>>,
) -> anyhow::Result<()> {
let started = Instant::now();
log::info!("[AgentQuic] new connection from {}", incoming.remote_address());

Expand All @@ -74,6 +82,7 @@ async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ>(validate: Arc<VALI
let domain = validate.generate_domain(&req)?;
let agent_id = AgentId::try_from_domain(&domain)?;
let session_id = AgentSessionId::rand();
let agent_ctx = req.context();

log::info!("[AgentQuic] new connection validated with domain {domain} agent_id: {agent_id}, session uuid: {session_id}");

Expand Down Expand Up @@ -121,8 +130,9 @@ async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ>(validate: Arc<VALI
Ok((send, recv)) => {
let stream = TunnelStream::new(recv, send);
let internal_tx = internal_tx.clone();
let agent_ctx = agent_ctx.clone();
tokio::spawn(async move {
internal_tx.send(AgentListenerEvent::IncomingStream(agent_id, stream)).await.expect("should send to main loop");
internal_tx.send(AgentListenerEvent::IncomingStream(agent_id, agent_ctx, stream)).await.expect("should send to main loop");
});
},
Err(err) => {
Expand Down
25 changes: 15 additions & 10 deletions bin/relayer/src/agent/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use std::{marker::PhantomData, net::SocketAddr, sync::Arc, time::Instant};

use futures::StreamExt;
use metrics::histogram;
use protocol::{key::ClusterValidator, proxy::AgentId};
use protocol::{
key::{ClusterRequest, ClusterValidator},
proxy::AgentId,
};
use serde::de::DeserializeOwned;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
Expand All @@ -18,15 +21,15 @@ use super::{AgentListener, AgentListenerEvent, AgentSession, AgentSessionId};

pub type TunnelTcpStream = StreamHandle;

pub struct AgentTcpListener<VALIDATE, HANDSHAKE> {
pub struct AgentTcpListener<VALIDATE, HANDSHAKE: ClusterRequest> {
validate: Arc<VALIDATE>,
listener: TcpListener,
internal_tx: Sender<AgentListenerEvent<TunnelTcpStream>>,
internal_rx: Receiver<AgentListenerEvent<TunnelTcpStream>>,
internal_tx: Sender<AgentListenerEvent<HANDSHAKE::Context, TunnelTcpStream>>,
internal_rx: Receiver<AgentListenerEvent<HANDSHAKE::Context, TunnelTcpStream>>,
_tmp: PhantomData<HANDSHAKE>,
}

impl<VALIDATE, HANDSHAKE> AgentTcpListener<VALIDATE, HANDSHAKE> {
impl<VALIDATE, HANDSHAKE: ClusterRequest> AgentTcpListener<VALIDATE, HANDSHAKE> {
pub async fn new(addr: SocketAddr, validate: VALIDATE) -> anyhow::Result<Self> {
let (internal_tx, internal_rx) = channel(10);

Expand All @@ -40,8 +43,8 @@ impl<VALIDATE, HANDSHAKE> AgentTcpListener<VALIDATE, HANDSHAKE> {
}
}

impl<VALIDATE: ClusterValidator<REQ>, REQ: DeserializeOwned + Send + Sync + 'static> AgentListener<TunnelTcpStream> for AgentTcpListener<VALIDATE, REQ> {
async fn recv(&mut self) -> anyhow::Result<AgentListenerEvent<TunnelTcpStream>> {
impl<VALIDATE: ClusterValidator<REQ>, REQ: DeserializeOwned + Send + Sync + 'static + ClusterRequest> AgentListener<REQ::Context, TunnelTcpStream> for AgentTcpListener<VALIDATE, REQ> {
async fn recv(&mut self) -> anyhow::Result<AgentListenerEvent<REQ::Context, TunnelTcpStream>> {
loop {
select! {
incoming = self.listener.accept() => {
Expand All @@ -56,11 +59,11 @@ impl<VALIDATE: ClusterValidator<REQ>, REQ: DeserializeOwned + Send + Sync + 'sta
async fn shutdown(&mut self) {}
}

async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ>(
async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ: ClusterRequest>(
validate: Arc<VALIDATE>,
mut in_stream: TcpStream,
remote: SocketAddr,
internal_tx: Sender<AgentListenerEvent<TunnelTcpStream>>,
internal_tx: Sender<AgentListenerEvent<REQ::Context, TunnelTcpStream>>,
) -> anyhow::Result<()> {
let started = Instant::now();
log::info!("[AgentTcp] new connection from {}", remote);
Expand All @@ -74,6 +77,7 @@ async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ>(
let domain = validate.generate_domain(&req)?;
let agent_id = AgentId::try_from_domain(&domain)?;
let session_id = AgentSessionId::rand();
let agent_ctx = req.context();

log::info!("[AgentTcp] new connection validated with domain {domain} agent_id: {agent_id}, session uuid: {session_id}");

Expand Down Expand Up @@ -118,8 +122,9 @@ async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ>(
accept = session.next() => match accept {
Some(Ok(stream)) => {
let internal_tx = internal_tx.clone();
let agent_ctx = agent_ctx.clone();
tokio::spawn(async move {
internal_tx.send(AgentListenerEvent::IncomingStream(agent_id, stream)).await.expect("should send to main loop");
internal_tx.send(AgentListenerEvent::IncomingStream(agent_id, agent_ctx, stream)).await.expect("should send to main loop");
});
},
Some(Err(err)) => {
Expand Down
20 changes: 10 additions & 10 deletions bin/relayer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use p2p::{
};
use protocol::{
cluster::{write_object, AgentTunnelRequest},
key::ClusterValidator,
key::{ClusterRequest, ClusterValidator},
proxy::{AgentId, ProxyDestination},
};
use quic::TunnelQuicStream;
Expand Down Expand Up @@ -45,9 +45,9 @@ pub struct TunnelServiceCtx {
}

/// This service take care how we process a incoming request from agent
pub trait TunnelServiceHandle {
pub trait TunnelServiceHandle<Ctx> {
fn start(&mut self, _ctx: &TunnelServiceCtx);
fn on_agent_conn<S: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(&mut self, _ctx: &TunnelServiceCtx, _agent_id: AgentId, _stream: S);
fn on_agent_conn<S: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(&mut self, _ctx: &TunnelServiceCtx, _agent_id: AgentId, ctx: Ctx, _stream: S);
fn on_cluster_event(&mut self, _ctx: &TunnelServiceCtx, _event: P2pServiceEvent);
}

Expand Down Expand Up @@ -78,7 +78,7 @@ pub enum QuicRelayerEvent {
Continue,
}

pub struct QuicRelayer<SECURE, VALIDATE, REQ, TSH> {
pub struct QuicRelayer<SECURE, VALIDATE, REQ: ClusterRequest, TSH> {
agent_quic: AgentQuicListener<VALIDATE, REQ>,
agent_tcp: AgentTcpListener<VALIDATE, REQ>,
http_proxy: ProxyTcpListener<HttpDestinationDetector>,
Expand All @@ -100,12 +100,12 @@ pub struct QuicRelayer<SECURE, VALIDATE, REQ, TSH> {
agent_tcp_sessions: HashMap<AgentId, HashMap<AgentSessionId, (AgentSession<TunnelTcpStream>, AliasGuard)>>,
}

impl<SECURE, VALIDATE, REQ, TSH> QuicRelayer<SECURE, VALIDATE, REQ, TSH>
impl<SECURE, VALIDATE, REQ: ClusterRequest, TSH> QuicRelayer<SECURE, VALIDATE, REQ, TSH>
where
SECURE: HandshakeProtocol,
VALIDATE: ClusterValidator<REQ>,
REQ: DeserializeOwned + Send + Sync + 'static,
TSH: TunnelServiceHandle + Send + Sync + 'static,
TSH: TunnelServiceHandle<REQ::Context> + Send + Sync + 'static,
{
pub async fn new(mut cfg: QuicRelayerConfig<SECURE, TSH>, validate: VALIDATE) -> anyhow::Result<Self> {
let mut sdn = P2pNetwork::new(P2pNetworkConfig {
Expand Down Expand Up @@ -237,8 +237,8 @@ where
gauge!(METRICS_AGENT_LIVE).increment(1.0);
Ok(QuicRelayerEvent::AgentConnected(agent_id, session_id, domain))
},
AgentListenerEvent::IncomingStream(agent_id, stream) => {
self.tunnel_service_handle.on_agent_conn(&self.tunnel_service_ctx, agent_id, stream);
AgentListenerEvent::IncomingStream(agent_id, agent_ctx, stream) => {
self.tunnel_service_handle.on_agent_conn(&self.tunnel_service_ctx, agent_id, agent_ctx, stream);
Ok(QuicRelayerEvent::Continue)
}
AgentListenerEvent::Disconnected(agent_id, session_id) => {
Expand All @@ -265,8 +265,8 @@ where
gauge!(METRICS_AGENT_LIVE).increment(1.0);
Ok(QuicRelayerEvent::AgentConnected(agent_id, session_id, domain))
},
AgentListenerEvent::IncomingStream(agent_id, stream) => {
self.tunnel_service_handle.on_agent_conn(&self.tunnel_service_ctx, agent_id, stream);
AgentListenerEvent::IncomingStream(agent_id, agent_ctx, stream) => {
self.tunnel_service_handle.on_agent_conn(&self.tunnel_service_ctx, agent_id, agent_ctx, stream);
Ok(QuicRelayerEvent::Continue)
}
AgentListenerEvent::Disconnected(agent_id, session_id) => {
Expand Down
11 changes: 9 additions & 2 deletions bin/relayer/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,17 @@ async fn main() {

struct DummyTunnelHandle;

impl TunnelServiceHandle for DummyTunnelHandle {
impl TunnelServiceHandle<Vec<u8>> for DummyTunnelHandle {
fn start(&mut self, _ctx: &atm0s_reverse_proxy_relayer::TunnelServiceCtx) {}

fn on_agent_conn<S: tokio::io::AsyncRead + tokio::io::AsyncWrite>(&mut self, _ctx: &atm0s_reverse_proxy_relayer::TunnelServiceCtx, _agent_id: protocol::proxy::AgentId, _stream: S) {}
fn on_agent_conn<S: tokio::io::AsyncRead + tokio::io::AsyncWrite>(
&mut self,
_ctx: &atm0s_reverse_proxy_relayer::TunnelServiceCtx,
_agent_id: protocol::proxy::AgentId,
_metadata: Vec<u8>,
_stream: S,
) {
}

fn on_cluster_event(&mut self, _ctx: &atm0s_reverse_proxy_relayer::TunnelServiceCtx, _event: p2p::P2pServiceEvent) {}
}
9 changes: 8 additions & 1 deletion crates/protocol/src/key.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
use std::fmt::Debug;

pub trait ClusterRequest {
type Context: Clone + Send + Sync + 'static + Debug;
fn context(&self) -> Self::Context;
}

pub trait AgentSigner<RES> {
fn sign_connect_req(&self) -> Vec<u8>;
fn validate_connect_res(&self, resp: &[u8]) -> anyhow::Result<RES>;
}

pub trait ClusterValidator<REQ>: Send + Sync + Clone + 'static {
pub trait ClusterValidator<REQ: ClusterRequest>: Send + Sync + Clone + 'static {
fn validate_connect_req(&self, req: &[u8]) -> anyhow::Result<REQ>;
fn generate_domain(&self, req: &REQ) -> anyhow::Result<String>;
fn sign_response_res(&self, m: &REQ, err: Option<String>) -> Vec<u8>;
Expand Down
10 changes: 9 additions & 1 deletion crates/protocol_ed25519/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ed25519_dalek::pkcs8::spki::der::pem::LineEnding;
use ed25519_dalek::pkcs8::{DecodePrivateKey, EncodePrivateKey};
use ed25519_dalek::SigningKey;
use ed25519_dalek::{Signer, Verifier};
use protocol::key::{AgentSigner, ClusterValidator};
use protocol::key::{AgentSigner, ClusterRequest, ClusterValidator};
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};

Expand All @@ -12,6 +12,14 @@ pub struct RegisterRequest {
pub signature: ed25519_dalek::Signature,
}

impl ClusterRequest for RegisterRequest {
type Context = Vec<u8>;

fn context(&self) -> Self::Context {
self.pub_key.to_bytes().to_vec()
}
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct RegisterResponse {
pub response: Result<String, String>,
Expand Down

0 comments on commit d7082c9

Please sign in to comment.