diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7ff2ba5..fd22c94 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -35,10 +35,10 @@ jobs: - linux musl aarch64 - linux gnueabihf arm - linux gnueabihf armv7 - - linux gnu mips - - linux gnuabi64 mips64 - - linux gnuabi64 mips64el - - linux gnu mipsel + # - linux gnu mips + # - linux gnuabi64 mips64 + # - linux gnuabi64 mips64el + # - linux gnu mipsel - macos x64 - macos aarch64 include: @@ -72,26 +72,26 @@ jobs: rust: stable target: armv7-unknown-linux-gnueabihf cross: true - - build: linux gnu mips - os: ubuntu-latest - rust: 1.71.1 - target: mips-unknown-linux-gnu - cross: true - - build: linux gnuabi64 mips64 - os: ubuntu-latest - rust: 1.71.1 - target: mips64-unknown-linux-gnuabi64 - cross: true - - build: linux gnuabi64 mips64el - os: ubuntu-latest - rust: 1.71.1 - target: mips64el-unknown-linux-gnuabi64 - cross: true - - build: linux gnu mipsel - os: ubuntu-latest - rust: 1.71.1 - target: mipsel-unknown-linux-gnu - cross: true + # - build: linux gnu mips + # os: ubuntu-latest + # rust: 1.71.1 + # target: mips-unknown-linux-gnu + # cross: true + # - build: linux gnuabi64 mips64 + # os: ubuntu-latest + # rust: 1.71.1 + # target: mips64-unknown-linux-gnuabi64 + # cross: true + # - build: linux gnuabi64 mips64el + # os: ubuntu-latest + # rust: 1.71.1 + # target: mips64el-unknown-linux-gnuabi64 + # cross: true + # - build: linux gnu mipsel + # os: ubuntu-latest + # rust: 1.71.1 + # target: mipsel-unknown-linux-gnu + # cross: true - build: linux musl aarch64 os: ubuntu-latest rust: stable diff --git a/Cargo.lock b/Cargo.lock index 51d52a8..d4bf820 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,6 +27,8 @@ dependencies = [ "futures", "log", "protocol", + "quinn", + "rustls", "tracing-subscriber", "yamux", ] @@ -482,6 +484,16 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.6" @@ -578,6 +590,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "deranged" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eb30d70a07a3b04884d2677f06bec33509dc67ca60d92949e5535352d3191dc" +dependencies = [ + "powerfmt", +] + [[package]] name = "digest" version = "0.10.7" @@ -1397,6 +1418,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + [[package]] name = "ordered-float" version = "3.9.2" @@ -1441,6 +1468,16 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "pem" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b8fcc794035347fb64beda2d3b462595dd2753e3f268d89c5aae77e8cf2c310" +dependencies = [ + "base64", + "serde", +] + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -1636,6 +1673,12 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1708,6 +1751,57 @@ dependencies = [ "winapi", ] +[[package]] +name = "quinn" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cc2c5017e4b43d5995dcea317bc46c1e09404c0a9664d2908f7f02dfe943d75" +dependencies = [ + "async-io 1.13.0", + "async-std", + "bytes", + "futures-io", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "thiserror", + "tokio", + "tracing", +] + +[[package]] +name = "quinn-proto" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "141bf7dfde2fbc246bfd3fe12f2455aa24b0fbd9af535d8c86c7bd1381ff2b1a" +dependencies = [ + "bytes", + "rand", + "ring 0.16.20", + "rustc-hash", + "rustls", + "rustls-native-certs", + "slab", + "thiserror", + "tinyvec", + "tracing", +] + +[[package]] +name = "quinn-udp" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "055b4e778e8feb9f93c4e439f71dc2156ef13360b432b799e179a8c4cdf0b1d7" +dependencies = [ + "bytes", + "libc", + "socket2 0.5.5", + "tracing", + "windows-sys 0.48.0", +] + [[package]] name = "quote" version = "1.0.33" @@ -1786,6 +1880,18 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rcgen" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d918c80c5a4c7560db726763020bd16db179e4d5b828078842274a443addb5d" +dependencies = [ + "pem", + "ring 0.17.7", + "time", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -1853,6 +1959,9 @@ dependencies = [ "metrics-dashboard", "poem", "protocol", + "quinn", + "rcgen", + "rustls", "tls-parser", "tracing-subscriber", "yamux", @@ -1867,6 +1976,35 @@ dependencies = [ "uncased", ] +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin 0.5.2", + "untrusted 0.7.1", + "web-sys", + "winapi", +] + +[[package]] +name = "ring" +version = "0.17.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" +dependencies = [ + "cc", + "getrandom", + "libc", + "spin 0.9.8", + "untrusted 0.9.0", + "windows-sys 0.48.0", +] + [[package]] name = "rust-embed" version = "8.1.0" @@ -1907,6 +2045,12 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc_version" version = "0.4.0" @@ -1952,6 +2096,48 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustls" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" +dependencies = [ + "ring 0.17.7", + "rustls-webpki", + "sct", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring 0.17.7", + "untrusted 0.9.0", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -1973,12 +2159,31 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" +dependencies = [ + "windows-sys 0.48.0", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring 0.17.7", + "untrusted 0.9.0", +] + [[package]] name = "sealed" version = "0.5.0" @@ -1991,6 +2196,29 @@ dependencies = [ "syn 2.0.41", ] +[[package]] +name = "security-framework" +version = "2.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.20" @@ -2136,6 +2364,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "spki" version = "0.7.3" @@ -2231,6 +2471,39 @@ dependencies = [ "once_cell", ] +[[package]] +name = "time" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4a34ab300f2dee6e562c10a046fc05e358b29f9bf92277f30c3c8d82275f6f5" +dependencies = [ + "deranged", + "powerfmt", + "serde", + "time-core", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tls-parser" version = "0.11.0" @@ -2316,6 +2589,7 @@ version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -2407,6 +2681,18 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "utf8parse" version = "0.2.1" @@ -2732,6 +3018,15 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "zerocopy" version = "0.7.31" diff --git a/README.md b/README.md index c1b2408..e350d64 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,13 @@ To get started with the Decentralized HomeAssistant Proxy, follow these steps: 3. Run the server: ```shell - cargo run --release + cargo run -- --root-domain local.ha.8xff.io + ``` + +4. Run the client: + + ```shell + cargo run --release -- --connector-addr 127.0.0.1:33333 --connector-protocol quic --http-dest 127.0.0.1:18080 --https-dest 127.0.0.1:18443 ``` ## Contributing diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index dc2f2b7..f017cc0 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -14,9 +14,11 @@ log = "0.4.20" tracing-subscriber = { version = "0.3.18", features = ["env-filter", "std"] } yamux = "0.13.1" protocol = { path = "../protocol" } +quinn = { version = "0.10.2", default-features = false, features = ["native-certs", "tls-rustls", "log", "runtime-async-std", "futures-io", "ring"] } +rustls = { version = "0.21.0", default-features = false, features = ["quic", "dangerous_configuration"] } [profile.release] strip = true # Automatically strip symbols from the binary. opt-level = "z" # Optimize for size. lto = true -codegen-units = 1 \ No newline at end of file +codegen-units = 1 diff --git a/crates/agent/src/connection.rs b/crates/agent/src/connection.rs index bc9558a..043de07 100644 --- a/crates/agent/src/connection.rs +++ b/crates/agent/src/connection.rs @@ -1,7 +1,10 @@ //! Tunnel is a trait that defines the interface for a tunnel which connect to connector port of relayer. +use std::error::Error; + use futures::{AsyncRead, AsyncWrite}; +pub mod quic; pub mod tcp; pub trait SubConnection: Send + Sync { @@ -12,5 +15,5 @@ pub trait SubConnection: Send + Syn pub trait Connection, R: AsyncRead + Unpin, W: AsyncWrite + Unpin>: Send + Sync { - async fn recv(&mut self) -> Option; + async fn recv(&mut self) -> Result>; } diff --git a/crates/agent/src/connection/quic.rs b/crates/agent/src/connection/quic.rs new file mode 100644 index 0000000..73a072a --- /dev/null +++ b/crates/agent/src/connection/quic.rs @@ -0,0 +1,106 @@ +use std::sync::Arc; +use std::time::Duration; +use std::{error::Error, net::SocketAddr}; + +use protocol::{key::LocalKey, rpc::RegisterResponse}; +use quinn::{ClientConfig, Endpoint, RecvStream, SendStream, TransportConfig}; + +use super::{Connection, SubConnection}; + +pub struct QuicSubConnection { + pub send: SendStream, + pub recv: RecvStream, +} + +impl SubConnection for QuicSubConnection { + fn split(self) -> (RecvStream, SendStream) { + (self.recv, self.send) + } +} + +pub struct QuicConnection { + connection: quinn::Connection, + #[allow(unused)] + domain: String, +} + +impl QuicConnection { + pub async fn new(dest: SocketAddr, local_key: &LocalKey) -> Result> { + let mut endpoint = Endpoint::client("0.0.0.0:0".parse().expect(""))?; + endpoint.set_default_client_config(configure_client()); + + // connect to server + let connection = endpoint.connect(dest, "localhost")?.await?; + + log::info!("connected to {}, open bi stream", dest); + let (mut send_stream, mut recv_stream) = connection.open_bi().await?; + log::info!("opened bi stream, send register request"); + + let request = local_key.to_request(); + let request_buf: Vec = (&request).into(); + send_stream.write_all(&request_buf).await?; + + let mut buf = [0u8; 4096]; + let buf_len = recv_stream + .read(&mut buf) + .await? + .ok_or::>("read register response error".into())?; + let response = RegisterResponse::try_from(&buf[..buf_len])?; + match response.response { + Ok(domain) => { + log::info!("registed domain {}", domain); + Ok(Self { connection, domain }) + } + Err(e) => { + log::error!("register response error {}", e); + return Err(e.into()); + } + } + } +} + +#[async_trait::async_trait] +impl Connection for QuicConnection { + async fn recv(&mut self) -> Result> { + let (send, recv) = self.connection.accept_bi().await?; + Ok(QuicSubConnection { send, recv }) + } +} + +fn configure_client() -> ClientConfig { + let crypto = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(SkipServerVerification::new()) + .with_no_client_auth(); + + let mut transport = TransportConfig::default(); + transport.keep_alive_interval(Some(Duration::from_secs(5))); + + let mut config = ClientConfig::new(Arc::new(crypto) as Arc<_>); + config.transport_config(Arc::new(transport)); + config +} + +/// Dummy certificate verifier that treats any certificate as valid. +/// NOTE, such verification is vulnerable to MITM attacks, but convenient for testing. +struct SkipServerVerification; + +impl SkipServerVerification { + fn new() -> Arc { + Arc::new(Self) + } +} + +impl rustls::client::ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &rustls::Certificate, + _intermediates: &[rustls::Certificate], + _server_name: &rustls::ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: std::time::SystemTime, + ) -> Result { + Ok(rustls::client::ServerCertVerified::assertion()) + } +} diff --git a/crates/agent/src/connection/tcp.rs b/crates/agent/src/connection/tcp.rs index 9353a24..4fa7546 100644 --- a/crates/agent/src/connection/tcp.rs +++ b/crates/agent/src/connection/tcp.rs @@ -1,4 +1,5 @@ use std::{ + error::Error, net::SocketAddr, pin::Pin, task::{Context, Poll}, @@ -35,26 +36,26 @@ pub struct TcpConnection { } impl TcpConnection { - pub async fn new(dest: SocketAddr, local_key: &LocalKey) -> Option { - let mut stream = TcpStream::connect(dest).await.ok()?; + pub async fn new(dest: SocketAddr, local_key: &LocalKey) -> Result> { + let mut stream = TcpStream::connect(dest).await?; let request = local_key.to_request(); let request_buf: Vec = (&request).into(); - stream.write_all(&request_buf).await.ok()?; + stream.write_all(&request_buf).await?; let mut buf = [0u8; 4096]; - let buf_len = stream.read(&mut buf).await.ok()?; - let response = RegisterResponse::try_from(&buf[..buf_len]).ok()?; + let buf_len = stream.read(&mut buf).await?; + let response = RegisterResponse::try_from(&buf[..buf_len])?; match response.response { Ok(domain) => { - log::info!("registered domain {}", domain); - Some(Self { + log::info!("registed domain {}", domain); + Ok(Self { conn: yamux::Connection::new(stream, Default::default(), Mode::Server), domain, }) } Err(e) => { log::error!("register response error {}", e); - return None; + return Err(e.into()); } } } @@ -64,11 +65,12 @@ impl TcpConnection { impl Connection, WriteHalf> for TcpConnection { - async fn recv(&mut self) -> Option { + async fn recv(&mut self) -> Result> { let mux_server = YamuxConnectionServer::new(&mut self.conn); match mux_server.await { - Ok(Some(stream)) => Some(TcpSubConnection::new(stream)), - _ => None, + Ok(Some(stream)) => Ok(TcpSubConnection::new(stream)), + Ok(None) => Err("yamux server poll next inbound return None".into()), + Err(e) => Err(e.into()), } } } diff --git a/crates/agent/src/local_tunnel/tcp.rs b/crates/agent/src/local_tunnel/tcp.rs index ff30a06..19a509b 100644 --- a/crates/agent/src/local_tunnel/tcp.rs +++ b/crates/agent/src/local_tunnel/tcp.rs @@ -1,4 +1,4 @@ -use std::net::SocketAddr; +use std::{error::Error, net::SocketAddr}; use async_std::net::TcpStream; use futures::{ @@ -13,9 +13,9 @@ pub struct LocalTcpTunnel { } impl LocalTcpTunnel { - pub async fn new(dest: SocketAddr) -> Option { - Some(Self { - stream: TcpStream::connect(dest).await.ok()?, + pub async fn new(dest: SocketAddr) -> Result> { + Ok(Self { + stream: TcpStream::connect(dest).await?, }) } } diff --git a/crates/agent/src/main.rs b/crates/agent/src/main.rs index 3dae3e0..9656623 100644 --- a/crates/agent/src/main.rs +++ b/crates/agent/src/main.rs @@ -6,28 +6,39 @@ static A: System = System; use std::net::SocketAddr; use async_std::io::WriteExt; -use clap::Parser; +use clap::{Parser, ValueEnum}; -use futures::{select, AsyncReadExt, FutureExt}; +use connection::tcp::TcpConnection; +use futures::{select, AsyncRead, AsyncReadExt, AsyncWrite, FutureExt}; use local_tunnel::tcp::LocalTcpTunnel; use protocol::key::LocalKey; use tracing_subscriber::{fmt, layer::*, util::SubscriberInitExt, EnvFilter}; use crate::{ - connection::{tcp::TcpConnection, Connection, SubConnection}, + connection::{quic::QuicConnection, Connection, SubConnection}, local_tunnel::LocalTunnel, }; mod connection; mod local_tunnel; +#[derive(ValueEnum, Debug, Clone)] +enum Protocol { + Tcp, + Quic, +} + /// A HTTP and SNI HTTPs proxy for expose your local service to the internet. #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { /// Address of relay server - #[arg(env, long, long, default_value = "127.0.0.1:33333")] - tcp_connector_addr: SocketAddr, + #[arg(env, long)] + connector_addr: SocketAddr, + + /// Protocol of relay server + #[arg(env, long)] + connector_protocol: Protocol, /// Http proxy dest #[arg(env, long, default_value = "127.0.0.1:8080")] @@ -86,60 +97,120 @@ async fn main() { }; loop { - log::info!("Connecting to connector..."); - if let Some(mut connection) = TcpConnection::new(args.tcp_connector_addr, &local_key).await - { - log::info!("Connection to connector is established"); - while let Some(sub_connection) = connection.recv().await { - let http_dest = args.http_dest; - let https_dest = args.https_dest; - async_std::task::spawn(async move { - log::info!("sub_connection pipe to local_tunnel start"); - let (mut reader1, mut writer1) = sub_connection.split(); - let mut first_pkt = [0u8; 4096]; - let (local_tunnel, first_pkt_len) = match reader1.read(&mut first_pkt).await { - Ok(first_pkt_len) => { - log::info!("first pkt size: {}", first_pkt_len); - if first_pkt_len == 0 { - log::error!("first pkt size is 0 => close"); - return; - } - if first_pkt[0] == 0x16 { - log::info!("create tunnel to https dest {}", https_dest); - (LocalTcpTunnel::new(https_dest).await, first_pkt_len) - } else { - log::info!("create tunnel to http dest {}", http_dest); - (LocalTcpTunnel::new(http_dest).await, first_pkt_len) - } - } - Err(e) => { - log::error!("read first pkt error: {}", e); - return; - } - }; - - if let Some(local_tunnel) = local_tunnel { - let (mut reader2, mut writer2) = local_tunnel.split(); - - if let Err(e) = writer2.write_all(&first_pkt[..first_pkt_len]).await { - log::error!("write first pkt to local_tunnel error: {}", e); - return; - } - - let job1 = futures::io::copy(&mut reader1, &mut writer2); - let job2 = futures::io::copy(&mut reader2, &mut writer1); - - select! { - _ = job1.fuse() => {} - _ = job2.fuse() => {} - } - } - log::info!("sub_connection pipe to local_tunnel stop"); - }); - } - log::warn!("Connection to connector is closed, try to reconnect..."); + log::info!( + "Connecting to connector... {:?} addr: {:?}", + args.connector_protocol, + args.connector_addr + ); + match args.connector_protocol { + Protocol::Tcp => match TcpConnection::new(args.connector_addr, &local_key).await { + Ok(conn) => { + log::info!("Connected to connector via tcp"); + run_loop(conn, args.http_dest, args.https_dest).await; + } + Err(e) => { + log::error!("Connect to connector via tcp error: {}", e); + } + }, + Protocol::Quic => match QuicConnection::new(args.connector_addr, &local_key).await { + Ok(conn) => { + log::info!("Connected to connector via quic"); + run_loop(conn, args.http_dest, args.https_dest).await; + } + Err(e) => { + log::error!("Connect to connector via quic error: {}", e); + } + }, } //TODO exponential backoff async_std::task::sleep(std::time::Duration::from_secs(1)).await; } } + +async fn run_loop( + mut connection: impl Connection, + http_dest: SocketAddr, + https_dest: SocketAddr, +) where + S: SubConnection + 'static, + R: AsyncRead + Send + Unpin + 'static, + W: AsyncWrite + Send + Unpin + 'static, +{ + log::info!("Connection to connector is established"); + loop { + match connection.recv().await { + Ok(sub_connection) => { + log::info!("recv sub_connection"); + async_std::task::spawn_local(run_connection(sub_connection, http_dest, https_dest)); + } + Err(e) => { + log::error!("recv sub_connection error: {}", e); + break; + } + } + } + log::warn!("Connection to connector is closed, try to reconnect..."); +} + +async fn run_connection(sub_connection: S, http_dest: SocketAddr, https_dest: SocketAddr) +where + S: SubConnection + 'static, + R: AsyncRead + Send + Unpin, + W: AsyncWrite + Send + Unpin, +{ + log::info!("sub_connection pipe to local_tunnel start"); + let (mut reader1, mut writer1) = sub_connection.split(); + let mut first_pkt = [0u8; 4096]; + let (local_tunnel, first_pkt_len) = match reader1.read(&mut first_pkt).await { + Ok(first_pkt_len) => { + log::info!("first pkt size: {}", first_pkt_len); + if first_pkt_len == 0 { + log::error!("first pkt size is 0 => close"); + return; + } + if first_pkt[0] == 0x16 { + log::info!("create tunnel to https dest {}", https_dest); + (LocalTcpTunnel::new(https_dest).await, first_pkt_len) + } else { + log::info!("create tunnel to http dest {}", http_dest); + (LocalTcpTunnel::new(http_dest).await, first_pkt_len) + } + } + Err(e) => { + log::error!("read first pkt error: {}", e); + return; + } + }; + + let local_tunnel = match local_tunnel { + Ok(local_tunnel) => local_tunnel, + Err(e) => { + log::error!("create local_tunnel error: {}", e); + return; + } + }; + + let (mut reader2, mut writer2) = local_tunnel.split(); + + if let Err(e) = writer2.write_all(&first_pkt[..first_pkt_len]).await { + log::error!("write first pkt to local_tunnel error: {}", e); + return; + } + + let job1 = futures::io::copy(&mut reader1, &mut writer2); + let job2 = futures::io::copy(&mut reader2, &mut writer1); + + select! { + e = job1.fuse() => { + if let Err(e) = e { + log::error!("job1 error: {}", e); + } + } + e = job2.fuse() => { + if let Err(e) = e { + log::error!("job2 error: {}", e); + } + } + } + log::info!("sub_connection pipe to local_tunnel stop"); +} diff --git a/crates/relayer/Cargo.toml b/crates/relayer/Cargo.toml index 17814e4..7ad5b93 100644 --- a/crates/relayer/Cargo.toml +++ b/crates/relayer/Cargo.toml @@ -17,7 +17,10 @@ protocol = { path = "../protocol" } metrics-dashboard = { version = "0.1.3", features = ["system"], optional = true } poem = { version = "1.3.59", optional = true } metrics = { version = "0.21.1" } +quinn = { version = "0.10.2", default-features = false, features = ["native-certs", "tls-rustls", "log", "runtime-async-std", "futures-io", "ring"] } +rustls = { version = "0.21.0", default-features = false, features = ["quic", "dangerous_configuration"] } +rcgen = "0.12.0" [features] default = [] -expose-metrics = ["metrics-dashboard", "poem"] \ No newline at end of file +expose-metrics = ["metrics-dashboard", "poem"] diff --git a/crates/relayer/src/agent_listener.rs b/crates/relayer/src/agent_listener.rs index b8e8519..57a6a5d 100644 --- a/crates/relayer/src/agent_listener.rs +++ b/crates/relayer/src/agent_listener.rs @@ -1,7 +1,10 @@ //! Connector is server which accept connection from agent and wait msg from user. +use std::error::Error; + use futures::{AsyncRead, AsyncWrite}; +pub mod quic; pub mod tcp; pub trait AgentSubConnection: Send + Sync { @@ -13,8 +16,8 @@ pub trait AgentConnection, R: AsyncRead + Unpin, W: Send + Sync { fn domain(&self) -> String; - async fn create_sub_connection(&mut self) -> Option; - async fn recv(&mut self) -> Option<()>; + async fn create_sub_connection(&mut self) -> Result>; + async fn recv(&mut self) -> Result<(), Box>; } #[async_trait::async_trait] @@ -25,5 +28,5 @@ pub trait AgentListener< W: AsyncWrite + Unpin, >: Send + Sync { - async fn recv(&mut self) -> Option; + async fn recv(&mut self) -> Result>; } diff --git a/crates/relayer/src/agent_listener/quic.rs b/crates/relayer/src/agent_listener/quic.rs new file mode 100644 index 0000000..de9245d --- /dev/null +++ b/crates/relayer/src/agent_listener/quic.rs @@ -0,0 +1,145 @@ +use std::{error::Error, net::SocketAddr, sync::Arc}; + +use protocol::{ + key::validate_request, + rpc::{RegisterRequest, RegisterResponse}, +}; +use quinn::{Endpoint, RecvStream, SendStream, ServerConfig}; + +use super::{AgentConnection, AgentListener, AgentSubConnection}; + +pub struct AgentQuicListener { + endpoint: Endpoint, + root_domain: String, +} + +impl AgentQuicListener { + pub async fn new(addr: SocketAddr, root_domain: String) -> Self { + log::info!("AgentQuicListener::new {}", addr); + let (endpoint, _server_cert) = + make_server_endpoint(addr).expect("Should make server endpoint"); + + Self { + endpoint, + root_domain, + } + } + + async fn process_incoming_conn( + &self, + conn: quinn::Connection, + ) -> Result> { + let (mut send, mut recv) = conn.accept_bi().await?; + let mut buf = [0u8; 4096]; + let buf_len = recv + .read(&mut buf) + .await? + .ok_or::>("No incomming data".into())?; + + match RegisterRequest::try_from(&buf[..buf_len]) { + Ok(request) => { + let response = if let Some(sub_domain) = validate_request(&request) { + log::info!("register request domain {}", sub_domain); + Ok(format!("{}.{}", sub_domain, self.root_domain)) + } else { + log::error!("invalid register request {:?}", request); + Err(String::from("invalid request")) + }; + + let res = RegisterResponse { response }; + let res_buf: Vec = (&res).into(); + send.write_all(&res_buf).await?; + + let domain = res.response?; + Ok(AgentQuicConnection { domain, conn }) + } + Err(e) => { + log::error!("register request error {:?}", e); + Err(e.into()) + } + } + } +} + +#[async_trait::async_trait] +impl AgentListener + for AgentQuicListener +{ + async fn recv(&mut self) -> Result> { + loop { + let incoming_conn = self + .endpoint + .accept() + .await + .ok_or::>("Cannot accept".into())?; + let conn: quinn::Connection = incoming_conn.await?; + log::info!( + "[AgentQuicListener] new conn from {}", + conn.remote_address() + ); + match self.process_incoming_conn(conn).await { + Ok(connection) => { + log::info!("new connection {}", connection.domain()); + return Ok(connection); + } + Err(e) => { + log::error!("process_incoming_conn error: {}", e); + } + } + } + } +} + +pub struct AgentQuicConnection { + domain: String, + conn: quinn::Connection, +} + +#[async_trait::async_trait] +impl AgentConnection for AgentQuicConnection { + fn domain(&self) -> String { + self.domain.clone() + } + + async fn create_sub_connection(&mut self) -> Result> { + let (send, recv) = self.conn.open_bi().await?; + Ok(AgentQuicSubConnection { send, recv }) + } + + async fn recv(&mut self) -> Result<(), Box> { + self.conn.read_datagram().await?; + Ok(()) + } +} + +pub struct AgentQuicSubConnection { + send: SendStream, + recv: RecvStream, +} + +impl AgentSubConnection for AgentQuicSubConnection { + fn split(self) -> (RecvStream, SendStream) { + (self.recv, self.send) + } +} + +fn make_server_endpoint(bind_addr: SocketAddr) -> Result<(Endpoint, Vec), Box> { + let (server_config, server_cert) = configure_server()?; + let endpoint = Endpoint::server(server_config, bind_addr)?; + Ok((endpoint, server_cert)) +} + +/// Returns default server configuration along with its certificate. +fn configure_server() -> Result<(ServerConfig, Vec), Box> { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert_der = cert.serialize_der().unwrap(); + let priv_key = cert.serialize_private_key_der(); + let priv_key = rustls::PrivateKey(priv_key); + let cert_chain = vec![rustls::Certificate(cert_der.clone())]; + + let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?; + let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); + transport_config.max_concurrent_uni_streams(0_u8.into()); + + Ok((server_config, cert_der)) +} diff --git a/crates/relayer/src/agent_listener/tcp.rs b/crates/relayer/src/agent_listener/tcp.rs index de19c4d..050fcca 100644 --- a/crates/relayer/src/agent_listener/tcp.rs +++ b/crates/relayer/src/agent_listener/tcp.rs @@ -1,5 +1,6 @@ use std::{ - net::{Ipv4Addr, SocketAddr}, + error::Error, + net::SocketAddr, pin::Pin, task::{Context, Poll}, }; @@ -22,19 +23,20 @@ pub struct AgentTcpListener { } impl AgentTcpListener { - pub async fn new(port: u16, root_domain: String) -> Option { - log::info!("AgentTcpListener::new {}", port); - Some(Self { - tcp_listener: TcpListener::bind(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), port)) - .await - .ok()?, + pub async fn new(addr: SocketAddr, root_domain: String) -> Self { + log::info!("AgentTcpListener::new {}", addr); + Self { + tcp_listener: TcpListener::bind(addr).await.expect("Should open"), root_domain, - }) + } } - async fn process_incoming_stream(&self, mut stream: TcpStream) -> Option { + async fn process_incoming_stream( + &self, + mut stream: TcpStream, + ) -> Result> { let mut buf = [0u8; 4096]; - let buf_len = stream.read(&mut buf).await.ok()?; + let buf_len = stream.read(&mut buf).await?; match RegisterRequest::try_from(&buf[..buf_len]) { Ok(request) => { @@ -50,11 +52,11 @@ impl AgentTcpListener { let res_buf: Vec = (&res).into(); if let Err(e) = stream.write_all(&res_buf).await { log::error!("register response error {:?}", e); - return None; + return Err(e.into()); } - let domain = res.response.ok()?; - Some(AgentTcpConnection { + let domain = res.response?; + Ok(AgentTcpConnection { domain, connector: yamux::Connection::new( stream, @@ -65,7 +67,7 @@ impl AgentTcpListener { } Err(e) => { log::error!("register request error {:?}", e); - None + Err(e.into()) } } } @@ -80,12 +82,18 @@ impl WriteHalf, > for AgentTcpListener { - async fn recv(&mut self) -> Option { + async fn recv(&mut self) -> Result> { loop { - let (stream, remote) = self.tcp_listener.accept().await.ok()?; + let (stream, remote) = self.tcp_listener.accept().await?; log::info!("[AgentTcpListener] new conn from {}", remote); - if let Some(connection) = self.process_incoming_stream(stream).await { - return Some(connection); + match self.process_incoming_stream(stream).await { + Ok(connection) => { + log::info!("new connection {}", connection.domain()); + return Ok(connection); + } + Err(e) => { + log::error!("process_incoming_stream error: {}", e); + } } } } @@ -104,16 +112,16 @@ impl AgentConnection, WriteHalf Option { + async fn create_sub_connection(&mut self) -> Result> { let client = OpenStreamsClient { connection: &mut self.connector, }; - Some(AgentTcpSubConnection { - stream: client.await.ok()?, + Ok(AgentTcpSubConnection { + stream: client.await?, }) } - async fn recv(&mut self) -> Option<()> { + async fn recv(&mut self) -> Result<(), Box> { RecvStreamsClient { connection: &mut self.connector, } @@ -162,14 +170,16 @@ impl<'a, T> Future for RecvStreamsClient<'a, T> where T: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug, { - type Output = Option<()>; + type Output = Result<(), Box>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); match this.connection.poll_next_inbound(cx) { - Poll::Ready(Some(Ok(_stream))) => return Poll::Ready(Some(())), - Poll::Ready(Some(Err(_e))) => return Poll::Ready(None), - Poll::Ready(None) => return Poll::Ready(None), + Poll::Ready(Some(Ok(stream))) => return Poll::Ready(Ok(())), + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e.into())), + Poll::Ready(None) => { + return Poll::Ready(Err("yamux server poll next inbound return None".into())) + } Poll::Pending => Poll::Pending, } } diff --git a/crates/relayer/src/agent_worker.rs b/crates/relayer/src/agent_worker.rs index 82ce5ce..63bc614 100644 --- a/crates/relayer/src/agent_worker.rs +++ b/crates/relayer/src/agent_worker.rs @@ -1,6 +1,6 @@ -use std::marker::PhantomData; +use std::{error::Error, marker::PhantomData}; -use futures::{select, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt}; +use futures::{select, AsyncRead, AsyncWrite, FutureExt}; use metrics::increment_gauge; use crate::{ @@ -45,12 +45,12 @@ where ) } - pub async fn run(&mut self) -> Option<()> { + pub async fn run(&mut self) -> Result<(), Box> { let incoming = select! { - incoming = self.rx.recv().fuse() => incoming.ok()?, + incoming = self.rx.recv().fuse() => incoming?, e = self.connection.recv().fuse() => { e?; - return Some(()); + return Ok(()); } }; let sub_connection = self.connection.create_sub_connection().await?; @@ -58,29 +58,28 @@ where increment_gauge!(crate::METRICS_PROXY_LIVE, 1.0); let domain = incoming.domain().to_string(); log::info!("start proxy tunnel for domain {}", domain); - let first_pkt = incoming.first_pkt(); let (mut reader1, mut writer1) = sub_connection.split(); - let success = if first_pkt.len() > 0 { - writer1.write_all(first_pkt).await.is_ok() - } else { - true - }; + let (mut reader2, mut writer2) = incoming.split(); - if success { - let (mut reader2, mut writer2) = incoming.split(); + let job1 = futures::io::copy(&mut reader1, &mut writer2); + let job2 = futures::io::copy(&mut reader2, &mut writer1); - let job1 = futures::io::copy(&mut reader1, &mut writer2); - let job2 = futures::io::copy(&mut reader2, &mut writer1); - - select! { - _ = job1.fuse() => {} - _ = job2.fuse() => {} + select! { + e = job1.fuse() => { + if let Err(e) = e { + log::info!("agent => proxy error: {}", e); + } + } + e = job2.fuse() => { + if let Err(e) = e { + log::info!("proxy => agent error: {}", e); + } } } log::info!("end proxy tunnel for domain {}", domain); increment_gauge!(crate::METRICS_PROXY_LIVE, -1.0); }); - Some(()) + Ok(()) } } diff --git a/crates/relayer/src/main.rs b/crates/relayer/src/main.rs index 5c794b5..91e2f7c 100644 --- a/crates/relayer/src/main.rs +++ b/crates/relayer/src/main.rs @@ -3,11 +3,12 @@ use clap::Parser; use metrics_dashboard::build_dashboard_route; #[cfg(feature = "expose-metrics")] use poem::{listener::TcpListener, middleware::Tracing, EndpointExt as _, Route, Server}; -use std::{collections::HashMap, process::exit, sync::Arc}; +use std::{collections::HashMap, net::SocketAddr, process::exit, sync::Arc, time::Duration}; use agent_listener::tcp::AgentTcpListener; -use async_std::sync::RwLock; -use futures::{select, FutureExt}; +use agent_listener::{quic::AgentQuicListener, AgentSubConnection}; +use async_std::{prelude::FutureExt as _, sync::RwLock}; +use futures::{select, AsyncRead, AsyncWrite, FutureExt}; use metrics::{ decrement_gauge, describe_counter, describe_gauge, increment_counter, increment_gauge, }; @@ -45,8 +46,8 @@ struct Args { https_port: u16, /// Number of times to greet - #[arg(env, long, default_value_t = 33333)] - tcp_connector_port: u16, + #[arg(env, long, default_value = "0.0.0.0:33333")] + connector_port: SocketAddr, /// Root domain #[arg(env, long, default_value = "localtunnel.me")] @@ -65,9 +66,9 @@ async fn main() { .with(fmt::layer()) .with(EnvFilter::from_default_env()) .init(); - let mut agent_listener = AgentTcpListener::new(args.tcp_connector_port, args.root_domain) - .await - .expect("Should listen agent port"); + let mut quic_agent_listener = + AgentQuicListener::new(args.connector_port, args.root_domain.clone()).await; + let mut tcp_agent_listener = AgentTcpListener::new(args.connector_port, args.root_domain).await; let mut proxy_http_listener = ProxyHttpListener::new(args.http_port, false) .await .expect("Should listen http port"); @@ -96,41 +97,27 @@ async fn main() { loop { select! { - e = agent_listener.recv().fuse() => match e { - Some(agent_connection) => { - increment_counter!(METRICS_AGENT_COUNT); - log::info!("agent_connection.domain(): {}", agent_connection.domain()); - let domain = agent_connection.domain().to_string(); - let (mut agent_worker, proxy_tunnel_tx) = agent_worker::AgentWorker::new(agent_connection); - agents.write().await.insert(domain.clone(), proxy_tunnel_tx); - let agents = agents.clone(); - async_std::task::spawn(async move { - increment_gauge!(METRICS_AGENT_LIVE, 1.0); - log::info!("agent_worker run for domain: {}", domain); - while let Some(_) = agent_worker.run().await {} - agents.write().await.remove(&domain); - log::info!("agent_worker exit for domain: {}", domain); - decrement_gauge!(METRICS_AGENT_LIVE, 1.0); - }); + e = quic_agent_listener.recv().fuse() => match e { + Ok(agent_connection) => { + run_agent_connection(agent_connection, agents.clone()).await; } - None => { - log::error!("agent_listener error"); + Err(e) => { + log::error!("agent_listener error {}", e); + exit(1); + } + }, + e = tcp_agent_listener.recv().fuse() => match e { + Ok(agent_connection) => { + run_agent_connection(agent_connection, agents.clone()).await; + } + Err(e) => { + log::error!("agent_listener error {}", e); exit(1); } }, e = proxy_http_listener.recv().fuse() => match e { - Some(mut proxy_tunnel) => { - if proxy_tunnel.wait().await.is_none() { - continue; - } - increment_counter!(METRICS_PROXY_COUNT); - log::info!("proxy_tunnel.domain(): {}", proxy_tunnel.domain()); - let domain = proxy_tunnel.domain().to_string(); - if let Some(agent_tx) = agents.read().await.get(&domain) { - agent_tx.send(proxy_tunnel).await.ok(); - } else { - log::warn!("agent not found for domain: {}", domain); - } + Some(proxy_tunnel) => { + async_std::task::spawn(run_http_request(proxy_tunnel, agents.clone())); } None => { log::error!("proxy_http_listener.recv()"); @@ -138,17 +125,8 @@ async fn main() { } }, e = proxy_tls_listener.recv().fuse() => match e { - Some(mut proxy_tunnel) => { - if proxy_tunnel.wait().await.is_none() { - continue; - } - log::info!("proxy_tunnel.domain(): {}", proxy_tunnel.domain()); - let domain = proxy_tunnel.domain().to_string(); - if let Some(agent_tx) = agents.read().await.get(&domain) { - agent_tx.send(proxy_tunnel).await.ok(); - } else { - log::warn!("agent not found for domain: {}", domain); - } + Some(proxy_tunnel) => { + async_std::task::spawn(run_http_request(proxy_tunnel, agents.clone())); } None => { log::error!("proxy_http_listener.recv()"); @@ -158,3 +136,69 @@ async fn main() { } } } + +async fn run_agent_connection( + agent_connection: AG, + agents: Arc>>>, +) where + AG: AgentConnection + 'static, + S: AgentSubConnection + 'static, + R: AsyncRead + Send + Unpin + 'static, + W: AsyncWrite + Send + Unpin + 'static, + PT: ProxyTunnel + 'static, + PR: AsyncRead + Send + Unpin + 'static, + PW: AsyncWrite + Send + Unpin + 'static, +{ + increment_counter!(METRICS_AGENT_COUNT); + log::info!("agent_connection.domain(): {}", agent_connection.domain()); + let domain = agent_connection.domain().to_string(); + let (mut agent_worker, proxy_tunnel_tx) = + agent_worker::AgentWorker::::new(agent_connection); + agents.write().await.insert(domain.clone(), proxy_tunnel_tx); + let agents = agents.clone(); + async_std::task::spawn(async move { + increment_gauge!(METRICS_AGENT_LIVE, 1.0); + log::info!("agent_worker run for domain: {}", domain); + loop { + match agent_worker.run().await { + Ok(()) => {} + Err(e) => { + log::error!("agent_worker error: {}", e); + break; + } + } + } + agents.write().await.remove(&domain); + log::info!("agent_worker exit for domain: {}", domain); + decrement_gauge!(METRICS_AGENT_LIVE, 1.0); + }); +} + +async fn run_http_request( + mut proxy_tunnel: PT, + agents: Arc>>>, +) where + PT: ProxyTunnel + 'static, + PR: AsyncRead + Send + Unpin + 'static, + PW: AsyncWrite + Send + Unpin + 'static, +{ + match proxy_tunnel.wait().timeout(Duration::from_secs(5)).await { + Err(_) => { + log::error!("proxy_tunnel.wait() for checking url timeout"); + return; + } + Ok(None) => { + log::error!("proxy_tunnel.wait() for checking url invalid"); + return; + } + _ => {} + } + increment_counter!(METRICS_PROXY_COUNT); + log::info!("proxy_tunnel.domain(): {}", proxy_tunnel.domain()); + let domain = proxy_tunnel.domain().to_string(); + if let Some(agent_tx) = agents.read().await.get(&domain) { + agent_tx.send(proxy_tunnel).await.ok(); + } else { + log::warn!("agent not found for domain: {}", domain); + } +} diff --git a/crates/relayer/src/proxy_listener.rs b/crates/relayer/src/proxy_listener.rs index db4b628..ba64a9c 100644 --- a/crates/relayer/src/proxy_listener.rs +++ b/crates/relayer/src/proxy_listener.rs @@ -6,7 +6,6 @@ pub mod http; #[async_trait::async_trait] pub trait ProxyTunnel: Send + Sync { - fn first_pkt(&self) -> &[u8]; async fn wait(&mut self) -> Option<()>; fn domain(&self) -> &str; fn split(self) -> (R, W); diff --git a/crates/relayer/src/proxy_listener/http.rs b/crates/relayer/src/proxy_listener/http.rs index 187fc7c..03f4104 100644 --- a/crates/relayer/src/proxy_listener/http.rs +++ b/crates/relayer/src/proxy_listener/http.rs @@ -34,8 +34,6 @@ impl ProxyListener, WriteHalf> let (stream, remote) = self.tcp_listener.accept().await.ok()?; log::info!("[ProxyHttpListener] new conn from {}", remote); Some(ProxyHttpTunnel { - first_pkt: vec![0u8; 4096], - first_pkt_size: 0, domain: "demo".to_string(), stream, tls: self.tls, @@ -44,8 +42,6 @@ impl ProxyListener, WriteHalf> } pub struct ProxyHttpTunnel { - first_pkt: Vec, - first_pkt_size: usize, domain: String, stream: TcpStream, tls: bool, @@ -53,22 +49,20 @@ pub struct ProxyHttpTunnel { #[async_trait::async_trait] impl ProxyTunnel, WriteHalf> for ProxyHttpTunnel { - fn first_pkt(&self) -> &[u8] { - &self.first_pkt[..self.first_pkt_size] - } - async fn wait(&mut self) -> Option<()> { - self.first_pkt_size = self.stream.read(&mut self.first_pkt).await.ok()?; + log::info!("[ProxyHttpTunnel] wait first data for checking url..."); + let mut first_pkt = [0u8; 4096]; + let first_pkt_size = self.stream.peek(&mut first_pkt).await.ok()?; log::info!( "[ProxyHttpTunnel] read {} bytes for determine url", - self.first_pkt_size + first_pkt_size ); if self.tls { - self.domain = get_sni_from_packet(&self.first_pkt[..self.first_pkt_size])?; + self.domain = get_sni_from_packet(&first_pkt[..first_pkt_size])?; } else { let mut headers = [httparse::EMPTY_HEADER; 64]; let mut req = httparse::Request::new(&mut headers); - let _ = req.parse(&self.first_pkt[..self.first_pkt_size]).ok()?; + let _ = req.parse(&first_pkt[..first_pkt_size]).ok()?; let domain = req.headers.iter().find(|h| h.name == "Host")?.value; // dont get the port let domain = String::from_utf8_lossy(domain).to_string();