Skip to content

Commit

Permalink
feat: turn tcp/quic into features in agent, optimize client binary si…
Browse files Browse the repository at this point in the history
…ze (#87)

* feat: turn tcp/quic into features in agent

* reduce agent size
  • Loading branch information
giangndm authored Feb 14, 2025
1 parent 38f45ca commit 8ca0ba0
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 75 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ protocol-ed25519 = { path = "crates/protocol_ed25519", package = "atm0s-reverse-
log = "0.4"
tokio-yamux = "0.3"
clap = "4.4"
argh = "=0.1.13" # small cli
async-trait = "0.1"
tokio = "1"
httparse = "1.8"
tls-parser = "0.12"
rtsp-types = "0.1"
tracing-subscriber = "0.3"
picolog = "1.0"
atm0s-sdn = "0.2"
serde = "1.0"
bincode = "1.3"
Expand Down
26 changes: 17 additions & 9 deletions bin/agent/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,27 @@ protocol-ed25519 = { workspace = true, optional = true }
tokio = { workspace = true, features = ["full"] }
futures = { version = "0.3" }
async-trait = { workspace = true }
clap = { workspace = true, features = ["derive", "env"] }
log = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter", "std"], optional = true }
tokio-yamux = { workspace = true }
bincode = { workspace = true }
serde = { workspace = true, features = ["derive"] }
quinn = { workspace = true, features = ["ring", "runtime-tokio", "futures-io"] }
rustls = { workspace = true, features = ["ring", "std"] }
url = { workspace = true }
base64 = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true, optional = true }

# for binary build
picolog = { workspace = true, optional = true }
argh = { workspace = true, optional = true }

# for tcp protocol
tokio-yamux = { workspace = true, optional = true }

# for quic protocol
quinn = { workspace = true, features = ["ring", "runtime-tokio", "futures-io"], optional = true }
rustls = { workspace = true, features = ["ring", "std"], optional = true }
base64 = { workspace = true, optional = true }

[features]
default = ["binary"]
binary = ["protocol-ed25519", "tracing-subscriber"]
default = ["binary", "tcp"]
binary = ["protocol-ed25519", "argh", "picolog"]
tcp = ["tokio-yamux"]
quic = ["quinn", "rustls", "base64", "thiserror"]
68 changes: 41 additions & 27 deletions bin/agent/examples/benchmark_clients.rs
Original file line number Diff line number Diff line change
@@ -1,69 +1,79 @@
use std::str::FromStr;
use std::{
net::SocketAddr,
sync::Arc,
time::{Duration, Instant},
};

use atm0s_reverse_proxy_agent::{run_tunnel_connection, Connection, Protocol, QuicConnection, ServiceRegistry, SimpleServiceRegistry, SubConnection, TcpConnection};
use argh::FromArgs;
#[cfg(feature = "quic")]
use atm0s_reverse_proxy_agent::QuicConnection;
#[cfg(feature = "tcp")]
use atm0s_reverse_proxy_agent::TcpConnection;
use atm0s_reverse_proxy_agent::{run_tunnel_connection, Connection, Protocol, ServiceRegistry, SimpleServiceRegistry, SubConnection};
#[cfg(feature = "quic")]
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
use clap::Parser;
use log::LevelFilter;
use picolog::PicoLogger;
#[cfg(feature = "quic")]
use protocol::DEFAULT_TUNNEL_CERT;
use protocol_ed25519::AgentLocalKey;
#[cfg(feature = "quic")]
use rustls::pki_types::CertificateDer;
use tokio::time::sleep;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use url::Url;

/// A benchmark util for simulating multiple clients connect to relay server
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
#[derive(FromArgs, Debug, Clone)]
struct Args {
/// Address of relay server
#[arg(env, long)]
/// address of relay server
#[argh(option)]
connector_addr: Url,

/// Protocol of relay server
#[arg(env, long)]
/// protocol of relay server
#[argh(option)]
connector_protocol: Protocol,

/// Http proxy dest
#[arg(env, long, default_value = "127.0.0.1:8080")]
/// http proxy dest
#[argh(option, default = "SocketAddr::from_str(\"127.0.0.1:8080\").unwrap()")]
http_dest: SocketAddr,

/// Sni-https proxy dest
#[arg(env, long, default_value = "127.0.0.1:8443")]
/// sni-https proxy dest
#[argh(option, default = "SocketAddr::from_str(\"127.0.0.1:8443\").unwrap()")]
https_dest: SocketAddr,

/// Custom quic server cert in base64
#[arg(env, long)]
#[cfg(feature = "quic")]
/// custom quic server cert in base64
#[argh(option)]
custom_quic_cert_base64: Option<String>,

/// Allow connect in insecure mode
#[arg(env, long)]
#[cfg(feature = "quic")]
/// allow connect in insecure mode
#[argh(option)]
allow_quic_insecure: bool,

/// clients
#[arg(env, long)]
#[argh(option)]
clients: usize,

/// wait time between connect action
#[arg(env, long, default_value_t = 1000)]
#[argh(option, default = "1000")]
connect_wait_ms: u64,
}

#[tokio::main]
async fn main() {
let args = Args::parse();
let args: Args = argh::from_env();

#[cfg(feature = "quic")]
rustls::crypto::ring::default_provider().install_default().expect("should install ring as default");

//if RUST_LOG env is not set, set it to info
if std::env::var("RUST_LOG").is_err() {
std::env::set_var("RUST_LOG", "warn");
}
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(tracing_subscriber::EnvFilter::from_default_env())
.init();
let level = match std::env::var("RUST_LOG") {
Ok(v) => LevelFilter::from_str(&v).unwrap_or(LevelFilter::Info),
_ => LevelFilter::Info,
};
PicoLogger::new(level).init();

let registry = SimpleServiceRegistry::new(args.http_dest, args.https_dest);
let registry = Arc::new(registry);
Expand All @@ -83,8 +93,10 @@ async fn main() {
}

async fn connect(client: usize, args: Args, registry: Arc<dyn ServiceRegistry>) {
#[cfg(feature = "quic")]
let default_tunnel_cert = CertificateDer::from(DEFAULT_TUNNEL_CERT.to_vec());

#[cfg(feature = "quic")]
let server_certs = if let Some(cert) = args.custom_quic_cert_base64 {
vec![CertificateDer::from(URL_SAFE.decode(&cert).expect("Custom cert should in base64 format").to_vec())]
} else {
Expand All @@ -96,6 +108,7 @@ async fn connect(client: usize, args: Args, registry: Arc<dyn ServiceRegistry>)
log::info!("Connecting to connector... {:?} addr: {}", args.connector_protocol, args.connector_addr);
let started = Instant::now();
match args.connector_protocol {
#[cfg(feature = "tcp")]
Protocol::Tcp => match TcpConnection::new(args.connector_addr.clone(), &agent_signer).await {
Ok(conn) => {
log::info!("Connected to connector via tcp with res {:?}", conn.response());
Expand All @@ -106,6 +119,7 @@ async fn connect(client: usize, args: Args, registry: Arc<dyn ServiceRegistry>)
log::error!("Connect to connector via tcp error: {e}");
}
},
#[cfg(feature = "quic")]
Protocol::Quic => match QuicConnection::new(args.connector_addr.clone(), &agent_signer, &server_certs, args.allow_quic_insecure).await {
Ok(conn) => {
log::info!("Connected to connector via quic with res {:?}", conn.response());
Expand Down
23 changes: 20 additions & 3 deletions bin/agent/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,37 @@
//! Tunnel is a trait that defines the interface for a tunnel which connect to connector port of relayer.
use std::fmt::Debug;
use std::{fmt::Debug, str::FromStr};

use clap::ValueEnum;
use protocol::stream::TunnelStream;
use tokio::io::{AsyncRead, AsyncWrite};

#[cfg(feature = "quic")]
pub mod quic;

#[cfg(feature = "tcp")]
pub mod tcp;

#[derive(ValueEnum, Debug, Clone)]
#[derive(Debug, Clone)]
pub enum Protocol {
#[cfg(feature = "tcp")]
Tcp,
#[cfg(feature = "quic")]
Quic,
}

impl FromStr for Protocol {
type Err = &'static str;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
#[cfg(feature = "tcp")]
"tcp" | "TCP" => Ok(Protocol::Tcp),
#[cfg(feature = "quic")]
"quic" | "QUIC" => Ok(Protocol::Quic),
_ => Err("invalid protocol"),
}
}
}

pub trait SubConnection: AsyncRead + AsyncWrite + Unpin + Send + Sync {}

impl<R: AsyncRead + Unpin + Send + Sync, W: AsyncWrite + Unpin + Send + Sync> SubConnection for TunnelStream<R, W> {}
Expand Down
10 changes: 5 additions & 5 deletions bin/agent/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ use protocol::cluster::{wait_object, AgentTunnelRequest};

mod connection;
mod local_tunnel;
#[cfg(feature = "quic")]
pub use connection::quic::{QuicConnection, QuicSubConnection};
#[cfg(feature = "tcp")]
pub use connection::tcp::{TcpConnection, TcpSubConnection};

pub use connection::{
quic::{QuicConnection, QuicSubConnection},
tcp::{TcpConnection, TcpSubConnection},
Connection, Protocol, SubConnection,
};
pub use connection::{Connection, Protocol, SubConnection};
pub use local_tunnel::{registry::SimpleServiceRegistry, LocalTunnel, ServiceRegistry};
use tokio::{io::copy_bidirectional, net::TcpStream};

Expand Down
80 changes: 49 additions & 31 deletions bin/agent/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,77 +1,93 @@
use std::str::FromStr;
use std::{alloc::System, net::SocketAddr, sync::Arc};

use atm0s_reverse_proxy_agent::{run_tunnel_connection, Connection, Protocol, QuicConnection, ServiceRegistry, SimpleServiceRegistry, SubConnection, TcpConnection};
use log::LevelFilter;
use picolog::PicoLogger;

#[cfg(feature = "quic")]
use atm0s_reverse_proxy_agent::QuicConnection;
#[cfg(feature = "tcp")]
use atm0s_reverse_proxy_agent::TcpConnection;
use atm0s_reverse_proxy_agent::{run_tunnel_connection, Connection, Protocol, ServiceRegistry, SimpleServiceRegistry, SubConnection};
#[cfg(feature = "quic")]
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
use clap::Parser;
use protocol::{services::SERVICE_RTSP, DEFAULT_TUNNEL_CERT};

use argh::FromArgs;
use protocol::services::SERVICE_RTSP;
#[cfg(feature = "quic")]
use protocol::DEFAULT_TUNNEL_CERT;
use protocol_ed25519::AgentLocalKey;
#[cfg(feature = "quic")]
use rustls::pki_types::CertificateDer;
use tokio::time::sleep;
use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
// use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use url::Url;

#[global_allocator]
static A: System = System;

/// A HTTP and SNI HTTPs proxy for expose your local service to the internet.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
#[derive(FromArgs, Debug)]
struct Args {
/// Address of relay server
#[arg(env, long)]
/// address of relay server
#[argh(option)]
connector_addr: Url,

/// Protocol of relay server
#[arg(env, long)]
/// protocol of relay server
#[argh(option)]
connector_protocol: Protocol,

/// Http proxy dest
#[arg(env, long, default_value = "127.0.0.1:8080")]
/// http proxy dest
#[argh(option, default = "SocketAddr::from_str(\"127.0.0.1:8080\").unwrap()")]
http_dest: SocketAddr,

/// Sni-https proxy dest
#[arg(env, long, default_value = "127.0.0.1:8443")]
/// sni-https proxy dest
#[argh(option, default = "SocketAddr::from_str(\"127.0.0.1:8443\").unwrap()")]
https_dest: SocketAddr,

/// Rtsp proxy dest
#[arg(env, long, default_value = "127.0.0.1:554")]
/// rtsp proxy dest
#[argh(option, default = "SocketAddr::from_str(\"127.0.0.1:554\").unwrap()")]
rtsp_dest: SocketAddr,

/// Sni-https proxy dest
#[arg(env, long, default_value = "127.0.0.1:5443")]
/// sni-https proxy dest
#[argh(option, default = "SocketAddr::from_str(\"127.0.0.1:5443\").unwrap()")]
rtsps_dest: SocketAddr,

/// Persistent local key
#[arg(env, long, default_value = "local_key.pem")]
/// persistent local key
#[argh(option, default = "String::from(\"local_key.pem\")")]
local_key: String,

/// Custom quic server cert in base64
#[arg(env, long)]
#[cfg(feature = "quic")]
/// custom quic server cert in base64
#[argh(option)]
custom_quic_cert_base64: Option<String>,

/// Allow connect in insecure mode
#[arg(env, long)]
#[cfg(feature = "quic")]
/// allow connect in insecure mode
#[argh(switch)]
allow_quic_insecure: bool,
}

#[tokio::main]
async fn main() {
let args = Args::parse();
let args: Args = argh::from_env();
//if RUST_LOG env is not set, set it to info
let level = match std::env::var("RUST_LOG") {
Ok(v) => LevelFilter::from_str(&v).unwrap_or(LevelFilter::Info),
_ => LevelFilter::Info,
};
PicoLogger::new(level).init();

#[cfg(feature = "quic")]
let server_certs = if let Some(cert) = args.custom_quic_cert_base64 {
vec![CertificateDer::from(URL_SAFE.decode(cert).expect("Custom cert should in base64 format").to_vec())]
} else {
vec![CertificateDer::from(DEFAULT_TUNNEL_CERT.to_vec())]
};

#[cfg(feature = "quic")]
rustls::crypto::ring::default_provider().install_default().expect("should install ring as default");

//if RUST_LOG env is not set, set it to info
if std::env::var("RUST_LOG").is_err() {
std::env::set_var("RUST_LOG", "info");
}
tracing_subscriber::registry().with(fmt::layer()).with(EnvFilter::from_default_env()).init();

//read local_key from file first, if not exist, create a new one and save to file
let agent_signer = match std::fs::read_to_string(&args.local_key) {
Ok(local_key) => match AgentLocalKey::from_pem(&local_key) {
Expand Down Expand Up @@ -110,6 +126,7 @@ async fn main() {
loop {
log::info!("Connecting to connector... {:?} addr: {}", args.connector_protocol, args.connector_addr);
match args.connector_protocol {
#[cfg(feature = "tcp")]
Protocol::Tcp => match TcpConnection::new(args.connector_addr.clone(), &agent_signer).await {
Ok(conn) => {
log::info!("Connected to connector via tcp with res {:?}", conn.response());
Expand All @@ -119,6 +136,7 @@ async fn main() {
log::error!("Connect to connector via tcp error: {e}");
}
},
#[cfg(feature = "quic")]
Protocol::Quic => match QuicConnection::new(args.connector_addr.clone(), &agent_signer, &server_certs, args.allow_quic_insecure).await {
Ok(conn) => {
log::info!("Connected to connector via quic with res {:?}", conn.response());
Expand Down

0 comments on commit 8ca0ba0

Please sign in to comment.