diff --git a/crates/shadowsocks-service/src/net/mon_socket.rs b/crates/shadowsocks-service/src/net/mon_socket.rs index 7bf541d2fe13..8b2bebdfca2f 100644 --- a/crates/shadowsocks-service/src/net/mon_socket.rs +++ b/crates/shadowsocks-service/src/net/mon_socket.rs @@ -6,7 +6,6 @@ use shadowsocks::{ relay::{socks5::Address, udprelay::options::UdpSocketControlData}, ProxySocket, }; -use tokio::net::ToSocketAddrs; use super::flow::FlowStat; @@ -47,7 +46,7 @@ impl MonProxySocket { /// Send a UDP packet to target from proxy #[inline] - pub async fn send_to(&self, target: A, addr: &Address, payload: &[u8]) -> io::Result<()> { + pub async fn send_to(&self, target: SocketAddr, addr: &Address, payload: &[u8]) -> io::Result<()> { let n = self.socket.send_to(target, addr, payload).await?; self.flow_stat.incr_tx(n as u64); @@ -56,9 +55,9 @@ impl MonProxySocket { /// Send a UDP packet to target from proxy #[inline] - pub async fn send_to_with_ctrl( + pub async fn send_to_with_ctrl( &self, - target: A, + target: SocketAddr, addr: &Address, control: &UdpSocketControlData, payload: &[u8], diff --git a/crates/shadowsocks/src/net/udp.rs b/crates/shadowsocks/src/net/udp.rs index 3a8acc2d44b4..d1b788d658cb 100644 --- a/crates/shadowsocks/src/net/udp.rs +++ b/crates/shadowsocks/src/net/udp.rs @@ -24,7 +24,7 @@ use std::{ ))] use futures::future; use futures::ready; -use pin_project::pin_project; + #[cfg(any( target_os = "linux", target_os = "android", @@ -86,9 +86,7 @@ fn make_mtu_error(packet_size: usize, mtu: usize) -> io::Error { /// Wrappers for outbound `UdpSocket` #[derive(Debug)] -#[pin_project] pub struct UdpSocket { - #[pin] socket: tokio::net::UdpSocket, mtu: Option, } diff --git a/crates/shadowsocks/src/relay/udprelay/compat.rs b/crates/shadowsocks/src/relay/udprelay/compat.rs new file mode 100644 index 000000000000..7245ae0e31ba --- /dev/null +++ b/crates/shadowsocks/src/relay/udprelay/compat.rs @@ -0,0 +1,88 @@ +use async_trait::async_trait; +use std::{ + io::Result, + net::SocketAddr, + ops::Deref, + task::{Context, Poll}, +}; +use tokio::io::ReadBuf; + +use crate::net::UdpSocket; + +/// a trait for datagram transport that wraps around a tokio `UdpSocket` +#[async_trait] +pub trait DatagramTransport: Send + Sync + std::fmt::Debug { + async fn recv(&self, buf: &mut [u8]) -> Result; + async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)>; + + async fn send(&self, buf: &[u8]) -> Result; + async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result; + + fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll>; + fn poll_recv_from(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll>; + fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll>; + + fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll>; + fn poll_send_to(&self, cx: &mut Context<'_>, buf: &[u8], target: SocketAddr) -> Poll>; + fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll>; + + fn local_addr(&self) -> Result; + + #[cfg(unix)] + fn as_raw_fd(&self) -> std::os::fd::RawFd; +} + +#[async_trait] +impl DatagramTransport for UdpSocket { + async fn recv(&self, buf: &mut [u8]) -> Result { + UdpSocket::recv(self, buf).await + } + + async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> { + UdpSocket::recv_from(self, buf).await + } + + async fn send(&self, buf: &[u8]) -> Result { + UdpSocket::send(self, buf).await + } + + async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result { + UdpSocket::send_to(self, buf, target).await + } + + fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + UdpSocket::poll_recv(self, cx, buf) + } + + fn poll_recv_from(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + UdpSocket::poll_recv_from(self, cx, buf) + } + + fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.deref().poll_recv_ready(cx) + } + + fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + UdpSocket::poll_send(self, cx, buf) + } + + fn poll_send_to(&self, cx: &mut Context<'_>, buf: &[u8], target: SocketAddr) -> Poll> { + UdpSocket::poll_send_to(self, cx, buf, target) + } + + fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.deref().poll_send_ready(cx) + } + + fn local_addr(&self) -> Result { + self.deref().local_addr() + } + + #[cfg(unix)] + fn as_raw_fd(&self) -> std::os::fd::RawFd { + use std::ops::Deref; + use std::os::fd::AsRawFd; + + self.deref().as_raw_fd() + } +} diff --git a/crates/shadowsocks/src/relay/udprelay/mod.rs b/crates/shadowsocks/src/relay/udprelay/mod.rs index dc22d0d39784..c445d77f29e8 100644 --- a/crates/shadowsocks/src/relay/udprelay/mod.rs +++ b/crates/shadowsocks/src/relay/udprelay/mod.rs @@ -50,10 +50,12 @@ use std::time::Duration; pub use self::proxy_socket::ProxySocket; +pub use compat::DatagramTransport; mod aead; #[cfg(feature = "aead-cipher-2022")] mod aead_2022; +mod compat; pub mod crypto_io; pub mod options; pub mod proxy_socket; diff --git a/crates/shadowsocks/src/relay/udprelay/proxy_socket.rs b/crates/shadowsocks/src/relay/udprelay/proxy_socket.rs index 5fdc647d85fa..431719713a0d 100644 --- a/crates/shadowsocks/src/relay/udprelay/proxy_socket.rs +++ b/crates/shadowsocks/src/relay/udprelay/proxy_socket.rs @@ -12,7 +12,7 @@ use byte_string::ByteStr; use bytes::{Bytes, BytesMut}; use log::{info, trace, warn}; use once_cell::sync::Lazy; -use tokio::{io::ReadBuf, net::ToSocketAddrs, time}; +use tokio::{io::ReadBuf, time}; use crate::{ config::{ServerAddr, ServerConfig, ServerUserManager}, @@ -22,9 +22,12 @@ use crate::{ relay::{socks5::Address, udprelay::options::UdpSocketControlData}, }; -use super::crypto_io::{ - decrypt_client_payload, decrypt_server_payload, encrypt_client_payload, encrypt_server_payload, ProtocolError, - ProtocolResult, +use super::{ + compat::DatagramTransport, + crypto_io::{ + decrypt_client_payload, decrypt_server_payload, encrypt_client_payload, encrypt_server_payload, ProtocolError, + ProtocolResult, + }, }; #[cfg(unix)] @@ -72,7 +75,7 @@ pub type ProxySocketResult = Result; #[derive(Debug)] pub struct ProxySocket { socket_type: UdpSocketType, - socket: ShadowUdpSocket, + io: Box, method: CipherKind, key: Box<[u8]>, send_timeout: Option, @@ -128,11 +131,40 @@ impl ProxySocket { let key = svr_cfg.key().to_vec().into_boxed_slice(); let method = svr_cfg.method(); + // NOTE: svr_cfg.timeout() is not for this socket, but for associations. + ProxySocket { + socket_type, + io: Box::new(socket.into()), + method, + key, + send_timeout: None, + recv_timeout: None, + context, + identity_keys: match socket_type { + UdpSocketType::Client => svr_cfg.clone_identity_keys(), + UdpSocketType::Server => Arc::new(Vec::new()), + }, + user_manager: match socket_type { + UdpSocketType::Client => None, + UdpSocketType::Server => svr_cfg.clone_user_manager(), + }, + } + } + + pub fn from_io( + socket_type: UdpSocketType, + context: SharedContext, + svr_cfg: &ServerConfig, + io: Box, + ) -> ProxySocket { + let key = svr_cfg.key().to_vec().into_boxed_slice(); + let method = svr_cfg.method(); + // NOTE: svr_cfg.timeout() is not for this socket, but for associations. ProxySocket { socket_type, - socket: socket.into(), + io, method, key, send_timeout: None, @@ -241,8 +273,8 @@ impl ProxySocket { ); let send_len = match self.send_timeout { - None => self.socket.send(&send_buf).await?, - Some(d) => match time::timeout(d, self.socket.send(&send_buf)).await { + None => self.io.send(&send_buf).await?, + Some(d) => match time::timeout(d, self.io.send(&send_buf)).await { Ok(Ok(l)) => l, Ok(Err(err)) => return Err(err.into()), Err(..) => return Err(io::Error::from(ErrorKind::TimedOut).into()), @@ -295,7 +327,7 @@ impl ProxySocket { let n_send_buf = send_buf.len(); - match self.socket.poll_send(cx, &send_buf).map_err(|x| x.into()) { + match self.io.poll_send(cx, &send_buf).map_err(|x| x.into()) { Poll::Ready(Ok(l)) => { if l == n_send_buf { Poll::Ready(Ok(payload.len())) @@ -340,14 +372,14 @@ impl ProxySocket { self.encrypt_send_buffer(addr, control, &self.identity_keys, payload, &mut send_buf)?; info!( - "UDP server client send to {}, payload length {} bytes, packet length {} bytes", + "UDP server client poll_send_to to {}, payload length {} bytes, packet length {} bytes", target, payload.len(), send_buf.len() ); let n_send_buf = send_buf.len(); - match self.socket.poll_send_to(cx, &send_buf, target).map_err(|x| x.into()) { + match self.io.poll_send_to(cx, &send_buf, target).map_err(|x| x.into()) { Poll::Ready(Ok(l)) => { if l == n_send_buf { Poll::Ready(Ok(payload.len())) @@ -363,25 +395,20 @@ impl ProxySocket { /// /// Check if socket is ready to `send`, or writable. pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.socket.poll_send_ready(cx).map_err(|x| x.into()) + self.io.poll_send_ready(cx).map_err(|x| x.into()) } /// Send a UDP packet to target through proxy `target` - pub async fn send_to( - &self, - target: A, - addr: &Address, - payload: &[u8], - ) -> ProxySocketResult { + pub async fn send_to(&self, target: SocketAddr, addr: &Address, payload: &[u8]) -> ProxySocketResult { self.send_to_with_ctrl(target, addr, &DEFAULT_SOCKET_CONTROL, payload) .await .map_err(Into::into) } /// Send a UDP packet to target through proxy `target` - pub async fn send_to_with_ctrl( + pub async fn send_to_with_ctrl( &self, - target: A, + target: SocketAddr, addr: &Address, control: &UdpSocketControlData, payload: &[u8], @@ -390,7 +417,7 @@ impl ProxySocket { self.encrypt_send_buffer(addr, control, &self.identity_keys, payload, &mut send_buf)?; trace!( - "UDP server client send to, addr {}, control: {:?}, payload length {} bytes, packet length {} bytes", + "UDP server client send_to to, addr {}, control: {:?}, payload length {} bytes, packet length {} bytes", addr, control, payload.len(), @@ -398,8 +425,8 @@ impl ProxySocket { ); let send_len = match self.send_timeout { - None => self.socket.send_to(&send_buf, target).await?, - Some(d) => match time::timeout(d, self.socket.send_to(&send_buf, target)).await { + None => self.io.send_to(&send_buf, target).await?, + Some(d) => match time::timeout(d, self.io.send_to(&send_buf, target)).await { Ok(Ok(l)) => l, Ok(Err(err)) => return Err(err.into()), Err(..) => return Err(io::Error::from(ErrorKind::TimedOut).into()), @@ -408,7 +435,7 @@ impl ProxySocket { if send_buf.len() != send_len { warn!( - "UDP server client send {} bytes, but actually sent {} bytes", + "UDP server client send_to {} bytes, but actually sent {} bytes", send_buf.len(), send_len ); @@ -448,10 +475,9 @@ impl ProxySocket { &self, recv_buf: &mut [u8], ) -> ProxySocketResult<(usize, Address, usize, Option)> { - // Waiting for response from server SERVER -> CLIENT let recv_n = match self.recv_timeout { - None => self.socket.recv(recv_buf).await?, - Some(d) => match time::timeout(d, self.socket.recv(recv_buf)).await { + None => self.io.recv(recv_buf).await?, + Some(d) => match time::timeout(d, self.io.recv(recv_buf)).await { Ok(Ok(l)) => l, Ok(Err(err)) => return Err(err.into()), Err(..) => return Err(io::Error::from(ErrorKind::TimedOut).into()), @@ -498,8 +524,8 @@ impl ProxySocket { ) -> ProxySocketResult<(usize, SocketAddr, Address, usize, Option)> { // Waiting for response from server SERVER -> CLIENT let (recv_n, target_addr) = match self.recv_timeout { - None => self.socket.recv_from(recv_buf).await?, - Some(d) => match time::timeout(d, self.socket.recv_from(recv_buf)).await { + None => self.io.recv_from(recv_buf).await?, + Some(d) => match time::timeout(d, self.io.recv_from(recv_buf)).await { Ok(Ok(l)) => l, Ok(Err(err)) => return Err(err.into()), Err(..) => return Err(io::Error::from(ErrorKind::TimedOut).into()), @@ -542,7 +568,7 @@ impl ProxySocket { cx: &mut Context<'_>, recv_buf: &mut ReadBuf, ) -> Poll)>> { - ready!(self.socket.poll_recv(cx, recv_buf))?; + ready!(self.io.poll_recv(cx, recv_buf))?; let n_recv = recv_buf.filled().len(); @@ -570,7 +596,7 @@ impl ProxySocket { cx: &mut Context<'_>, recv_buf: &mut ReadBuf, ) -> Poll)>> { - let src = ready!(self.socket.poll_recv_from(cx, recv_buf))?; + let src = ready!(self.io.poll_recv_from(cx, recv_buf))?; let n_recv = recv_buf.filled().len(); match self.decrypt_recv_buffer(recv_buf.filled_mut(), self.user_manager.as_deref()) { @@ -581,12 +607,12 @@ impl ProxySocket { /// poll family functions pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.socket.poll_recv_ready(cx).map_err(|x| x.into()) + self.io.poll_recv_ready(cx).map_err(|x| x.into()) } /// Get local addr of socket pub fn local_addr(&self) -> io::Result { - self.socket.local_addr() + self.io.local_addr() } /// Set `send` timeout, `None` will clear timeout @@ -604,6 +630,6 @@ impl ProxySocket { impl AsRawFd for ProxySocket { /// Retrieve raw fd of the outbound socket fn as_raw_fd(&self) -> RawFd { - self.socket.as_raw_fd() + self.io.as_raw_fd() } }