Skip to content

Commit

Permalink
Allow to join multiple multicast groups on UDP (#554)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mallets authored Sep 20, 2023
1 parent a8cd82c commit 0b8f431
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 62 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ serde_yaml = "0.9.19"
sha3 = "0.10.6"
shared_memory = "0.12.4"
shellexpand = "3.0.0"
socket2 = "0.5.1"
socket2 = { version ="0.5.1", features = [ "all" ] }
stop-token = "0.7.0"
syn = "2.0"
tide = "0.16.0"
Expand Down
52 changes: 52 additions & 0 deletions commons/zenoh-protocol/src/core/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub const METADATA_SEPARATOR: char = '?';
pub const LIST_SEPARATOR: char = ';';
pub const FIELD_SEPARATOR: char = '=';
pub const CONFIG_SEPARATOR: char = '#';
pub const VALUE_SEPARATOR: char = '|';

fn split_once(s: &str, c: char) -> (&str, &str) {
match s.find(c) {
Expand Down Expand Up @@ -98,6 +99,17 @@ impl Parameters {
Self::iter(s).find(|x| x.0 == k).map(|x| x.1)
}

pub fn values<'s>(s: &'s str, k: &str) -> impl Iterator<Item = &'s str> + DoubleEndedIterator {
match Self::get(s, k) {
Some(v) => v.split(VALUE_SEPARATOR),
None => {
let mut i = "".split(VALUE_SEPARATOR);
i.next();
i
}
}
}

pub(super) fn insert<'s, I>(iter: I, k: &'s str, v: &'s str) -> String
where
I: Iterator<Item = (&'s str, &'s str)>,
Expand Down Expand Up @@ -272,6 +284,10 @@ impl<'a> Metadata<'a> {
pub fn get(&'a self, k: &str) -> Option<&'a str> {
Parameters::get(self.0, k)
}

pub fn values(&'a self, k: &str) -> impl Iterator<Item = &'a str> + DoubleEndedIterator {
Parameters::values(self.0, k)
}
}

impl AsRef<str> for Metadata<'_> {
Expand Down Expand Up @@ -385,6 +401,10 @@ impl<'a> Config<'a> {
pub fn get(&'a self, k: &str) -> Option<&'a str> {
Parameters::get(self.0, k)
}

pub fn values(&'a self, k: &str) -> impl Iterator<Item = &'a str> + DoubleEndedIterator {
Parameters::values(self.0, k)
}
}

impl AsRef<str> for Config<'_> {
Expand Down Expand Up @@ -764,11 +784,13 @@ fn endpoints() {
.iter()
.find(|x| x == &("a", "1"))
.unwrap();
assert_eq!(endpoint.metadata().get("a"), Some("1"));
endpoint
.metadata()
.iter()
.find(|x| x == &("b", "2"))
.unwrap();
assert_eq!(endpoint.metadata().get("b"), Some("2"));
assert!(endpoint.config().as_str().is_empty());
assert_eq!(endpoint.config().iter().count(), 0);

Expand All @@ -783,11 +805,13 @@ fn endpoints() {
.iter()
.find(|x| x == &("a", "1"))
.unwrap();
assert_eq!(endpoint.metadata().get("a"), Some("1"));
endpoint
.metadata()
.iter()
.find(|x| x == &("b", "2"))
.unwrap();
assert_eq!(endpoint.metadata().get("a"), Some("1"));
assert!(endpoint.config().as_str().is_empty());
assert_eq!(endpoint.config().iter().count(), 0);

Expand All @@ -800,7 +824,9 @@ fn endpoints() {
assert_eq!(endpoint.config().as_str(), "A=1;B=2");
assert_eq!(endpoint.config().iter().count(), 2);
endpoint.config().iter().find(|x| x == &("A", "1")).unwrap();
assert_eq!(endpoint.config().get("A"), Some("1"));
endpoint.config().iter().find(|x| x == &("B", "2")).unwrap();
assert_eq!(endpoint.config().get("B"), Some("2"));

let endpoint = EndPoint::from_str("udp/127.0.0.1:7447#B=2;A=1").unwrap();
assert_eq!(endpoint.as_str(), "udp/127.0.0.1:7447#A=1;B=2");
Expand All @@ -811,7 +837,9 @@ fn endpoints() {
assert_eq!(endpoint.config().as_str(), "A=1;B=2");
assert_eq!(endpoint.config().iter().count(), 2);
endpoint.config().iter().find(|x| x == &("A", "1")).unwrap();
assert_eq!(endpoint.config().get("A"), Some("1"));
endpoint.config().iter().find(|x| x == &("B", "2")).unwrap();
assert_eq!(endpoint.config().get("B"), Some("2"));

let endpoint = EndPoint::from_str("udp/127.0.0.1:7447?a=1;b=2#A=1;B=2").unwrap();
assert_eq!(endpoint.as_str(), "udp/127.0.0.1:7447?a=1;b=2#A=1;B=2");
Expand All @@ -824,15 +852,19 @@ fn endpoints() {
.iter()
.find(|x| x == &("a", "1"))
.unwrap();
assert_eq!(endpoint.metadata().get("a"), Some("1"));
endpoint
.metadata()
.iter()
.find(|x| x == &("b", "2"))
.unwrap();
assert_eq!(endpoint.metadata().get("b"), Some("2"));
assert_eq!(endpoint.config().as_str(), "A=1;B=2");
assert_eq!(endpoint.config().iter().count(), 2);
endpoint.config().iter().find(|x| x == &("A", "1")).unwrap();
assert_eq!(endpoint.config().get("A"), Some("1"));
endpoint.config().iter().find(|x| x == &("B", "2")).unwrap();
assert_eq!(endpoint.config().get("B"), Some("2"));

let endpoint = EndPoint::from_str("udp/127.0.0.1:7447?b=2;a=1#B=2;A=1").unwrap();
assert_eq!(endpoint.as_str(), "udp/127.0.0.1:7447?a=1;b=2#A=1;B=2");
Expand All @@ -845,15 +877,19 @@ fn endpoints() {
.iter()
.find(|x| x == &("a", "1"))
.unwrap();
assert_eq!(endpoint.metadata().get("a"), Some("1"));
endpoint
.metadata()
.iter()
.find(|x| x == &("b", "2"))
.unwrap();
assert_eq!(endpoint.metadata().get("b"), Some("2"));
assert_eq!(endpoint.config().as_str(), "A=1;B=2");
assert_eq!(endpoint.config().iter().count(), 2);
endpoint.config().iter().find(|x| x == &("A", "1")).unwrap();
assert_eq!(endpoint.config().get("A"), Some("1"));
endpoint.config().iter().find(|x| x == &("B", "2")).unwrap();
assert_eq!(endpoint.config().get("B"), Some("2"));

let mut endpoint = EndPoint::from_str("udp/127.0.0.1:7447?a=1;b=2").unwrap();
endpoint.metadata_mut().insert("c", "3").unwrap();
Expand Down Expand Up @@ -884,4 +920,20 @@ fn endpoints() {
.extend([("A", "1"), ("C", "3"), ("B", "2")].iter().copied())
.unwrap();
assert_eq!(endpoint.as_str(), "udp/127.0.0.1:7447#A=1;B=2;C=3");

let endpoint =
EndPoint::from_str("udp/127.0.0.1:7447#iface=en0;join=224.0.0.1|224.0.0.2|224.0.0.3")
.unwrap();
let c = endpoint.config();
assert_eq!(c.get("iface"), Some("en0"));
assert_eq!(c.get("join"), Some("224.0.0.1|224.0.0.2|224.0.0.3"));
assert_eq!(c.values("iface").count(), 1);
let mut i = c.values("iface");
assert_eq!(i.next(), Some("en0"));
assert_eq!(c.values("join").count(), 3);
let mut i = c.values("join");
assert_eq!(i.next(), Some("224.0.0.1"));
assert_eq!(i.next(), Some("224.0.0.2"));
assert_eq!(i.next(), Some("224.0.0.3"));
assert_eq!(i.next(), None);
}
1 change: 1 addition & 0 deletions io/zenoh-links/zenoh-link-udp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ impl LocatorInspector for UdpLocatorInspector {

pub mod config {
pub const UDP_MULTICAST_IFACE: &str = "iface";
pub const UDP_MULTICAST_JOIN: &str = "join";
}

pub async fn get_udp_addrs(address: Address<'_>) -> ZResult<impl Iterator<Item = SocketAddr>> {
Expand Down
85 changes: 48 additions & 37 deletions io/zenoh-links/zenoh-link-udp/src/multicast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use socket2::{Domain, Protocol, Socket, Type};
use std::sync::Arc;
use std::{borrow::Cow, fmt};
use zenoh_link_commons::{LinkManagerMulticastTrait, LinkMulticast, LinkMulticastTrait};
use zenoh_protocol::core::{EndPoint, Locator};
use zenoh_protocol::core::{Config, EndPoint, Locator};
use zenoh_result::{bail, zerror, Error as ZError, ZResult};

pub struct LinkMulticastUdp {
Expand Down Expand Up @@ -154,22 +154,16 @@ impl LinkManagerMulticastUdp {
async fn new_link_inner(
&self,
mcast_addr: &SocketAddr,
iface: Option<&str>,
config: Config<'_>,
) -> ZResult<(UdpSocket, UdpSocket, SocketAddr)> {
let domain = match mcast_addr.ip() {
IpAddr::V4(_) => Domain::IPV4,
IpAddr::V6(_) => Domain::IPV6,
};

// Defaults
let _default_ipv4_iface = Ipv4Addr::UNSPECIFIED;
let default_ipv6_iface = 0;
let default_ipv4_addr = Ipv4Addr::UNSPECIFIED;
let default_ipv6_addr = Ipv6Addr::UNSPECIFIED;

// Get default iface address to bind the socket on if provided
let mut iface_addr: Option<IpAddr> = None;
if let Some(iface) = iface {
if let Some(iface) = config.get(UDP_MULTICAST_IFACE) {
iface_addr = match iface.parse() {
Ok(addr) => Some(addr),
Err(_) => zenoh_util::net::get_unicast_addresses_of_interface(iface)?
Expand Down Expand Up @@ -206,8 +200,8 @@ impl LinkManagerMulticastUdp {
match iface {
Some(iface) => iface,
None => match mcast_addr.ip() {
IpAddr::V4(_) => IpAddr::V4(default_ipv4_addr),
IpAddr::V6(_) => IpAddr::V6(default_ipv6_addr),
IpAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
IpAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
},
}
}
Expand Down Expand Up @@ -242,37 +236,57 @@ impl LinkManagerMulticastUdp {
mcast_sock
.set_reuse_address(true)
.map_err(|e| zerror!("{}: {}", mcast_addr, e))?;
#[cfg(target_family = "unix")]
{
mcast_sock
.set_reuse_port(true)
.map_err(|e| zerror!("{}: {}", mcast_addr, e))?;
}

// Bind the socket
let default_mcast_addr = {
#[cfg(unix)]
{
match mcast_addr.ip() {
IpAddr::V4(ip4) => IpAddr::V4(ip4),
IpAddr::V6(_) => local_addr,
}
} // See UNIX Network Programmping p.212
#[cfg(windows)]
{
match mcast_addr.ip() {
IpAddr::V4(_) => IpAddr::V4(default_ipv4_addr),
IpAddr::V6(_) => IpAddr::V6(default_ipv6_addr),
}
}
// Bind the socket: let's bing to the unspecified address so we can join and read
// from multiple multicast groups.
let bind_mcast_addr = match mcast_addr.ip() {
IpAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
IpAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
};
mcast_sock
.bind(&SocketAddr::new(default_mcast_addr, mcast_addr.port()).into())
.bind(&SocketAddr::new(bind_mcast_addr, mcast_addr.port()).into())
.map_err(|e| zerror!("{}: {}", mcast_addr, e))?;

// Join the multicast group
let join = config.values(UDP_MULTICAST_JOIN);
match mcast_addr.ip() {
IpAddr::V4(dst_ip4) => match local_addr {
IpAddr::V4(src_ip4) => mcast_sock.join_multicast_v4(&dst_ip4, &src_ip4),
IpAddr::V6(_) => panic!(),
IpAddr::V4(src_ip4) => {
// Join default multicast group
mcast_sock
.join_multicast_v4(&dst_ip4, &src_ip4)
.map_err(|e| zerror!("{}: {}", mcast_addr, e))?;
// Join any additional multicast group
for g in join {
let g: Ipv4Addr =
g.parse().map_err(|e| zerror!("{}: {}", mcast_addr, e))?;
mcast_sock
.join_multicast_v4(&g, &src_ip4)
.map_err(|e| zerror!("{}: {}", mcast_addr, e))?;
}
}
IpAddr::V6(src_ip6) => bail!("{}: unexepcted IPv6 source address", src_ip6),
},
IpAddr::V6(dst_ip6) => mcast_sock.join_multicast_v6(&dst_ip6, default_ipv6_iface),
}
.map_err(|e| zerror!("{}: {}", mcast_addr, e))?;
IpAddr::V6(dst_ip6) => {
// Join default multicast group
mcast_sock
.join_multicast_v6(&dst_ip6, 0)
.map_err(|e| zerror!("{}: {}", mcast_addr, e))?;
// Join any additional multicast group
for g in join {
let g: Ipv6Addr = g.parse().map_err(|e| zerror!("{}: {}", mcast_addr, e))?;
mcast_sock
.join_multicast_v6(&g, 0)
.map_err(|e| zerror!("{}: {}", mcast_addr, e))?;
}
}
};

// Build the async_std multicast UdpSocket
let mcast_sock: UdpSocket = std::net::UdpSocket::from(mcast_sock).into();
Expand All @@ -296,10 +310,7 @@ impl LinkManagerMulticastTrait for LinkManagerMulticastUdp {

let mut errs: Vec<ZError> = vec![];
for maddr in mcast_addrs {
match self
.new_link_inner(&maddr, endpoint.config().get(UDP_MULTICAST_IFACE))
.await
{
match self.new_link_inner(&maddr, endpoint.config()).await {
Ok((mcast_sock, ucast_sock, ucast_addr)) => {
let link = Arc::new(LinkMulticastUdp::new(
ucast_addr, ucast_sock, maddr, mcast_sock,
Expand Down
29 changes: 6 additions & 23 deletions io/zenoh-transport/src/multicast/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use zenoh_core::zlock;
use zenoh_link::{LinkMulticast, Locator};
use zenoh_protocol::{
core::{Bits, Priority, Resolution, WhatAmI, ZenohId},
transport::{BatchSize, Join, KeepAlive, PrioritySn, TransportMessage, TransportSn},
transport::{BatchSize, Join, PrioritySn, TransportMessage, TransportSn},
};
use zenoh_result::{bail, zerror, ZResult};
use zenoh_sync::{RecyclingObjectPool, Signal};
Expand All @@ -40,7 +40,6 @@ pub(super) struct TransportLinkMulticastConfig {
pub(super) zid: ZenohId,
pub(super) whatami: WhatAmI,
pub(super) lease: Duration,
pub(super) keep_alive: usize,
pub(super) join_interval: Duration,
pub(super) sn_resolution: Bits,
pub(super) batch_size: BatchSize,
Expand Down Expand Up @@ -212,17 +211,13 @@ async fn tx_task(
enum Action {
Pull((WBatch, usize)),
Join,
KeepAlive,
Stop,
}

async fn pull(pipeline: &mut TransmissionPipelineConsumer, keep_alive: Duration) -> Action {
match pipeline.pull().timeout(keep_alive).await {
Ok(res) => match res {
Some(sb) => Action::Pull(sb),
None => Action::Stop,
},
Err(_) => Action::KeepAlive,
async fn pull(pipeline: &mut TransmissionPipelineConsumer) -> Action {
match pipeline.pull().await {
Some(sb) => Action::Pull(sb),
None => Action::Stop,
}
}

Expand All @@ -236,10 +231,9 @@ async fn tx_task(
Action::Join
}

let keep_alive = config.join_interval / config.keep_alive as u32;
let mut last_join = Instant::now().checked_sub(config.join_interval).unwrap();
loop {
match pull(&mut pipeline, keep_alive)
match pull(&mut pipeline)
.race(join(last_join, config.join_interval))
.await
{
Expand Down Expand Up @@ -300,17 +294,6 @@ async fn tx_task(

last_join = Instant::now();
}
Action::KeepAlive => {
let message: TransportMessage = KeepAlive.into();

#[allow(unused_variables)] // Used when stats feature is enabled
let n = link.send(&message).await?;
#[cfg(feature = "stats")]
{
stats.inc_tx_t_msgs(1);
stats.inc_tx_bytes(n);
}
}
Action::Stop => {
// Drain the transmission pipeline and write remaining bytes on the wire
let mut batches = pipeline.drain();
Expand Down
1 change: 0 additions & 1 deletion io/zenoh-transport/src/multicast/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,6 @@ impl TransportMulticastInner {
zid: self.manager.config.zid,
whatami: self.manager.config.whatami,
lease: self.manager.config.multicast.lease,
keep_alive: self.manager.config.multicast.keep_alive,
join_interval: self.manager.config.multicast.join_interval,
sn_resolution: self.manager.config.resolution.get(Field::FrameSN),
batch_size,
Expand Down

0 comments on commit 0b8f431

Please sign in to comment.