Skip to content

Commit

Permalink
take arbitrary IO for udp proxy (#1641)
Browse files Browse the repository at this point in the history
* udp relay

* refact

* reset

* Update crates/shadowsocks/src/relay/udprelay/proxy_socket.rs

* export trait
  • Loading branch information
ibigbug authored Sep 20, 2024
1 parent e691853 commit 4e29581
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 41 deletions.
7 changes: 3 additions & 4 deletions crates/shadowsocks-service/src/net/mon_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use shadowsocks::{
relay::{socks5::Address, udprelay::options::UdpSocketControlData},
ProxySocket,
};
use tokio::net::ToSocketAddrs;

use super::flow::FlowStat;

Expand Down Expand Up @@ -47,7 +46,7 @@ impl MonProxySocket {

/// Send a UDP packet to target from proxy
#[inline]
pub async fn send_to<A: ToSocketAddrs>(&self, target: A, addr: &Address, payload: &[u8]) -> io::Result<()> {
pub async fn send_to(&self, target: SocketAddr, addr: &Address, payload: &[u8]) -> io::Result<()> {
let n = self.socket.send_to(target, addr, payload).await?;
self.flow_stat.incr_tx(n as u64);

Expand All @@ -56,9 +55,9 @@ impl MonProxySocket {

/// Send a UDP packet to target from proxy
#[inline]
pub async fn send_to_with_ctrl<A: ToSocketAddrs>(
pub async fn send_to_with_ctrl(
&self,
target: A,
target: SocketAddr,
addr: &Address,
control: &UdpSocketControlData,
payload: &[u8],
Expand Down
4 changes: 1 addition & 3 deletions crates/shadowsocks/src/net/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::{
))]
use futures::future;
use futures::ready;
use pin_project::pin_project;

#[cfg(any(
target_os = "linux",
target_os = "android",
Expand Down Expand Up @@ -86,9 +86,7 @@ fn make_mtu_error(packet_size: usize, mtu: usize) -> io::Error {

/// Wrappers for outbound `UdpSocket`
#[derive(Debug)]
#[pin_project]
pub struct UdpSocket {
#[pin]
socket: tokio::net::UdpSocket,
mtu: Option<usize>,
}
Expand Down
88 changes: 88 additions & 0 deletions crates/shadowsocks/src/relay/udprelay/compat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use async_trait::async_trait;
use std::{
io::Result,
net::SocketAddr,
ops::Deref,
task::{Context, Poll},
};
use tokio::io::ReadBuf;

use crate::net::UdpSocket;

/// a trait for datagram transport that wraps around a tokio `UdpSocket`
#[async_trait]
pub trait DatagramTransport: Send + Sync + std::fmt::Debug {
async fn recv(&self, buf: &mut [u8]) -> Result<usize>;
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)>;

async fn send(&self, buf: &[u8]) -> Result<usize>;
async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result<usize>;

fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>>;
fn poll_recv_from(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<SocketAddr>>;
fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>>;

fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>;
fn poll_send_to(&self, cx: &mut Context<'_>, buf: &[u8], target: SocketAddr) -> Poll<Result<usize>>;
fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>>;

fn local_addr(&self) -> Result<SocketAddr>;

#[cfg(unix)]
fn as_raw_fd(&self) -> std::os::fd::RawFd;
}

#[async_trait]
impl DatagramTransport for UdpSocket {
async fn recv(&self, buf: &mut [u8]) -> Result<usize> {
UdpSocket::recv(self, buf).await
}

async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
UdpSocket::recv_from(self, buf).await
}

async fn send(&self, buf: &[u8]) -> Result<usize> {
UdpSocket::send(self, buf).await
}

async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result<usize> {
UdpSocket::send_to(self, buf, target).await
}

fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
UdpSocket::poll_recv(self, cx, buf)
}

fn poll_recv_from(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<SocketAddr>> {
UdpSocket::poll_recv_from(self, cx, buf)
}

fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.deref().poll_recv_ready(cx)
}

fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
UdpSocket::poll_send(self, cx, buf)
}

fn poll_send_to(&self, cx: &mut Context<'_>, buf: &[u8], target: SocketAddr) -> Poll<Result<usize>> {
UdpSocket::poll_send_to(self, cx, buf, target)
}

fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.deref().poll_send_ready(cx)
}

fn local_addr(&self) -> Result<SocketAddr> {
self.deref().local_addr()
}

#[cfg(unix)]
fn as_raw_fd(&self) -> std::os::fd::RawFd {
use std::ops::Deref;
use std::os::fd::AsRawFd;

self.deref().as_raw_fd()
}
}
2 changes: 2 additions & 0 deletions crates/shadowsocks/src/relay/udprelay/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@
use std::time::Duration;

pub use self::proxy_socket::ProxySocket;
pub use compat::DatagramTransport;

mod aead;
#[cfg(feature = "aead-cipher-2022")]
mod aead_2022;
mod compat;
pub mod crypto_io;
pub mod options;
pub mod proxy_socket;
Expand Down
94 changes: 60 additions & 34 deletions crates/shadowsocks/src/relay/udprelay/proxy_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use byte_string::ByteStr;
use bytes::{Bytes, BytesMut};
use log::{info, trace, warn};
use once_cell::sync::Lazy;
use tokio::{io::ReadBuf, net::ToSocketAddrs, time};
use tokio::{io::ReadBuf, time};

use crate::{
config::{ServerAddr, ServerConfig, ServerUserManager},
Expand All @@ -22,9 +22,12 @@ use crate::{
relay::{socks5::Address, udprelay::options::UdpSocketControlData},
};

use super::crypto_io::{
decrypt_client_payload, decrypt_server_payload, encrypt_client_payload, encrypt_server_payload, ProtocolError,
ProtocolResult,
use super::{
compat::DatagramTransport,
crypto_io::{
decrypt_client_payload, decrypt_server_payload, encrypt_client_payload, encrypt_server_payload, ProtocolError,
ProtocolResult,
},
};

#[cfg(unix)]
Expand Down Expand Up @@ -72,7 +75,7 @@ pub type ProxySocketResult<T> = Result<T, ProxySocketError>;
#[derive(Debug)]
pub struct ProxySocket {
socket_type: UdpSocketType,
socket: ShadowUdpSocket,
io: Box<dyn DatagramTransport>,
method: CipherKind,
key: Box<[u8]>,
send_timeout: Option<Duration>,
Expand Down Expand Up @@ -128,11 +131,40 @@ impl ProxySocket {
let key = svr_cfg.key().to_vec().into_boxed_slice();
let method = svr_cfg.method();

// NOTE: svr_cfg.timeout() is not for this socket, but for associations.
ProxySocket {
socket_type,
io: Box::new(socket.into()),
method,
key,
send_timeout: None,
recv_timeout: None,
context,
identity_keys: match socket_type {
UdpSocketType::Client => svr_cfg.clone_identity_keys(),
UdpSocketType::Server => Arc::new(Vec::new()),
},
user_manager: match socket_type {
UdpSocketType::Client => None,
UdpSocketType::Server => svr_cfg.clone_user_manager(),
},
}
}

pub fn from_io(
socket_type: UdpSocketType,
context: SharedContext,
svr_cfg: &ServerConfig,
io: Box<dyn DatagramTransport>,
) -> ProxySocket {
let key = svr_cfg.key().to_vec().into_boxed_slice();
let method = svr_cfg.method();

// NOTE: svr_cfg.timeout() is not for this socket, but for associations.

ProxySocket {
socket_type,
socket: socket.into(),
io,
method,
key,
send_timeout: None,
Expand Down Expand Up @@ -241,8 +273,8 @@ impl ProxySocket {
);

let send_len = match self.send_timeout {
None => self.socket.send(&send_buf).await?,
Some(d) => match time::timeout(d, self.socket.send(&send_buf)).await {
None => self.io.send(&send_buf).await?,
Some(d) => match time::timeout(d, self.io.send(&send_buf)).await {
Ok(Ok(l)) => l,
Ok(Err(err)) => return Err(err.into()),
Err(..) => return Err(io::Error::from(ErrorKind::TimedOut).into()),
Expand Down Expand Up @@ -295,7 +327,7 @@ impl ProxySocket {

let n_send_buf = send_buf.len();

match self.socket.poll_send(cx, &send_buf).map_err(|x| x.into()) {
match self.io.poll_send(cx, &send_buf).map_err(|x| x.into()) {
Poll::Ready(Ok(l)) => {
if l == n_send_buf {
Poll::Ready(Ok(payload.len()))
Expand Down Expand Up @@ -340,14 +372,14 @@ impl ProxySocket {
self.encrypt_send_buffer(addr, control, &self.identity_keys, payload, &mut send_buf)?;

info!(
"UDP server client send to {}, payload length {} bytes, packet length {} bytes",
"UDP server client poll_send_to to {}, payload length {} bytes, packet length {} bytes",
target,
payload.len(),
send_buf.len()
);

let n_send_buf = send_buf.len();
match self.socket.poll_send_to(cx, &send_buf, target).map_err(|x| x.into()) {
match self.io.poll_send_to(cx, &send_buf, target).map_err(|x| x.into()) {
Poll::Ready(Ok(l)) => {
if l == n_send_buf {
Poll::Ready(Ok(payload.len()))
Expand All @@ -363,25 +395,20 @@ impl ProxySocket {
///
/// Check if socket is ready to `send`, or writable.
pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<ProxySocketResult<()>> {
self.socket.poll_send_ready(cx).map_err(|x| x.into())
self.io.poll_send_ready(cx).map_err(|x| x.into())
}

/// Send a UDP packet to target through proxy `target`
pub async fn send_to<A: ToSocketAddrs>(
&self,
target: A,
addr: &Address,
payload: &[u8],
) -> ProxySocketResult<usize> {
pub async fn send_to(&self, target: SocketAddr, addr: &Address, payload: &[u8]) -> ProxySocketResult<usize> {
self.send_to_with_ctrl(target, addr, &DEFAULT_SOCKET_CONTROL, payload)
.await
.map_err(Into::into)
}

/// Send a UDP packet to target through proxy `target`
pub async fn send_to_with_ctrl<A: ToSocketAddrs>(
pub async fn send_to_with_ctrl(
&self,
target: A,
target: SocketAddr,
addr: &Address,
control: &UdpSocketControlData,
payload: &[u8],
Expand All @@ -390,16 +417,16 @@ impl ProxySocket {
self.encrypt_send_buffer(addr, control, &self.identity_keys, payload, &mut send_buf)?;

trace!(
"UDP server client send to, addr {}, control: {:?}, payload length {} bytes, packet length {} bytes",
"UDP server client send_to to, addr {}, control: {:?}, payload length {} bytes, packet length {} bytes",
addr,
control,
payload.len(),
send_buf.len()
);

let send_len = match self.send_timeout {
None => self.socket.send_to(&send_buf, target).await?,
Some(d) => match time::timeout(d, self.socket.send_to(&send_buf, target)).await {
None => self.io.send_to(&send_buf, target).await?,
Some(d) => match time::timeout(d, self.io.send_to(&send_buf, target)).await {
Ok(Ok(l)) => l,
Ok(Err(err)) => return Err(err.into()),
Err(..) => return Err(io::Error::from(ErrorKind::TimedOut).into()),
Expand All @@ -408,7 +435,7 @@ impl ProxySocket {

if send_buf.len() != send_len {
warn!(
"UDP server client send {} bytes, but actually sent {} bytes",
"UDP server client send_to {} bytes, but actually sent {} bytes",
send_buf.len(),
send_len
);
Expand Down Expand Up @@ -448,10 +475,9 @@ impl ProxySocket {
&self,
recv_buf: &mut [u8],
) -> ProxySocketResult<(usize, Address, usize, Option<UdpSocketControlData>)> {
// Waiting for response from server SERVER -> CLIENT
let recv_n = match self.recv_timeout {
None => self.socket.recv(recv_buf).await?,
Some(d) => match time::timeout(d, self.socket.recv(recv_buf)).await {
None => self.io.recv(recv_buf).await?,
Some(d) => match time::timeout(d, self.io.recv(recv_buf)).await {
Ok(Ok(l)) => l,
Ok(Err(err)) => return Err(err.into()),
Err(..) => return Err(io::Error::from(ErrorKind::TimedOut).into()),
Expand Down Expand Up @@ -498,8 +524,8 @@ impl ProxySocket {
) -> ProxySocketResult<(usize, SocketAddr, Address, usize, Option<UdpSocketControlData>)> {
// Waiting for response from server SERVER -> CLIENT
let (recv_n, target_addr) = match self.recv_timeout {
None => self.socket.recv_from(recv_buf).await?,
Some(d) => match time::timeout(d, self.socket.recv_from(recv_buf)).await {
None => self.io.recv_from(recv_buf).await?,
Some(d) => match time::timeout(d, self.io.recv_from(recv_buf)).await {
Ok(Ok(l)) => l,
Ok(Err(err)) => return Err(err.into()),
Err(..) => return Err(io::Error::from(ErrorKind::TimedOut).into()),
Expand Down Expand Up @@ -542,7 +568,7 @@ impl ProxySocket {
cx: &mut Context<'_>,
recv_buf: &mut ReadBuf,
) -> Poll<ProxySocketResult<(usize, Address, usize, Option<UdpSocketControlData>)>> {
ready!(self.socket.poll_recv(cx, recv_buf))?;
ready!(self.io.poll_recv(cx, recv_buf))?;

let n_recv = recv_buf.filled().len();

Expand Down Expand Up @@ -570,7 +596,7 @@ impl ProxySocket {
cx: &mut Context<'_>,
recv_buf: &mut ReadBuf,
) -> Poll<ProxySocketResult<(usize, SocketAddr, Address, usize, Option<UdpSocketControlData>)>> {
let src = ready!(self.socket.poll_recv_from(cx, recv_buf))?;
let src = ready!(self.io.poll_recv_from(cx, recv_buf))?;

let n_recv = recv_buf.filled().len();
match self.decrypt_recv_buffer(recv_buf.filled_mut(), self.user_manager.as_deref()) {
Expand All @@ -581,12 +607,12 @@ impl ProxySocket {

/// poll family functions
pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<ProxySocketResult<()>> {
self.socket.poll_recv_ready(cx).map_err(|x| x.into())
self.io.poll_recv_ready(cx).map_err(|x| x.into())
}

/// Get local addr of socket
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.socket.local_addr()
self.io.local_addr()
}

/// Set `send` timeout, `None` will clear timeout
Expand All @@ -604,6 +630,6 @@ impl ProxySocket {
impl AsRawFd for ProxySocket {
/// Retrieve raw fd of the outbound socket
fn as_raw_fd(&self) -> RawFd {
self.socket.as_raw_fd()
self.io.as_raw_fd()
}
}

0 comments on commit 4e29581

Please sign in to comment.