From 5455ecc8fed27d31d988edc694f7af72fb11ce1b Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Mon, 25 Nov 2024 20:59:02 +0100 Subject: [PATCH] feat: add local/remote sub cache to publisher --- examples/examples/z_local_pub_sub_thr.rs | 155 ++++++++++++++++++ zenoh/src/api/builders/publisher.rs | 11 +- zenoh/src/api/publisher.rs | 3 + zenoh/src/api/session.rs | 65 ++++++-- zenoh/src/net/routing/dispatcher/face.rs | 5 + zenoh/src/net/routing/dispatcher/pubsub.rs | 180 ++++++++++----------- 6 files changed, 315 insertions(+), 104 deletions(-) create mode 100644 examples/examples/z_local_pub_sub_thr.rs diff --git a/examples/examples/z_local_pub_sub_thr.rs b/examples/examples/z_local_pub_sub_thr.rs new file mode 100644 index 000000000..666fd89f8 --- /dev/null +++ b/examples/examples/z_local_pub_sub_thr.rs @@ -0,0 +1,155 @@ +// +// Copyright (c) 2023 ZettaScale Technology +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// +// Contributors: +// ZettaScale Zenoh Team, +// + +use std::{convert::TryInto, time::Instant}; + +use clap::Parser; +use zenoh::{ + bytes::ZBytes, + qos::{CongestionControl, Priority}, + Wait, +}; +use zenoh_examples::CommonArgs; + +struct Stats { + round_count: usize, + round_size: usize, + finished_rounds: usize, + round_start: Instant, + global_start: Option, +} +impl Stats { + fn new(round_size: usize) -> Self { + Stats { + round_count: 0, + round_size, + finished_rounds: 0, + round_start: Instant::now(), + global_start: None, + } + } + fn increment(&mut self) { + if self.round_count == 0 { + self.round_start = Instant::now(); + if self.global_start.is_none() { + self.global_start = Some(self.round_start) + } + self.round_count += 1; + } else if self.round_count < self.round_size { + self.round_count += 1; + } else { + self.print_round(); + self.finished_rounds += 1; + self.round_count = 0; + } + } + fn print_round(&self) { + let elapsed = self.round_start.elapsed().as_secs_f64(); + let throughput = (self.round_size as f64) / elapsed; + println!("{throughput} msg/s"); + } +} +impl Drop for Stats { + fn drop(&mut self) { + let Some(global_start) = self.global_start else { + return; + }; + let elapsed = global_start.elapsed().as_secs_f64(); + let total = self.round_size * self.finished_rounds + self.round_count; + let throughput = total as f64 / elapsed; + println!("Received {total} messages over {elapsed:.2}s: {throughput}msg/s"); + } +} + +fn main() { + // initiate logging + zenoh::init_log_from_env_or("error"); + let args = Args::parse(); + + let session = zenoh::open(args.common).wait().unwrap(); + + let key_expr = "test/thr"; + + let mut stats = Stats::new(args.number); + session + .declare_subscriber(key_expr) + .callback_mut(move |_sample| { + stats.increment(); + if stats.finished_rounds >= args.samples { + std::process::exit(0) + } + }) + .background() + .wait() + .unwrap(); + + let mut prio = Priority::DEFAULT; + if let Some(p) = args.priority { + prio = p.try_into().unwrap(); + } + + let publisher = session + .declare_publisher(key_expr) + .congestion_control(CongestionControl::Block) + .priority(prio) + .express(args.express) + .wait() + .unwrap(); + + println!("Press CTRL-C to quit..."); + let payload_size = args.payload_size; + let data: ZBytes = (0..payload_size) + .map(|i| (i % 10) as u8) + .collect::>() + .into(); + let mut count: usize = 0; + let mut start = std::time::Instant::now(); + loop { + publisher.put(data.clone()).wait().unwrap(); + + if args.print { + if count < args.number { + count += 1; + } else { + let thpt = count as f64 / start.elapsed().as_secs_f64(); + println!("{thpt} msg/s"); + count = 0; + start = std::time::Instant::now(); + } + } + } +} + +#[derive(Parser, Clone, PartialEq, Eq, Hash, Debug)] +struct Args { + #[arg(short, long, default_value = "10")] + /// Number of throughput measurements. + samples: usize, + /// express for sending data + #[arg(long, default_value = "false")] + express: bool, + /// Priority for sending data + #[arg(short, long)] + priority: Option, + /// Print the statistics + #[arg(short = 't', long)] + print: bool, + /// Number of messages in each throughput measurements + #[arg(short, long, default_value = "10000000")] + number: usize, + /// Sets the size of the payload to publish + payload_size: usize, + #[command(flatten)] + common: CommonArgs, +} diff --git a/zenoh/src/api/builders/publisher.rs b/zenoh/src/api/builders/publisher.rs index a2515eb28..ea0ce193a 100644 --- a/zenoh/src/api/builders/publisher.rs +++ b/zenoh/src/api/builders/publisher.rs @@ -11,7 +11,10 @@ // Contributors: // ZettaScale Zenoh Team, // -use std::future::{IntoFuture, Ready}; +use std::{ + future::{IntoFuture, Ready}, + sync::atomic::AtomicU64, +}; use zenoh_core::{Resolvable, Result as ZResult, Wait}; use zenoh_protocol::core::CongestionControl; @@ -206,6 +209,7 @@ impl Wait for PublicationBuilder, PublicationBuilderPut #[inline] fn wait(self) -> ::To { self.publisher.session.0.resolve_put( + None, &self.publisher.key_expr?, self.kind.payload, SampleKind::Put, @@ -228,6 +232,7 @@ impl Wait for PublicationBuilder, PublicationBuilderDel #[inline] fn wait(self) -> ::To { self.publisher.session.0.resolve_put( + None, &self.publisher.key_expr?, ZBytes::new(), SampleKind::Delete, @@ -383,6 +388,8 @@ impl Wait for PublisherBuilder<'_, '_> { .declare_publisher_inner(key_expr.clone(), self.destination)?; Ok(Publisher { session: self.session.downgrade(), + // TODO use constants here + cache: AtomicU64::new(0b11), id, key_expr, encoding: self.encoding, @@ -411,6 +418,7 @@ impl IntoFuture for PublisherBuilder<'_, '_> { impl Wait for PublicationBuilder<&Publisher<'_>, PublicationBuilderPut> { fn wait(self) -> ::To { self.publisher.session.resolve_put( + Some(&self.publisher.cache), &self.publisher.key_expr, self.kind.payload, SampleKind::Put, @@ -432,6 +440,7 @@ impl Wait for PublicationBuilder<&Publisher<'_>, PublicationBuilderPut> { impl Wait for PublicationBuilder<&Publisher<'_>, PublicationBuilderDelete> { fn wait(self) -> ::To { self.publisher.session.resolve_put( + Some(&self.publisher.cache), &self.publisher.key_expr, ZBytes::new(), SampleKind::Delete, diff --git a/zenoh/src/api/publisher.rs b/zenoh/src/api/publisher.rs index 78d44fc79..c50501d45 100644 --- a/zenoh/src/api/publisher.rs +++ b/zenoh/src/api/publisher.rs @@ -17,6 +17,7 @@ use std::{ fmt, future::{IntoFuture, Ready}, pin::Pin, + sync::atomic::AtomicU64, task::{Context, Poll}, }; @@ -100,6 +101,7 @@ impl fmt::Debug for PublisherState { #[derive(Debug)] pub struct Publisher<'a> { pub(crate) session: WeakSession, + pub(crate) cache: AtomicU64, pub(crate) id: Id, pub(crate) key_expr: KeyExpr<'a>, pub(crate) encoding: Encoding, @@ -385,6 +387,7 @@ impl Sink for Publisher<'_> { .. } = item.into(); self.session.resolve_put( + Some(&self.cache), &self.key_expr, payload, kind, diff --git a/zenoh/src/api/session.rs b/zenoh/src/api/session.rs index b679a13db..08f00a7cd 100644 --- a/zenoh/src/api/session.rs +++ b/zenoh/src/api/session.rs @@ -19,7 +19,7 @@ use std::{ fmt, ops::Deref, sync::{ - atomic::{AtomicU16, Ordering}, + atomic::{AtomicU16, AtomicU64, Ordering}, Arc, Mutex, RwLock, }, time::{Duration, SystemTime, UNIX_EPOCH}, @@ -121,6 +121,7 @@ zconfigurable! { } pub(crate) struct SessionState { + pub(crate) subscription_version: u64, pub(crate) primitives: Option>, // @TODO replace with MaybeUninit ?? pub(crate) expr_id_counter: AtomicExprId, // @TODO: manage rollover and uniqueness pub(crate) qid_counter: AtomicRequestId, @@ -156,6 +157,7 @@ impl SessionState { aggregated_publishers: Vec, ) -> SessionState { SessionState { + subscription_version: 0, primitives: None, expr_id_counter: AtomicExprId::new(1), // Note: start at 1 because 0 is reserved for NO_RESOURCE qid_counter: AtomicRequestId::new(0), @@ -1496,6 +1498,7 @@ impl SessionInner { callback: Callback, ) -> ZResult> { let mut state = zwrite!(self.state); + state.subscription_version += 1; tracing::trace!("declare_subscriber({:?})", key_expr); let id = self.runtime.next_id(); let (sub_state, declared_sub) = state.register_subscriber(id, key_expr, origin, callback); @@ -2079,11 +2082,11 @@ impl SessionInner { kind: SubscriberKind, #[cfg(feature = "unstable")] reliability: Reliability, attachment: Option, - ) { + ) -> bool { let mut callbacks = SingleOrVec::default(); let state = zread!(self.state); if state.primitives.is_none() { - return; // Session closing or closed + return false; // Session closing or closed } if key_expr.suffix.is_empty() { match state.get_res(&key_expr.scope, key_expr.mapping, local) { @@ -2101,11 +2104,11 @@ impl SessionInner { "Received Data for `{}`, which isn't a key expression", prefix ); - return; + return false; } None => { tracing::error!("Received Data for unknown expr_id: {}", key_expr.scope); - return; + return false; } } } else { @@ -2122,11 +2125,14 @@ impl SessionInner { } Err(err) => { tracing::error!("Received Data for unknown key_expr: {}", err); - return; + return false; } } }; drop(state); + if callbacks.is_empty() { + return false; + } let mut sample = info.clone().into_sample( // SAFETY: the keyexpr is valid unsafe { KeyExpr::from_str_unchecked("dummy") }, @@ -2144,11 +2150,14 @@ impl SessionInner { sample.key_expr = key_expr; cb.call(sample); } + true } #[allow(clippy::too_many_arguments)] // TODO fixme + #[inline(always)] pub(crate) fn resolve_put( &self, + cache: Option<&AtomicU64>, key_expr: &KeyExpr, payload: ZBytes, kind: SampleKind, @@ -2162,12 +2171,32 @@ impl SessionInner { #[cfg(feature = "unstable")] source_info: SourceInfo, attachment: Option, ) -> ZResult<()> { + const REMOTE_TAG: u64 = 0b01; + const LOCAL_TAG: u64 = 0b10; + const VERSION_SHIFT: u64 = 2; trace!("write({:?}, [...])", key_expr); - let primitives = zread!(self.state).primitives()?; + let state = zread!(self.state); + let primitives = state + .primitives + .as_ref() + .cloned() + .ok_or(SessionClosedError)?; + let mut cached = REMOTE_TAG | LOCAL_TAG; + let mut to_cache = REMOTE_TAG | LOCAL_TAG; + if let Some(cache) = cache { + cached = cache.load(Ordering::Relaxed); + let version = cached >> VERSION_SHIFT; + if version == state.subscription_version { + to_cache = cached; + } else { + to_cache = (state.subscription_version << VERSION_SHIFT) | REMOTE_TAG | LOCAL_TAG; + } + } + drop(state); let timestamp = timestamp.or_else(|| self.runtime.new_timestamp()); let wire_expr = key_expr.to_wire(self); - if destination != Locality::SessionLocal { - primitives.send_push( + if (to_cache & REMOTE_TAG) != 0 && destination != Locality::SessionLocal { + let remote = primitives.route_data( Push { wire_expr: wire_expr.to_owned(), ext_qos: push::ext::QoSType::new( @@ -2207,8 +2236,11 @@ impl SessionInner { #[cfg(not(feature = "unstable"))] Reliability::DEFAULT, ); + if !remote { + to_cache &= !REMOTE_TAG; + } } - if destination != Locality::Remote { + if (to_cache & LOCAL_TAG) != 0 && destination != Locality::Remote { let data_info = DataInfo { kind, encoding: Some(encoding), @@ -2222,7 +2254,7 @@ impl SessionInner { )), }; - self.execute_subscriber_callbacks( + let local = self.execute_subscriber_callbacks( true, &wire_expr, Some(data_info), @@ -2232,6 +2264,12 @@ impl SessionInner { reliability, attachment, ); + if !local { + to_cache &= !LOCAL_TAG; + } + } + if let Some(cache) = cache.filter(|_| to_cache != cached) { + let _ = cache.compare_exchange(cached, to_cache, Ordering::Relaxed, Ordering::Relaxed); } Ok(()) } @@ -2546,6 +2584,7 @@ impl Primitives for WeakSession { #[cfg(feature = "unstable")] { let mut state = zwrite!(self.state); + state.subscription_version += 1; if state.primitives.is_none() { return; // Session closing or closed } @@ -2798,7 +2837,7 @@ impl Primitives for WeakSession { #[cfg(feature = "unstable")] _reliability, m.ext_attachment.map(Into::into), - ) + ); } PushBody::Del(m) => { let info = DataInfo { @@ -2818,7 +2857,7 @@ impl Primitives for WeakSession { #[cfg(feature = "unstable")] _reliability, m.ext_attachment.map(Into::into), - ) + ); } } } diff --git a/zenoh/src/net/routing/dispatcher/face.rs b/zenoh/src/net/routing/dispatcher/face.rs index 6e1db6bbf..8ad41ba98 100644 --- a/zenoh/src/net/routing/dispatcher/face.rs +++ b/zenoh/src/net/routing/dispatcher/face.rs @@ -213,6 +213,11 @@ impl Face { state: Arc::downgrade(&self.state), } } + + #[inline] + pub fn route_data(&self, msg: Push, reliability: Reliability) -> bool { + route_data(&self.tables, &self.state, msg, reliability) + } } impl Primitives for Face { diff --git a/zenoh/src/net/routing/dispatcher/pubsub.rs b/zenoh/src/net/routing/dispatcher/pubsub.rs index c755e26a4..4782e9f00 100644 --- a/zenoh/src/net/routing/dispatcher/pubsub.rs +++ b/zenoh/src/net/routing/dispatcher/pubsub.rs @@ -305,7 +305,7 @@ macro_rules! treat_timestamp { "Error treating timestamp for received Data ({}). Drop it!", e ); - return; + return false; } else { data.timestamp = Some(hlc.new_timestamp()); tracing::error!( @@ -393,104 +393,104 @@ pub fn route_data( face: &FaceState, mut msg: Push, reliability: Reliability, -) { +) -> bool { let tables = zread!(tables_ref.tables); - match tables + let Some(prefix) = tables .get_mapping(face, &msg.wire_expr.scope, msg.wire_expr.mapping) .cloned() - { - Some(prefix) => { - tracing::trace!( - "{} Route data for res {}{}", - face, - prefix.expr(), - msg.wire_expr.suffix.as_ref() - ); - let mut expr = RoutingExpr::new(&prefix, msg.wire_expr.suffix.as_ref()); - - #[cfg(feature = "stats")] - let admin = expr.full_expr().starts_with("@/"); - #[cfg(feature = "stats")] - if !admin { - inc_stats!(face, rx, user, msg.payload) - } else { - inc_stats!(face, rx, admin, msg.payload) - } - - if tables.hat_code.ingress_filter(&tables, face, &mut expr) { - let res = Resource::get_resource(&prefix, expr.suffix); - - let route = get_data_route(&tables, face, &res, &mut expr, msg.ext_nodeid.node_id); - - if !route.is_empty() { - treat_timestamp!(&tables.hlc, msg.payload, tables.drop_future_timestamp); - - if route.len() == 1 { - let (outface, key_expr, context) = route.values().next().unwrap(); - if tables - .hat_code - .egress_filter(&tables, face, outface, &mut expr) - { - drop(tables); - #[cfg(feature = "stats")] - if !admin { - inc_stats!(face, tx, user, msg.payload) - } else { - inc_stats!(face, tx, admin, msg.payload) - } - - outface.primitives.send_push( - Push { - wire_expr: key_expr.into(), - ext_qos: msg.ext_qos, - ext_tstamp: msg.ext_tstamp, - ext_nodeid: ext::NodeIdType { node_id: *context }, - payload: msg.payload, - }, - reliability, - ) - } + else { + tracing::error!( + "{} Route data with unknown scope {}!", + face, + msg.wire_expr.scope + ); + return false; + }; + tracing::trace!( + "{} Route data for res {}{}", + face, + prefix.expr(), + msg.wire_expr.suffix.as_ref() + ); + let mut expr = RoutingExpr::new(&prefix, msg.wire_expr.suffix.as_ref()); + + #[cfg(feature = "stats")] + let admin = expr.full_expr().starts_with("@/"); + #[cfg(feature = "stats")] + if !admin { + inc_stats!(face, rx, user, msg.payload) + } else { + inc_stats!(face, rx, admin, msg.payload) + } + let mut routed = false; + if tables.hat_code.ingress_filter(&tables, face, &mut expr) { + let res = Resource::get_resource(&prefix, expr.suffix); + + let route = get_data_route(&tables, face, &res, &mut expr, msg.ext_nodeid.node_id); + + if !route.is_empty() { + treat_timestamp!(&tables.hlc, msg.payload, tables.drop_future_timestamp); + + if route.len() == 1 { + let (outface, key_expr, context) = route.values().next().unwrap(); + if tables + .hat_code + .egress_filter(&tables, face, outface, &mut expr) + { + drop(tables); + #[cfg(feature = "stats")] + if !admin { + inc_stats!(face, tx, user, msg.payload) } else { - let route = route - .values() - .filter(|(outface, _key_expr, _context)| { - tables - .hat_code - .egress_filter(&tables, face, outface, &mut expr) - }) - .cloned() - .collect::>(); - - drop(tables); - for (outface, key_expr, context) in route { - #[cfg(feature = "stats")] - if !admin { - inc_stats!(face, tx, user, msg.payload) - } else { - inc_stats!(face, tx, admin, msg.payload) - } + inc_stats!(face, tx, admin, msg.payload) + } - outface.primitives.send_push( - Push { - wire_expr: key_expr, - ext_qos: msg.ext_qos, - ext_tstamp: None, - ext_nodeid: ext::NodeIdType { node_id: context }, - payload: msg.payload.clone(), - }, - reliability, - ) + outface.primitives.send_push( + Push { + wire_expr: key_expr.into(), + ext_qos: msg.ext_qos, + ext_tstamp: msg.ext_tstamp, + ext_nodeid: ext::NodeIdType { node_id: *context }, + payload: msg.payload, + }, + reliability, + ); + routed = true; + } else { + let route = route + .values() + .filter(|(outface, _key_expr, _context)| { + tables + .hat_code + .egress_filter(&tables, face, outface, &mut expr) + }) + .cloned() + .collect::>(); + + drop(tables); + for (outface, key_expr, context) in route { + #[cfg(feature = "stats")] + if !admin { + inc_stats!(face, tx, user, msg.payload) + } else { + inc_stats!(face, tx, admin, msg.payload) } + + outface.primitives.send_push( + Push { + wire_expr: key_expr, + ext_qos: msg.ext_qos, + ext_tstamp: None, + ext_nodeid: ext::NodeIdType { node_id: context }, + payload: msg.payload.clone(), + }, + reliability, + ); + routed = true; } } } } - None => { - tracing::error!( - "{} Route data with unknown scope {}!", - face, - msg.wire_expr.scope - ); - } } + routed }