Skip to content

Commit

Permalink
Merge pull request #78 from hubertshelley/features/unix_support
Browse files Browse the repository at this point in the history
feat: Unix support
  • Loading branch information
hubertshelley authored Jan 6, 2025
2 parents 2ba8f34 + aa79f96 commit 2609ce7
Show file tree
Hide file tree
Showing 11 changed files with 344 additions and 31 deletions.
14 changes: 14 additions & 0 deletions examples/custom_tokio_unix_listener/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "example-custom_tokio_unix_listener"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
silent = { path = "../../silent" }
http-body-util = "0.1"
hyper = { version = "1.0.0", features = ["full"] }
hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] }
tokio = { version = "1.0", features = ["full"] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
66 changes: 66 additions & 0 deletions examples/custom_tokio_unix_listener/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//! Run with
//!
//! ```not_rust
//! cargo run -p example-custom_tokio_unix_listener
//! ```
#[cfg(unix)]
#[tokio::main]
async fn main() {
unix::server().await;
}

#[cfg(not(unix))]
fn main() {
println!("This example requires unix")
}

#[cfg(unix)]
mod unix {
use http_body_util::BodyExt;
use hyper_util::rt::TokioIo;
use silent::prelude::*;
use silent::prelude::{logger, HandlerAppend, Level, Route, Server};
use std::time::Duration;
use tokio::net::{UnixListener, UnixStream};

pub async fn server() {
logger::fmt().with_max_level(Level::INFO).init();
let listener_path = "./examples/custom_tokio_unix_listener/custom_handler.sock";

tokio::spawn(async move {
let route = Route::new("").get(handler);
let listener = UnixListener::bind(listener_path).unwrap();

Server::new().listen(listener).serve(route).await;
// Server::new().bind_unix(listener_path).serve(route).await;
});

tokio::time::sleep(Duration::from_secs(1)).await;

let stream = TokioIo::new(UnixStream::connect(listener_path).await.unwrap());
let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
}
});

let request = Request::empty();

let response = sender.send_request(request.into_http()).await.unwrap();

assert_eq!(response.status(), StatusCode::OK);

let body = response.collect().await.unwrap().to_bytes();
let body = String::from_utf8(body.to_vec()).unwrap();
assert_eq!(body, "Hello, World!");

let _ = tokio::fs::remove_file(listener_path).await;
}

async fn handler(req: Request) -> Result<&'static str> {
println!("new connection from `{:?}`", req.remote());

Ok("Hello, World!")
}
}
80 changes: 80 additions & 0 deletions silent/src/core/listener.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use super::socket_addr::SocketAddr;
use super::stream::Stream;

pub enum Listener {
TcpListener(std::net::TcpListener),
UnixListener(std::os::unix::net::UnixListener),
TokioListener(tokio::net::TcpListener),
TokioUnixListener(tokio::net::UnixListener),
}

impl From<std::net::TcpListener> for Listener {
fn from(listener: std::net::TcpListener) -> Self {
Listener::TcpListener(listener)
}
}

impl From<std::os::unix::net::UnixListener> for Listener {
fn from(value: std::os::unix::net::UnixListener) -> Self {
Listener::UnixListener(value)
}
}

impl From<tokio::net::TcpListener> for Listener {
fn from(listener: tokio::net::TcpListener) -> Self {
Listener::TokioListener(listener)
}
}

impl From<tokio::net::UnixListener> for Listener {
fn from(value: tokio::net::UnixListener) -> Self {
Listener::TokioUnixListener(value)
}
}

impl Listener {
pub async fn accept(&self) -> std::io::Result<(Stream, SocketAddr)> {
match self {
Listener::TcpListener(listener) => {
let (stream, addr) = listener.accept()?;
Ok((
Stream::TcpStream(tokio::net::TcpStream::from_std(stream)?),
SocketAddr::TcpSocketAddr(addr),
))
}
Listener::UnixListener(listener) => {
let (stream, addr) = listener.accept()?;
Ok((
Stream::UnixStream(tokio::net::UnixStream::from_std(stream)?),
SocketAddr::UnixSocketAddr(addr),
))
}
Listener::TokioListener(listener) => {
let (stream, addr) = listener.accept().await?;
Ok((Stream::TcpStream(stream), SocketAddr::TcpSocketAddr(addr)))
}
Listener::TokioUnixListener(listener) => {
let (stream, addr) = listener.accept().await?;
Ok((
Stream::UnixStream(stream),
SocketAddr::UnixSocketAddr(addr.into()),
))
}
}
}

pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
match self {
Listener::TcpListener(listener) => listener.local_addr().map(SocketAddr::TcpSocketAddr),
Listener::UnixListener(listener) => {
Ok(SocketAddr::UnixSocketAddr(listener.local_addr()?))
}
Listener::TokioListener(listener) => {
listener.local_addr().map(SocketAddr::TcpSocketAddr)
}
Listener::TokioUnixListener(listener) => {
Ok(SocketAddr::UnixSocketAddr(listener.local_addr()?.into()))
}
}
}
}
3 changes: 3 additions & 0 deletions silent/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod adapt;

#[cfg(feature = "multipart")]
pub(crate) mod form;
pub(crate) mod listener;
pub(crate) mod next;
pub(crate) mod path_param;
pub(crate) mod req_body;
Expand All @@ -10,3 +11,5 @@ pub(crate) mod res_body;
pub(crate) mod response;
#[allow(dead_code)]
mod serde;
pub(crate) mod socket_addr;
pub(crate) mod stream;
6 changes: 3 additions & 3 deletions silent/src/core/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::core::path_param::PathParam;
use crate::core::req_body::ReqBody;
#[cfg(feature = "multipart")]
use crate::core::serde::from_str_multi_val;
use crate::core::socket_addr::SocketAddr;
use crate::header::CONTENT_TYPE;
use crate::{Configs, Result, SilentError};
use bytes::Bytes;
Expand All @@ -16,7 +17,6 @@ use serde::de::StdError;
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use tokio::sync::OnceCell;
use url::form_urlencoded;

Expand Down Expand Up @@ -134,7 +134,7 @@ impl Request {

/// 获取访问真实地址
#[inline]
pub fn remote(&self) -> IpAddr {
pub fn remote(&self) -> SocketAddr {
self.headers()
.get("x-real-ip")
.and_then(|h| h.to_str().ok())
Expand All @@ -148,7 +148,7 @@ impl Request {
pub fn set_remote(&mut self, remote_addr: SocketAddr) {
if self.headers().get("x-real-ip").is_none() {
self.headers_mut()
.insert("x-real-ip", remote_addr.ip().to_string().parse().unwrap());
.insert("x-real-ip", remote_addr.to_string().parse().unwrap());
}
}

Expand Down
70 changes: 70 additions & 0 deletions silent/src/core/socket_addr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use std::fmt::{Display, Formatter};
use std::str::FromStr;

#[derive(Clone, Debug)]
pub enum SocketAddr {
TcpSocketAddr(std::net::SocketAddr),
UnixSocketAddr(std::os::unix::net::SocketAddr),
}

impl Display for SocketAddr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
#[allow(clippy::write_literal)]
SocketAddr::TcpSocketAddr(addr) => write!(f, "http{}//{:?}", ':', addr),
SocketAddr::UnixSocketAddr(addr) => {
write!(f, "{:?}", addr.as_pathname())
}
}
}
}

impl From<std::net::SocketAddr> for SocketAddr {
fn from(addr: std::net::SocketAddr) -> Self {
SocketAddr::TcpSocketAddr(addr)
}
}

impl From<std::os::unix::net::SocketAddr> for SocketAddr {
fn from(addr: std::os::unix::net::SocketAddr) -> Self {
SocketAddr::UnixSocketAddr(addr)
}
}

impl FromStr for SocketAddr {
type Err = std::io::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Ok(addr) = s.parse::<std::net::SocketAddr>() {
Ok(SocketAddr::TcpSocketAddr(addr))
} else if let Ok(addr) = std::os::unix::net::SocketAddr::from_pathname(s) {
Ok(SocketAddr::UnixSocketAddr(addr))
} else {
Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"invalid socket address",
))
}
}
}

#[cfg(test)]
mod tests {
use crate::core::socket_addr::SocketAddr;
use std::path::Path;

#[test]
fn test_socket_addr() {
let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
let socket_addr = SocketAddr::from(addr);
assert_eq!(format!("{}", socket_addr), "http://127.0.0.1:8080");

let _ = std::fs::remove_file("/tmp/sock");
let addr = std::os::unix::net::SocketAddr::from_pathname("/tmp/sock").unwrap();
let socket_addr = SocketAddr::from(addr);
assert_eq!(
format!("{}", socket_addr),
format!("{:?}", Some(Path::new("/tmp/sock")))
);
}
}
61 changes: 61 additions & 0 deletions silent/src/core/stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use crate::core::socket_addr::SocketAddr;
use std::io;
use std::io::Error;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{TcpStream, UnixStream};

pub enum Stream {
TcpStream(TcpStream),
UnixStream(UnixStream),
}

impl Stream {
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
match self {
Stream::TcpStream(s) => Ok(s.peer_addr()?.into()),
Stream::UnixStream(s) => Ok(SocketAddr::UnixSocketAddr(s.peer_addr()?.into())),
}
}
}

impl AsyncRead for Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Stream::TcpStream(s) => Pin::new(s).poll_read(cx, buf),
Stream::UnixStream(s) => Pin::new(s).poll_read(cx, buf),
}
}
}

impl AsyncWrite for Stream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
match self.get_mut() {
Stream::TcpStream(s) => Pin::new(s).poll_write(cx, buf),
Stream::UnixStream(s) => Pin::new(s).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.get_mut() {
Stream::TcpStream(s) => Pin::new(s).poll_flush(cx),
Stream::UnixStream(s) => Pin::new(s).poll_flush(cx),
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.get_mut() {
Stream::TcpStream(s) => Pin::new(s).poll_shutdown(cx),
Stream::UnixStream(s) => Pin::new(s).poll_shutdown(cx),
}
}
}
4 changes: 2 additions & 2 deletions silent/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ pub use crate::cookie::cookie_ext::CookieExt;
#[cfg(feature = "multipart")]
pub use crate::core::form::{FilePart, FormData};
pub use crate::core::{
next::Next, path_param::PathParam, req_body::ReqBody, request::Request, res_body::full,
res_body::stream_body, res_body::ResBody, response::Response,
listener::Listener, next::Next, path_param::PathParam, req_body::ReqBody, request::Request,
res_body::full, res_body::stream_body, res_body::ResBody, response::Response, stream::Stream,
};
pub use crate::error::{SilentError, SilentResult as Result};
#[cfg(feature = "grpc")]
Expand Down
7 changes: 5 additions & 2 deletions silent/src/service/hyper_service.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;

use hyper::service::Service as HyperService;
use hyper::{Request as HyperRequest, Response as HyperResponse};

use crate::core::socket_addr::SocketAddr;
use crate::core::{adapt::RequestAdapt, adapt::ResponseAdapt, res_body::ResBody};
use crate::prelude::ReqBody;
use crate::{Handler, Request, Response};
Expand Down Expand Up @@ -65,7 +65,10 @@ mod tests {
#[tokio::test]
async fn test_handle_request() {
// Arrange
let remote_addr = "127.0.0.1:8080".parse().unwrap();
let remote_addr = "127.0.0.1:8080"
.parse::<std::net::SocketAddr>()
.unwrap()
.into();
let routes = RootRoute::new(); // Assuming RootRoute::new() creates a new instance of RootRoute
let hsh = HyperServiceHandler::new(remote_addr, routes);
let req = hyper::Request::builder().body(()).unwrap(); // Assuming Request::new() creates a new instance of Request
Expand Down
Loading

0 comments on commit 2609ce7

Please sign in to comment.