Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add possibility to binding on the interface #755

Merged
merged 9 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 69 additions & 7 deletions commons/zenoh-util/src/std_only/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
// Contributors:
// ZettaScale Zenoh Team, <[email protected]>
//
use async_std::net::TcpStream;
use async_std::net::{TcpListener, TcpStream, UdpSocket};
use std::net::{IpAddr, Ipv6Addr};
use std::time::Duration;
use zenoh_core::zconfigurable;
Expand Down Expand Up @@ -210,12 +210,19 @@ pub fn get_multicast_interfaces() -> Vec<IpAddr> {
}
}

pub fn get_local_addresses() -> ZResult<Vec<IpAddr>> {
pub fn get_local_addresses(interface: Option<&str>) -> ZResult<Vec<IpAddr>> {
#[cfg(unix)]
{
Ok(pnet_datalink::interfaces()
.into_iter()
.filter(|iface| iface.is_up() && iface.is_running())
.filter(|iface| {
if let Some(interface) = interface.as_ref() {
if iface.name != *interface {
return false;
}
}
iface.is_up() && iface.is_running()
})
.flat_map(|iface| iface.ips)
.map(|ipnet| ipnet.ip())
.collect())
Expand All @@ -232,6 +239,11 @@ pub fn get_local_addresses() -> ZResult<Vec<IpAddr>> {
let mut result = vec![];
let mut next_iface = (buffer.as_ptr() as *mut IP_ADAPTER_ADDRESSES_LH).as_ref();
while let Some(iface) = next_iface {
if let Some(interface) = interface.as_ref() {
if ffi::pstr_to_string(iface.AdapterName) != *interface {
continue;
}
}
let mut next_ucast_addr = iface.FirstUnicastAddress.as_ref();
while let Some(ucast_addr) = next_ucast_addr {
if let Ok(ifaddr) = ffi::win::sockaddr_to_addr(ucast_addr.Address) {
Expand Down Expand Up @@ -412,8 +424,8 @@ pub fn get_interface_names_by_addr(addr: IpAddr) -> ZResult<Vec<String>> {
}
}

pub fn get_ipv4_ipaddrs() -> Vec<IpAddr> {
get_local_addresses()
pub fn get_ipv4_ipaddrs(interface: Option<&str>) -> Vec<IpAddr> {
get_local_addresses(interface)
.unwrap_or_else(|_| vec![])
.drain(..)
.filter_map(|x| match x {
Expand All @@ -425,12 +437,12 @@ pub fn get_ipv4_ipaddrs() -> Vec<IpAddr> {
.collect()
}

pub fn get_ipv6_ipaddrs() -> Vec<IpAddr> {
pub fn get_ipv6_ipaddrs(interface: Option<&str>) -> Vec<IpAddr> {
const fn is_unicast_link_local(addr: &Ipv6Addr) -> bool {
(addr.segments()[0] & 0xffc0) == 0xfe80
}

let ipaddrs = get_local_addresses().unwrap_or_else(|_| vec![]);
let ipaddrs = get_local_addresses(interface).unwrap_or_else(|_| vec![]);

// Get first all IPv4 addresses
let ipv4_iter = ipaddrs
Expand Down Expand Up @@ -479,3 +491,53 @@ pub fn get_ipv6_ipaddrs() -> Vec<IpAddr> {
.chain(priv_ipv4_addrs)
.collect()
}

#[cfg(target_os = "linux")]
fn set_bind_to_device(socket: std::os::raw::c_int, iface: Option<&str>) {
if let Some(iface) = iface {
// @TODO: switch to bind_device after tokio porting
log::debug!("Listen at the interface: {}", iface);
unsafe {
libc::setsockopt(
socket,
libc::SOL_SOCKET,
libc::SO_BINDTODEVICE,
iface.as_ptr() as *const std::os::raw::c_void,
iface.len() as libc::socklen_t,
);
}
}
}

#[cfg(target_os = "linux")]
pub fn set_bind_to_device_tcp_listener(socket: &TcpListener, iface: Option<&str>) {
use std::os::fd::AsRawFd;
set_bind_to_device(socket.as_raw_fd(), iface);
}

#[cfg(target_os = "linux")]
pub fn set_bind_to_device_tcp_stream(socket: &TcpStream, iface: Option<&str>) {
use std::os::fd::AsRawFd;
set_bind_to_device(socket.as_raw_fd(), iface);
}

#[cfg(target_os = "linux")]
pub fn set_bind_to_device_udp_socket(socket: &UdpSocket, iface: Option<&str>) {
use std::os::fd::AsRawFd;
set_bind_to_device(socket.as_raw_fd(), iface);
}

#[cfg(any(target_os = "macos", target_os = "windows"))]
pub fn set_bind_to_device_tcp_listener(_socket: &TcpListener, _iface: Option<&str>) {
log::warn!("Listen at the interface is not supported for this platform");
}

#[cfg(any(target_os = "macos", target_os = "windows"))]
pub fn set_bind_to_device_tcp_stream(_socket: &TcpStream, _iface: Option<&str>) {
log::warn!("Listen at the interface is not supported for this platform");
}

#[cfg(any(target_os = "macos", target_os = "windows"))]
pub fn set_bind_to_device_udp_socket(_socket: &UdpSocket, _iface: Option<&str>) {
log::warn!("Listen at the interface is not supported for this platform");
}
3 changes: 3 additions & 0 deletions io/zenoh-link-commons/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ use zenoh_result::ZResult;
/*************************************/
/* GENERAL */
/*************************************/

pub const BIND_INTERFACE: &str = "iface";

#[derive(Clone, Debug, Serialize, Hash, PartialEq, Eq)]
pub struct Link {
pub src: Locator,
Expand Down
8 changes: 6 additions & 2 deletions io/zenoh-link-commons/src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ use zenoh_protocol::core::{EndPoint, Locator};
use zenoh_result::{zerror, ZResult};
use zenoh_sync::Signal;

use crate::BIND_INTERFACE;

pub struct ListenerUnicastIP {
endpoint: EndPoint,
active: Arc<AtomicBool>,
Expand Down Expand Up @@ -109,12 +111,14 @@ impl ListenersUnicastIP {
let guard = zread!(self.listeners);
for (key, value) in guard.iter() {
let (kip, kpt) = (key.ip(), key.port());
let config = value.endpoint.config();
let iface = config.get(BIND_INTERFACE);

// Either ipv4/0.0.0.0 or ipv6/[::]
if kip.is_unspecified() {
let mut addrs = match kip {
IpAddr::V4(_) => zenoh_util::net::get_ipv4_ipaddrs(),
IpAddr::V6(_) => zenoh_util::net::get_ipv6_ipaddrs(),
IpAddr::V4(_) => zenoh_util::net::get_ipv4_ipaddrs(iface),
IpAddr::V6(_) => zenoh_util::net::get_ipv6_ipaddrs(iface),
};
let iter = addrs.drain(..).map(|x| {
Locator::new(
Expand Down
21 changes: 17 additions & 4 deletions io/zenoh-links/zenoh-link-tcp/src/unicast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::sync::Arc;
use std::time::Duration;
use zenoh_link_commons::{
get_ip_interface_names, LinkManagerUnicastTrait, LinkUnicast, LinkUnicastTrait,
ListenersUnicastIP, NewLinkChannelSender,
ListenersUnicastIP, NewLinkChannelSender, BIND_INTERFACE,
};
use zenoh_protocol::core::{EndPoint, Locator};
use zenoh_result::{bail, zerror, Error as ZError, ZResult};
Expand Down Expand Up @@ -199,6 +199,7 @@ impl LinkManagerUnicastTcp {
async fn new_link_inner(
&self,
dst_addr: &SocketAddr,
iface: Option<&str>,
) -> ZResult<(TcpStream, SocketAddr, SocketAddr)> {
let stream = TcpStream::connect(dst_addr)
.await
Expand All @@ -212,15 +213,23 @@ impl LinkManagerUnicastTcp {
.peer_addr()
.map_err(|e| zerror!("{}: {}", dst_addr, e))?;

zenoh_util::net::set_bind_to_device_tcp_stream(&stream, iface);

Ok((stream, src_addr, dst_addr))
}

async fn new_listener_inner(&self, addr: &SocketAddr) -> ZResult<(TcpListener, SocketAddr)> {
async fn new_listener_inner(
&self,
addr: &SocketAddr,
iface: Option<&str>,
) -> ZResult<(TcpListener, SocketAddr)> {
// Bind the TCP socket
let socket = TcpListener::bind(addr)
.await
.map_err(|e| zerror!("{}: {}", addr, e))?;

zenoh_util::net::set_bind_to_device_tcp_listener(&socket, iface);

let local_addr = socket
.local_addr()
.map_err(|e| zerror!("{}: {}", addr, e))?;
Expand All @@ -233,10 +242,12 @@ impl LinkManagerUnicastTcp {
impl LinkManagerUnicastTrait for LinkManagerUnicastTcp {
async fn new_link(&self, endpoint: EndPoint) -> ZResult<LinkUnicast> {
let dst_addrs = get_tcp_addrs(endpoint.address()).await?;
let config = endpoint.config();
let iface = config.get(BIND_INTERFACE);

let mut errs: Vec<ZError> = vec![];
for da in dst_addrs {
match self.new_link_inner(&da).await {
match self.new_link_inner(&da, iface).await {
Ok((stream, src_addr, dst_addr)) => {
let link = Arc::new(LinkUnicastTcp::new(stream, src_addr, dst_addr));
return Ok(LinkUnicast(link));
Expand All @@ -260,10 +271,12 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastTcp {

async fn new_listener(&self, mut endpoint: EndPoint) -> ZResult<Locator> {
let addrs = get_tcp_addrs(endpoint.address()).await?;
let config = endpoint.config();
let iface = config.get(BIND_INTERFACE);

let mut errs: Vec<ZError> = vec![];
for da in addrs {
match self.new_listener_inner(&da).await {
match self.new_listener_inner(&da, iface).await {
Ok((socket, local_addr)) => {
// Update the endpoint locator address
endpoint = EndPoint::new(
Expand Down
21 changes: 17 additions & 4 deletions io/zenoh-links/zenoh-link-udp/src/unicast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use std::time::Duration;
use zenoh_core::{zasynclock, zlock};
use zenoh_link_commons::{
get_ip_interface_names, ConstructibleLinkManagerUnicast, LinkManagerUnicastTrait, LinkUnicast,
LinkUnicastTrait, ListenersUnicastIP, NewLinkChannelSender,
LinkUnicastTrait, ListenersUnicastIP, NewLinkChannelSender, BIND_INTERFACE,
};
use zenoh_protocol::core::{EndPoint, Locator};
use zenoh_result::{bail, zerror, Error as ZError, ZResult};
Expand Down Expand Up @@ -261,6 +261,7 @@ impl LinkManagerUnicastUdp {
async fn new_link_inner(
&self,
dst_addr: &SocketAddr,
iface: Option<&str>,
) -> ZResult<(UdpSocket, SocketAddr, SocketAddr)> {
// Establish a UDP socket
let socket = UdpSocket::bind(SocketAddr::new(
Expand All @@ -278,6 +279,8 @@ impl LinkManagerUnicastUdp {
e
})?;

zenoh_util::net::set_bind_to_device_udp_socket(&socket, iface);

// Connect the socket to the remote address
socket.connect(dst_addr).await.map_err(|e| {
let e = zerror!("Can not create a new UDP link bound to {}: {}", dst_addr, e);
Expand All @@ -301,14 +304,20 @@ impl LinkManagerUnicastUdp {
Ok((socket, src_addr, dst_addr))
}

async fn new_listener_inner(&self, addr: &SocketAddr) -> ZResult<(UdpSocket, SocketAddr)> {
async fn new_listener_inner(
&self,
addr: &SocketAddr,
iface: Option<&str>,
) -> ZResult<(UdpSocket, SocketAddr)> {
// Bind the UDP socket
let socket = UdpSocket::bind(addr).await.map_err(|e| {
let e = zerror!("Can not create a new UDP listener on {}: {}", addr, e);
log::warn!("{}", e);
e
})?;

zenoh_util::net::set_bind_to_device_udp_socket(&socket, iface);

let local_addr = socket.local_addr().map_err(|e| {
let e = zerror!("Can not create a new UDP listener on {}: {}", addr, e);
log::warn!("{}", e);
Expand All @@ -325,10 +334,12 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastUdp {
let dst_addrs = get_udp_addrs(endpoint.address())
.await?
.filter(|a| !a.ip().is_multicast());
let config = endpoint.config();
let iface = config.get(BIND_INTERFACE);

let mut errs: Vec<ZError> = vec![];
for da in dst_addrs {
match self.new_link_inner(&da).await {
match self.new_link_inner(&da, iface).await {
Ok((socket, src_addr, dst_addr)) => {
// Create UDP link
let link = Arc::new(LinkUnicastUdp::new(
Expand Down Expand Up @@ -362,10 +373,12 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastUdp {
let addrs = get_udp_addrs(endpoint.address())
.await?
.filter(|a| !a.ip().is_multicast());
let config = endpoint.config();
let iface = config.get(BIND_INTERFACE);

let mut errs: Vec<ZError> = vec![];
for da in addrs {
match self.new_listener_inner(&da).await {
match self.new_listener_inner(&da, iface).await {
Ok((socket, local_addr)) => {
// Update the endpoint locator address
endpoint = EndPoint::new(
Expand Down
4 changes: 2 additions & 2 deletions io/zenoh-links/zenoh-link-ws/src/unicast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastWs {
for (key, value) in guard.iter() {
let listener_locator = value.endpoint.to_locator();
if key.ip() == default_ipv4 {
match zenoh_util::net::get_local_addresses() {
match zenoh_util::net::get_local_addresses(None) {
Ok(ipaddrs) => {
for ipaddr in ipaddrs {
if !ipaddr.is_loopback() && !ipaddr.is_multicast() && ipaddr.is_ipv4() {
Expand All @@ -433,7 +433,7 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastWs {
Err(err) => log::error!("Unable to get local addresses: {}", err),
}
} else if key.ip() == default_ipv6 {
match zenoh_util::net::get_local_addresses() {
match zenoh_util::net::get_local_addresses(None) {
Ok(ipaddrs) => {
for ipaddr in ipaddrs {
if !ipaddr.is_loopback() && !ipaddr.is_multicast() && ipaddr.is_ipv6() {
Expand Down
Loading