diff --git a/commons/zenoh-shm/src/posix_shm/segment.rs b/commons/zenoh-shm/src/posix_shm/segment.rs index 5458ab3e3e..657976ece1 100644 --- a/commons/zenoh-shm/src/posix_shm/segment.rs +++ b/commons/zenoh-shm/src/posix_shm/segment.rs @@ -12,10 +12,7 @@ // ZettaScale Zenoh Team, // -use std::{ - fmt::{Debug, Display}, - mem::size_of, -}; +use std::fmt::{Debug, Display}; use rand::Rng; use shared_memory::{Shmem, ShmemConf, ShmemError}; @@ -63,7 +60,7 @@ where // If creation fails because segment already exists for this id, // the creation attempt will be repeated with another id match ShmemConf::new() - .size(alloc_size + size_of::()) + .size(alloc_size) .os_id(Self::os_id(id.clone(), id_prefix)) .create() { @@ -71,7 +68,6 @@ where tracing::debug!( "Created SHM segment, size: {alloc_size}, prefix: {id_prefix}, id: {id}" ); - unsafe { *(shmem.as_ptr() as *mut usize) = alloc_size }; return Ok(Segment { shmem, id }); } Err(ShmemError::LinkExists) => {} @@ -94,10 +90,6 @@ where ) })?; - if shmem.len() <= size_of::() { - bail!("SHM segment too small") - } - tracing::debug!("Opened SHM segment, prefix: {id_prefix}, id: {id}"); Ok(Self { shmem, id }) @@ -110,17 +102,21 @@ where } pub fn as_ptr(&self) -> *mut u8 { - unsafe { self.shmem.as_ptr().add(size_of::()) } + self.shmem.as_ptr() } + /// Returns the length of this [`Segment`]. + /// NOTE: one some platforms (at least windows) the returned len will be the actual length of an shm segment + /// (a required len rounded up to the nearest multiply of page size), on other (at least linux and macos) this + /// returns a value requested upon segment creation pub fn len(&self) -> usize { - unsafe { *(self.shmem.as_ptr() as *mut usize) } + self.shmem.len() } // TODO: dead code warning occurs because of `tested_crate_module!()` macro when feature `test` is not enabled. Better to fix that #[allow(dead_code)] pub fn is_empty(&self) -> bool { - unsafe { *(self.shmem.as_ptr() as *mut usize) == 0 } + self.len() == 0 } pub fn id(&self) -> ID { diff --git a/commons/zenoh-shm/tests/posix_array.rs b/commons/zenoh-shm/tests/posix_array.rs index 562102ea17..83fdad88fb 100644 --- a/commons/zenoh-shm/tests/posix_array.rs +++ b/commons/zenoh-shm/tests/posix_array.rs @@ -41,25 +41,25 @@ impl TestElem { } fn validate_array( - array1: &mut ArrayInSHM, - array2: &ArrayInSHM, + created_array: &mut ArrayInSHM, + opened_array: &ArrayInSHM, expected_elem_count: usize, ) where ElemIndex: Unsigned + PrimInt + 'static + AsPrimitive, isize: AsPrimitive, usize: AsPrimitive, { - assert!(array1.elem_count() == expected_elem_count); - assert!(array2.elem_count() == expected_elem_count); + assert!(created_array.elem_count() == expected_elem_count); + assert!(opened_array.elem_count() >= expected_elem_count); let mut fill_ctr = 0; let mut validate_ctr = 0; // first of all, fill and validate elements sequentially - for i in 0..array1.elem_count() { + for i in 0..expected_elem_count { unsafe { - let elem1 = &mut *array1.elem_mut(i.as_()); - let elem2 = &*array2.elem(i.as_()); + let elem1 = &mut *created_array.elem_mut(i.as_()); + let elem2 = &*opened_array.elem(i.as_()); elem1.fill(&mut fill_ctr); elem2.validate(&mut validate_ctr); @@ -67,17 +67,17 @@ fn validate_array( } // then fill all the elements... - for i in 0..array1.elem_count() { + for i in 0..expected_elem_count { unsafe { - let elem1 = &mut *array1.elem_mut(i.as_()); + let elem1 = &mut *created_array.elem_mut(i.as_()); elem1.fill(&mut fill_ctr); } } // ...and validate all the elements - for i in 0..array2.elem_count() { + for i in 0..expected_elem_count { unsafe { - let elem2 = &*array2.elem(i.as_()); + let elem2 = &*opened_array.elem(i.as_()); elem2.validate(&mut validate_ctr); } } diff --git a/commons/zenoh-shm/tests/posix_segment.rs b/commons/zenoh-shm/tests/posix_segment.rs index 094ae40a85..879fccf298 100644 --- a/commons/zenoh-shm/tests/posix_segment.rs +++ b/commons/zenoh-shm/tests/posix_segment.rs @@ -19,18 +19,22 @@ use zenoh_shm::posix_shm::segment::Segment; pub mod common; use common::{validate_memory, TEST_SEGMENT_PREFIX}; -fn validate_segment(segment1: &Segment, segment2: &Segment) -where +fn validate_segment( + created_segment: &Segment, + opened_segment: &Segment, + expected_elem_count: usize, +) where rand::distributions::Standard: rand::distributions::Distribution, ID: Clone + Display, { - assert!(segment1.len() == segment2.len()); + assert!(created_segment.len() == expected_elem_count); + assert!(opened_segment.len() >= expected_elem_count); - let ptr1 = segment1.as_ptr(); - let ptr2 = segment2.as_ptr(); + let ptr1 = created_segment.as_ptr(); + let ptr2 = opened_segment.as_ptr(); - let slice1 = unsafe { slice::from_raw_parts_mut(ptr1, segment1.len()) }; - let slice2 = unsafe { slice::from_raw_parts(ptr2, segment2.len()) }; + let slice1 = unsafe { slice::from_raw_parts_mut(ptr1, expected_elem_count) }; + let slice2 = unsafe { slice::from_raw_parts(ptr2, expected_elem_count) }; validate_memory(slice1, slice2); } @@ -40,22 +44,24 @@ where rand::distributions::Standard: rand::distributions::Distribution, ID: Copy + Clone + Display, { - let new_segment: Segment = - Segment::create(900, TEST_SEGMENT_PREFIX).expect("error creating new segment"); + let elem_count = 900; + + let created_segment: Segment = + Segment::create(elem_count, TEST_SEGMENT_PREFIX).expect("error creating new segment"); - let opened_segment_instance_1 = Segment::open(new_segment.id(), TEST_SEGMENT_PREFIX) + let opened_segment_instance_1 = Segment::open(created_segment.id(), TEST_SEGMENT_PREFIX) .expect("error opening existing segment!"); - validate_segment(&new_segment, &opened_segment_instance_1); + validate_segment(&created_segment, &opened_segment_instance_1, elem_count); - let opened_segment_instance_2 = Segment::open(new_segment.id(), TEST_SEGMENT_PREFIX) + let opened_segment_instance_2 = Segment::open(created_segment.id(), TEST_SEGMENT_PREFIX) .expect("error opening existing segment!"); - validate_segment(&new_segment, &opened_segment_instance_1); - validate_segment(&new_segment, &opened_segment_instance_2); + validate_segment(&created_segment, &opened_segment_instance_1, elem_count); + validate_segment(&created_segment, &opened_segment_instance_2, elem_count); drop(opened_segment_instance_1); - validate_segment(&new_segment, &opened_segment_instance_2); + validate_segment(&created_segment, &opened_segment_instance_2, elem_count); } /// UNSIGNED /// @@ -116,19 +122,19 @@ fn segment_i128_id() { #[test] fn segment_open() { - let new_segment: Segment = + let created_segment: Segment = Segment::create(900, TEST_SEGMENT_PREFIX).expect("error creating new segment"); - let _opened_segment = Segment::open(new_segment.id(), TEST_SEGMENT_PREFIX) + let _opened_segment = Segment::open(created_segment.id(), TEST_SEGMENT_PREFIX) .expect("error opening existing segment!"); } #[test] fn segment_open_error() { let id = { - let new_segment: Segment = + let created_segment: Segment = Segment::create(900, TEST_SEGMENT_PREFIX).expect("error creating new segment"); - new_segment.id() + created_segment.id() }; let _opened_segment = Segment::open(id, TEST_SEGMENT_PREFIX) diff --git a/io/zenoh-transport/src/unicast/establishment/ext/shm.rs b/io/zenoh-transport/src/unicast/establishment/ext/shm.rs index e2068af94a..025aaaef44 100644 --- a/io/zenoh-transport/src/unicast/establishment/ext/shm.rs +++ b/io/zenoh-transport/src/unicast/establishment/ext/shm.rs @@ -36,6 +36,10 @@ const AUTH_SEGMENT_PREFIX: &str = "auth"; pub(crate) type AuthSegmentID = u32; pub(crate) type AuthChallenge = u64; +const LEN_INDEX: usize = 0; +const CHALLENGE_INDEX: usize = 1; +const ID_START_INDEX: usize = 2; + #[derive(Debug)] pub struct AuthSegment { array: ArrayInSHM, @@ -44,13 +48,14 @@ pub struct AuthSegment { impl AuthSegment { pub fn create(challenge: AuthChallenge, shm_protocols: &[ProtocolID]) -> ZResult { let array = ArrayInSHM::::create( - 1 + shm_protocols.len(), + ID_START_INDEX + shm_protocols.len(), AUTH_SEGMENT_PREFIX, )?; unsafe { - (*array.elem_mut(0)) = challenge; - for elem in 1..array.elem_count() { - (*array.elem_mut(elem)) = shm_protocols[elem - 1] as u64; + (*array.elem_mut(LEN_INDEX)) = shm_protocols.len() as AuthChallenge; + (*array.elem_mut(CHALLENGE_INDEX)) = challenge; + for elem in ID_START_INDEX..array.elem_count() { + (*array.elem_mut(elem)) = shm_protocols[elem - ID_START_INDEX] as u64; } }; Ok(Self { array }) @@ -62,12 +67,12 @@ impl AuthSegment { } pub fn challenge(&self) -> AuthChallenge { - unsafe { *self.array.elem(0) } + unsafe { *self.array.elem(CHALLENGE_INDEX) } } pub fn protocols(&self) -> Vec { let mut result = vec![]; - for elem in 1..self.array.elem_count() { + for elem in ID_START_INDEX..self.array.elem_count() { result.push(unsafe { *self.array.elem(elem) as u32 }); } result