diff --git a/zenoh/src/api/builders/publisher.rs b/zenoh/src/api/builders/publisher.rs index e4a879235..1264cf39d 100644 --- a/zenoh/src/api/builders/publisher.rs +++ b/zenoh/src/api/builders/publisher.rs @@ -446,8 +446,7 @@ impl Wait for PublisherBuilder<'_, '_> { .declare_publisher_inner(key_expr.clone(), self.destination)?; Ok(Publisher { session: self.session.downgrade(), - // TODO use constants here - cache: AtomicU64::new(0b11), + cache: AtomicU64::new(0), id, key_expr, encoding: self.encoding, diff --git a/zenoh/src/api/session.rs b/zenoh/src/api/session.rs index f0a2b1b29..bdc0be249 100644 --- a/zenoh/src/api/session.rs +++ b/zenoh/src/api/session.rs @@ -2153,8 +2153,8 @@ impl SessionInner { #[cfg(feature = "unstable")] source_info: SourceInfo, attachment: Option, ) -> ZResult<()> { - const REMOTE_TAG: u64 = 0b01; - const LOCAL_TAG: u64 = 0b10; + const NO_REMOTE_FLAG: u64 = 0b01; + const NO_LOCAL_FLAG: u64 = 0b10; const VERSION_SHIFT: u64 = 2; trace!("write({:?}, [...])", key_expr); let state = zread!(self.state); @@ -2163,21 +2163,26 @@ impl SessionInner { .as_ref() .cloned() .ok_or(SessionClosedError)?; - let mut cached = REMOTE_TAG | LOCAL_TAG; - let mut to_cache = REMOTE_TAG | LOCAL_TAG; + let version = state.subscription_version; + drop(state); + let mut cached = 0; + let mut update_cache = None; if let Some(cache) = cache { - cached = cache.load(Ordering::Relaxed); - let version = cached >> VERSION_SHIFT; - if version == state.subscription_version { - to_cache = cached; + let c = cache.load(Ordering::Relaxed); + if (c >> VERSION_SHIFT) == version { + cached = c; } else { - to_cache = (state.subscription_version << VERSION_SHIFT) | REMOTE_TAG | LOCAL_TAG; + cached = version << VERSION_SHIFT; } + update_cache = Some(move |cached| { + if cached != c { + let _ = cache.compare_exchange(c, cached, Ordering::Relaxed, Ordering::Relaxed); + } + }); } - drop(state); let timestamp = timestamp.or_else(|| self.runtime.new_timestamp()); let wire_expr = key_expr.to_wire(self); - if (to_cache & REMOTE_TAG) != 0 && destination != Locality::SessionLocal { + if (cached & NO_REMOTE_FLAG) == 0 && destination != Locality::SessionLocal { let remote = primitives.route_data( Push { wire_expr: wire_expr.to_owned(), @@ -2219,10 +2224,10 @@ impl SessionInner { Reliability::DEFAULT, ); if !remote { - to_cache &= !REMOTE_TAG; + cached |= NO_REMOTE_FLAG } } - if (to_cache & LOCAL_TAG) != 0 && destination != Locality::Remote { + if (cached & NO_LOCAL_FLAG) == 0 && destination != Locality::Remote { let data_info = DataInfo { kind, encoding: Some(encoding), @@ -2247,11 +2252,11 @@ impl SessionInner { attachment, ); if !local { - to_cache &= !LOCAL_TAG; + cached |= NO_LOCAL_FLAG; } } - if let Some(cache) = cache.filter(|_| to_cache != cached) { - let _ = cache.compare_exchange(cached, to_cache, Ordering::Relaxed, Ordering::Relaxed); + if let Some(update) = update_cache { + update(cached); } Ok(()) } @@ -2563,10 +2568,10 @@ impl Primitives for WeakSession { } zenoh_protocol::network::DeclareBody::DeclareSubscriber(m) => { trace!("recv DeclareSubscriber {} {:?}", m.id, m.wire_expr); + let mut state = zwrite!(self.state); + state.subscription_version += 1; #[cfg(feature = "unstable")] { - let mut state = zwrite!(self.state); - state.subscription_version += 1; if state.primitives.is_none() { return; // Session closing or closed }