From 643498b51af7700e330d51757d606a62ce0a9dc4 Mon Sep 17 00:00:00 2001 From: shadowWalker Date: Mon, 23 Sep 2024 21:09:33 +0700 Subject: [PATCH] fix: sometime handshake pkt is split to some small chunks (#60) Co-authored-by: Giang Minh --- crates/agent/src/lib.rs | 94 ++++++++++++++++++++--------------------- 1 file changed, 46 insertions(+), 48 deletions(-) diff --git a/crates/agent/src/lib.rs b/crates/agent/src/lib.rs index fddaf0d..4000c82 100644 --- a/crates/agent/src/lib.rs +++ b/crates/agent/src/lib.rs @@ -1,7 +1,5 @@ use std::sync::Arc; -use async_std::io::WriteExt; - use futures::{select, AsyncRead, AsyncReadExt, AsyncWrite, FutureExt}; use local_tunnel::tcp::LocalTcpTunnel; use protocol::cluster::AgentTunnelRequest; @@ -26,45 +24,29 @@ where { log::info!("sub_connection pipe to local_tunnel start"); let (mut reader1, mut writer1) = sub_connection.split(); - let mut first_pkt = [0u8; 4096]; - let (local_tunnel, first_pkt_start, first_pkt_end) = match reader1.read(&mut first_pkt).await { - Ok(first_pkt_len) => { - log::info!("first pkt size: {}", first_pkt_len); - if first_pkt_len < 2 { - log::error!("first pkt size is < 4 => close"); - return; - } - let handshake_len = u16::from_be_bytes([first_pkt[0], first_pkt[1]]) as usize; - if handshake_len + 2 > first_pkt_len { - log::error!("first pkt size is < handshake {handshake_len} + 2 => close"); + + let local_tunnel = match wait_handshake(&mut reader1).await { + Ok(handshake) => { + log::info!( + "sub_connection pipe with handshake: tls: {}, {}/{:?} ", + handshake.tls, + handshake.domain, + handshake.service + ); + if let Some(dest) = + registry.dest_for(handshake.tls, handshake.service, &handshake.domain) + { + log::info!("create tunnel to dest {}", dest); + LocalTcpTunnel::new(dest).await + } else { + log::warn!( + "dest for service {:?} tls {} domain {} not found", + handshake.service, + handshake.tls, + handshake.domain + ); return; } - match AgentTunnelRequest::try_from(&first_pkt[2..(handshake_len + 2)]) { - Ok(handshake) => { - if let Some(dest) = - registry.dest_for(handshake.tls, handshake.service, &handshake.domain) - { - log::info!("create tunnel to dest {}", dest); - ( - LocalTcpTunnel::new(dest).await, - handshake_len + 2, - first_pkt_len, - ) - } else { - log::warn!( - "dest for service {:?} tls {} domain {} not found", - handshake.service, - handshake.tls, - handshake.domain - ); - return; - } - } - Err(e) => { - log::error!("handshake parse error: {}", e); - return; - } - } } Err(e) => { log::error!("read first pkt error: {}", e); @@ -81,15 +63,6 @@ where }; let (mut reader2, mut writer2) = local_tunnel.split(); - - if let Err(e) = writer2 - .write_all(&first_pkt[first_pkt_start..first_pkt_end]) - .await - { - log::error!("write first pkt to local_tunnel error: {}", e); - return; - } - let job1 = futures::io::copy(&mut reader1, &mut writer2); let job2 = futures::io::copy(&mut reader2, &mut writer1); @@ -107,3 +80,28 @@ where } log::info!("sub_connection pipe to local_tunnel stop"); } + +pub async fn wait_handshake( + reader: &mut R, +) -> Result { + let mut len_buf = [0; 2]; + let mut data_buf = [0; 1000]; + reader + .read_exact(&mut len_buf) + .await + .map_err(|e| e.to_string())?; + let handshake_len = u16::from_be_bytes([len_buf[0], len_buf[1]]) as usize; + log::info!("first pkt size: {}", handshake_len); + if handshake_len > data_buf.len() { + return Err("Handshake package too big".to_string()); + } + + reader + .read_exact(&mut data_buf[0..handshake_len]) + .await + .map_err(|e| e.to_string())?; + + log::info!("got first pkt with size: {}", handshake_len); + + AgentTunnelRequest::try_from(&data_buf[0..handshake_len]).map_err(|e| e.to_string()) +}