Skip to content

Commit

Permalink
WIP: AF_VSOCK support
Browse files Browse the repository at this point in the history
Signed-off-by: Alexander V. Nikolaev <[email protected]>
  • Loading branch information
avnik committed Sep 9, 2024
1 parent 5bae864 commit 058499f
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 3 deletions.
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ tracing = "0.1.37"
axum07 = { version="0.7", package="axum", optional=true }
futures-util = {version="0.3", optional=true}
tokio-util = {version = "0.7", optional=true, features=["net","codec"]}
tokio-vsock = {version = "0.5", optional=true}

[target.'cfg(unix)'.dependencies]
nix = { version = "0.26.2", default-features = false, features = ["user", "fs"], optional=true }
nix = { version = "0.26.2", default-features = false, features = ["user", "fs", "socket"], optional=true }


[features]
Expand Down Expand Up @@ -85,6 +86,9 @@ tokio-util = ["dep:tokio-util"]
## Enable [`Listener::bind_multiple`] and `sd-listen:*` (if combined with `sd_listen` feature)
multi-listener = ["dep:futures-util"]

## Enable `Vsock` support
vsock = ["dep:tokio-vsock"]

[dev-dependencies]
anyhow = "1.0.71"
argh = "0.1.10"
Expand Down
45 changes: 45 additions & 0 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ use tracing::{debug, warn};
#[cfg(unix)]
use tokio::net::UnixStream;

#[cfg(all(feature = "vsock", target_os = "linux"))]
use tokio_vsock::VsockStream;

/// Accepted connection, which can be a TCP socket, AF_UNIX stream socket or a stdin/stdout pair.
///
/// Although inner enum is private, you can use methods or `From` impls to convert this to/from usual Tokio types.
Expand All @@ -34,6 +37,8 @@ impl std::fmt::Debug for Connection {
ConnectionImpl::Tcp(_) => f.write_str("Connection(tcp)"),
#[cfg(all(feature = "unix", unix))]
ConnectionImpl::Unix(_) => f.write_str("Connection(unix)"),
#[cfg(all(feature = "vsock", target_os = "linux"))]
ConnectionImpl::Vsock(_) => f.write_str("Connection(vsock)"),
#[cfg(feature = "inetd")]
ConnectionImpl::Stdio(_, _, _) => f.write_str("Connection(stdio)"),
}
Expand All @@ -52,6 +57,8 @@ pub(crate) enum ConnectionImpl {
#[pin] tokio::io::Stdout,
Option<Sender<()>>,
),
#[cfg(all(feature = "vsock", target_os = "linux"))]
Vsock(#[pin] VsockStream),
}

#[allow(missing_docs)]
Expand Down Expand Up @@ -88,6 +95,15 @@ impl Connection {
Err(self)
}
}
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs_alt, doc(cfg(feature = "vsock")))]
pub fn try_into_vsock(self) -> Result<VsockStream, Self> {
if let ConnectionImpl::Vsock(vsock) = self.0 {
Ok(vsock)
} else {
Err(self)
}
}

pub fn try_borrow_tcp(&self) -> Option<&TcpStream> {
if let ConnectionImpl::Tcp(ref s) = self.0 {
Expand All @@ -114,6 +130,15 @@ impl Connection {
None
}
}
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs_alt, doc(cfg(feature = "vsock")))]
pub fn try_borrow_vsock(&self) -> Option<&VsockStream> {
if let ConnectionImpl::Vsock(ref vsock) = self.0 {
Some(vsock)
} else {
None
}
}
}

impl From<TcpStream> for Connection {
Expand All @@ -136,6 +161,14 @@ impl From<(Stdin, Stdout, Option<Sender<()>>)> for Connection {
}
}

#[cfg(all(feature = "vsock", target_os = "linux"))]
#[cfg_attr(docsrs_alt, doc(cfg(all(feature = "vsock", target_os = "linux"))))]
impl From<VsockStream> for Connection {
fn from(s: VsockStream) ->Self {
Connection(ConnectionImpl::Vsock(s))
}
}

impl AsyncRead for Connection {
#[inline]
fn poll_read(
Expand All @@ -150,6 +183,8 @@ impl AsyncRead for Connection {
ConnectionImplProj::Unix(s) => s.poll_read(cx, buf),
#[cfg(feature = "inetd")]
ConnectionImplProj::Stdio(s, _, _) => s.poll_read(cx, buf),
#[cfg(all(feature = "vsock", target_os = "linux"))]
ConnectionImplProj::Vsock(s) => s.poll_read(cx, buf),
}
}
}
Expand All @@ -168,6 +203,8 @@ impl AsyncWrite for Connection {
ConnectionImplProj::Unix(s) => s.poll_write(cx, buf),
#[cfg(feature = "inetd")]
ConnectionImplProj::Stdio(_, s, _) => s.poll_write(cx, buf),
#[cfg(all(feature = "vsock", target_os = "linux"))]
ConnectionImplProj::Vsock(s) => s.poll_write(cx, buf),
}
}

Expand All @@ -180,6 +217,8 @@ impl AsyncWrite for Connection {
ConnectionImplProj::Unix(s) => s.poll_flush(cx),
#[cfg(feature = "inetd")]
ConnectionImplProj::Stdio(_, s, _) => s.poll_flush(cx),
#[cfg(all(feature = "vsock", target_os = "linux"))]
ConnectionImplProj::Vsock(s) => s.poll_flush(cx),
}
}

Expand Down Expand Up @@ -207,6 +246,8 @@ impl AsyncWrite for Connection {
Poll::Ready(ret)
}
},
#[cfg(all(feature = "vsock", target_os = "linux"))]
ConnectionImplProj::Vsock(s) => s.poll_shutdown(cx),
}
}

Expand All @@ -223,6 +264,8 @@ impl AsyncWrite for Connection {
ConnectionImplProj::Unix(s) => s.poll_write_vectored(cx, bufs),
#[cfg(feature = "inetd")]
ConnectionImplProj::Stdio(_, s, _) => s.poll_write_vectored(cx, bufs),
#[cfg(all(feature = "vsock", target_os = "linux"))]
ConnectionImplProj::Vsock(s) => s.poll_write_vectored(cx, bufs),
}
}

Expand All @@ -234,6 +277,8 @@ impl AsyncWrite for Connection {
ConnectionImpl::Unix(s) => s.is_write_vectored(),
#[cfg(feature = "inetd")]
ConnectionImpl::Stdio(_, s, _) => s.is_write_vectored(),
#[cfg(all(feature = "vsock", target_os = "linux"))]
ConnectionImpl::Vsock(s) => s.is_write_vectored(),
}
}
}
58 changes: 57 additions & 1 deletion src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ impl std::fmt::Debug for Listener {
ListenerImpl::Unix { .. } => f.write_str("tokio_listener::Listener(unix)"),
#[cfg(feature = "inetd")]
ListenerImpl::Stdio(_) => f.write_str("tokio_listener::Listener(stdio)"),
#[cfg(all(feature = "vsock", target_os = "linux"))]
ListenerImpl::Vsock(_) => f.write_str("tokio_listener::Listener(vsock)"),
#[cfg(feature = "multi-listener")]
ListenerImpl::Multi(ref x) => {
write!(f, "tokio_listener::Listener(multi, n={})", x.v.len())
Expand Down Expand Up @@ -166,6 +168,13 @@ fn listen_abstract(a: &String, usr_opts: &UserOptions) -> Result<ListenerImpl, s
}))
}

#[cfg(all(target_os = "linux", feature = "vsock"))]
fn listen_vsock(vs: &tokio_vsock::VsockAddr) -> Result<ListenerImpl, std::io::Error> {
use tokio_vsock::VsockListener;
let listener = VsockListener::bind(vs.to_owned())?;
Ok(ListenerImpl::Vsock(ListenerImplVsock{ s: listener}))
}

#[cfg(all(feature = "sd_listen", unix))]
fn listen_from_fd(
usr_opts: &UserOptions,
Expand Down Expand Up @@ -353,6 +362,8 @@ impl Listener {
ListenerAddress::FromFdNamed(fdname) => {
listen_from_fd_named(usr_opts, fdname, sys_opts)?
}
#[cfg(all(target_os = "linux", feature = "vsock"))]
ListenerAddress::Vsock(vs) => listen_vsock(vs)?,
#[allow(unreachable_patterns)]
_ => {
#[allow(unused_imports)]
Expand Down Expand Up @@ -410,7 +421,23 @@ impl Listener {
feature: "UNIX-like platform",
}
}
}
},
ListenerAddress::Vsock(_) => {
#[cfg(target_os = "linux")]
{
MissingCompileTimeFeature {
reason: "use vsock socket",
feature: "vsock",
}
}
#[cfg(not(target_os = "linux"))]
{
MissingPlatformSypport {
reason: "use vsock socket",
feature: "linux platform",
}
}
},
};
return err.ioerr();
}
Expand Down Expand Up @@ -634,6 +661,11 @@ pub(crate) struct ListenerImplUnix {
send_buffer_size: Option<usize>,
}

#[cfg(all(feature = "vsock", target_os = "linux"))]
pub(crate) struct ListenerImplVsock {
pub(crate) s: tokio_vsock::VsockListener,
}

#[cfg(feature = "multi-listener")]
pub(crate) struct ListenerImplMulti {
pub(crate) v: Vec<ListenerImpl>,
Expand All @@ -645,6 +677,8 @@ pub(crate) enum ListenerImpl {
Unix(ListenerImplUnix),
#[cfg(feature = "inetd")]
Stdio(StdioListener),
#[cfg(all(feature = "vsock", target_os = "linux"))]
Vsock(ListenerImplVsock),
#[cfg(feature = "multi-listener")]
Multi(ListenerImplMulti),
}
Expand All @@ -660,6 +694,8 @@ impl ListenerImpl {
ListenerImpl::Unix(ui) => ui.poll_accept(cx),
#[cfg(feature = "inetd")]
ListenerImpl::Stdio(x) => x.poll_accept(cx),
#[cfg(all(feature = "vsock", target_os = "linux"))]
ListenerImpl::Vsock(vs) => vs.poll_accept(cx),
#[cfg(feature = "multi-listener")]
ListenerImpl::Multi(x) => x.poll_accept(cx),
}
Expand Down Expand Up @@ -736,6 +772,26 @@ impl ListenerImplUnix {
}
}

#[cfg(all(feature = "vsock", target_os = "linux"))]
impl ListenerImplVsock {
fn poll_accept(
&mut self,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<(Connection, SomeSocketAddr)>> {
match self.s.poll_accept(cx) {
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Ready(Ok((c, a))) => {
debug!(r#type = "vsock", "incoming connection");
Poll::Ready(Ok((
Connection(ConnectionImpl::Vsock(c)),
SomeSocketAddr::Vsock(a),
)))
},
Poll::Pending => Poll::Pending,
}
}
}

#[cfg(feature = "multi-listener")]
impl ListenerImplMulti {
fn poll_accept(
Expand Down
7 changes: 6 additions & 1 deletion src/listener_address.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::{fmt::Display, net::SocketAddr, path::PathBuf, str::FromStr};

#[cfg(all(unix, feature = "vsock"))]
use tokio_vsock::VsockAddr;

/// Abstraction over socket address that instructs in which way and at what address (if any) [`Listener`]
/// should listen for incoming stream connections.
///
Expand Down Expand Up @@ -28,7 +31,7 @@ use std::{fmt::Display, net::SocketAddr, path::PathBuf, str::FromStr};
feature = "serde",
derive(serde_with::DeserializeFromStr, serde_with::SerializeDisplay)
)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Debug, Clone, PartialEq)]
pub enum ListenerAddress {
/// Usual server TCP socket. Triggered by specifying IPv4 or IPv6 address and port pair.
/// Example: `127.0.0.1:8080`.
Expand All @@ -52,6 +55,7 @@ pub enum ListenerAddress {
///
/// Special name `*` means to bind all passed addresses simultaneously, if `multi-listener` crate feature is enabled.
FromFdNamed(String),
Vsock(VsockAddr),
}

pub(crate) const SD_LISTEN_FDS_START: i32 = 3;
Expand Down Expand Up @@ -121,6 +125,7 @@ impl Display for ListenerAddress {
ListenerAddress::FromFdNamed(name) => {
write!(f, "sd-listen:{name}")
}
ListenerAddress::Vsock(a) => a.fmt(f),
}
}
}
Expand Down
12 changes: 12 additions & 0 deletions src/some_socket_addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ pub enum SomeSocketAddr {
#[cfg(feature = "inetd")]
#[cfg_attr(docsrs_alt, doc(cfg(feature = "inetd")))]
Stdio,
#[cfg(all(feature = "vsock", target_os = "linux"))]
#[cfg_attr(docsrs_alt, doc(cfg(all(feature = "vsock", target_os = "linux"))))]
Vsock(tokio_vsock::VsockAddr),
#[cfg(feature = "multi-listener")]
#[cfg_attr(docsrs_alt, doc(cfg(feature = "multi-listener")))]
Multiple,
Expand All @@ -27,6 +30,8 @@ impl Display for SomeSocketAddr {
SomeSocketAddr::Unix(_x) => "unix".fmt(f),
#[cfg(feature = "inetd")]
SomeSocketAddr::Stdio => "stdio".fmt(f),
#[cfg(all(feature = "vsock", target_os = "linux"))]
SomeSocketAddr::Vsock(x) => x.fmt(f),
#[cfg(feature = "multi-listener")]
SomeSocketAddr::Multiple => "multiple".fmt(f),
}
Expand All @@ -44,6 +49,8 @@ impl SomeSocketAddr {
SomeSocketAddr::Unix(x) => SomeSocketAddrClonable::Unix(Arc::new(x)),
#[cfg(feature = "inetd")]
SomeSocketAddr::Stdio => SomeSocketAddrClonable::Stdio,
#[cfg(all(feature = "vsock", target_os = "linux"))]
SomeSocketAddr::Vsock(x) => SomeSocketAddrClonable::Vsock(x),
#[cfg(feature = "multi-listener")]
SomeSocketAddr::Multiple => SomeSocketAddrClonable::Multiple,
}
Expand All @@ -62,6 +69,9 @@ pub enum SomeSocketAddrClonable {
#[cfg(feature = "inetd")]
#[cfg_attr(docsrs_alt, doc(cfg(feature = "inetd")))]
Stdio,
#[cfg(all(feature = "vsock", target_os = "linux"))]
#[cfg_attr(docsrs_alt, doc(cfg(all(feature = "vsock", target_os = "linux"))))]
Vsock(tokio_vsock::VsockAddr),
#[cfg(feature = "multi-listener")]
#[cfg_attr(docsrs_alt, doc(cfg(feature = "multi-listener")))]
Multiple,
Expand All @@ -75,6 +85,8 @@ impl Display for SomeSocketAddrClonable {
SomeSocketAddrClonable::Unix(x) => write!(f, "unix:{x:?}"),
#[cfg(feature = "inetd")]
SomeSocketAddrClonable::Stdio => "stdio".fmt(f),
#[cfg(all(feature = "vsock", target_os = "linux"))]
SomeSocketAddrClonable::Vsock(x) => x.fmt(f),
#[cfg(feature = "multi-listener")]
SomeSocketAddrClonable::Multiple => "multiple".fmt(f),
}
Expand Down
4 changes: 4 additions & 0 deletions src/tokioutil.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ impl tokio_util::net::Listener for crate::Listener {
}) => Ok(SomeSocketAddr::Unix(s.local_addr()?)),
#[cfg(feature = "inetd")]
crate::listener::ListenerImpl::Stdio(_) => Ok(SomeSocketAddr::Stdio),
#[cfg(all(feature = "vsock", target_os = "linux"))]
crate::listener::ListenerImpl::Vsock(crate::listener::ListenerImplVsock{s}) => {
Ok(SomeSocketAddr::Vsock(s.local_addr()?))
},
#[cfg(feature = "multi-listener")]
crate::listener::ListenerImpl::Multi(_) => Ok(SomeSocketAddr::Multiple),
}
Expand Down
4 changes: 4 additions & 0 deletions src/tonic010.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ impl Connected for Connection {
if self.try_borrow_stdio().is_some() {
return ListenerConnectInfo::Stdio;
}
#[cfg(feature = "vsock")]
if let Some(vsock) = self.try_borrow_vsock() {
return ListenerConnectInfo::Vsock(vsock.connect_info())
}

ListenerConnectInfo::Other
}
Expand Down

0 comments on commit 058499f

Please sign in to comment.