Skip to content

Commit

Permalink
fix: wrong handshake when tunnel request over cluster (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
giangndm authored Sep 9, 2024
1 parent fa4f7cb commit fa3c018
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 48 deletions.
1 change: 1 addition & 0 deletions crates/protocol/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct ClusterTunnelRequest {
pub domain: String,
pub handshake: Vec<u8>,
}

impl From<&ClusterTunnelRequest> for Vec<u8> {
Expand Down
5 changes: 4 additions & 1 deletion crates/relayer/run_local_node1.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ cargo run -- \
--connector-port 0.0.0.0:13001 \
--root-domain local.ha.8xff.io \
--sdn-node-id 1 \
--sdn-port 50001
--sdn-ip 127.0.0.1 \
--sdn-port 50001 \
--rtsp-port 5341 \
--rtsps-port 15341 \
3 changes: 3 additions & 0 deletions crates/relayer/run_local_node2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,8 @@ cargo run -- \
--connector-port 0.0.0.0:13002 \
--root-domain local.ha.8xff.io \
--sdn-node-id 2 \
--sdn-ip 127.0.0.1 \
--sdn-port 50002 \
--rtsp-port 5342 \
--rtsps-port 15342 \
--sdn-seeds '1@/ip4/127.0.0.1/udp/50001'
25 changes: 12 additions & 13 deletions crates/relayer/src/agent_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,19 @@ where
async_std::task::spawn(async move {
let domain = conn.domain().to_string();
let (mut reader1, mut writer1) = sub_connection.split();
if let Some(handshake) = conn.handshake() {
let err1 = writer1
.write_all(&(handshake.len() as u16).to_be_bytes())
.await;
let err2 = writer1.write_all(handshake).await;
if let Err(e) = err1.and(err2) {
log::error!("handshake for domain {domain} failed {:?}", e);
if conn.local() {
gauge!(METRICS_PROXY_HTTP_ERROR_COUNT).increment(1.0);
} else {
gauge!(METRICS_PROXY_CLUSTER_ERROR_COUNT).increment(1.0);
}
return;
let handshake = conn.handshake();
let err1 = writer1
.write_all(&(handshake.len() as u16).to_be_bytes())
.await;
let err2 = writer1.write_all(handshake).await;
if let Err(e) = err1.and(err2) {
log::error!("handshake for domain {domain} failed {:?}", e);
if conn.local() {
gauge!(METRICS_PROXY_HTTP_ERROR_COUNT).increment(1.0);
} else {
gauge!(METRICS_PROXY_CLUSTER_ERROR_COUNT).increment(1.0);
}
return;
}
histogram!(METRICS_TUNNEL_AGENT_HISTOGRAM)
.record(started.elapsed().as_millis() as f32 / 1000.0);
Expand Down
43 changes: 29 additions & 14 deletions crates/relayer/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ use protocol::{
use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
#[cfg(feature = "expose-metrics")]
use std::net::{Ipv4Addr, SocketAddrV4};
use std::{net::SocketAddr, process::exit, sync::Arc};
use std::{
net::{IpAddr, SocketAddr},
process::exit,
sync::Arc,
};

use futures::{select, FutureExt};
use metrics::{describe_counter, describe_gauge, describe_histogram};
Expand Down Expand Up @@ -65,10 +69,14 @@ struct Args {
#[arg(env, long)]
sdn_node_id: NodeId,

/// atm0s-sdn node-id
/// atm0s-sdn port
#[arg(env, long, default_value_t = 0)]
sdn_port: u16,

/// atm0s-sdn ip
#[arg(env, long)]
sdn_ip: Vec<IpAddr>,

/// atm0s-sdn secret key
#[arg(env, long, default_value = "insecure")]
sdn_secret_key: String,
Expand Down Expand Up @@ -247,18 +255,25 @@ async fn main() {
.run(app)
.await;
});
let sdn_addrs = local_ip_address::list_afinet_netifas()
.expect("Should have list interfaces")
.into_iter()
.filter(|(_, ip)| {
if ip.is_unspecified() || ip.is_multicast() {
false
} else {
std::net::UdpSocket::bind(SocketAddr::new(*ip, 0)).is_ok()
}
})
.map(|(_name, ip)| SocketAddr::new(ip, args.sdn_port))
.collect::<Vec<_>>();
let sdn_addrs = if args.sdn_ip.is_empty() {
local_ip_address::list_afinet_netifas()
.expect("Should have list interfaces")
.into_iter()
.filter(|(_, ip)| {
if ip.is_unspecified() || ip.is_multicast() {
false
} else {
std::net::UdpSocket::bind(SocketAddr::new(*ip, 0)).is_ok()
}
})
.map(|(_name, ip)| SocketAddr::new(ip, args.sdn_port))
.collect::<Vec<_>>()
} else {
args.sdn_ip
.iter()
.map(|ip| SocketAddr::new(*ip, args.sdn_port))
.collect::<Vec<_>>()
};

let (mut cluster_endpoint, alias_sdk, mut virtual_net) = run_sdn(
args.sdn_node_id,
Expand Down
4 changes: 3 additions & 1 deletion crates/relayer/src/proxy_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@ pub mod cluster;
pub mod tcp;

pub trait DomainDetector: Send + Sync {
fn name(&self) -> &str;
fn get_domain(&self, buf: &[u8]) -> Option<String>;
}

#[async_trait::async_trait]
pub trait ProxyTunnel: Send + Sync {
async fn wait(&mut self) -> Option<()>;
fn source_addr(&self) -> String;
fn local(&self) -> bool;
fn domain(&self) -> &str;
fn handshake(&self) -> Option<&[u8]>;
fn handshake(&self) -> &[u8];
fn split(
&mut self,
) -> (
Expand Down
29 changes: 24 additions & 5 deletions crates/relayer/src/proxy_listener/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,19 @@ impl ProxyListener for ProxyClusterListener {
let connecting = self.server.accept().await?;
log::info!("incoming connection from {}", connecting.remote_address());
Some(Box::new(ProxyClusterTunnel {
virtual_addr: connecting.remote_address(),
domain: "".to_string(),
handshake: vec![],
connecting: Some(connecting),
streams: None,
}))
}
}

pub struct ProxyClusterTunnel {
virtual_addr: SocketAddr,
domain: String,
handshake: Vec<u8>,
connecting: Option<Incoming>,
streams: Option<(
Box<dyn AsyncRead + Send + Sync + Unpin>,
Expand All @@ -216,20 +220,35 @@ pub struct ProxyClusterTunnel {

#[async_trait::async_trait]
impl ProxyTunnel for ProxyClusterTunnel {
fn source_addr(&self) -> String {
format!("sdn-quic://{}", self.virtual_addr)
}

async fn wait(&mut self) -> Option<()> {
let connecting = self.connecting.take()?;
let connection = connecting.await.ok()?;
log::info!("incoming connection from: {}", connection.remote_address());
log::info!(
"[ProxyClusterTunnel] incoming connection from: {}",
connection.remote_address()
);
let (mut send, mut recv) = connection.accept_bi().await.ok()?;
log::info!("accepted bi stream from: {}", connection.remote_address());
log::info!(
"[ProxyClusterTunnel] accepted bi stream from: {}",
connection.remote_address()
);
let mut req_buf = [0; 1500];
let req_size = recv.read(&mut req_buf).await.ok()??;
log::info!(
"[ProxyClusterTunnel] read {req_size} handhshake buffer from: {}",
connection.remote_address()
);
let req = ClusterTunnelRequest::try_from(&req_buf[..req_size]).ok()?;
let res_buf: Vec<u8> = (&ClusterTunnelResponse { success: true }).into();
send.write_all(&res_buf).await.ok()?;
log::info!("ProxyClusterTunnel domain: {}", req.domain);
log::info!("[ProxyClusterTunnel] got domain: {}", req.domain);

self.domain = req.domain;
self.handshake = req.handshake;
self.streams = Some((Box::new(recv), Box::new(send)));
Some(())
}
Expand All @@ -239,8 +258,8 @@ impl ProxyTunnel for ProxyClusterTunnel {
fn domain(&self) -> &str {
&self.domain
}
fn handshake(&self) -> Option<&[u8]> {
None
fn handshake(&self) -> &[u8] {
&self.handshake
}
fn split(
&mut self,
Expand Down
16 changes: 13 additions & 3 deletions crates/relayer/src/proxy_listener/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ impl ProxyListener for ProxyTcpListener {
let (stream, remote) = self.tcp_listener.accept().await.ok()?;
log::info!("[ProxyTcpListener] new conn from {}", remote);
Some(Box::new(ProxyTcpTunnel {
stream_addr: remote,
detector: self.detector.clone(),
service: self.service,
domain: "".to_string(),
Expand All @@ -63,6 +64,7 @@ impl ProxyListener for ProxyTcpListener {
}

pub struct ProxyTcpTunnel {
stream_addr: SocketAddr,
detector: Arc<dyn DomainDetector>,
service: Option<u16>,
domain: String,
Expand All @@ -73,6 +75,14 @@ pub struct ProxyTcpTunnel {

#[async_trait::async_trait]
impl ProxyTunnel for ProxyTcpTunnel {
fn source_addr(&self) -> String {
if self.tls {
format!("tls+{}://{}", self.detector.name(), self.stream_addr)
} else {
format!("tcp+{}://{}", self.detector.name(), self.stream_addr)
}
}

async fn wait(&mut self) -> Option<()> {
log::info!("[ProxyTcpTunnel] wait first data for checking url...");
let mut first_pkt = [0u8; 4096];
Expand All @@ -83,7 +93,7 @@ impl ProxyTunnel for ProxyTcpTunnel {
first_pkt_size
);
self.domain = self.detector.get_domain(&first_pkt[..first_pkt_size])?;
log::info!("[PRoxyTcpTunnel] detected domain {}", self.domain);
log::info!("[ProxyTcpTunnel] detected domain {}", self.domain);
self.handshake = (&AgentTunnelRequest {
service: self.service,
tls: self.tls,
Expand All @@ -101,8 +111,8 @@ impl ProxyTunnel for ProxyTcpTunnel {
&self.domain
}

fn handshake(&self) -> Option<&[u8]> {
Some(&self.handshake)
fn handshake(&self) -> &[u8] {
&self.handshake
}

fn split(
Expand Down
8 changes: 8 additions & 0 deletions crates/relayer/src/proxy_listener/tcp/http_detector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@ use crate::proxy_listener::DomainDetector;
pub struct HttpDomainDetector();

impl DomainDetector for HttpDomainDetector {
fn name(&self) -> &str {
"http"
}

fn get_domain(&self, buf: &[u8]) -> Option<String> {
log::info!(
"[HttpDomainDetector] check domain for {}",
String::from_utf8_lossy(buf)
);
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut req = httparse::Request::new(&mut headers);
let _ = req.parse(buf).ok()?;
Expand Down
8 changes: 8 additions & 0 deletions crates/relayer/src/proxy_listener/tcp/rtsp_detector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@ use crate::proxy_listener::DomainDetector;
pub struct RtspDomainDetector();

impl DomainDetector for RtspDomainDetector {
fn name(&self) -> &str {
"rtsp"
}

fn get_domain(&self, buf: &[u8]) -> Option<String> {
log::info!(
"[RtspDomainDetector] check domain for {}",
String::from_utf8_lossy(buf)
);
let (message, _consumed): (rtsp_types::Message<Vec<u8>>, _) =
rtsp_types::Message::parse(buf).ok()?;
log::info!("{:?}", message);
Expand Down
8 changes: 8 additions & 0 deletions crates/relayer/src/proxy_listener/tcp/tls_detector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@ use crate::proxy_listener::DomainDetector;
pub struct TlsDomainDetector();

impl DomainDetector for TlsDomainDetector {
fn name(&self) -> &str {
"tls"
}

fn get_domain(&self, packet: &[u8]) -> Option<String> {
log::info!(
"[TlsDomainDetector] check domain for buffer {} bytes",
packet.len()
);
let res = match parse_tls_plaintext(packet) {
Ok(res) => res,
Err(e) => {
Expand Down
25 changes: 14 additions & 11 deletions crates/relayer/src/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ pub async fn tunnel_task(
} else {
counter!(METRICS_PROXY_CLUSTER_COUNT).increment(1);
}
log::info!(
"proxy_tunnel.wait() for checking url from {}",
proxy_tunnel.source_addr()
);
match proxy_tunnel.wait().timeout(Duration::from_secs(5)).await {
Err(_) => {
log::error!("proxy_tunnel.wait() for checking url timeout");
Expand All @@ -64,7 +68,10 @@ pub async fn tunnel_task(
_ => {}
}

log::info!("proxy_tunnel.domain(): {}", proxy_tunnel.domain());
log::info!(
"proxy_tunnel.wait() done and got domain {}",
proxy_tunnel.domain()
);
let domain = proxy_tunnel.domain().to_string();
let home_id = home_id_from_domain(&domain);
if let Some(agent_tx) = agents.get(home_id) {
Expand Down Expand Up @@ -102,36 +109,32 @@ async fn tunnel_over_cluster<'a>(
) -> Result<(), Box<dyn Error>> {
let started = Instant::now();
log::warn!(
"agent not found for domain: {} in local => finding in cluster",
"[TunnerOverCluster] agent not found for domain: {} in local => finding in cluster",
domain
);
let node_alias_id = home_id_from_domain(&domain);
let dest = node_alias_sdk
.find_alias(node_alias_id)
.await
.ok_or("NODE_ALIAS_NOT_FOUND".to_string())?;
log::info!("found agent for domain: {domain} in node {dest}");
log::info!("[TunnerOverCluster] found agent for domain: {domain} in node {dest}");
let client = make_quinn_client(socket, server_certs)?;
log::info!("connecting to agent for domain: {domain} in node {dest}");
log::info!("[TunnerOverCluster] connecting to agent for domain: {domain} in node {dest}");
let connecting = client.connect(
SocketAddr::V4(SocketAddrV4::new(dest.into(), 443)),
"cluster",
)?;
let connection = connecting.await?;
log::info!("connected to agent for domain: {domain} in node {dest}");
log::info!("[TunnerOverCluster] connected to agent for domain: {domain} in node {dest}");
let (mut send, mut recv) = connection.open_bi().await?;
if let Some(handshake) = proxy_tunnel.handshake() {
send.write_all(&(handshake.len() as u16).to_be_bytes())
.await?;
send.write_all(handshake).await?;
}
log::info!("opened bi stream to agent for domain: {domain} in node {dest}");
log::info!("[TunnerOverCluster] opened bi stream to agent for domain: {domain} in node {dest}");

histogram!(METRICS_TUNNEL_CLUSTER_HISTOGRAM)
.record(started.elapsed().as_millis() as f32 / 1000.0);

let req_buf: Vec<u8> = (&ClusterTunnelRequest {
domain: domain.clone(),
handshake: proxy_tunnel.handshake().to_vec(),
})
.into();
send.write_all(&req_buf).await?;
Expand Down

0 comments on commit fa3c018

Please sign in to comment.