diff --git a/crates/relayer/src/agent_listener/quic.rs b/crates/relayer/src/agent_listener/quic.rs index 4f9d1e5..de9245d 100644 --- a/crates/relayer/src/agent_listener/quic.rs +++ b/crates/relayer/src/agent_listener/quic.rs @@ -25,10 +25,16 @@ impl AgentQuicListener { } } - async fn process_incoming_conn(&self, conn: quinn::Connection) -> Result> { + async fn process_incoming_conn( + &self, + conn: quinn::Connection, + ) -> Result> { let (mut send, mut recv) = conn.accept_bi().await?; let mut buf = [0u8; 4096]; - let buf_len = recv.read(&mut buf).await?.ok_or::>("No incomming data".into())?; + let buf_len = recv + .read(&mut buf) + .await? + .ok_or::>("No incomming data".into())?; match RegisterRequest::try_from(&buf[..buf_len]) { Ok(request) => { @@ -61,7 +67,11 @@ impl AgentListener Result> { loop { - let incoming_conn = self.endpoint.accept().await.ok_or::>("Cannot accept".into())?; + let incoming_conn = self + .endpoint + .accept() + .await + .ok_or::>("Cannot accept".into())?; let conn: quinn::Connection = incoming_conn.await?; log::info!( "[AgentQuicListener] new conn from {}", diff --git a/crates/relayer/src/agent_listener/tcp.rs b/crates/relayer/src/agent_listener/tcp.rs index 428d8ba..050fcca 100644 --- a/crates/relayer/src/agent_listener/tcp.rs +++ b/crates/relayer/src/agent_listener/tcp.rs @@ -1,7 +1,8 @@ use std::{ + error::Error, net::SocketAddr, pin::Pin, - task::{Context, Poll}, error::Error, + task::{Context, Poll}, }; use async_std::net::{TcpListener, TcpStream}; @@ -25,13 +26,15 @@ impl AgentTcpListener { pub async fn new(addr: SocketAddr, root_domain: String) -> Self { log::info!("AgentTcpListener::new {}", addr); Self { - tcp_listener: TcpListener::bind(addr) - .await.expect("Should open"), + tcp_listener: TcpListener::bind(addr).await.expect("Should open"), root_domain, } } - async fn process_incoming_stream(&self, mut stream: TcpStream) -> Result> { + async fn process_incoming_stream( + &self, + mut stream: TcpStream, + ) -> Result> { let mut buf = [0u8; 4096]; let buf_len = stream.read(&mut buf).await?; @@ -174,7 +177,9 @@ where match this.connection.poll_next_inbound(cx) { Poll::Ready(Some(Ok(stream))) => return Poll::Ready(Ok(())), Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e.into())), - Poll::Ready(None) => return Poll::Ready(Err("yamux server poll next inbound return None".into())), + Poll::Ready(None) => { + return Poll::Ready(Err("yamux server poll next inbound return None".into())) + } Poll::Pending => Poll::Pending, } } diff --git a/crates/relayer/src/agent_worker.rs b/crates/relayer/src/agent_worker.rs index db6a0e5..63bc614 100644 --- a/crates/relayer/src/agent_worker.rs +++ b/crates/relayer/src/agent_worker.rs @@ -1,4 +1,4 @@ -use std::{marker::PhantomData, error::Error}; +use std::{error::Error, marker::PhantomData}; use futures::{select, AsyncRead, AsyncWrite, FutureExt}; use metrics::increment_gauge; diff --git a/crates/relayer/src/main.rs b/crates/relayer/src/main.rs index bdcefdd..91e2f7c 100644 --- a/crates/relayer/src/main.rs +++ b/crates/relayer/src/main.rs @@ -5,10 +5,10 @@ use metrics_dashboard::build_dashboard_route; use poem::{listener::TcpListener, middleware::Tracing, EndpointExt as _, Route, Server}; use std::{collections::HashMap, net::SocketAddr, process::exit, sync::Arc, time::Duration}; -use agent_listener::quic::AgentQuicListener; use agent_listener::tcp::AgentTcpListener; -use async_std::{sync::RwLock, prelude::FutureExt as _}; -use futures::{select, FutureExt}; +use agent_listener::{quic::AgentQuicListener, AgentSubConnection}; +use async_std::{prelude::FutureExt as _, sync::RwLock}; +use futures::{select, AsyncRead, AsyncWrite, FutureExt}; use metrics::{ decrement_gauge, describe_counter, describe_gauge, increment_counter, increment_gauge, }; @@ -68,8 +68,7 @@ async fn main() { .init(); let mut quic_agent_listener = AgentQuicListener::new(args.connector_port, args.root_domain.clone()).await; - let mut tcp_agent_listener = - AgentTcpListener::new(args.connector_port, args.root_domain).await; + let mut tcp_agent_listener = AgentTcpListener::new(args.connector_port, args.root_domain).await; let mut proxy_http_listener = ProxyHttpListener::new(args.http_port, false) .await .expect("Should listen http port"); @@ -100,28 +99,7 @@ async fn main() { select! { e = quic_agent_listener.recv().fuse() => match e { Ok(agent_connection) => { - increment_counter!(METRICS_AGENT_COUNT); - log::info!("agent_connection.domain(): {}", agent_connection.domain()); - let domain = agent_connection.domain().to_string(); - let (mut agent_worker, proxy_tunnel_tx) = agent_worker::AgentWorker::new(agent_connection); - agents.write().await.insert(domain.clone(), proxy_tunnel_tx); - let agents = agents.clone(); - async_std::task::spawn(async move { - increment_gauge!(METRICS_AGENT_LIVE, 1.0); - log::info!("agent_worker run for domain: {}", domain); - loop { - match agent_worker.run().await { - Ok(()) => {} - Err(e) => { - log::error!("agent_worker error: {}", e); - break; - } - } - } - agents.write().await.remove(&domain); - log::info!("agent_worker exit for domain: {}", domain); - decrement_gauge!(METRICS_AGENT_LIVE, 1.0); - }); + run_agent_connection(agent_connection, agents.clone()).await; } Err(e) => { log::error!("agent_listener error {}", e); @@ -130,28 +108,7 @@ async fn main() { }, e = tcp_agent_listener.recv().fuse() => match e { Ok(agent_connection) => { - increment_counter!(METRICS_AGENT_COUNT); - log::info!("agent_connection.domain(): {}", agent_connection.domain()); - let domain = agent_connection.domain().to_string(); - let (mut agent_worker, proxy_tunnel_tx) = agent_worker::AgentWorker::new(agent_connection); - agents.write().await.insert(domain.clone(), proxy_tunnel_tx); - let agents = agents.clone(); - async_std::task::spawn(async move { - increment_gauge!(METRICS_AGENT_LIVE, 1.0); - log::info!("agent_worker run for domain: {}", domain); - loop { - match agent_worker.run().await { - Ok(()) => {} - Err(e) => { - log::error!("agent_worker error: {}", e); - break; - } - } - } - agents.write().await.remove(&domain); - log::info!("agent_worker exit for domain: {}", domain); - decrement_gauge!(METRICS_AGENT_LIVE, 1.0); - }); + run_agent_connection(agent_connection, agents.clone()).await; } Err(e) => { log::error!("agent_listener error {}", e); @@ -159,29 +116,8 @@ async fn main() { } }, e = proxy_http_listener.recv().fuse() => match e { - Some(mut proxy_tunnel) => { - let agents = agents.clone(); - async_std::task::spawn(async move { - match proxy_tunnel.wait().timeout(Duration::from_secs(5)).await { - Err(_) => { - log::error!("proxy_tunnel.wait() for checking url timeout"); - return; - }, - Ok(None) => { - log::error!("proxy_tunnel.wait() for checking url invalid"); - return; - }, - _ => {} - } - increment_counter!(METRICS_PROXY_COUNT); - log::info!("proxy_tunnel.domain(): {}", proxy_tunnel.domain()); - let domain = proxy_tunnel.domain().to_string(); - if let Some(agent_tx) = agents.read().await.get(&domain) { - agent_tx.send(proxy_tunnel).await.ok(); - } else { - log::warn!("agent not found for domain: {}", domain); - } - }); + Some(proxy_tunnel) => { + async_std::task::spawn(run_http_request(proxy_tunnel, agents.clone())); } None => { log::error!("proxy_http_listener.recv()"); @@ -189,17 +125,8 @@ async fn main() { } }, e = proxy_tls_listener.recv().fuse() => match e { - Some(mut proxy_tunnel) => { - if proxy_tunnel.wait().await.is_none() { - continue; - } - log::info!("proxy_tunnel.domain(): {}", proxy_tunnel.domain()); - let domain = proxy_tunnel.domain().to_string(); - if let Some(agent_tx) = agents.read().await.get(&domain) { - agent_tx.send(proxy_tunnel).await.ok(); - } else { - log::warn!("agent not found for domain: {}", domain); - } + Some(proxy_tunnel) => { + async_std::task::spawn(run_http_request(proxy_tunnel, agents.clone())); } None => { log::error!("proxy_http_listener.recv()"); @@ -209,3 +136,69 @@ async fn main() { } } } + +async fn run_agent_connection( + agent_connection: AG, + agents: Arc>>>, +) where + AG: AgentConnection + 'static, + S: AgentSubConnection + 'static, + R: AsyncRead + Send + Unpin + 'static, + W: AsyncWrite + Send + Unpin + 'static, + PT: ProxyTunnel + 'static, + PR: AsyncRead + Send + Unpin + 'static, + PW: AsyncWrite + Send + Unpin + 'static, +{ + increment_counter!(METRICS_AGENT_COUNT); + log::info!("agent_connection.domain(): {}", agent_connection.domain()); + let domain = agent_connection.domain().to_string(); + let (mut agent_worker, proxy_tunnel_tx) = + agent_worker::AgentWorker::::new(agent_connection); + agents.write().await.insert(domain.clone(), proxy_tunnel_tx); + let agents = agents.clone(); + async_std::task::spawn(async move { + increment_gauge!(METRICS_AGENT_LIVE, 1.0); + log::info!("agent_worker run for domain: {}", domain); + loop { + match agent_worker.run().await { + Ok(()) => {} + Err(e) => { + log::error!("agent_worker error: {}", e); + break; + } + } + } + agents.write().await.remove(&domain); + log::info!("agent_worker exit for domain: {}", domain); + decrement_gauge!(METRICS_AGENT_LIVE, 1.0); + }); +} + +async fn run_http_request( + mut proxy_tunnel: PT, + agents: Arc>>>, +) where + PT: ProxyTunnel + 'static, + PR: AsyncRead + Send + Unpin + 'static, + PW: AsyncWrite + Send + Unpin + 'static, +{ + match proxy_tunnel.wait().timeout(Duration::from_secs(5)).await { + Err(_) => { + log::error!("proxy_tunnel.wait() for checking url timeout"); + return; + } + Ok(None) => { + log::error!("proxy_tunnel.wait() for checking url invalid"); + return; + } + _ => {} + } + increment_counter!(METRICS_PROXY_COUNT); + log::info!("proxy_tunnel.domain(): {}", proxy_tunnel.domain()); + let domain = proxy_tunnel.domain().to_string(); + if let Some(agent_tx) = agents.read().await.get(&domain) { + agent_tx.send(proxy_tunnel).await.ok(); + } else { + log::warn!("agent not found for domain: {}", domain); + } +}