Skip to content

Commit

Permalink
avoid duplicate code in relayer
Browse files Browse the repository at this point in the history
  • Loading branch information
giangndm committed Dec 18, 2023
1 parent 52a38fe commit 6fdd2a9
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 92 deletions.
16 changes: 13 additions & 3 deletions crates/relayer/src/agent_listener/quic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,16 @@ impl AgentQuicListener {
}
}

async fn process_incoming_conn(&self, conn: quinn::Connection) -> Result<AgentQuicConnection, Box<dyn Error>> {
async fn process_incoming_conn(
&self,
conn: quinn::Connection,
) -> Result<AgentQuicConnection, Box<dyn Error>> {
let (mut send, mut recv) = conn.accept_bi().await?;
let mut buf = [0u8; 4096];
let buf_len = recv.read(&mut buf).await?.ok_or::<Box<dyn Error>>("No incomming data".into())?;
let buf_len = recv
.read(&mut buf)
.await?
.ok_or::<Box<dyn Error>>("No incomming data".into())?;

match RegisterRequest::try_from(&buf[..buf_len]) {
Ok(request) => {
Expand Down Expand Up @@ -61,7 +67,11 @@ impl AgentListener<AgentQuicConnection, AgentQuicSubConnection, RecvStream, Send
{
async fn recv(&mut self) -> Result<AgentQuicConnection, Box<dyn Error>> {
loop {
let incoming_conn = self.endpoint.accept().await.ok_or::<Box<dyn Error>>("Cannot accept".into())?;
let incoming_conn = self
.endpoint
.accept()
.await
.ok_or::<Box<dyn Error>>("Cannot accept".into())?;
let conn: quinn::Connection = incoming_conn.await?;
log::info!(
"[AgentQuicListener] new conn from {}",
Expand Down
15 changes: 10 additions & 5 deletions crates/relayer/src/agent_listener/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::{
error::Error,
net::SocketAddr,
pin::Pin,
task::{Context, Poll}, error::Error,
task::{Context, Poll},
};

use async_std::net::{TcpListener, TcpStream};
Expand All @@ -25,13 +26,15 @@ impl AgentTcpListener {
pub async fn new(addr: SocketAddr, root_domain: String) -> Self {
log::info!("AgentTcpListener::new {}", addr);
Self {
tcp_listener: TcpListener::bind(addr)
.await.expect("Should open"),
tcp_listener: TcpListener::bind(addr).await.expect("Should open"),
root_domain,
}
}

async fn process_incoming_stream(&self, mut stream: TcpStream) -> Result<AgentTcpConnection, Box<dyn Error>> {
async fn process_incoming_stream(
&self,
mut stream: TcpStream,
) -> Result<AgentTcpConnection, Box<dyn Error>> {
let mut buf = [0u8; 4096];
let buf_len = stream.read(&mut buf).await?;

Expand Down Expand Up @@ -174,7 +177,9 @@ where
match this.connection.poll_next_inbound(cx) {
Poll::Ready(Some(Ok(stream))) => return Poll::Ready(Ok(())),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e.into())),
Poll::Ready(None) => return Poll::Ready(Err("yamux server poll next inbound return None".into())),
Poll::Ready(None) => {
return Poll::Ready(Err("yamux server poll next inbound return None".into()))
}
Poll::Pending => Poll::Pending,
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/relayer/src/agent_worker.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{marker::PhantomData, error::Error};
use std::{error::Error, marker::PhantomData};

use futures::{select, AsyncRead, AsyncWrite, FutureExt};
use metrics::increment_gauge;
Expand Down
159 changes: 76 additions & 83 deletions crates/relayer/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use metrics_dashboard::build_dashboard_route;
use poem::{listener::TcpListener, middleware::Tracing, EndpointExt as _, Route, Server};
use std::{collections::HashMap, net::SocketAddr, process::exit, sync::Arc, time::Duration};

use agent_listener::quic::AgentQuicListener;
use agent_listener::tcp::AgentTcpListener;
use async_std::{sync::RwLock, prelude::FutureExt as _};
use futures::{select, FutureExt};
use agent_listener::{quic::AgentQuicListener, AgentSubConnection};
use async_std::{prelude::FutureExt as _, sync::RwLock};
use futures::{select, AsyncRead, AsyncWrite, FutureExt};
use metrics::{
decrement_gauge, describe_counter, describe_gauge, increment_counter, increment_gauge,
};
Expand Down Expand Up @@ -68,8 +68,7 @@ async fn main() {
.init();
let mut quic_agent_listener =
AgentQuicListener::new(args.connector_port, args.root_domain.clone()).await;
let mut tcp_agent_listener =
AgentTcpListener::new(args.connector_port, args.root_domain).await;
let mut tcp_agent_listener = AgentTcpListener::new(args.connector_port, args.root_domain).await;
let mut proxy_http_listener = ProxyHttpListener::new(args.http_port, false)
.await
.expect("Should listen http port");
Expand Down Expand Up @@ -100,28 +99,7 @@ async fn main() {
select! {
e = quic_agent_listener.recv().fuse() => match e {
Ok(agent_connection) => {
increment_counter!(METRICS_AGENT_COUNT);
log::info!("agent_connection.domain(): {}", agent_connection.domain());
let domain = agent_connection.domain().to_string();
let (mut agent_worker, proxy_tunnel_tx) = agent_worker::AgentWorker::new(agent_connection);
agents.write().await.insert(domain.clone(), proxy_tunnel_tx);
let agents = agents.clone();
async_std::task::spawn(async move {
increment_gauge!(METRICS_AGENT_LIVE, 1.0);
log::info!("agent_worker run for domain: {}", domain);
loop {
match agent_worker.run().await {
Ok(()) => {}
Err(e) => {
log::error!("agent_worker error: {}", e);
break;
}
}
}
agents.write().await.remove(&domain);
log::info!("agent_worker exit for domain: {}", domain);
decrement_gauge!(METRICS_AGENT_LIVE, 1.0);
});
run_agent_connection(agent_connection, agents.clone()).await;
}
Err(e) => {
log::error!("agent_listener error {}", e);
Expand All @@ -130,76 +108,25 @@ async fn main() {
},
e = tcp_agent_listener.recv().fuse() => match e {
Ok(agent_connection) => {
increment_counter!(METRICS_AGENT_COUNT);
log::info!("agent_connection.domain(): {}", agent_connection.domain());
let domain = agent_connection.domain().to_string();
let (mut agent_worker, proxy_tunnel_tx) = agent_worker::AgentWorker::new(agent_connection);
agents.write().await.insert(domain.clone(), proxy_tunnel_tx);
let agents = agents.clone();
async_std::task::spawn(async move {
increment_gauge!(METRICS_AGENT_LIVE, 1.0);
log::info!("agent_worker run for domain: {}", domain);
loop {
match agent_worker.run().await {
Ok(()) => {}
Err(e) => {
log::error!("agent_worker error: {}", e);
break;
}
}
}
agents.write().await.remove(&domain);
log::info!("agent_worker exit for domain: {}", domain);
decrement_gauge!(METRICS_AGENT_LIVE, 1.0);
});
run_agent_connection(agent_connection, agents.clone()).await;
}
Err(e) => {
log::error!("agent_listener error {}", e);
exit(1);
}
},
e = proxy_http_listener.recv().fuse() => match e {
Some(mut proxy_tunnel) => {
let agents = agents.clone();
async_std::task::spawn(async move {
match proxy_tunnel.wait().timeout(Duration::from_secs(5)).await {
Err(_) => {
log::error!("proxy_tunnel.wait() for checking url timeout");
return;
},
Ok(None) => {
log::error!("proxy_tunnel.wait() for checking url invalid");
return;
},
_ => {}
}
increment_counter!(METRICS_PROXY_COUNT);
log::info!("proxy_tunnel.domain(): {}", proxy_tunnel.domain());
let domain = proxy_tunnel.domain().to_string();
if let Some(agent_tx) = agents.read().await.get(&domain) {
agent_tx.send(proxy_tunnel).await.ok();
} else {
log::warn!("agent not found for domain: {}", domain);
}
});
Some(proxy_tunnel) => {
async_std::task::spawn(run_http_request(proxy_tunnel, agents.clone()));
}
None => {
log::error!("proxy_http_listener.recv()");
exit(2);
}
},
e = proxy_tls_listener.recv().fuse() => match e {
Some(mut proxy_tunnel) => {
if proxy_tunnel.wait().await.is_none() {
continue;
}
log::info!("proxy_tunnel.domain(): {}", proxy_tunnel.domain());
let domain = proxy_tunnel.domain().to_string();
if let Some(agent_tx) = agents.read().await.get(&domain) {
agent_tx.send(proxy_tunnel).await.ok();
} else {
log::warn!("agent not found for domain: {}", domain);
}
Some(proxy_tunnel) => {
async_std::task::spawn(run_http_request(proxy_tunnel, agents.clone()));
}
None => {
log::error!("proxy_http_listener.recv()");
Expand All @@ -209,3 +136,69 @@ async fn main() {
}
}
}

async fn run_agent_connection<AG, S, R, W, PT, PR, PW>(
agent_connection: AG,
agents: Arc<RwLock<HashMap<String, async_std::channel::Sender<PT>>>>,
) where
AG: AgentConnection<S, R, W> + 'static,
S: AgentSubConnection<R, W> + 'static,
R: AsyncRead + Send + Unpin + 'static,
W: AsyncWrite + Send + Unpin + 'static,
PT: ProxyTunnel<PR, PW> + 'static,
PR: AsyncRead + Send + Unpin + 'static,
PW: AsyncWrite + Send + Unpin + 'static,
{
increment_counter!(METRICS_AGENT_COUNT);
log::info!("agent_connection.domain(): {}", agent_connection.domain());
let domain = agent_connection.domain().to_string();
let (mut agent_worker, proxy_tunnel_tx) =
agent_worker::AgentWorker::<AG, S, R, W, PT, PR, PW>::new(agent_connection);
agents.write().await.insert(domain.clone(), proxy_tunnel_tx);
let agents = agents.clone();
async_std::task::spawn(async move {
increment_gauge!(METRICS_AGENT_LIVE, 1.0);
log::info!("agent_worker run for domain: {}", domain);
loop {
match agent_worker.run().await {
Ok(()) => {}
Err(e) => {
log::error!("agent_worker error: {}", e);
break;
}
}
}
agents.write().await.remove(&domain);
log::info!("agent_worker exit for domain: {}", domain);
decrement_gauge!(METRICS_AGENT_LIVE, 1.0);
});
}

async fn run_http_request<PT, PR, PW>(
mut proxy_tunnel: PT,
agents: Arc<RwLock<HashMap<String, async_std::channel::Sender<PT>>>>,
) where
PT: ProxyTunnel<PR, PW> + 'static,
PR: AsyncRead + Send + Unpin + 'static,
PW: AsyncWrite + Send + Unpin + 'static,
{
match proxy_tunnel.wait().timeout(Duration::from_secs(5)).await {
Err(_) => {
log::error!("proxy_tunnel.wait() for checking url timeout");
return;
}
Ok(None) => {
log::error!("proxy_tunnel.wait() for checking url invalid");
return;
}
_ => {}
}
increment_counter!(METRICS_PROXY_COUNT);
log::info!("proxy_tunnel.domain(): {}", proxy_tunnel.domain());
let domain = proxy_tunnel.domain().to_string();
if let Some(agent_tx) = agents.read().await.get(&domain) {
agent_tx.send(proxy_tunnel).await.ok();
} else {
log::warn!("agent not found for domain: {}", domain);
}
}

0 comments on commit 6fdd2a9

Please sign in to comment.