From bccdb9a1b5284b51dc53ba4ae99b2e3b118143df Mon Sep 17 00:00:00 2001 From: hubertshelley Date: Mon, 6 Jan 2025 16:34:59 +0800 Subject: [PATCH 1/2] feat: Unix support --- .../custom_tokio_unix_listener/Cargo.toml | 14 ++++ .../custom_tokio_unix_listener/src/main.rs | 66 +++++++++++++++ silent/src/core/listener.rs | 80 +++++++++++++++++++ silent/src/core/mod.rs | 3 + silent/src/core/request.rs | 6 +- silent/src/core/socket_addr.rs | 70 ++++++++++++++++ silent/src/core/stream.rs | 61 ++++++++++++++ silent/src/prelude.rs | 2 +- silent/src/service/hyper_service.rs | 7 +- silent/src/service/mod.rs | 50 ++++++++---- silent/src/service/serve.rs | 14 ++-- 11 files changed, 343 insertions(+), 30 deletions(-) create mode 100644 examples/custom_tokio_unix_listener/Cargo.toml create mode 100644 examples/custom_tokio_unix_listener/src/main.rs create mode 100644 silent/src/core/listener.rs create mode 100644 silent/src/core/socket_addr.rs create mode 100644 silent/src/core/stream.rs diff --git a/examples/custom_tokio_unix_listener/Cargo.toml b/examples/custom_tokio_unix_listener/Cargo.toml new file mode 100644 index 0000000..200f193 --- /dev/null +++ b/examples/custom_tokio_unix_listener/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "example-custom_tokio_unix_listener" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +silent = { path = "../../silent" } +http-body-util = "0.1" +hyper = { version = "1.0.0", features = ["full"] } +hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] } +tokio = { version = "1.0", features = ["full"] } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/custom_tokio_unix_listener/src/main.rs b/examples/custom_tokio_unix_listener/src/main.rs new file mode 100644 index 0000000..70ce298 --- /dev/null +++ b/examples/custom_tokio_unix_listener/src/main.rs @@ -0,0 +1,66 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-custom_tokio_unix_listener +//! ``` +#[cfg(unix)] +#[tokio::main] +async fn main() { + unix::server().await; +} + +#[cfg(not(unix))] +fn main() { + println!("This example requires unix") +} + +#[cfg(unix)] +mod unix { + use http_body_util::BodyExt; + use hyper_util::rt::TokioIo; + use silent::prelude::*; + use silent::prelude::{logger, HandlerAppend, Level, Route, Server}; + use std::time::Duration; + use tokio::net::{UnixListener, UnixStream}; + + pub async fn server() { + logger::fmt().with_max_level(Level::INFO).init(); + let listener_path = "./examples/custom_tokio_unix_listener/custom_handler.sock"; + + tokio::spawn(async move { + let route = Route::new("").get(handler); + let listener = UnixListener::bind(listener_path).unwrap(); + + Server::new().listen(listener).serve(route).await; + // Server::new().bind_unix(listener_path).serve(route).await; + }); + + tokio::time::sleep(Duration::from_secs(1)).await; + + let stream = TokioIo::new(UnixStream::connect(listener_path).await.unwrap()); + let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap(); + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {:?}", err); + } + }); + + let request = Request::empty(); + + let response = sender.send_request(request.into_http()).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = response.collect().await.unwrap().to_bytes(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert_eq!(body, "Hello, World!"); + + let _ = tokio::fs::remove_file(listener_path).await; + } + + async fn handler(req: Request) -> Result<&'static str> { + println!("new connection from `{:?}`", req.remote()); + + Ok("Hello, World!") + } +} diff --git a/silent/src/core/listener.rs b/silent/src/core/listener.rs new file mode 100644 index 0000000..e961d17 --- /dev/null +++ b/silent/src/core/listener.rs @@ -0,0 +1,80 @@ +use super::socket_addr::SocketAddr; +use super::stream::Stream; + +pub enum Listener { + TcpListener(std::net::TcpListener), + UnixListener(std::os::unix::net::UnixListener), + TokioListener(tokio::net::TcpListener), + TokioUnixListener(tokio::net::UnixListener), +} + +impl From for Listener { + fn from(listener: std::net::TcpListener) -> Self { + Listener::TcpListener(listener) + } +} + +impl From for Listener { + fn from(value: std::os::unix::net::UnixListener) -> Self { + Listener::UnixListener(value) + } +} + +impl From for Listener { + fn from(listener: tokio::net::TcpListener) -> Self { + Listener::TokioListener(listener) + } +} + +impl From for Listener { + fn from(value: tokio::net::UnixListener) -> Self { + Listener::TokioUnixListener(value) + } +} + +impl Listener { + pub async fn accept(&self) -> std::io::Result<(Stream, SocketAddr)> { + match self { + Listener::TcpListener(listener) => { + let (stream, addr) = listener.accept()?; + Ok(( + Stream::TcpStream(tokio::net::TcpStream::from_std(stream)?), + SocketAddr::TcpSocketAddr(addr), + )) + } + Listener::UnixListener(listener) => { + let (stream, addr) = listener.accept()?; + Ok(( + Stream::UnixStream(tokio::net::UnixStream::from_std(stream)?), + SocketAddr::UnixSocketAddr(addr), + )) + } + Listener::TokioListener(listener) => { + let (stream, addr) = listener.accept().await?; + Ok((Stream::TcpStream(stream), SocketAddr::TcpSocketAddr(addr))) + } + Listener::TokioUnixListener(listener) => { + let (stream, addr) = listener.accept().await?; + Ok(( + Stream::UnixStream(stream), + SocketAddr::UnixSocketAddr(addr.into()), + )) + } + } + } + + pub fn local_addr(&self) -> std::io::Result { + match self { + Listener::TcpListener(listener) => listener.local_addr().map(SocketAddr::TcpSocketAddr), + Listener::UnixListener(listener) => { + Ok(SocketAddr::UnixSocketAddr(listener.local_addr()?)) + } + Listener::TokioListener(listener) => { + listener.local_addr().map(SocketAddr::TcpSocketAddr) + } + Listener::TokioUnixListener(listener) => { + Ok(SocketAddr::UnixSocketAddr(listener.local_addr()?.into())) + } + } + } +} diff --git a/silent/src/core/mod.rs b/silent/src/core/mod.rs index 447e661..393850f 100644 --- a/silent/src/core/mod.rs +++ b/silent/src/core/mod.rs @@ -2,6 +2,7 @@ pub mod adapt; #[cfg(feature = "multipart")] pub(crate) mod form; +pub(crate) mod listener; pub(crate) mod next; pub(crate) mod path_param; pub(crate) mod req_body; @@ -10,3 +11,5 @@ pub(crate) mod res_body; pub(crate) mod response; #[allow(dead_code)] mod serde; +pub(crate) mod socket_addr; +pub(crate) mod stream; diff --git a/silent/src/core/request.rs b/silent/src/core/request.rs index 1ca2ac7..90c9de2 100644 --- a/silent/src/core/request.rs +++ b/silent/src/core/request.rs @@ -4,6 +4,7 @@ use crate::core::path_param::PathParam; use crate::core::req_body::ReqBody; #[cfg(feature = "multipart")] use crate::core::serde::from_str_multi_val; +use crate::core::socket_addr::SocketAddr; use crate::header::CONTENT_TYPE; use crate::{Configs, Result, SilentError}; use bytes::Bytes; @@ -16,7 +17,6 @@ use serde::de::StdError; use serde::Deserialize; use serde_json::Value; use std::collections::HashMap; -use std::net::{IpAddr, SocketAddr}; use tokio::sync::OnceCell; use url::form_urlencoded; @@ -134,7 +134,7 @@ impl Request { /// 获取访问真实地址 #[inline] - pub fn remote(&self) -> IpAddr { + pub fn remote(&self) -> SocketAddr { self.headers() .get("x-real-ip") .and_then(|h| h.to_str().ok()) @@ -148,7 +148,7 @@ impl Request { pub fn set_remote(&mut self, remote_addr: SocketAddr) { if self.headers().get("x-real-ip").is_none() { self.headers_mut() - .insert("x-real-ip", remote_addr.ip().to_string().parse().unwrap()); + .insert("x-real-ip", remote_addr.to_string().parse().unwrap()); } } diff --git a/silent/src/core/socket_addr.rs b/silent/src/core/socket_addr.rs new file mode 100644 index 0000000..2c1bc3e --- /dev/null +++ b/silent/src/core/socket_addr.rs @@ -0,0 +1,70 @@ +use std::fmt::{Display, Formatter}; +use std::str::FromStr; + +#[derive(Clone, Debug)] +pub enum SocketAddr { + TcpSocketAddr(std::net::SocketAddr), + UnixSocketAddr(std::os::unix::net::SocketAddr), +} + +impl Display for SocketAddr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + #[allow(clippy::write_literal)] + SocketAddr::TcpSocketAddr(addr) => write!(f, "http{}//{:?}", ':', addr), + SocketAddr::UnixSocketAddr(addr) => { + write!(f, "{:?}", addr.as_pathname()) + } + } + } +} + +impl From for SocketAddr { + fn from(addr: std::net::SocketAddr) -> Self { + SocketAddr::TcpSocketAddr(addr) + } +} + +impl From for SocketAddr { + fn from(addr: std::os::unix::net::SocketAddr) -> Self { + SocketAddr::UnixSocketAddr(addr) + } +} + +impl FromStr for SocketAddr { + type Err = std::io::Error; + + fn from_str(s: &str) -> Result { + if let Ok(addr) = s.parse::() { + Ok(SocketAddr::TcpSocketAddr(addr)) + } else if let Ok(addr) = std::os::unix::net::SocketAddr::from_pathname(s) { + Ok(SocketAddr::UnixSocketAddr(addr)) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "invalid socket address", + )) + } + } +} + +#[cfg(test)] +mod tests { + use crate::core::socket_addr::SocketAddr; + use std::path::Path; + + #[test] + fn test_socket_addr() { + let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080)); + let socket_addr = SocketAddr::from(addr); + assert_eq!(format!("{}", socket_addr), "http://127.0.0.1:8080"); + + let _ = std::fs::remove_file("/tmp/sock"); + let addr = std::os::unix::net::SocketAddr::from_pathname("/tmp/sock").unwrap(); + let socket_addr = SocketAddr::from(addr); + assert_eq!( + format!("{}", socket_addr), + format!("{:?}", Some(Path::new("/tmp/sock"))) + ); + } +} diff --git a/silent/src/core/stream.rs b/silent/src/core/stream.rs new file mode 100644 index 0000000..705f3d0 --- /dev/null +++ b/silent/src/core/stream.rs @@ -0,0 +1,61 @@ +use crate::core::socket_addr::SocketAddr; +use std::io; +use std::io::Error; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::{TcpStream, UnixStream}; + +pub enum Stream { + TcpStream(TcpStream), + UnixStream(UnixStream), +} + +impl Stream { + pub fn peer_addr(&self) -> io::Result { + match self { + Stream::TcpStream(s) => Ok(s.peer_addr()?.into()), + Stream::UnixStream(s) => Ok(SocketAddr::UnixSocketAddr(s.peer_addr()?.into())), + } + } +} + +impl AsyncRead for Stream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Stream::TcpStream(s) => Pin::new(s).poll_read(cx, buf), + Stream::UnixStream(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for Stream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Stream::TcpStream(s) => Pin::new(s).poll_write(cx, buf), + Stream::UnixStream(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Stream::TcpStream(s) => Pin::new(s).poll_flush(cx), + Stream::UnixStream(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Stream::TcpStream(s) => Pin::new(s).poll_shutdown(cx), + Stream::UnixStream(s) => Pin::new(s).poll_shutdown(cx), + } + } +} diff --git a/silent/src/prelude.rs b/silent/src/prelude.rs index dbaf0bd..884cedc 100644 --- a/silent/src/prelude.rs +++ b/silent/src/prelude.rs @@ -5,7 +5,7 @@ pub use crate::cookie::cookie_ext::CookieExt; pub use crate::core::form::{FilePart, FormData}; pub use crate::core::{ next::Next, path_param::PathParam, req_body::ReqBody, request::Request, res_body::full, - res_body::stream_body, res_body::ResBody, response::Response, + res_body::stream_body, res_body::ResBody, response::Response, listener::Listener, stream::Stream, }; pub use crate::error::{SilentError, SilentResult as Result}; #[cfg(feature = "grpc")] diff --git a/silent/src/service/hyper_service.rs b/silent/src/service/hyper_service.rs index e7b9983..9b520f7 100644 --- a/silent/src/service/hyper_service.rs +++ b/silent/src/service/hyper_service.rs @@ -1,10 +1,10 @@ use std::future::Future; -use std::net::SocketAddr; use std::pin::Pin; use hyper::service::Service as HyperService; use hyper::{Request as HyperRequest, Response as HyperResponse}; +use crate::core::socket_addr::SocketAddr; use crate::core::{adapt::RequestAdapt, adapt::ResponseAdapt, res_body::ResBody}; use crate::prelude::ReqBody; use crate::{Handler, Request, Response}; @@ -65,7 +65,10 @@ mod tests { #[tokio::test] async fn test_handle_request() { // Arrange - let remote_addr = "127.0.0.1:8080".parse().unwrap(); + let remote_addr = "127.0.0.1:8080" + .parse::() + .unwrap() + .into(); let routes = RootRoute::new(); // Assuming RootRoute::new() creates a new instance of RootRoute let hsh = HyperServiceHandler::new(remote_addr, routes); let req = hyper::Request::builder().body(()).unwrap(); // Assuming Request::new() creates a new instance of Request diff --git a/silent/src/service/mod.rs b/silent/src/service/mod.rs index 548d812..6f881b1 100644 --- a/silent/src/service/mod.rs +++ b/silent/src/service/mod.rs @@ -1,18 +1,21 @@ mod hyper_service; mod serve; +use crate::core::listener::Listener; use crate::route::RouteService; use crate::service::serve::Serve; use crate::Configs; #[cfg(feature = "scheduler")] use crate::Scheduler; use std::net::SocketAddr; -use tokio::net::TcpListener; +use tokio::net::{TcpListener, UnixListener}; use tokio::signal; use tokio::task::JoinSet; + pub struct Server { addr: Option, - listener: Option, + path: Option, + listener: Option, shutdown_callback: Option>, configs: Option, } @@ -27,6 +30,7 @@ impl Server { pub fn new() -> Self { Self { addr: None, + path: None, listener: None, shutdown_callback: None, configs: None, @@ -52,8 +56,14 @@ impl Server { } #[inline] - pub fn listen(mut self, listener: TcpListener) -> Self { - self.listener = Some(listener); + pub fn bind_unix>(mut self, path: P) -> Self { + self.path = Some(path.into()); + self + } + + #[inline] + pub fn listen>(mut self, listener: T) -> Self { + self.listener = Some(listener.into()); self } @@ -73,25 +83,33 @@ impl Server { listener, configs, addr, + path, .. } = self; let listener = match listener { - None => match addr { - None => TcpListener::bind("127.0.0.1:0") - .await - .expect("failed to listen"), - Some(addr) => TcpListener::bind(addr) - .await - .unwrap_or_else(|_| panic!("failed to listen {}", addr)), + None => match (addr, path.clone()) { + (None, None) => Listener::TokioListener( + TcpListener::bind("127.0.0.1:0") + .await + .expect("failed to listen"), + ), + (Some(addr), _) => Listener::TokioListener( + TcpListener::bind(addr) + .await + .unwrap_or_else(|_| panic!("failed to listen {}", addr)), + ), + (None, Some(path)) => { + let _ = tokio::fs::remove_file(&path).await; + Listener::TokioUnixListener( + UnixListener::bind(path.clone()) + .unwrap_or_else(|_| panic!("failed to listen {}", path)), + ) + } }, Some(listener) => listener, }; - tracing::info!( - "listening on: http{}//{}", - ":", - listener.local_addr().unwrap() - ); + tracing::info!("listening on: {:?}", listener.local_addr().unwrap()); let mut root_route = service.route(); root_route.set_configs(configs.clone()); #[cfg(feature = "session")] diff --git a/silent/src/service/serve.rs b/silent/src/service/serve.rs index 6a999dd..c91ddb2 100644 --- a/silent/src/service/serve.rs +++ b/silent/src/service/serve.rs @@ -1,12 +1,10 @@ -use std::error::Error as StdError; -use std::net::SocketAddr; - -use hyper_util::rt::{TokioExecutor, TokioIo}; -use hyper_util::server::conn::auto::Builder; -use tokio::net::TcpStream; - +use crate::core::socket_addr::SocketAddr; +use crate::core::stream::Stream; use crate::route::RootRoute; use crate::service::hyper_service::HyperServiceHandler; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn::auto::Builder; +use std::error::Error as StdError; pub(crate) struct Serve { pub(crate) routes: RootRoute, @@ -22,7 +20,7 @@ impl Serve { } pub(crate) async fn call( &self, - stream: TcpStream, + stream: Stream, peer_addr: SocketAddr, ) -> Result<(), Box> { let io = TokioIo::new(stream); From aa79f966bb122a3b40f4a93229afa53cacf91947 Mon Sep 17 00:00:00 2001 From: hubertshelley Date: Mon, 6 Jan 2025 17:26:26 +0800 Subject: [PATCH 2/2] code fmt --- silent/src/prelude.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/silent/src/prelude.rs b/silent/src/prelude.rs index 884cedc..cb3cc79 100644 --- a/silent/src/prelude.rs +++ b/silent/src/prelude.rs @@ -4,8 +4,8 @@ pub use crate::cookie::cookie_ext::CookieExt; #[cfg(feature = "multipart")] pub use crate::core::form::{FilePart, FormData}; pub use crate::core::{ - next::Next, path_param::PathParam, req_body::ReqBody, request::Request, res_body::full, - res_body::stream_body, res_body::ResBody, response::Response, listener::Listener, stream::Stream, + listener::Listener, next::Next, path_param::PathParam, req_body::ReqBody, request::Request, + res_body::full, res_body::stream_body, res_body::ResBody, response::Response, stream::Stream, }; pub use crate::error::{SilentError, SilentResult as Result}; #[cfg(feature = "grpc")]