Skip to content

Commit

Permalink
Finish implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
sashacmc committed Mar 14, 2024
1 parent a6b9f11 commit 8fc15d0
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 73 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions io/zenoh-links/zenoh-link-vsock/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ tokio = { workspace = true, features = ["net", "io-util", "rt", "time"] }
tokio-util = { workspace = true, features = ["rt"] }
tokio-vsock = { workspace = true }
log = { workspace = true }
libc = { workspace = true }
zenoh-core = { workspace = true }
zenoh-link-commons = { workspace = true }
zenoh-protocol = { workspace = true }
Expand Down
50 changes: 29 additions & 21 deletions io/zenoh-links/zenoh-link-vsock/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,26 @@
//!
//! [Click here for Zenoh's documentation](../zenoh/index.html)
use async_trait::async_trait;
use libc::VMADDR_PORT_ANY;
use tokio_vsock::{
VsockAddr, VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL,
};
use zenoh_core::zconfigurable;
use zenoh_link_commons::LocatorInspector;
use zenoh_protocol::core::{endpoint::Address, Locator};
use zenoh_result::{zerror, ZResult};
use zenoh_result::{bail, ZResult};

mod unicast;
pub use unicast::*;

pub const VSOCK_LOCATOR_PREFIX: &str = "vsock";

pub const VSOCK_ADDR_VMADDR_CID_ANY: &str = "VMADDR_CID_ANY";
pub const VSOCK_ADDR_VMADDR_CID_HYPERVISOR: &str = "VMADDR_CID_HYPERVISOR";
pub const VSOCK_ADDR_VMADDR_CID_LOCAL: &str = "VMADDR_CID_LOCAL";
pub const VSOCK_ADDR_VMADDR_CID_HOST: &str = "VMADDR_CID_HOST";
pub const VSOCK_VMADDR_CID_ANY: &str = "VMADDR_CID_ANY";
pub const VSOCK_VMADDR_CID_HYPERVISOR: &str = "VMADDR_CID_HYPERVISOR";
pub const VSOCK_VMADDR_CID_LOCAL: &str = "VMADDR_CID_LOCAL";
pub const VSOCK_VMADDR_CID_HOST: &str = "VMADDR_CID_HOST";

pub const VSOCK_VMADDR_PORT_ANY: &str = "VMADDR_PORT_ANY";

#[derive(Default, Clone, Copy)]
pub struct VsockLocatorInspector;
Expand All @@ -54,34 +57,39 @@ zconfigurable! {
static ref VSOCK_DEFAULT_MTU: u16 = u16::MAX;
// Amount of time in microseconds to throttle the accept loop upon an error.
// Default set to 100 ms.
static ref TCP_ACCEPT_THROTTLE_TIME: u64 = 100_000;
static ref VSOCK_ACCEPT_THROTTLE_TIME: u64 = 100_000;
}

pub fn get_vsock_addr(address: Address<'_>) -> ZResult<VsockAddr> {
let parts: Vec<&str> = address.as_str().split(':').collect();

if parts.len() != 2 {
zerror!("Incorrect vsock address: {:?}", address);
}

let mut port = 0;
if let Ok(p) = parts[1].parse::<u32>() {
port = p;
} else {
zerror!("Incorrect vsock port: {:?}", parts[1]);
bail!("Incorrect vsock address: {:?}", address);
}

let cid = match parts[0] {
VSOCK_ADDR_VMADDR_CID_ANY => VMADDR_CID_ANY,
VSOCK_ADDR_VMADDR_CID_HYPERVISOR => VMADDR_CID_HYPERVISOR,
VSOCK_ADDR_VMADDR_CID_HOST => VMADDR_CID_HOST,
VSOCK_ADDR_VMADDR_CID_LOCAL => VMADDR_CID_LOCAL,
let cid = match parts[0].to_uppercase().as_str() {
VSOCK_VMADDR_CID_HYPERVISOR => VMADDR_CID_HYPERVISOR,
VSOCK_VMADDR_CID_HOST => VMADDR_CID_HOST,
VSOCK_VMADDR_CID_LOCAL => VMADDR_CID_LOCAL,
VSOCK_VMADDR_CID_ANY => VMADDR_CID_ANY,
"-1" => VMADDR_CID_ANY,
_ => {
if let Ok(cid) = parts[0].parse::<u32>() {
cid
} else {
zerror!("Incorrect vsock cid: {:?}", parts[0]);
0
bail!("Incorrect vsock cid: {:?}", parts[0]);
}
}
};

let port = match parts[1].to_uppercase().as_str() {
VSOCK_VMADDR_PORT_ANY => VMADDR_PORT_ANY,
"-1" => VMADDR_PORT_ANY,
_ => {
if let Ok(cid) = parts[1].parse::<u32>() {
cid
} else {
bail!("Incorrect vsock port: {:?}", parts[1]);
}
}
};
Expand Down
103 changes: 51 additions & 52 deletions io/zenoh-links/zenoh-link-vsock/src/unicast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@
// Contributors:
// ZettaScale Zenoh Team, <[email protected]>
//

use async_trait::async_trait;
use std::cell::UnsafeCell;
use std::collections::HashMap;
use std::convert::TryInto;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::RwLock as AsyncRwLock;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use zenoh_core::{zasyncread, zasyncwrite};
use zenoh_link_commons::{
get_ip_interface_names, LinkManagerUnicastTrait, LinkUnicast, LinkUnicastTrait,
ListenersUnicastIP, NewLinkChannelSender, BIND_INTERFACE,
LinkManagerUnicastTrait, LinkUnicast, LinkUnicastTrait, NewLinkChannelSender,
};
use zenoh_protocol::core::{EndPoint, Locator};
use zenoh_result::{bail, zerror, Error as ZError, ZResult};
use zenoh_result::{bail, zerror, ZResult};

use super::{get_vsock_addr, TCP_ACCEPT_THROTTLE_TIME, VSOCK_DEFAULT_MTU, VSOCK_LOCATOR_PREFIX};
use super::{get_vsock_addr, VSOCK_ACCEPT_THROTTLE_TIME, VSOCK_DEFAULT_MTU, VSOCK_LOCATOR_PREFIX};
use tokio_vsock::{VsockAddr, VsockListener, VsockStream};

pub struct LinkUnicastVsock {
Expand Down Expand Up @@ -66,7 +66,6 @@ impl LinkUnicastVsock {
impl LinkUnicastTrait for LinkUnicastVsock {
async fn close(&self) -> ZResult<()> {
log::trace!("Closing vsock link: {}", self);
// Close the underlying vsock socket
self.get_mut_socket().shutdown().await.map_err(|e| {
let e = zerror!("vsock link shutdown {}: {:?}", self, e);
log::trace!("{}", e);
Expand Down Expand Up @@ -128,8 +127,7 @@ impl LinkUnicastTrait for LinkUnicastVsock {

#[inline(always)]
fn get_interface_names(&self) -> Vec<String> {
// TODO(sashacmc): get_ip_interface_names(&self.src_addr)
vec![]
vec!["vsock".to_string()]
}

#[inline(always)]
Expand All @@ -152,7 +150,7 @@ impl fmt::Display for LinkUnicastVsock {

impl fmt::Debug for LinkUnicastVsock {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Tcp")
f.debug_struct("Vsock")
.field("src", &self.src_addr)
.field("dst", &self.dst_addr)
.finish()
Expand Down Expand Up @@ -181,14 +179,14 @@ impl ListenerUnicastVsock {

pub struct LinkManagerUnicastVsock {
manager: NewLinkChannelSender,
listeners: tokio::sync::RwLock<HashMap<VsockAddr, ListenerUnicastVsock>>,
listeners: Arc<AsyncRwLock<HashMap<VsockAddr, ListenerUnicastVsock>>>,
}

impl LinkManagerUnicastVsock {
pub fn new(manager: NewLinkChannelSender) -> Self {
Self {
manager,
listeners: tokio::sync::RwLock::new(HashMap::new()),
listeners: Arc::new(AsyncRwLock::new(HashMap::new())),
}
}
}
Expand All @@ -198,11 +196,9 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastVsock {
async fn new_link(&self, endpoint: EndPoint) -> ZResult<LinkUnicast> {
let addr = get_vsock_addr(endpoint.address())?;
if let Ok(stream) = VsockStream::connect(addr).await {
let link = Arc::new(LinkUnicastVsock::new(
stream,
stream.local_addr()?,
stream.peer_addr()?,
));
let local_addr = stream.local_addr()?;
let peer_addr = stream.peer_addr()?;
let link = Arc::new(LinkUnicastVsock::new(stream, local_addr, peer_addr));
return Ok(LinkUnicast(link));
}

Expand All @@ -224,80 +220,83 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastVsock {
let c_token = token.clone();

let c_manager = self.manager.clone();
let task = async move { accept_task(listener, c_token, c_manager).await };

let locator = endpoint.to_locator();
zasyncwrite!(self.listeners)
.add_listener(endpoint, local_addr, task, token)
.await?;

let mut listeners = zasyncwrite!(self.listeners);
let c_listeners = self.listeners.clone();
let c_addr = addr;
let task = async move {
// Wait for the accept loop to terminate
let res = accept_task(listener, c_token, c_manager).await;
zasyncwrite!(c_listeners).remove(&c_addr);
res
};
let handle = zenoh_runtime::ZRuntime::Acceptor.spawn(task);

let listener = ListenerUnicastVsock::new(endpoint, token, handle);
// Update the list of active listeners on the manager
listeners.insert(addr, listener);
return Ok(locator);
}

bail!("Can not create a new vsock listener bound to {}", endpoint)
}

async fn del_listener(&self, endpoint: &EndPoint) -> ZResult<()> {
let addrs = get_vsock_addrs(endpoint.address()).await?;

// Stop the listener
let mut errs: Vec<ZError> = vec![];
let mut failed = true;
for a in addrs {
match self.listeners.del_listener(a).await {
Ok(_) => {
failed = false;
break;
}
Err(err) => {
errs.push(zerror!("{}", err).into());
}
}
}
let addr = get_vsock_addr(endpoint.address())?;

if failed {
bail!(
"Can not delete the TCP listener bound to {}: {:?}",
endpoint,
errs
let listener = zasyncwrite!(self.listeners).remove(&addr).ok_or_else(|| {
zerror!(
"Can not delete the listener because it has not been found: {}",
addr
)
}
Ok(())
})?;

// Send the stop signal
listener.stop().await;
listener.handle.await?
}

async fn get_listeners(&self) -> Vec<EndPoint> {
self.listeners.get_endpoints()
zasyncread!(self.listeners)
.values()
.map(|x| x.endpoint.clone())
.collect()
}

async fn get_locators(&self) -> Vec<Locator> {
self.listeners.get_locators()
zasyncread!(self.listeners)
.values()
.map(|x| x.endpoint.to_locator())
.collect()
}
}

async fn accept_task(
socket: VsockListener,
mut socket: VsockListener,
token: CancellationToken,
manager: NewLinkChannelSender,
) -> ZResult<()> {
async fn accept(socket: &VsockListener) -> ZResult<(VsockStream, VsockAddr)> {
async fn accept(socket: &mut VsockListener) -> ZResult<(VsockStream, VsockAddr)> {
let res = socket.accept().await.map_err(|e| zerror!(e))?;
Ok(res)
}

let src_addr = socket.local_addr().map_err(|e| {
let e = zerror!("Can not accept TCP connections: {}", e);
let e = zerror!("Can not accept vsock connections: {}", e);
log::warn!("{}", e);
e
})?;

log::trace!("Ready to accept TCP connections on: {:?}", src_addr);
log::trace!("Ready to accept vsock connections on: {:?}", src_addr);
loop {
tokio::select! {
_ = token.cancelled() => break,
res = accept(&socket) => {
res = accept(&mut socket) => {
match res {
Ok((stream, dst_addr)) => {
log::debug!("Accepted TCP connection on {:?}: {:?}", src_addr, dst_addr);
log::debug!("Accepted vsock connection on {:?}: {:?}", src_addr, dst_addr);
// Create the new link object
let link = Arc::new(LinkUnicastVsock::new(stream, src_addr, dst_addr));

Expand All @@ -314,7 +313,7 @@ async fn accept_task(
// Linux systems this limit can be changed by using the "ulimit" command line
// tool. In case of systemd-based systems, this can be changed by using the
// "sysctl" command line tool.
tokio::time::sleep(Duration::from_micros(*TCP_ACCEPT_THROTTLE_TIME)).await;
tokio::time::sleep(Duration::from_micros(*VSOCK_ACCEPT_THROTTLE_TIME)).await;
}

}
Expand Down

0 comments on commit 8fc15d0

Please sign in to comment.