From 795b9fa5ac52eb01270c1aa974e3cd393048f1bb Mon Sep 17 00:00:00 2001 From: giangndm <45644921+giangndm@users.noreply.github.com> Date: Sat, 17 Aug 2024 14:38:08 +0700 Subject: [PATCH] feat: rtsp proxy (#37) --- crates/agent/examples/benchmark_clients.rs | 28 ++-- crates/agent/node_local_quic.sh | 1 + crates/agent/node_local_tcp.sh | 1 + crates/agent/src/lib.rs | 58 +++++--- crates/agent/src/local_tunnel.rs | 7 + crates/agent/src/local_tunnel/registry.rs | 40 ++++++ crates/agent/src/main.rs | 33 +++-- crates/protocol/src/cluster.rs | 21 +++ crates/protocol/src/lib.rs | 1 + crates/protocol/src/services.rs | 1 + crates/relayer/Cargo.toml | 1 + crates/relayer/src/agent_worker.rs | 33 +++-- crates/relayer/src/lib.rs | 3 +- crates/relayer/src/main.rs | 91 ++++++++++-- crates/relayer/src/proxy_listener.rs | 7 +- crates/relayer/src/proxy_listener/cluster.rs | 3 + crates/relayer/src/proxy_listener/http.rs | 130 ------------------ crates/relayer/src/proxy_listener/tcp.rs | 117 ++++++++++++++++ .../src/proxy_listener/tcp/http_detector.rs | 21 +++ .../src/proxy_listener/tcp/rtsp_detector.rs | 16 +++ .../src/proxy_listener/tcp/tls_detector.rs | 47 +++++++ crates/relayer/src/tunnel.rs | 5 + crates/relayer/src/utils.rs | 1 - 23 files changed, 470 insertions(+), 196 deletions(-) create mode 100644 crates/agent/src/local_tunnel/registry.rs create mode 100644 crates/protocol/src/services.rs delete mode 100644 crates/relayer/src/proxy_listener/http.rs create mode 100644 crates/relayer/src/proxy_listener/tcp.rs create mode 100644 crates/relayer/src/proxy_listener/tcp/http_detector.rs create mode 100644 crates/relayer/src/proxy_listener/tcp/rtsp_detector.rs create mode 100644 crates/relayer/src/proxy_listener/tcp/tls_detector.rs diff --git a/crates/agent/examples/benchmark_clients.rs b/crates/agent/examples/benchmark_clients.rs index fa38605..9ad851b 100644 --- a/crates/agent/examples/benchmark_clients.rs +++ b/crates/agent/examples/benchmark_clients.rs @@ -1,10 +1,12 @@ use std::{ net::SocketAddr, + sync::Arc, time::{Duration, Instant}, }; use atm0s_reverse_proxy_agent::{ - run_tunnel_connection, Connection, Protocol, QuicConnection, SubConnection, TcpConnection, + run_tunnel_connection, Connection, Protocol, QuicConnection, ServiceRegistry, + SimpleServiceRegistry, SubConnection, TcpConnection, }; use base64::{engine::general_purpose::URL_SAFE, Engine as _}; use clap::Parser; @@ -67,10 +69,14 @@ async fn main() { .with(tracing_subscriber::EnvFilter::from_default_env()) .init(); + let registry = SimpleServiceRegistry::new(args.http_dest, args.https_dest); + let registry = Arc::new(registry); + for client in 0..args.clients { - let args2 = args.clone(); + let args_c = args.clone(); + let registry = registry.clone(); async_std::task::spawn(async move { - async_std::task::spawn_local(connect(client, args2)); + async_std::task::spawn_local(connect(client, args_c, registry)); }); async_std::task::sleep(Duration::from_millis(args.connect_wait_ms)).await; } @@ -80,7 +86,7 @@ async fn main() { } } -async fn connect(client: usize, args: Args) { +async fn connect(client: usize, args: Args, registry: Arc) { let default_tunnel_cert_buf = include_bytes!("../../../certs/tunnel.cert"); let default_tunnel_cert = CertificateDer::from(default_tunnel_cert_buf.to_vec()); @@ -112,7 +118,7 @@ async fn connect(client: usize, args: Args) { conn.response() ); println!("{client} connected after {:?}", started.elapsed()); - run_connection_loop(conn, args.http_dest, args.https_dest).await; + run_connection_loop(conn, registry.clone()).await; } Err(e) => { log::error!("Connect to connector via tcp error: {}", e); @@ -134,7 +140,7 @@ async fn connect(client: usize, args: Args) { conn.response() ); println!("{client} connected after {:?}", started.elapsed()); - run_connection_loop(conn, args.http_dest, args.https_dest).await; + run_connection_loop(conn, registry.clone()).await; } Err(e) => { log::error!("Connect to connector via quic error: {}", e); @@ -149,8 +155,7 @@ async fn connect(client: usize, args: Args) { async fn run_connection_loop( mut connection: impl Connection, - http_dest: SocketAddr, - https_dest: SocketAddr, + registry: Arc, ) where S: SubConnection + 'static, R: AsyncRead + Send + Unpin + 'static, @@ -160,11 +165,8 @@ async fn run_connection_loop( match connection.recv().await { Ok(sub_connection) => { log::info!("recv sub_connection"); - async_std::task::spawn_local(run_tunnel_connection( - sub_connection, - http_dest, - https_dest, - )); + let registry = registry.clone(); + async_std::task::spawn_local(run_tunnel_connection(sub_connection, registry)); } Err(e) => { log::error!("recv sub_connection error: {}", e); diff --git a/crates/agent/node_local_quic.sh b/crates/agent/node_local_quic.sh index 7c15ca7..761f24c 100644 --- a/crates/agent/node_local_quic.sh +++ b/crates/agent/node_local_quic.sh @@ -3,4 +3,5 @@ RUST_LOG=info cargo run -- \ --connector-addr https://127.0.0.1:13001 \ --http-dest 127.0.0.1:8080 \ --https-dest 127.0.0.1:8443 \ + --rtsp-dest 10.10.30.90:554 \ --allow-quic-insecure diff --git a/crates/agent/node_local_tcp.sh b/crates/agent/node_local_tcp.sh index 08f8b59..ac8a61b 100644 --- a/crates/agent/node_local_tcp.sh +++ b/crates/agent/node_local_tcp.sh @@ -3,4 +3,5 @@ RUST_LOG=info cargo run -- \ --connector-addr tcp://127.0.0.1:13001 \ --http-dest 127.0.0.1:8080 \ --https-dest 127.0.0.1:8443 \ + --rtsp-dest 10.10.30.90:554 \ --allow-quic-insecure diff --git a/crates/agent/src/lib.rs b/crates/agent/src/lib.rs index cbce530..fddaf0d 100644 --- a/crates/agent/src/lib.rs +++ b/crates/agent/src/lib.rs @@ -1,9 +1,10 @@ -use std::net::SocketAddr; +use std::sync::Arc; use async_std::io::WriteExt; use futures::{select, AsyncRead, AsyncReadExt, AsyncWrite, FutureExt}; use local_tunnel::tcp::LocalTcpTunnel; +use protocol::cluster::AgentTunnelRequest; use crate::local_tunnel::LocalTunnel; @@ -15,12 +16,10 @@ pub use connection::{ tcp::{TcpConnection, TcpSubConnection}, Connection, Protocol, SubConnection, }; +pub use local_tunnel::{registry::SimpleServiceRegistry, ServiceRegistry}; -pub async fn run_tunnel_connection( - sub_connection: S, - http_dest: SocketAddr, - https_dest: SocketAddr, -) where +pub async fn run_tunnel_connection(sub_connection: S, registry: Arc) +where S: SubConnection + 'static, R: AsyncRead + Send + Unpin, W: AsyncWrite + Send + Unpin, @@ -28,19 +27,43 @@ pub async fn run_tunnel_connection( log::info!("sub_connection pipe to local_tunnel start"); let (mut reader1, mut writer1) = sub_connection.split(); let mut first_pkt = [0u8; 4096]; - let (local_tunnel, first_pkt_len) = match reader1.read(&mut first_pkt).await { + let (local_tunnel, first_pkt_start, first_pkt_end) = match reader1.read(&mut first_pkt).await { Ok(first_pkt_len) => { log::info!("first pkt size: {}", first_pkt_len); - if first_pkt_len == 0 { - log::error!("first pkt size is 0 => close"); + if first_pkt_len < 2 { + log::error!("first pkt size is < 4 => close"); return; } - if first_pkt[0] == 0x16 { - log::info!("create tunnel to https dest {}", https_dest); - (LocalTcpTunnel::new(https_dest).await, first_pkt_len) - } else { - log::info!("create tunnel to http dest {}", http_dest); - (LocalTcpTunnel::new(http_dest).await, first_pkt_len) + let handshake_len = u16::from_be_bytes([first_pkt[0], first_pkt[1]]) as usize; + if handshake_len + 2 > first_pkt_len { + log::error!("first pkt size is < handshake {handshake_len} + 2 => close"); + return; + } + match AgentTunnelRequest::try_from(&first_pkt[2..(handshake_len + 2)]) { + Ok(handshake) => { + if let Some(dest) = + registry.dest_for(handshake.tls, handshake.service, &handshake.domain) + { + log::info!("create tunnel to dest {}", dest); + ( + LocalTcpTunnel::new(dest).await, + handshake_len + 2, + first_pkt_len, + ) + } else { + log::warn!( + "dest for service {:?} tls {} domain {} not found", + handshake.service, + handshake.tls, + handshake.domain + ); + return; + } + } + Err(e) => { + log::error!("handshake parse error: {}", e); + return; + } } } Err(e) => { @@ -59,7 +82,10 @@ pub async fn run_tunnel_connection( let (mut reader2, mut writer2) = local_tunnel.split(); - if let Err(e) = writer2.write_all(&first_pkt[..first_pkt_len]).await { + if let Err(e) = writer2 + .write_all(&first_pkt[first_pkt_start..first_pkt_end]) + .await + { log::error!("write first pkt to local_tunnel error: {}", e); return; } diff --git a/crates/agent/src/local_tunnel.rs b/crates/agent/src/local_tunnel.rs index 1a61c66..22552c2 100644 --- a/crates/agent/src/local_tunnel.rs +++ b/crates/agent/src/local_tunnel.rs @@ -1,7 +1,14 @@ +use std::net::SocketAddr; + use futures::{AsyncRead, AsyncWrite}; +pub mod registry; pub mod tcp; pub trait LocalTunnel: Send + Sync { fn split(self) -> (R, W); } + +pub trait ServiceRegistry { + fn dest_for(&self, tls: bool, service: Option, domain: &str) -> Option; +} diff --git a/crates/agent/src/local_tunnel/registry.rs b/crates/agent/src/local_tunnel/registry.rs new file mode 100644 index 0000000..b4de59a --- /dev/null +++ b/crates/agent/src/local_tunnel/registry.rs @@ -0,0 +1,40 @@ +use std::{collections::HashMap, net::SocketAddr}; + +use super::ServiceRegistry; + +pub struct SimpleServiceRegistry { + default_tcp: SocketAddr, + default_tls: SocketAddr, + tcp_services: HashMap, + tls_services: HashMap, +} + +impl SimpleServiceRegistry { + pub fn new(default_tcp: SocketAddr, default_tls: SocketAddr) -> Self { + Self { + default_tcp, + default_tls, + tcp_services: HashMap::new(), + tls_services: HashMap::new(), + } + } + + pub fn set_tcp_service(&mut self, service: u16, dest: SocketAddr) { + self.tcp_services.insert(service, dest); + } + + pub fn set_tls_service(&mut self, service: u16, dest: SocketAddr) { + self.tls_services.insert(service, dest); + } +} + +impl ServiceRegistry for SimpleServiceRegistry { + fn dest_for(&self, tls: bool, service: Option, _domain: &str) -> Option { + match (tls, service) { + (false, None) => Some(self.default_tcp), + (true, None) => Some(self.default_tls), + (false, Some(service)) => self.tcp_services.get(&service).cloned(), + (true, Some(service)) => self.tls_services.get(&service).cloned(), + } + } +} diff --git a/crates/agent/src/main.rs b/crates/agent/src/main.rs index c185dee..199390b 100644 --- a/crates/agent/src/main.rs +++ b/crates/agent/src/main.rs @@ -1,11 +1,13 @@ -use std::{alloc::System, net::SocketAddr}; +use std::{alloc::System, net::SocketAddr, sync::Arc}; use atm0s_reverse_proxy_agent::{ - run_tunnel_connection, Connection, Protocol, QuicConnection, SubConnection, TcpConnection, + run_tunnel_connection, Connection, Protocol, QuicConnection, ServiceRegistry, + SimpleServiceRegistry, SubConnection, TcpConnection, }; use base64::{engine::general_purpose::URL_SAFE, Engine as _}; use clap::Parser; use futures::{AsyncRead, AsyncWrite}; +use protocol::services::SERVICE_RTSP; use protocol_ed25519::AgentLocalKey; use rustls::pki_types::CertificateDer; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; @@ -34,6 +36,14 @@ struct Args { #[arg(env, long, default_value = "127.0.0.1:8443")] https_dest: SocketAddr, + /// Rtsp proxy dest + #[arg(env, long, default_value = "127.0.0.1:554")] + rtsp_dest: SocketAddr, + + /// Sni-https proxy dest + #[arg(env, long, default_value = "127.0.0.1:5443")] + rtsps_dest: SocketAddr, + /// Persistent local key #[arg(env, long, default_value = "local_key.pem")] local_key: String, @@ -107,6 +117,11 @@ async fn main() { } }; + let mut registry = SimpleServiceRegistry::new(args.http_dest, args.https_dest); + registry.set_tcp_service(SERVICE_RTSP, args.rtsp_dest); + registry.set_tls_service(SERVICE_RTSP, args.rtsps_dest); + let registry = Arc::new(registry); + loop { log::info!( "Connecting to connector... {:?} addr: {}", @@ -121,7 +136,7 @@ async fn main() { "Connected to connector via tcp with res {:?}", conn.response() ); - run_connection_loop(conn, args.http_dest, args.https_dest).await; + run_connection_loop(conn, registry.clone()).await; } Err(e) => { log::error!("Connect to connector via tcp error: {}", e); @@ -142,7 +157,7 @@ async fn main() { "Connected to connector via quic with res {:?}", conn.response() ); - run_connection_loop(conn, args.http_dest, args.https_dest).await; + run_connection_loop(conn, registry.clone()).await; } Err(e) => { log::error!("Connect to connector via quic error: {}", e); @@ -157,8 +172,7 @@ async fn main() { pub async fn run_connection_loop( mut connection: impl Connection, - http_dest: SocketAddr, - https_dest: SocketAddr, + registry: Arc, ) where S: SubConnection + 'static, R: AsyncRead + Send + Unpin + 'static, @@ -168,11 +182,8 @@ pub async fn run_connection_loop( match connection.recv().await { Ok(sub_connection) => { log::info!("recv sub_connection"); - async_std::task::spawn_local(run_tunnel_connection( - sub_connection, - http_dest, - https_dest, - )); + let registry = registry.clone(); + async_std::task::spawn_local(run_tunnel_connection(sub_connection, registry)); } Err(e) => { log::error!("recv sub_connection error: {}", e); diff --git a/crates/protocol/src/cluster.rs b/crates/protocol/src/cluster.rs index 60ca069..da410a8 100644 --- a/crates/protocol/src/cluster.rs +++ b/crates/protocol/src/cluster.rs @@ -37,3 +37,24 @@ impl TryFrom<&[u8]> for ClusterTunnelResponse { bincode::deserialize(buf) } } + +#[derive(Debug, Serialize, Deserialize)] +pub struct AgentTunnelRequest { + pub service: Option, + pub tls: bool, + pub domain: String, +} + +impl From<&AgentTunnelRequest> for Vec { + fn from(resp: &AgentTunnelRequest) -> Self { + bincode::serialize(resp).expect("Should ok") + } +} + +impl TryFrom<&[u8]> for AgentTunnelRequest { + type Error = bincode::Error; + + fn try_from(buf: &[u8]) -> Result { + bincode::deserialize(buf) + } +} diff --git a/crates/protocol/src/lib.rs b/crates/protocol/src/lib.rs index 6d23bc0..5e2ea93 100644 --- a/crates/protocol/src/lib.rs +++ b/crates/protocol/src/lib.rs @@ -1,2 +1,3 @@ pub mod cluster; pub mod key; +pub mod services; diff --git a/crates/protocol/src/services.rs b/crates/protocol/src/services.rs new file mode 100644 index 0000000..566bc31 --- /dev/null +++ b/crates/protocol/src/services.rs @@ -0,0 +1 @@ +pub const SERVICE_RTSP: u16 = 554; diff --git a/crates/relayer/Cargo.toml b/crates/relayer/Cargo.toml index 5f783b2..e73bfd4 100644 --- a/crates/relayer/Cargo.toml +++ b/crates/relayer/Cargo.toml @@ -22,6 +22,7 @@ metrics = { version = "0.22.0" } quinn = { version = "0.11", features = ["ring", "runtime-async-std", "futures-io"] } rustls = { version = "0.23", features = ["ring", "std"] } atm0s-sdn = { git = "https://github.com/8xFF/atm0s-sdn.git", rev = "e5acc4458f8ce9bd0d9286bb3ad68a2a21fffb11" } +rtsp-types = "0.1.2" [features] default = ["binary"] diff --git a/crates/relayer/src/agent_worker.rs b/crates/relayer/src/agent_worker.rs index ef655f5..9fa6e2e 100644 --- a/crates/relayer/src/agent_worker.rs +++ b/crates/relayer/src/agent_worker.rs @@ -1,6 +1,6 @@ use std::{error::Error, marker::PhantomData, sync::Arc, time::Instant}; -use futures::{select, AsyncRead, AsyncWrite, FutureExt}; +use futures::{select, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt}; use metrics::{counter, gauge, histogram}; enum IncommingConn< @@ -16,8 +16,9 @@ use crate::{ agent_listener::{AgentConnection, AgentConnectionHandler, AgentSubConnection}, proxy_listener::ProxyTunnel, METRICS_PROXY_AGENT_COUNT, METRICS_PROXY_AGENT_ERROR_COUNT, METRICS_PROXY_AGENT_LIVE, - METRICS_PROXY_CLUSTER_LIVE, METRICS_PROXY_HTTP_LIVE, METRICS_TUNNEL_AGENT_COUNT, - METRICS_TUNNEL_AGENT_ERROR_COUNT, METRICS_TUNNEL_AGENT_HISTOGRAM, METRICS_TUNNEL_AGENT_LIVE, + METRICS_PROXY_CLUSTER_ERROR_COUNT, METRICS_PROXY_CLUSTER_LIVE, METRICS_PROXY_HTTP_ERROR_COUNT, + METRICS_PROXY_HTTP_LIVE, METRICS_TUNNEL_AGENT_COUNT, METRICS_TUNNEL_AGENT_ERROR_COUNT, + METRICS_TUNNEL_AGENT_HISTOGRAM, METRICS_TUNNEL_AGENT_LIVE, }; pub struct AgentWorker @@ -63,19 +64,35 @@ where counter!(METRICS_TUNNEL_AGENT_COUNT).increment(1); let started = Instant::now(); let sub_connection = self.connection.create_sub_connection().await?; - histogram!(METRICS_TUNNEL_AGENT_HISTOGRAM) - .record(started.elapsed().as_millis() as f32 / 1000.0); async_std::task::spawn(async move { + let domain = conn.domain().to_string(); + let (mut reader1, mut writer1) = sub_connection.split(); + if let Some(handshake) = conn.handshake() { + let err1 = writer1 + .write_all(&(handshake.len() as u16).to_be_bytes()) + .await; + let err2 = writer1.write_all(handshake).await; + if let Err(e) = err1.and(err2) { + log::error!("handshake for domain {domain} failed {:?}", e); + if conn.local() { + gauge!(METRICS_PROXY_HTTP_ERROR_COUNT).increment(1.0); + } else { + gauge!(METRICS_PROXY_CLUSTER_ERROR_COUNT).increment(1.0); + } + return; + } + } + histogram!(METRICS_TUNNEL_AGENT_HISTOGRAM) + .record(started.elapsed().as_millis() as f32 / 1000.0); + if conn.local() { gauge!(METRICS_PROXY_HTTP_LIVE).increment(1.0); } else { gauge!(METRICS_PROXY_CLUSTER_LIVE).increment(1.0); } gauge!(METRICS_TUNNEL_AGENT_LIVE).increment(1.0); - let domain = conn.domain().to_string(); - log::info!("start proxy tunnel for domain {}", domain); - let (mut reader1, mut writer1) = sub_connection.split(); + log::info!("start proxy tunnel for domain {domain}"); let (mut reader2, mut writer2) = conn.split(); let job1 = futures::io::copy(&mut reader1, &mut writer2); diff --git a/crates/relayer/src/lib.rs b/crates/relayer/src/lib.rs index ba231e5..6fa9b9b 100644 --- a/crates/relayer/src/lib.rs +++ b/crates/relayer/src/lib.rs @@ -62,10 +62,11 @@ pub use proxy_listener::cluster::{ pub use quinn; pub use proxy_listener::cluster::{run_sdn, ProxyClusterListener, ProxyClusterTunnel}; -pub use proxy_listener::http::{ProxyHttpListener, ProxyHttpTunnel}; +pub use proxy_listener::tcp::{ProxyTcpListener, ProxyTcpTunnel}; pub use proxy_listener::{ProxyListener, ProxyTunnel}; pub use agent_store::AgentStore; +pub use proxy_listener::tcp::{HttpDomainDetector, RtspDomainDetector, TlsDomainDetector}; pub use tunnel::{tunnel_task, TunnelContext}; pub async fn run_agent_connection( diff --git a/crates/relayer/src/main.rs b/crates/relayer/src/main.rs index 3d074e1..f30f6fe 100644 --- a/crates/relayer/src/main.rs +++ b/crates/relayer/src/main.rs @@ -1,14 +1,14 @@ use atm0s_reverse_proxy_relayer::{ run_agent_connection, run_sdn, tunnel_task, AgentIncommingConnHandlerDummy, AgentListener, - AgentQuicListener, AgentStore, AgentTcpListener, ProxyHttpListener, ProxyListener, - TunnelContext, METRICS_AGENT_COUNT, METRICS_AGENT_HISTOGRAM, METRICS_AGENT_LIVE, - METRICS_PROXY_AGENT_COUNT, METRICS_PROXY_AGENT_ERROR_COUNT, METRICS_PROXY_AGENT_HISTOGRAM, - METRICS_PROXY_AGENT_LIVE, METRICS_PROXY_CLUSTER_COUNT, METRICS_PROXY_CLUSTER_ERROR_COUNT, - METRICS_PROXY_CLUSTER_LIVE, METRICS_PROXY_HTTP_COUNT, METRICS_PROXY_HTTP_ERROR_COUNT, - METRICS_PROXY_HTTP_LIVE, METRICS_TUNNEL_AGENT_COUNT, METRICS_TUNNEL_AGENT_ERROR_COUNT, - METRICS_TUNNEL_AGENT_HISTOGRAM, METRICS_TUNNEL_AGENT_LIVE, METRICS_TUNNEL_CLUSTER_COUNT, - METRICS_TUNNEL_CLUSTER_ERROR_COUNT, METRICS_TUNNEL_CLUSTER_HISTOGRAM, - METRICS_TUNNEL_CLUSTER_LIVE, + AgentQuicListener, AgentStore, AgentTcpListener, HttpDomainDetector, ProxyListener, + ProxyTcpListener, RtspDomainDetector, TlsDomainDetector, TunnelContext, METRICS_AGENT_COUNT, + METRICS_AGENT_HISTOGRAM, METRICS_AGENT_LIVE, METRICS_PROXY_AGENT_COUNT, + METRICS_PROXY_AGENT_ERROR_COUNT, METRICS_PROXY_AGENT_HISTOGRAM, METRICS_PROXY_AGENT_LIVE, + METRICS_PROXY_CLUSTER_COUNT, METRICS_PROXY_CLUSTER_ERROR_COUNT, METRICS_PROXY_CLUSTER_LIVE, + METRICS_PROXY_HTTP_COUNT, METRICS_PROXY_HTTP_ERROR_COUNT, METRICS_PROXY_HTTP_LIVE, + METRICS_TUNNEL_AGENT_COUNT, METRICS_TUNNEL_AGENT_ERROR_COUNT, METRICS_TUNNEL_AGENT_HISTOGRAM, + METRICS_TUNNEL_AGENT_LIVE, METRICS_TUNNEL_CLUSTER_COUNT, METRICS_TUNNEL_CLUSTER_ERROR_COUNT, + METRICS_TUNNEL_CLUSTER_HISTOGRAM, METRICS_TUNNEL_CLUSTER_LIVE, }; use atm0s_sdn::{NodeAddr, NodeId}; use clap::Parser; @@ -16,6 +16,7 @@ use clap::Parser; use metrics_dashboard::{build_dashboard_route, DashboardOptions}; #[cfg(feature = "expose-metrics")] use poem::{listener::TcpListener, middleware::Tracing, EndpointExt as _, Route, Server}; +use protocol::services::SERVICE_RTSP; use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; #[cfg(feature = "expose-metrics")] use std::net::{Ipv4Addr, SocketAddrV4}; @@ -41,6 +42,14 @@ struct Args { #[arg(env, long, default_value_t = 443)] https_port: u16, + /// Rtsp proxy port + #[arg(env, long, default_value_t = 554)] + rtsp_port: u16, + + /// Sni-rtsp proxy port + #[arg(env, long, default_value_t = 5443)] + rtsps_port: u16, + /// Number of times to greet #[arg(env, long, default_value = "0.0.0.0:33333")] connector_port: SocketAddr, @@ -105,12 +114,38 @@ async fn main() { .await; let mut tcp_agent_listener = AgentTcpListener::new(args.connector_port, cluster_validator).await; - let mut proxy_http_listener = ProxyHttpListener::new(args.http_port, false) - .await - .expect("Should listen http port"); - let mut proxy_tls_listener = ProxyHttpListener::new(args.https_port, true) - .await - .expect("Should listen tls port"); + let mut proxy_http_listener = ProxyTcpListener::new( + args.http_port, + false, + None, + Arc::new(HttpDomainDetector::default()), + ) + .await + .expect("Should listen http port"); + let mut proxy_tls_listener = ProxyTcpListener::new( + args.https_port, + true, + None, + Arc::new(TlsDomainDetector::default()), + ) + .await + .expect("Should listen tls port"); + let mut proxy_rtsp_listener = ProxyTcpListener::new( + args.rtsp_port, + false, + Some(SERVICE_RTSP), + Arc::new(RtspDomainDetector::default()), + ) + .await + .expect("Should listen rtsp port"); + let mut proxy_rtsps_listener = ProxyTcpListener::new( + args.rtsps_port, + true, + Some(SERVICE_RTSP), + Arc::new(TlsDomainDetector::default()), + ) + .await + .expect("Should listen rtsps port"); let agents = AgentStore::new(); #[cfg(feature = "expose-metrics")] @@ -274,6 +309,32 @@ async fn main() { exit(2); } }, + e = proxy_rtsps_listener.recv().fuse() => match e { + Some(proxy_tunnel) => { + if let Some(socket) = virtual_net.udp_socket(0).await { + async_std::task::spawn(tunnel_task(proxy_tunnel, agents.clone(), TunnelContext::Local(alias_sdk.clone(), socket, vec![default_cluster_cert.clone()]))); + } else { + log::error!("Virtual Net create socket error"); + } + } + None => { + log::error!("proxy_http_listener.recv()"); + exit(2); + } + }, + e = proxy_rtsp_listener.recv().fuse() => match e { + Some(proxy_tunnel) => { + if let Some(socket) = virtual_net.udp_socket(0).await { + async_std::task::spawn(tunnel_task(proxy_tunnel, agents.clone(), TunnelContext::Local(alias_sdk.clone(), socket, vec![default_cluster_cert.clone()]))); + } else { + log::error!("Virtual Net create socket error"); + } + } + None => { + log::error!("proxy_http_listener.recv()"); + exit(2); + } + }, e = cluster_endpoint.recv().fuse() => match e { Some(proxy_tunnel) => { async_std::task::spawn(tunnel_task(proxy_tunnel, agents.clone(), TunnelContext::Cluster)); diff --git a/crates/relayer/src/proxy_listener.rs b/crates/relayer/src/proxy_listener.rs index 32c8a7c..4d22f15 100644 --- a/crates/relayer/src/proxy_listener.rs +++ b/crates/relayer/src/proxy_listener.rs @@ -3,13 +3,18 @@ use futures::{AsyncRead, AsyncWrite}; pub mod cluster; -pub mod http; +pub mod tcp; + +pub trait DomainDetector: Send + Sync { + fn get_domain(&self, buf: &[u8]) -> Option; +} #[async_trait::async_trait] pub trait ProxyTunnel: Send + Sync { async fn wait(&mut self) -> Option<()>; fn local(&self) -> bool; fn domain(&self) -> &str; + fn handshake(&self) -> Option<&[u8]>; fn split( &mut self, ) -> ( diff --git a/crates/relayer/src/proxy_listener/cluster.rs b/crates/relayer/src/proxy_listener/cluster.rs index 0b3d914..0bb7612 100644 --- a/crates/relayer/src/proxy_listener/cluster.rs +++ b/crates/relayer/src/proxy_listener/cluster.rs @@ -233,6 +233,9 @@ impl ProxyTunnel for ProxyClusterTunnel { fn domain(&self) -> &str { &self.domain } + fn handshake(&self) -> Option<&[u8]> { + None + } fn split( &mut self, ) -> ( diff --git a/crates/relayer/src/proxy_listener/http.rs b/crates/relayer/src/proxy_listener/http.rs deleted file mode 100644 index 5db0e6b..0000000 --- a/crates/relayer/src/proxy_listener/http.rs +++ /dev/null @@ -1,130 +0,0 @@ -use std::net::{Ipv4Addr, SocketAddr}; - -use async_std::net::{TcpListener, TcpStream}; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite}; -use tls_parser::{parse_tls_extensions, parse_tls_plaintext}; - -use super::{ProxyListener, ProxyTunnel}; - -pub struct ProxyHttpListener { - tcp_listener: TcpListener, - tls: bool, -} - -impl ProxyHttpListener { - pub async fn new(port: u16, tls: bool) -> Option { - log::info!("ProxyHttpListener::new {}", port); - Some(Self { - tcp_listener: TcpListener::bind(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), port)) - .await - .ok()?, - tls, - }) - } -} - -#[async_trait::async_trait] -impl ProxyListener for ProxyHttpListener { - async fn recv(&mut self) -> Option> { - let (stream, remote) = self.tcp_listener.accept().await.ok()?; - log::info!("[ProxyHttpListener] new conn from {}", remote); - Some(Box::new(ProxyHttpTunnel { - domain: "".to_string(), - stream: Some(stream), - tls: self.tls, - })) - } -} - -pub struct ProxyHttpTunnel { - domain: String, - stream: Option, - tls: bool, -} - -#[async_trait::async_trait] -impl ProxyTunnel for ProxyHttpTunnel { - async fn wait(&mut self) -> Option<()> { - log::info!("[ProxyHttpTunnel] wait first data for checking url..."); - let mut first_pkt = [0u8; 4096]; - let stream = self.stream.as_mut()?; - let first_pkt_size = stream.peek(&mut first_pkt).await.ok()?; - log::info!( - "[ProxyHttpTunnel] read {} bytes for determine url", - first_pkt_size - ); - if self.tls { - self.domain = get_sni_from_packet(&first_pkt[..first_pkt_size])?; - } else { - let mut headers = [httparse::EMPTY_HEADER; 64]; - let mut req = httparse::Request::new(&mut headers); - let _ = req.parse(&first_pkt[..first_pkt_size]).ok()?; - let domain = req - .headers - .iter() - .find(|h| h.name.to_lowercase() == "host")? - .value; - // dont get the port - let domain = String::from_utf8_lossy(domain).to_string(); - let domain = domain.split(':').next()?; - self.domain = domain.to_string(); - } - Some(()) - } - - fn local(&self) -> bool { - true - } - - fn domain(&self) -> &str { - &self.domain - } - fn split( - &mut self, - ) -> ( - Box, - Box, - ) { - let (read, write) = AsyncReadExt::split(self.stream.take().expect("Should has stream")); - (Box::new(read), Box::new(write)) - } -} - -fn get_sni_from_packet(packet: &[u8]) -> Option { - let res = match parse_tls_plaintext(&packet) { - Ok(res) => res, - Err(e) => { - log::error!("parse_tls_plaintext error {:?}", e); - return None; - } - }; - - let tls_message = &res.1.msg[0]; - if let tls_parser::TlsMessage::Handshake(handshake) = tls_message { - if let tls_parser::TlsMessageHandshake::ClientHello(client_hello) = handshake { - // get the extensions - let extensions: &[u8] = client_hello.ext?; - // parse the extensions - let res = match parse_tls_extensions(extensions) { - Ok(res) => res, - Err(e) => { - log::error!("parse_tls_extensions error {:?}", e); - return None; - } - }; - // iterate over the extensions and find the SNI - for extension in res.1 { - if let tls_parser::TlsExtension::SNI(sni) = extension { - // get the hostname - let hostname: &[u8] = sni[0].1; - let s: String = match String::from_utf8(hostname.to_vec()) { - Ok(v) => v, - Err(e) => panic!("Invalid UTF-8 sequence: {}", e), - }; - return Some(s); - } - } - } - } - None -} diff --git a/crates/relayer/src/proxy_listener/tcp.rs b/crates/relayer/src/proxy_listener/tcp.rs new file mode 100644 index 0000000..b44e1cd --- /dev/null +++ b/crates/relayer/src/proxy_listener/tcp.rs @@ -0,0 +1,117 @@ +use std::{ + net::{Ipv4Addr, SocketAddr}, + sync::Arc, +}; + +use async_std::net::{TcpListener, TcpStream}; +use futures::{AsyncRead, AsyncReadExt, AsyncWrite}; +use protocol::cluster::AgentTunnelRequest; + +use super::{DomainDetector, ProxyListener, ProxyTunnel}; + +mod http_detector; +mod rtsp_detector; +mod tls_detector; + +pub use http_detector::*; +pub use rtsp_detector::*; +pub use tls_detector::*; + +pub struct ProxyTcpListener { + tcp_listener: TcpListener, + tls: bool, + service: Option, + detector: Arc, +} + +impl ProxyTcpListener { + pub async fn new( + port: u16, + tls: bool, + service: Option, + detector: Arc, + ) -> Option { + log::info!( + "ProxyTcpListener::new port {port} tls {tls} service {:?}", + service + ); + Some(Self { + tcp_listener: TcpListener::bind(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), port)) + .await + .ok()?, + tls, + service, + detector, + }) + } +} + +#[async_trait::async_trait] +impl ProxyListener for ProxyTcpListener { + async fn recv(&mut self) -> Option> { + let (stream, remote) = self.tcp_listener.accept().await.ok()?; + log::info!("[ProxyTcpListener] new conn from {}", remote); + Some(Box::new(ProxyTcpTunnel { + detector: self.detector.clone(), + service: self.service, + domain: "".to_string(), + handshake: vec![], + stream: Some(stream), + tls: self.tls, + })) + } +} + +pub struct ProxyTcpTunnel { + detector: Arc, + service: Option, + domain: String, + stream: Option, + handshake: Vec, + tls: bool, +} + +#[async_trait::async_trait] +impl ProxyTunnel for ProxyTcpTunnel { + async fn wait(&mut self) -> Option<()> { + log::info!("[ProxyTcpTunnel] wait first data for checking url..."); + let mut first_pkt = [0u8; 4096]; + let stream = self.stream.as_mut()?; + let first_pkt_size = stream.peek(&mut first_pkt).await.ok()?; + log::info!( + "[ProxyTcpTunnel] read {} bytes for determine url", + first_pkt_size + ); + self.domain = self.detector.get_domain(&first_pkt[..first_pkt_size])?; + log::info!("[PRoxyTcpTunnel] detected domain {}", self.domain); + self.handshake = (&AgentTunnelRequest { + service: self.service.clone(), + tls: self.tls, + domain: self.domain.clone(), + }) + .into(); + Some(()) + } + + fn local(&self) -> bool { + true + } + + fn domain(&self) -> &str { + &self.domain + } + + fn handshake(&self) -> Option<&[u8]> { + Some(&self.handshake) + } + + fn split( + &mut self, + ) -> ( + Box, + Box, + ) { + let (read, write) = AsyncReadExt::split(self.stream.take().expect("Should has stream")); + (Box::new(read), Box::new(write)) + } +} diff --git a/crates/relayer/src/proxy_listener/tcp/http_detector.rs b/crates/relayer/src/proxy_listener/tcp/http_detector.rs new file mode 100644 index 0000000..b6c242b --- /dev/null +++ b/crates/relayer/src/proxy_listener/tcp/http_detector.rs @@ -0,0 +1,21 @@ +use crate::proxy_listener::DomainDetector; + +#[derive(Default)] +pub struct HttpDomainDetector(); + +impl DomainDetector for HttpDomainDetector { + fn get_domain(&self, buf: &[u8]) -> Option { + let mut headers = [httparse::EMPTY_HEADER; 64]; + let mut req = httparse::Request::new(&mut headers); + let _ = req.parse(buf).ok()?; + let domain = req + .headers + .iter() + .find(|h| h.name.to_lowercase() == "host")? + .value; + // dont get the port + let domain = String::from_utf8_lossy(domain).to_string(); + let domain = domain.split(':').next()?; + Some(domain.to_string()) + } +} diff --git a/crates/relayer/src/proxy_listener/tcp/rtsp_detector.rs b/crates/relayer/src/proxy_listener/tcp/rtsp_detector.rs new file mode 100644 index 0000000..56c4774 --- /dev/null +++ b/crates/relayer/src/proxy_listener/tcp/rtsp_detector.rs @@ -0,0 +1,16 @@ +use crate::proxy_listener::DomainDetector; + +#[derive(Default)] +pub struct RtspDomainDetector(); + +impl DomainDetector for RtspDomainDetector { + fn get_domain(&self, buf: &[u8]) -> Option { + let (message, _consumed): (rtsp_types::Message>, _) = + rtsp_types::Message::parse(buf).ok()?; + log::info!("{:?}", message); + match message { + rtsp_types::Message::Request(req) => req.request_uri()?.host().map(|h| h.to_string()), + _ => None, + } + } +} diff --git a/crates/relayer/src/proxy_listener/tcp/tls_detector.rs b/crates/relayer/src/proxy_listener/tcp/tls_detector.rs new file mode 100644 index 0000000..9b31b36 --- /dev/null +++ b/crates/relayer/src/proxy_listener/tcp/tls_detector.rs @@ -0,0 +1,47 @@ +use tls_parser::{parse_tls_extensions, parse_tls_plaintext}; + +use crate::proxy_listener::DomainDetector; + +#[derive(Default)] +pub struct TlsDomainDetector(); + +impl DomainDetector for TlsDomainDetector { + fn get_domain(&self, packet: &[u8]) -> Option { + let res = match parse_tls_plaintext(&packet) { + Ok(res) => res, + Err(e) => { + log::error!("parse_tls_plaintext error {:?}", e); + return None; + } + }; + + let tls_message = &res.1.msg[0]; + if let tls_parser::TlsMessage::Handshake(handshake) = tls_message { + if let tls_parser::TlsMessageHandshake::ClientHello(client_hello) = handshake { + // get the extensions + let extensions: &[u8] = client_hello.ext?; + // parse the extensions + let res = match parse_tls_extensions(extensions) { + Ok(res) => res, + Err(e) => { + log::error!("parse_tls_extensions error {:?}", e); + return None; + } + }; + // iterate over the extensions and find the SNI + for extension in res.1 { + if let tls_parser::TlsExtension::SNI(sni) = extension { + // get the hostname + let hostname: &[u8] = sni[0].1; + let s: String = match String::from_utf8(hostname.to_vec()) { + Ok(v) => v, + Err(e) => panic!("Invalid UTF-8 sequence: {}", e), + }; + return Some(s); + } + } + } + } + None + } +} diff --git a/crates/relayer/src/tunnel.rs b/crates/relayer/src/tunnel.rs index 98f5e6b..4f94fc3 100644 --- a/crates/relayer/src/tunnel.rs +++ b/crates/relayer/src/tunnel.rs @@ -120,6 +120,11 @@ async fn tunnel_over_cluster<'a>( let connection = connecting.await?; log::info!("connected to agent for domain: {domain} in node {dest}"); let (mut send, mut recv) = connection.open_bi().await?; + if let Some(handshake) = proxy_tunnel.handshake() { + send.write_all(&(handshake.len() as u16).to_be_bytes()) + .await?; + send.write_all(handshake).await?; + } log::info!("opened bi stream to agent for domain: {domain} in node {dest}"); histogram!(METRICS_TUNNEL_CLUSTER_HISTOGRAM) diff --git a/crates/relayer/src/utils.rs b/crates/relayer/src/utils.rs index 83fd580..9578996 100644 --- a/crates/relayer/src/utils.rs +++ b/crates/relayer/src/utils.rs @@ -1,7 +1,6 @@ use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, - time::Instant, }; /// get home id from domain by get first subdomain