From a825511f500e794737e9e976ff186f04dbf0144b Mon Sep 17 00:00:00 2001 From: Michael Hoy Date: Thu, 14 Dec 2023 08:59:54 +0800 Subject: [PATCH 1/3] Add raw publisher --- r2r/src/lib.rs | 2 +- r2r/src/nodes.rs | 16 +++++++- r2r/src/publishers.rs | 56 ++++++++++++++++++++++++++++ r2r/tests/tokio_test_raw.rs | 73 +++++++++++++++++++++++++++++++++++++ 4 files changed, 145 insertions(+), 2 deletions(-) diff --git a/r2r/src/lib.rs b/r2r/src/lib.rs index 3bb5a52e9..bfae684ef 100644 --- a/r2r/src/lib.rs +++ b/r2r/src/lib.rs @@ -88,7 +88,7 @@ pub use utils::*; mod subscribers; mod publishers; -pub use publishers::{Publisher, PublisherUntyped}; +pub use publishers::{Publisher, PublisherUntyped, PublisherRaw}; mod services; pub use services::ServiceRequest; diff --git a/r2r/src/nodes.rs b/r2r/src/nodes.rs index d7a78b2d7..2be4a84be 100644 --- a/r2r/src/nodes.rs +++ b/r2r/src/nodes.rs @@ -791,7 +791,7 @@ impl Node { Ok(p) } - /// Create a ROS publisher with a type given at runtime. + /// Create a ROS publisher with a type given at runtime, where the data is supplied as JSON. pub fn create_publisher_untyped( &mut self, topic: &str, topic_type: &str, qos_profile: QosProfile, ) -> Result { @@ -804,6 +804,20 @@ impl Node { Ok(p) } + /// Create a ROS publisher with a type given at runtime, where the data is supplied as + /// a pre-serialized ROS message (i.e. &[u8]) + pub fn create_publisher_raw( + &mut self, topic: &str, topic_type: &str, qos_profile: QosProfile, + ) -> Result { + let dummy = WrappedNativeMsgUntyped::new_from(topic_type)?; + let publisher_handle = + create_publisher_helper(self.node_handle.as_mut(), topic, dummy.ts, qos_profile)?; + let arc = Arc::new(publisher_handle); + let p = make_publisher_raw(Arc::downgrade(&arc), topic_type.to_owned()); + self.pubs.push(arc); + Ok(p) + } + /// Spin the ROS node. /// /// This handles wakeups of all subscribes, services, etc on the diff --git a/r2r/src/publishers.rs b/r2r/src/publishers.rs index 55668c49a..32b95fed4 100644 --- a/r2r/src/publishers.rs +++ b/r2r/src/publishers.rs @@ -62,6 +62,18 @@ pub struct PublisherUntyped { type_: String, } +unsafe impl Send for PublisherRaw {} + +/// A ROS (raw) publisher. +/// +/// This contains a `Weak Arc` to an "untyped" publisher. As such it is safe to +/// move between threads. +#[derive(Debug, Clone)] +pub struct PublisherRaw { + handle: Weak, + type_: String, +} + pub fn make_publisher(handle: Weak) -> Publisher where T: WrappedTypesupport, @@ -76,6 +88,10 @@ pub fn make_publisher_untyped(handle: Weak, type_: String) -> P PublisherUntyped { handle, type_ } } +pub fn make_publisher_raw(handle: Weak, type_: String) -> PublisherRaw { + PublisherRaw { handle, type_ } +} + pub fn create_publisher_helper( node: &mut rcl_node_t, topic: &str, typesupport: *const rosidl_message_type_support_t, qos_profile: QosProfile, @@ -127,6 +143,46 @@ impl PublisherUntyped { } } + +impl PublisherRaw { + /// Publish an pre-serialized ROS message represented by a `&[u8]`. + /// + /// It is up to the user to make sure data is a valid ROS serialized message + pub fn publish(&self, data: &[u8]) -> Result<()> { + // TODO should this be an unsafe function? I'm not sure what happens if the data is malformed .. + + // upgrade to actual ref. if still alive + let publisher = self + .handle + .upgrade() + .ok_or(Error::RCL_RET_PUBLISHER_INVALID)?; + + // Safety: Not retained beyond this function + let msg_buf = rcl_serialized_message_t { + buffer: data.as_ptr() as *mut u8, + buffer_length: data.len(), + buffer_capacity: data.len(), + + // Since its read only, this should never be used .. + allocator: unsafe { rcutils_get_default_allocator() } + }; + + let result = + unsafe { rcl_publish_serialized_message( + publisher.as_ref(), + &msg_buf as *const rcl_serialized_message_t, + std::ptr::null_mut() + ) }; + + if result == RCL_RET_OK as i32 { + Ok(()) + } else { + log::error!("could not publish {}", result); + Err(Error::from_rcl_error(result)) + } + } +} + impl Publisher where T: WrappedTypesupport, diff --git a/r2r/tests/tokio_test_raw.rs b/r2r/tests/tokio_test_raw.rs index e619449a1..c0fed1e99 100644 --- a/r2r/tests/tokio_test_raw.rs +++ b/r2r/tests/tokio_test_raw.rs @@ -1,6 +1,7 @@ use futures::stream::StreamExt; use r2r::QosProfile; use tokio::task; +use r2r::WrappedTypesupport; #[tokio::test(flavor = "multi_thread")] async fn tokio_subscribe_raw_testing() -> Result<(), Box> { @@ -63,3 +64,75 @@ async fn tokio_subscribe_raw_testing() -> Result<(), Box> Ok(()) } + + + +#[tokio::test(flavor = "multi_thread")] +async fn tokio_publish_raw_testing() -> Result<(), Box> { + let ctx = r2r::Context::create()?; + let mut node = r2r::Node::create(ctx, "testnode2", "")?; + + let mut sub_int = node.subscribe::("/int", QosProfile::default())?; + + let mut sub_array = + node.subscribe::("/int_array", QosProfile::default())?; + + let pub_int = node.create_publisher_raw( + "/int", + "std_msgs/msg/Int32", + QosProfile::default() + )?; + + // Use an array as well since its a variable sized type + let pub_array = node.create_publisher_raw( + "/int_array", + "std_msgs/msg/Int32MultiArray", + QosProfile::default(), + )?; + + task::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + (0..10).for_each(|i| { + pub_int + .publish(&r2r::std_msgs::msg::Int32 { data: i }.to_serialized_bytes().unwrap()) + .unwrap(); + + pub_array + .publish(&r2r::std_msgs::msg::Int32MultiArray { + layout: r2r::std_msgs::msg::MultiArrayLayout::default(), + data: vec![i], + }.to_serialized_bytes().unwrap()) + .unwrap(); + }); + }); + + let sub_int_handle = task::spawn(async move { + while let Some(msg) = sub_int.next().await { + let len = msg.to_serialized_bytes().unwrap().len(); + + println!("Got int msg of len {len}"); + assert_eq!(len, 8); + } + }); + + let sub_array_handle = task::spawn(async move { + while let Some(msg) = sub_array.next().await { + let len = msg.to_serialized_bytes().unwrap().len(); + + println!("Got array msg of len {len}"); + assert_eq!(len, 20); + } + }); + + let handle = std::thread::spawn(move || { + for _ in 1..=30 { + node.spin_once(std::time::Duration::from_millis(100)); + } + }); + + sub_int_handle.await?; + sub_array_handle.await?; + handle.join().unwrap(); + + Ok(()) +} From 4774d7f5c4d1134ffd7ae1edc19149484d0232ab Mon Sep 17 00:00:00 2001 From: Michael Hoy Date: Thu, 14 Dec 2023 09:03:14 +0800 Subject: [PATCH 2/3] Add raw publisher examples --- r2r/examples/tokio_raw_publisher.rs | 27 +++++++++++++++++++++++++++ r2r/tests/tokio_test_raw.rs | 10 ++++++---- 2 files changed, 33 insertions(+), 4 deletions(-) create mode 100644 r2r/examples/tokio_raw_publisher.rs diff --git a/r2r/examples/tokio_raw_publisher.rs b/r2r/examples/tokio_raw_publisher.rs new file mode 100644 index 000000000..a11d5b08d --- /dev/null +++ b/r2r/examples/tokio_raw_publisher.rs @@ -0,0 +1,27 @@ +use r2r::QosProfile; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let ctx = r2r::Context::create()?; + let mut node = r2r::Node::create(ctx, "testnode", "")?; + let duration = std::time::Duration::from_millis(2500); + + let mut timer = node.create_wall_timer(duration)?; + let publisher = + node.create_publisher_raw("/topic", "std_msgs/msg/String", QosProfile::default())?; + + let handle = tokio::task::spawn_blocking(move || loop { + node.spin_once(std::time::Duration::from_millis(100)); + }); + + for _ in 1..10 { + timer.tick().await?; + let msg = r2r::std_msgs::msg::String { + data: "hello from r2r".to_string(), + }; + publisher.publish(&msg.to_serialized_bytes()?)?; + } + + handle.await?; + Ok(()) +} diff --git a/r2r/tests/tokio_test_raw.rs b/r2r/tests/tokio_test_raw.rs index c0fed1e99..9f2b5a123 100644 --- a/r2r/tests/tokio_test_raw.rs +++ b/r2r/tests/tokio_test_raw.rs @@ -98,10 +98,12 @@ async fn tokio_publish_raw_testing() -> Result<(), Box> { .unwrap(); pub_array - .publish(&r2r::std_msgs::msg::Int32MultiArray { - layout: r2r::std_msgs::msg::MultiArrayLayout::default(), - data: vec![i], - }.to_serialized_bytes().unwrap()) + .publish( + &r2r::std_msgs::msg::Int32MultiArray { + layout: r2r::std_msgs::msg::MultiArrayLayout::default(), + data: vec![i], + }.to_serialized_bytes().unwrap() + ) .unwrap(); }); }); From d8731caad99c2baab4f0ecaa324d34672025b2f5 Mon Sep 17 00:00:00 2001 From: Michael Hoy Date: Sat, 30 Dec 2023 00:36:38 +0800 Subject: [PATCH 3/3] Add more tests. Convert Context into singleton to avoid memory corruption when creatin multiple Context instances --- r2r/src/context.rs | 97 +++++++----- r2r/src/error.rs | 2 +- r2r/src/publishers.rs | 3 - r2r/tests/threads.rs | 102 ++++++------ r2r/tests/tokio_test_raw.rs | 298 ++++++++++++++++++++---------------- r2r/tests/tokio_testing.rs | 138 ++++++++++------- 6 files changed, 369 insertions(+), 271 deletions(-) diff --git a/r2r/src/context.rs b/r2r/src/context.rs index c73c3ab67..20c4cb6b2 100644 --- a/r2r/src/context.rs +++ b/r2r/src/context.rs @@ -2,6 +2,7 @@ use std::ffi::CStr; use std::ffi::CString; use std::fmt::Debug; use std::ops::{Deref, DerefMut}; +use std::sync::OnceLock; use std::sync::{Arc, Mutex}; use crate::error::*; @@ -29,50 +30,64 @@ macro_rules! check_rcl_ret { unsafe impl Send for Context {} +// Safety: Context is just a Arc> wrapper around ContextHandle +// so it should be safe to access from different threads +unsafe impl Sync for Context {} + +// Memory corruption (double free and others) was observed creating multiple +// `Context` objects in a single thread +// +// To reproduce, run the tests from `tokio_testing` or `tokio_test_raw` +// without this OnceLock + +static CONTEXT: OnceLock> = OnceLock::new(); + impl Context { /// Create a ROS context. pub fn create() -> Result { - let mut ctx: Box = unsafe { Box::new(rcl_get_zero_initialized_context()) }; - // argc/v - let args = std::env::args() - .map(|arg| CString::new(arg).unwrap()) - .collect::>(); - let mut c_args = args - .iter() - .map(|arg| arg.as_ptr()) - .collect::>(); - c_args.push(std::ptr::null()); - - let is_valid = unsafe { - let allocator = rcutils_get_default_allocator(); - let mut init_options = rcl_get_zero_initialized_init_options(); - check_rcl_ret!(rcl_init_options_init(&mut init_options, allocator)); - check_rcl_ret!(rcl_init( - (c_args.len() - 1) as ::std::os::raw::c_int, - c_args.as_ptr(), - &init_options, - ctx.as_mut(), - )); - check_rcl_ret!(rcl_init_options_fini(&mut init_options as *mut _)); - rcl_context_is_valid(ctx.as_mut()) - }; - - let logging_ok = unsafe { - let _guard = log_guard(); - let ret = rcl_logging_configure( - &ctx.as_ref().global_arguments, - &rcutils_get_default_allocator(), - ); - ret == RCL_RET_OK as i32 - }; - - if is_valid && logging_ok { - Ok(Context { - context_handle: Arc::new(Mutex::new(ContextHandle(ctx))), - }) - } else { - Err(Error::RCL_RET_ERROR) // TODO - } + CONTEXT.get_or_init(|| { + let mut ctx: Box = unsafe { Box::new(rcl_get_zero_initialized_context()) }; + // argc/v + let args = std::env::args() + .map(|arg| CString::new(arg).unwrap()) + .collect::>(); + let mut c_args = args + .iter() + .map(|arg| arg.as_ptr()) + .collect::>(); + c_args.push(std::ptr::null()); + + let is_valid = unsafe { + let allocator = rcutils_get_default_allocator(); + let mut init_options = rcl_get_zero_initialized_init_options(); + check_rcl_ret!(rcl_init_options_init(&mut init_options, allocator)); + check_rcl_ret!(rcl_init( + (c_args.len() - 1) as ::std::os::raw::c_int, + c_args.as_ptr(), + &init_options, + ctx.as_mut(), + )); + check_rcl_ret!(rcl_init_options_fini(&mut init_options as *mut _)); + rcl_context_is_valid(ctx.as_mut()) + }; + + let logging_ok = unsafe { + let _guard = log_guard(); + let ret = rcl_logging_configure( + &ctx.as_ref().global_arguments, + &rcutils_get_default_allocator(), + ); + ret == RCL_RET_OK as i32 + }; + + if is_valid && logging_ok { + Ok(Context { + context_handle: Arc::new(Mutex::new(ContextHandle(ctx))), + }) + } else { + Err(Error::RCL_RET_ERROR) // TODO + } + }).clone() } /// Check if the ROS context is valid. diff --git a/r2r/src/error.rs b/r2r/src/error.rs index 649b8fdf2..e0e29bd28 100644 --- a/r2r/src/error.rs +++ b/r2r/src/error.rs @@ -11,7 +11,7 @@ pub type Result = std::result::Result; /// These values are mostly copied straight from the RCL headers, but /// some are specific to r2r, such as `GoalCancelRejected` which does /// not have an analogue in the rcl. -#[derive(Error, Debug)] +#[derive(Error, Clone, Debug)] pub enum Error { #[error("RCL_RET_OK")] RCL_RET_OK, diff --git a/r2r/src/publishers.rs b/r2r/src/publishers.rs index 25c9dc9c9..f984b1dc1 100644 --- a/r2r/src/publishers.rs +++ b/r2r/src/publishers.rs @@ -128,7 +128,6 @@ pub struct PublisherUntyped { type_: String, } - pub fn make_publisher(handle: Weak) -> Publisher where T: WrappedTypesupport, @@ -146,7 +145,6 @@ pub fn make_publisher_untyped(handle: Weak, type_: String) -> Publis } } - pub fn create_publisher_helper( node: &mut rcl_node_t, topic: &str, typesupport: *const rosidl_message_type_support_t, qos_profile: QosProfile, @@ -268,7 +266,6 @@ impl PublisherUntyped { } - impl Publisher where T: WrappedTypesupport, diff --git a/r2r/tests/threads.rs b/r2r/tests/threads.rs index f99debdfe..0b7c2d1df 100644 --- a/r2r/tests/threads.rs +++ b/r2r/tests/threads.rs @@ -3,62 +3,78 @@ use std::time::Duration; use r2r::QosProfile; +const N_NODE_PER_CONTEXT: usize = 5; +const N_CONCURRENT_ROS_CONTEXT: usize = 2; +const N_TEARDOWN_CYCLES: usize = 2; + #[test] // Let's create and drop a lot of node and publishers for a while to see that we can cope. fn doesnt_crash() -> Result<(), Box> { - // a global shared context. - let ctx = r2r::Context::create()?; - for c in 0..10 { - let mut ths = Vec::new(); - // I have lowered this from 30 to 10 because cyclonedds can only handle a hard-coded number of - // publishers in threads. See - // https://github.com/eclipse-cyclonedds/cyclonedds/blob/cd2136d9321212bd52fdc613f07bbebfddd90dec/src/core/ddsc/src/dds_init.c#L115 - for i in 0..10 { - // create concurrent nodes that max out the cpu - let ctx = ctx.clone(); - ths.push(thread::spawn(move || { - let mut node = r2r::Node::create(ctx, &format!("testnode{}", i), "").unwrap(); + let threads = (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| std::thread::spawn(move || { + + for _i_cycle in 0..N_TEARDOWN_CYCLES { + // a global shared context. + let ctx = r2r::Context::create().unwrap(); - // each with 10 publishers - for _j in 0..10 { - let p = node - .create_publisher::( - &format!("/r2r{}", i), - QosProfile::default(), - ) - .unwrap(); - let to_send = r2r::std_msgs::msg::String { - data: format!("[node{}]: {}", i, c), - }; + for c in 0..10 { + let mut ths = Vec::new(); + // I have lowered this from 30 to (10 / N_CONCURRENT_ROS_CONTEXT) because cyclonedds can only handle a hard-coded number of + // publishers in threads. See + // https://github.com/eclipse-cyclonedds/cyclonedds/blob/cd2136d9321212bd52fdc613f07bbebfddd90dec/src/core/ddsc/src/dds_init.c#L115 + for i_node in 0..N_NODE_PER_CONTEXT { + // create concurrent nodes that max out the cpu + let ctx = ctx.clone(); + ths.push(thread::spawn(move || { + let mut node = r2r::Node::create(ctx, &format!("testnode_{}_{}", i_context, i_node), "").unwrap(); - // move publisher to its own thread and publish as fast as we can - thread::spawn(move || loop { - let res = p.publish(&to_send); - thread::sleep(Duration::from_millis(1)); - match res { - Ok(_) => (), - Err(_) => { - // println!("publisher died, quitting thread."); - break; - } + // each with 10 publishers + for _j in 0..10 { + let p = node + .create_publisher::( + &format!("/r2r{}", i_node), + QosProfile::default(), + ) + .unwrap(); + let to_send = r2r::std_msgs::msg::String { + data: format!("[node{}]: {}", i_node, c), + }; + + // move publisher to its own thread and publish as fast as we can + thread::spawn(move || loop { + let res = p.publish(&to_send); + thread::sleep(Duration::from_millis(1)); + match res { + Ok(_) => (), + Err(_) => { + // println!("publisher died, quitting thread."); + break; + } + } + }); } - }); + + // spin to simulate some load + for _j in 0..100 { + node.spin_once(Duration::from_millis(10)); + } + + // println!("all done {}-{}", c, i); + })); } - // spin to simulate some load - for _j in 0..100 { - node.spin_once(Duration::from_millis(10)); + for t in ths { + t.join().unwrap(); } + // println!("all threads done {}", c); - // println!("all done {}-{}", c, i); - })); + } } + + })); - for t in ths { - t.join().unwrap(); - } - // println!("all threads done {}", c); + for thread in threads.into_iter() { + thread.join().unwrap(); } Ok(()) diff --git a/r2r/tests/tokio_test_raw.rs b/r2r/tests/tokio_test_raw.rs index 3dec4961a..1819143d7 100644 --- a/r2r/tests/tokio_test_raw.rs +++ b/r2r/tests/tokio_test_raw.rs @@ -3,144 +3,180 @@ use r2r::QosProfile; use tokio::task; use r2r::WrappedTypesupport; -#[tokio::test(flavor = "multi_thread")] + +const N_CONCURRENT_ROS_CONTEXT: usize = 3; +const N_TEARDOWN_CYCLES: usize = 2; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn tokio_subscribe_raw_testing() -> Result<(), Box> { - let ctx = r2r::Context::create()?; - let mut node = r2r::Node::create(ctx, "testnode2", "")?; - - let mut sub_int = node.subscribe_raw("/int", "std_msgs/msg/Int32", QosProfile::default())?; - - let mut sub_array = - node.subscribe_raw("/int_array", "std_msgs/msg/Int32MultiArray", QosProfile::default())?; - - let pub_int = - node.create_publisher::("/int", QosProfile::default())?; - - // Use an array as well since its a variable sized type - let pub_array = node.create_publisher::( - "/int_array", - QosProfile::default(), - )?; - - task::spawn(async move { - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - (0..10).for_each(|i| { - pub_int - .publish(&r2r::std_msgs::msg::Int32 { data: i }) - .unwrap(); - - pub_array - .publish(&r2r::std_msgs::msg::Int32MultiArray { - layout: r2r::std_msgs::msg::MultiArrayLayout::default(), - data: vec![i], - }) - .unwrap(); - }); - }); - - let sub_int_handle = task::spawn(async move { - while let Some(msg) = sub_int.next().await { - println!("Got int msg of len {}", msg.len()); - assert_eq!(msg.len(), 8); - } - }); - - let sub_array_handle = task::spawn(async move { - while let Some(msg) = sub_array.next().await { - println!("Got array msg of len {}", msg.len()); - assert_eq!(msg.len(), 20); - } - }); - - let handle = std::thread::spawn(move || { - for _ in 1..=30 { - node.spin_once(std::time::Duration::from_millis(100)); - } - }); - - sub_int_handle.await?; - sub_array_handle.await?; - handle.join().unwrap(); + let mut threads = futures::stream::FuturesUnordered::from_iter( + (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| tokio::spawn(async move { + // Iterate to check for memory corruption on node setup/teardown + for i_cycle in 0..N_TEARDOWN_CYCLES { + println!("tokio_subscribe_raw_testing iteration {i_cycle}"); + + let ctx = r2r::Context::create().unwrap(); + let mut node = r2r::Node::create(ctx, &format!("testnode2_{i_context}"), "").unwrap(); + + let mut sub_int = node.subscribe_raw("/int", "std_msgs/msg/Int32", QosProfile::default()).unwrap(); + + let mut sub_array = + node.subscribe_raw("/int_array", "std_msgs/msg/Int32MultiArray", QosProfile::default()).unwrap(); + + let pub_int = + node.create_publisher::("/int", QosProfile::default()).unwrap(); + + // Use an array as well since its a variable sized type + let pub_array = node.create_publisher::( + "/int_array", + QosProfile::default(), + ).unwrap(); + + task::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + (0..10).for_each(|i| { + pub_int + .publish(&r2r::std_msgs::msg::Int32 { data: i }) + .unwrap(); + + pub_array + .publish(&r2r::std_msgs::msg::Int32MultiArray { + layout: r2r::std_msgs::msg::MultiArrayLayout::default(), + data: vec![i], + }) + .unwrap(); + }); + }); + + let sub_int_handle = task::spawn(async move { + while let Some(msg) = sub_int.next().await { + println!("Got int msg of len {}", msg.len()); + assert_eq!(msg.len(), 8); + } + }); + + let sub_array_handle = task::spawn(async move { + while let Some(msg) = sub_array.next().await { + println!("Got array msg of len {}", msg.len()); + assert_eq!(msg.len(), 20); + } + }); + + let handle = std::thread::spawn(move || { + for _ in 1..=30 { + node.spin_once(std::time::Duration::from_millis(100)); + } + }); + + sub_int_handle.await.unwrap(); + sub_array_handle.await.unwrap(); + handle.join().unwrap(); + + println!("Going to drop tokio_subscribe_raw_testing iteration {i_cycle}"); + } + + }))); + + while let Some(thread) = threads.next().await { + thread.unwrap(); + } Ok(()) } - -#[tokio::test(flavor = "multi_thread")] +// Limit the number of threads to force threads to be reused +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn tokio_publish_raw_testing() -> Result<(), Box> { - let ctx = r2r::Context::create()?; - let mut node = r2r::Node::create(ctx, "testnode2", "")?; - - let mut sub_int = node.subscribe::("/int", QosProfile::default())?; - - let mut sub_array = - node.subscribe::("/int_array", QosProfile::default())?; - - let pub_int = node.create_publisher_untyped( - "/int", - "std_msgs/msg/Int32", - QosProfile::default() - )?; - - // Use an array as well since its a variable sized type - let pub_array = node.create_publisher_untyped( - "/int_array", - "std_msgs/msg/Int32MultiArray", - QosProfile::default(), - )?; - - task::spawn(async move { - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - (0..10).for_each(|i| { - pub_int - .publish_raw(&r2r::std_msgs::msg::Int32 { data: i }.to_serialized_bytes().unwrap()) - .unwrap(); - - pub_array - .publish_raw( - &r2r::std_msgs::msg::Int32MultiArray { - layout: r2r::std_msgs::msg::MultiArrayLayout::default(), - data: vec![i], - }.to_serialized_bytes().unwrap() - ) - .unwrap(); - }); - }); - - let sub_int_handle = task::spawn(async move { - while let Some(msg) = sub_int.next().await { - // Try to check for any possible corruption - msg.to_serialized_bytes().unwrap(); - - println!("Got int msg with value {}", msg.data); - assert!(msg.data >= 0); - assert!(msg.data < 10); - - } - }); - - let sub_array_handle = task::spawn(async move { - while let Some(msg) = sub_array.next().await { - // Try to check for any possible corruption - msg.to_serialized_bytes().unwrap(); - - println!("Got array msg with value {:?}", msg.data); - assert_eq!(msg.data.len(), 1); - assert!(msg.data[0] >= 0); - assert!(msg.data[0] < 10); - } - }); - - let handle = std::thread::spawn(move || { - for _ in 1..=30 { - node.spin_once(std::time::Duration::from_millis(100)); - } - }); - - sub_int_handle.await?; - sub_array_handle.await?; - handle.join().unwrap(); + + let mut threads = futures::stream::FuturesUnordered::from_iter( + (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| tokio::spawn(async move { + // Iterate to check for memory corruption on node setup/teardown + for i_cycle in 0..N_TEARDOWN_CYCLES { + + println!("tokio_publish_raw_testing iteration {i_cycle}"); + + let ctx = r2r::Context::create().unwrap(); + let mut node = r2r::Node::create(ctx, &format!("testnode3_{i_context}"), "").unwrap(); + + let mut sub_int = node.subscribe::("/int", QosProfile::default()).unwrap(); + + let mut sub_array = + node.subscribe::("/int_array", QosProfile::default()).unwrap(); + + let pub_int = node.create_publisher_untyped( + "/int", + "std_msgs/msg/Int32", + QosProfile::default() + ).unwrap(); + + // Use an array as well since its a variable sized type + let pub_array = node.create_publisher_untyped( + "/int_array", + "std_msgs/msg/Int32MultiArray", + QosProfile::default(), + ).unwrap(); + + task::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + (0..10).for_each(|i| { + pub_int + .publish_raw(&r2r::std_msgs::msg::Int32 { data: i }.to_serialized_bytes().unwrap()) + .unwrap(); + + pub_array + .publish_raw( + &r2r::std_msgs::msg::Int32MultiArray { + layout: r2r::std_msgs::msg::MultiArrayLayout::default(), + data: vec![i], + }.to_serialized_bytes().unwrap() + ) + .unwrap(); + }); + }); + + let sub_int_handle = task::spawn(async move { + while let Some(msg) = sub_int.next().await { + // Try to check for any possible corruption + msg.to_serialized_bytes().unwrap(); + + println!("Got int msg with value {}", msg.data); + assert!(msg.data >= 0); + assert!(msg.data < 10); + + } + }); + + let sub_array_handle = task::spawn(async move { + while let Some(msg) = sub_array.next().await { + // Try to check for any possible corruption + msg.to_serialized_bytes().unwrap(); + + println!("Got array msg with value {:?}", msg.data); + assert_eq!(msg.data.len(), 1); + assert!(msg.data[0] >= 0); + assert!(msg.data[0] < 10); + } + }); + + let handle = std::thread::spawn(move || { + for _ in 1..=30 { + node.spin_once(std::time::Duration::from_millis(100)); + } + }); + + sub_int_handle.await.unwrap(); + sub_array_handle.await.unwrap(); + handle.join().unwrap(); + + println!("Going to drop tokio_publish_raw_testing iteration {i_cycle}"); + + } + }))); + + while let Some(thread) = threads.next().await { + thread.unwrap(); + } Ok(()) } diff --git a/r2r/tests/tokio_testing.rs b/r2r/tests/tokio_testing.rs index 0736cb58f..aa4077f93 100644 --- a/r2r/tests/tokio_testing.rs +++ b/r2r/tests/tokio_testing.rs @@ -4,60 +4,94 @@ use r2r::QosProfile; use std::sync::{Arc, Mutex}; use tokio::task; -#[tokio::test(flavor = "multi_thread")] +const N_CONCURRENT_ROS_CONTEXT: usize = 3; +const N_TEARDOWN_CYCLES: usize = 2; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn tokio_testing() -> Result<(), Box> { - let ctx = r2r::Context::create()?; - let mut node = r2r::Node::create(ctx, "testnode", "")?; - let mut s_the_no = - node.subscribe::("/the_no", QosProfile::default())?; - let mut s_new_no = - node.subscribe::("/new_no", QosProfile::default())?; - let p_the_no = - node.create_publisher::("/the_no", QosProfile::default())?; - let p_new_no = - node.create_publisher::("/new_no", QosProfile::default())?; - let state = Arc::new(Mutex::new(0)); - - task::spawn(async move { - (0..10).for_each(|i| { - p_the_no - .publish(&r2r::std_msgs::msg::Int32 { data: i }) - .unwrap(); - }); - }); - - task::spawn(async move { - while let Some(msg) = s_the_no.next().await { - p_new_no - .publish(&r2r::std_msgs::msg::Int32 { - data: msg.data + 10, - }) - .unwrap(); - } - }); - - let s = state.clone(); - task::spawn(async move { - while let Some(msg) = s_new_no.next().await { - let i = msg.data; - if i == 19 { - *s.lock().unwrap() = 19; - } - } - }); - - let handle = std::thread::spawn(move || { - for _ in 1..=30 { - node.spin_once(std::time::Duration::from_millis(100)); - let x = state.lock().unwrap(); - if *x == 19 { - break; + + let mut threads = futures::stream::FuturesUnordered::from_iter( + (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| tokio::spawn(async move { + // Iterate to check for memory corruption on node setup/teardown + for i_cycle in 0..N_TEARDOWN_CYCLES { + + println!("tokio_testing iteration {i_cycle}"); + + let ctx = r2r::Context::create().unwrap(); + // let ctx = std::thread::spawn(|| r2r::Context::create().unwrap()).join().unwrap(); + + let mut node = r2r::Node::create(ctx, &format!("testnode_{i_context}"), "").unwrap(); + let mut s_the_no = + node.subscribe::(&format!("/the_no_{i_context}"), QosProfile::default()).unwrap(); + let mut s_new_no = + node.subscribe::(&format!("/new_no_{i_context}"), QosProfile::default()).unwrap(); + let p_the_no = + node.create_publisher::(&format!("/the_no_{i_context}"), QosProfile::default()).unwrap(); + let p_new_no = + node.create_publisher::(&format!("/new_no_{i_context}"), QosProfile::default()).unwrap(); + let state = Arc::new(Mutex::new(0)); + + task::spawn(async move { + (0..10).for_each(|i| { + p_the_no + .publish(&r2r::std_msgs::msg::Int32 { data: i }) + .unwrap(); + + println!("send {i}"); + + }); + }); + + task::spawn(async move { + while let Some(msg) = s_the_no.next().await { + p_new_no + .publish(&r2r::std_msgs::msg::Int32 { + data: msg.data + 10, + }) + .unwrap(); + + println!("got {}, send {}", msg.data, msg.data + 10); + } + }); + + let s = state.clone(); + task::spawn(async move { + while let Some(msg) = s_new_no.next().await { + + println!("got {}", msg.data); + + let i = msg.data; + + *s.lock().unwrap() = i; + } + }); + + // std::thread::spawn doesn't work here anymore? + let handle = task::spawn_blocking(move || { + for _ in 1..30 { + node.spin_once(std::time::Duration::from_millis(100)); + let x = state.lock().unwrap(); + + println!("rec {}", x); + + if *x == 19 { + break; + } + } + + *state.lock().unwrap() + }); + let x = handle.await.unwrap(); + assert_eq!(x, 19); + + println!("tokio_testing finish iteration {i_cycle}"); + } - } + }))); + + while let Some(thread) = threads.next().await { + thread.unwrap(); + } - *state.lock().unwrap() - }); - let x = handle.join().unwrap(); - assert_eq!(x, 19); Ok(()) }