diff --git a/commons/zenoh-keyexpr/src/key_expr/include.rs b/commons/zenoh-keyexpr/src/key_expr/include.rs index 46a4e40d69..a8dbdaf4cd 100644 --- a/commons/zenoh-keyexpr/src/key_expr/include.rs +++ b/commons/zenoh-keyexpr/src/key_expr/include.rs @@ -22,9 +22,33 @@ pub trait Includer { impl Includer<&'a [u8], &'a [u8]>> Includer<&keyexpr, &keyexpr> for T { fn includes(&self, left: &keyexpr, right: &keyexpr) -> bool { - let left = left.as_bytes(); - let right = right.as_bytes(); - if left == right || left == b"**" { + let mut left = left.as_bytes(); + let mut right = right.as_bytes(); + if left == right { + return true; + } + + if unsafe { *left.get_unchecked(0) == b'@' || *right.get_unchecked(0) == b'@' } { + let mut end = left.len().min(right.len()); + for i in 0..end { + if left[i] != right[i] { + return false; + } + if left[i] == DELIMITER { + end = i; + break; + } + } + if left.len() == end { + return false; + } + if right.len() == end { + return left.get(end..) == Some(b"/**"); + } + left = &left[(end + 1)..]; + right = &right[(end + 1)..]; + } + if left == b"**" { return true; } self.includes(left, right) diff --git a/commons/zenoh-keyexpr/src/key_expr/intersect/mod.rs b/commons/zenoh-keyexpr/src/key_expr/intersect/mod.rs index bda0677404..cebc194d16 100644 --- a/commons/zenoh-keyexpr/src/key_expr/intersect/mod.rs +++ b/commons/zenoh-keyexpr/src/key_expr/intersect/mod.rs @@ -12,6 +12,8 @@ // ZettaScale Zenoh Team, // +use crate::DELIMITER; + use super::keyexpr; mod classical; @@ -88,13 +90,37 @@ impl< > Intersector<&keyexpr, &keyexpr> for T { fn intersect(&self, left: &keyexpr, right: &keyexpr) -> bool { - let left_bytes = left.as_bytes(); - let right_bytes = right.as_bytes(); + let mut left_bytes = left.as_bytes(); + let mut right_bytes = right.as_bytes(); if left_bytes == right_bytes { return true; } - match left.match_complexity() as u8 | right.match_complexity() as u8 { - 0 => false, + let complexity = left.match_complexity() as u8 | right.match_complexity() as u8; + if complexity == 0 { + return false; + } + if unsafe { *left_bytes.get_unchecked(0) == b'@' || *right_bytes.get_unchecked(0) == b'@' } + { + let mut end = left_bytes.len().min(right_bytes.len()); + for i in 0..end { + if left_bytes[i] != right_bytes[i] { + return false; + } + if left_bytes[i] == DELIMITER { + end = i; + break; + } + } + if left_bytes.len() == end { + return right_bytes.get(end..) == Some(b"/**"); + } + if right_bytes.len() == end { + return left_bytes.get(end..) == Some(b"/**"); + } + left_bytes = &left_bytes[(end + 1)..]; + right_bytes = &right_bytes[(end + 1)..]; + } + match complexity { 1 => self.intersect(NoSubWilds(left_bytes), NoSubWilds(right_bytes)), _ => self.intersect(left_bytes, right_bytes), } diff --git a/commons/zenoh-keyexpr/src/key_expr/tests.rs b/commons/zenoh-keyexpr/src/key_expr/tests.rs index ccba9b1ff6..002977255b 100644 --- a/commons/zenoh-keyexpr/src/key_expr/tests.rs +++ b/commons/zenoh-keyexpr/src/key_expr/tests.rs @@ -84,6 +84,16 @@ fn intersections() { assert!(intersect("x/a$*d$*e", "x/ade")); assert!(!intersect("x/c$*", "x/abc$*")); assert!(!intersect("x/$*d", "x/$*e")); + + assert!(intersect("@a", "@a")); + assert!(!intersect("@a", "@ab")); + assert!(!intersect("@a", "@a/b")); + assert!(!intersect("@a", "@a/*")); + assert!(!intersect("@a", "@a/*/**")); + assert!(!intersect("@a", "@a$*/**")); + assert!(intersect("@a", "@a/**")); + assert!(!intersect("**/xyz$*xyz", "@a/b/xyzdefxyz")); + assert!(intersect("@a/**/c/**/e", "@a/b/b/b/c/d/d/d/e")); } fn includes< @@ -146,6 +156,17 @@ fn inclusions() { assert!(!includes("x/c$*", "x/abc$*")); assert!(includes("x/$*c$*", "x/abc$*")); assert!(!includes("x/$*d", "x/$*e")); + + assert!(includes("@a", "@a")); + assert!(!includes("@a", "@ab")); + assert!(!includes("@a", "@a/b")); + assert!(!includes("@a", "@a/*")); + assert!(!includes("@a", "@a/*/**")); + assert!(!includes("@a$*/**", "@a")); + assert!(!includes("@a", "@a/**")); + assert!(includes("@a/**", "@a")); + assert!(!includes("**/xyz$*xyz", "@a/b/xyzdefxyz")); + assert!(includes("@a/**/c/**/e", "@a/b/b/b/c/d/d/d/e")); } #[test]