From cb87b9c01c3fee430a90fff352523e1fdd0511ac Mon Sep 17 00:00:00 2001 From: Michael Hoy Date: Sun, 7 Jan 2024 18:46:59 +0800 Subject: [PATCH] Add raw publisher, convert Context to singleton (#76) * Add raw publisher * Add raw publisher examples * Add more tests. Convert Context into singleton to avoid memory corruption when creatin multiple Context instances --- r2r/examples/tokio_raw_publisher.rs | 28 ++++ r2r/src/context.rs | 97 +++++++----- r2r/src/error.rs | 2 +- r2r/src/nodes.rs | 4 +- r2r/src/publishers.rs | 37 +++++ r2r/tests/threads.rs | 102 ++++++------ r2r/tests/tokio_test_raw.rs | 231 +++++++++++++++++++++------- r2r/tests/tokio_testing.rs | 138 ++++++++++------- 8 files changed, 444 insertions(+), 195 deletions(-) create mode 100644 r2r/examples/tokio_raw_publisher.rs diff --git a/r2r/examples/tokio_raw_publisher.rs b/r2r/examples/tokio_raw_publisher.rs new file mode 100644 index 000000000..8cccdf5cf --- /dev/null +++ b/r2r/examples/tokio_raw_publisher.rs @@ -0,0 +1,28 @@ +use r2r::QosProfile; +use r2r::WrappedTypesupport; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let ctx = r2r::Context::create()?; + let mut node = r2r::Node::create(ctx, "testnode", "")?; + let duration = std::time::Duration::from_millis(2500); + + let mut timer = node.create_wall_timer(duration)?; + let publisher = + node.create_publisher_untyped("/topic", "std_msgs/msg/String", QosProfile::default())?; + + let handle = tokio::task::spawn_blocking(move || loop { + node.spin_once(std::time::Duration::from_millis(100)); + }); + + for _ in 1..10 { + timer.tick().await?; + let msg = r2r::std_msgs::msg::String { + data: "hello from r2r".to_string(), + }; + publisher.publish_raw(&msg.to_serialized_bytes()?)?; + } + + handle.await?; + Ok(()) +} diff --git a/r2r/src/context.rs b/r2r/src/context.rs index c73c3ab67..20c4cb6b2 100644 --- a/r2r/src/context.rs +++ b/r2r/src/context.rs @@ -2,6 +2,7 @@ use std::ffi::CStr; use std::ffi::CString; use std::fmt::Debug; use std::ops::{Deref, DerefMut}; +use std::sync::OnceLock; use std::sync::{Arc, Mutex}; use crate::error::*; @@ -29,50 +30,64 @@ macro_rules! check_rcl_ret { unsafe impl Send for Context {} +// Safety: Context is just a Arc> wrapper around ContextHandle +// so it should be safe to access from different threads +unsafe impl Sync for Context {} + +// Memory corruption (double free and others) was observed creating multiple +// `Context` objects in a single thread +// +// To reproduce, run the tests from `tokio_testing` or `tokio_test_raw` +// without this OnceLock + +static CONTEXT: OnceLock> = OnceLock::new(); + impl Context { /// Create a ROS context. pub fn create() -> Result { - let mut ctx: Box = unsafe { Box::new(rcl_get_zero_initialized_context()) }; - // argc/v - let args = std::env::args() - .map(|arg| CString::new(arg).unwrap()) - .collect::>(); - let mut c_args = args - .iter() - .map(|arg| arg.as_ptr()) - .collect::>(); - c_args.push(std::ptr::null()); - - let is_valid = unsafe { - let allocator = rcutils_get_default_allocator(); - let mut init_options = rcl_get_zero_initialized_init_options(); - check_rcl_ret!(rcl_init_options_init(&mut init_options, allocator)); - check_rcl_ret!(rcl_init( - (c_args.len() - 1) as ::std::os::raw::c_int, - c_args.as_ptr(), - &init_options, - ctx.as_mut(), - )); - check_rcl_ret!(rcl_init_options_fini(&mut init_options as *mut _)); - rcl_context_is_valid(ctx.as_mut()) - }; - - let logging_ok = unsafe { - let _guard = log_guard(); - let ret = rcl_logging_configure( - &ctx.as_ref().global_arguments, - &rcutils_get_default_allocator(), - ); - ret == RCL_RET_OK as i32 - }; - - if is_valid && logging_ok { - Ok(Context { - context_handle: Arc::new(Mutex::new(ContextHandle(ctx))), - }) - } else { - Err(Error::RCL_RET_ERROR) // TODO - } + CONTEXT.get_or_init(|| { + let mut ctx: Box = unsafe { Box::new(rcl_get_zero_initialized_context()) }; + // argc/v + let args = std::env::args() + .map(|arg| CString::new(arg).unwrap()) + .collect::>(); + let mut c_args = args + .iter() + .map(|arg| arg.as_ptr()) + .collect::>(); + c_args.push(std::ptr::null()); + + let is_valid = unsafe { + let allocator = rcutils_get_default_allocator(); + let mut init_options = rcl_get_zero_initialized_init_options(); + check_rcl_ret!(rcl_init_options_init(&mut init_options, allocator)); + check_rcl_ret!(rcl_init( + (c_args.len() - 1) as ::std::os::raw::c_int, + c_args.as_ptr(), + &init_options, + ctx.as_mut(), + )); + check_rcl_ret!(rcl_init_options_fini(&mut init_options as *mut _)); + rcl_context_is_valid(ctx.as_mut()) + }; + + let logging_ok = unsafe { + let _guard = log_guard(); + let ret = rcl_logging_configure( + &ctx.as_ref().global_arguments, + &rcutils_get_default_allocator(), + ); + ret == RCL_RET_OK as i32 + }; + + if is_valid && logging_ok { + Ok(Context { + context_handle: Arc::new(Mutex::new(ContextHandle(ctx))), + }) + } else { + Err(Error::RCL_RET_ERROR) // TODO + } + }).clone() } /// Check if the ROS context is valid. diff --git a/r2r/src/error.rs b/r2r/src/error.rs index 649b8fdf2..e0e29bd28 100644 --- a/r2r/src/error.rs +++ b/r2r/src/error.rs @@ -11,7 +11,7 @@ pub type Result = std::result::Result; /// These values are mostly copied straight from the RCL headers, but /// some are specific to r2r, such as `GoalCancelRejected` which does /// not have an analogue in the rcl. -#[derive(Error, Debug)] +#[derive(Error, Clone, Debug)] pub enum Error { #[error("RCL_RET_OK")] RCL_RET_OK, diff --git a/r2r/src/nodes.rs b/r2r/src/nodes.rs index 49b08c1b5..02910ae77 100644 --- a/r2r/src/nodes.rs +++ b/r2r/src/nodes.rs @@ -791,7 +791,9 @@ impl Node { Ok(p) } - /// Create a ROS publisher with a type given at runtime. + /// Create a ROS publisher with a type given at runtime, where the data may either be + /// supplied as JSON (using the `publish` method) or a pre-serialized ROS message + /// (i.e. &[u8], using the `publish_raw` method). pub fn create_publisher_untyped( &mut self, topic: &str, topic_type: &str, qos_profile: QosProfile, ) -> Result { diff --git a/r2r/src/publishers.rs b/r2r/src/publishers.rs index 75b668f74..f984b1dc1 100644 --- a/r2r/src/publishers.rs +++ b/r2r/src/publishers.rs @@ -202,6 +202,43 @@ impl PublisherUntyped { } } + /// Publish an pre-serialized ROS message represented by a `&[u8]`. + /// + /// It is up to the user to make sure data is a valid ROS serialized message. + pub fn publish_raw(&self, data: &[u8]) -> Result<()> { + // TODO should this be an unsafe function? I'm not sure what happens if the data is malformed .. + + // upgrade to actual ref. if still alive + let publisher = self + .handle + .upgrade() + .ok_or(Error::RCL_RET_PUBLISHER_INVALID)?; + + // Safety: Not retained beyond this function + let msg_buf = rcl_serialized_message_t { + buffer: data.as_ptr() as *mut u8, + buffer_length: data.len(), + buffer_capacity: data.len(), + + // Since its read only, this should never be used .. + allocator: unsafe { rcutils_get_default_allocator() } + }; + + let result = + unsafe { rcl_publish_serialized_message( + &publisher.handle, + &msg_buf as *const rcl_serialized_message_t, + std::ptr::null_mut() + ) }; + + if result == RCL_RET_OK as i32 { + Ok(()) + } else { + log::error!("could not publish {}", result); + Err(Error::from_rcl_error(result)) + } + } + /// Gets the number of external subscribers (i.e. it doesn't /// count subscribers from the same process). pub fn get_inter_process_subscription_count(&self) -> Result { diff --git a/r2r/tests/threads.rs b/r2r/tests/threads.rs index f99debdfe..0b7c2d1df 100644 --- a/r2r/tests/threads.rs +++ b/r2r/tests/threads.rs @@ -3,62 +3,78 @@ use std::time::Duration; use r2r::QosProfile; +const N_NODE_PER_CONTEXT: usize = 5; +const N_CONCURRENT_ROS_CONTEXT: usize = 2; +const N_TEARDOWN_CYCLES: usize = 2; + #[test] // Let's create and drop a lot of node and publishers for a while to see that we can cope. fn doesnt_crash() -> Result<(), Box> { - // a global shared context. - let ctx = r2r::Context::create()?; - for c in 0..10 { - let mut ths = Vec::new(); - // I have lowered this from 30 to 10 because cyclonedds can only handle a hard-coded number of - // publishers in threads. See - // https://github.com/eclipse-cyclonedds/cyclonedds/blob/cd2136d9321212bd52fdc613f07bbebfddd90dec/src/core/ddsc/src/dds_init.c#L115 - for i in 0..10 { - // create concurrent nodes that max out the cpu - let ctx = ctx.clone(); - ths.push(thread::spawn(move || { - let mut node = r2r::Node::create(ctx, &format!("testnode{}", i), "").unwrap(); + let threads = (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| std::thread::spawn(move || { + + for _i_cycle in 0..N_TEARDOWN_CYCLES { + // a global shared context. + let ctx = r2r::Context::create().unwrap(); - // each with 10 publishers - for _j in 0..10 { - let p = node - .create_publisher::( - &format!("/r2r{}", i), - QosProfile::default(), - ) - .unwrap(); - let to_send = r2r::std_msgs::msg::String { - data: format!("[node{}]: {}", i, c), - }; + for c in 0..10 { + let mut ths = Vec::new(); + // I have lowered this from 30 to (10 / N_CONCURRENT_ROS_CONTEXT) because cyclonedds can only handle a hard-coded number of + // publishers in threads. See + // https://github.com/eclipse-cyclonedds/cyclonedds/blob/cd2136d9321212bd52fdc613f07bbebfddd90dec/src/core/ddsc/src/dds_init.c#L115 + for i_node in 0..N_NODE_PER_CONTEXT { + // create concurrent nodes that max out the cpu + let ctx = ctx.clone(); + ths.push(thread::spawn(move || { + let mut node = r2r::Node::create(ctx, &format!("testnode_{}_{}", i_context, i_node), "").unwrap(); - // move publisher to its own thread and publish as fast as we can - thread::spawn(move || loop { - let res = p.publish(&to_send); - thread::sleep(Duration::from_millis(1)); - match res { - Ok(_) => (), - Err(_) => { - // println!("publisher died, quitting thread."); - break; - } + // each with 10 publishers + for _j in 0..10 { + let p = node + .create_publisher::( + &format!("/r2r{}", i_node), + QosProfile::default(), + ) + .unwrap(); + let to_send = r2r::std_msgs::msg::String { + data: format!("[node{}]: {}", i_node, c), + }; + + // move publisher to its own thread and publish as fast as we can + thread::spawn(move || loop { + let res = p.publish(&to_send); + thread::sleep(Duration::from_millis(1)); + match res { + Ok(_) => (), + Err(_) => { + // println!("publisher died, quitting thread."); + break; + } + } + }); } - }); + + // spin to simulate some load + for _j in 0..100 { + node.spin_once(Duration::from_millis(10)); + } + + // println!("all done {}-{}", c, i); + })); } - // spin to simulate some load - for _j in 0..100 { - node.spin_once(Duration::from_millis(10)); + for t in ths { + t.join().unwrap(); } + // println!("all threads done {}", c); - // println!("all done {}-{}", c, i); - })); + } } + + })); - for t in ths { - t.join().unwrap(); - } - // println!("all threads done {}", c); + for thread in threads.into_iter() { + thread.join().unwrap(); } Ok(()) diff --git a/r2r/tests/tokio_test_raw.rs b/r2r/tests/tokio_test_raw.rs index e619449a1..1819143d7 100644 --- a/r2r/tests/tokio_test_raw.rs +++ b/r2r/tests/tokio_test_raw.rs @@ -1,65 +1,182 @@ use futures::stream::StreamExt; use r2r::QosProfile; use tokio::task; +use r2r::WrappedTypesupport; -#[tokio::test(flavor = "multi_thread")] + +const N_CONCURRENT_ROS_CONTEXT: usize = 3; +const N_TEARDOWN_CYCLES: usize = 2; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn tokio_subscribe_raw_testing() -> Result<(), Box> { - let ctx = r2r::Context::create()?; - let mut node = r2r::Node::create(ctx, "testnode2", "")?; - - let mut sub_int = node.subscribe_raw("/int", "std_msgs/msg/Int32", QosProfile::default())?; - - let mut sub_array = - node.subscribe_raw("/int_array", "std_msgs/msg/Int32MultiArray", QosProfile::default())?; - - let pub_int = - node.create_publisher::("/int", QosProfile::default())?; - - // Use an array as well since its a variable sized type - let pub_array = node.create_publisher::( - "/int_array", - QosProfile::default(), - )?; - - task::spawn(async move { - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - (0..10).for_each(|i| { - pub_int - .publish(&r2r::std_msgs::msg::Int32 { data: i }) - .unwrap(); - - pub_array - .publish(&r2r::std_msgs::msg::Int32MultiArray { - layout: r2r::std_msgs::msg::MultiArrayLayout::default(), - data: vec![i], - }) - .unwrap(); - }); - }); - - let sub_int_handle = task::spawn(async move { - while let Some(msg) = sub_int.next().await { - println!("Got int msg of len {}", msg.len()); - assert_eq!(msg.len(), 8); - } - }); - - let sub_array_handle = task::spawn(async move { - while let Some(msg) = sub_array.next().await { - println!("Got array msg of len {}", msg.len()); - assert_eq!(msg.len(), 20); - } - }); - - let handle = std::thread::spawn(move || { - for _ in 1..=30 { - node.spin_once(std::time::Duration::from_millis(100)); - } - }); - - sub_int_handle.await?; - sub_array_handle.await?; - handle.join().unwrap(); + let mut threads = futures::stream::FuturesUnordered::from_iter( + (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| tokio::spawn(async move { + // Iterate to check for memory corruption on node setup/teardown + for i_cycle in 0..N_TEARDOWN_CYCLES { + println!("tokio_subscribe_raw_testing iteration {i_cycle}"); + + let ctx = r2r::Context::create().unwrap(); + let mut node = r2r::Node::create(ctx, &format!("testnode2_{i_context}"), "").unwrap(); + + let mut sub_int = node.subscribe_raw("/int", "std_msgs/msg/Int32", QosProfile::default()).unwrap(); + + let mut sub_array = + node.subscribe_raw("/int_array", "std_msgs/msg/Int32MultiArray", QosProfile::default()).unwrap(); + + let pub_int = + node.create_publisher::("/int", QosProfile::default()).unwrap(); + + // Use an array as well since its a variable sized type + let pub_array = node.create_publisher::( + "/int_array", + QosProfile::default(), + ).unwrap(); + + task::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + (0..10).for_each(|i| { + pub_int + .publish(&r2r::std_msgs::msg::Int32 { data: i }) + .unwrap(); + + pub_array + .publish(&r2r::std_msgs::msg::Int32MultiArray { + layout: r2r::std_msgs::msg::MultiArrayLayout::default(), + data: vec![i], + }) + .unwrap(); + }); + }); + + let sub_int_handle = task::spawn(async move { + while let Some(msg) = sub_int.next().await { + println!("Got int msg of len {}", msg.len()); + assert_eq!(msg.len(), 8); + } + }); + + let sub_array_handle = task::spawn(async move { + while let Some(msg) = sub_array.next().await { + println!("Got array msg of len {}", msg.len()); + assert_eq!(msg.len(), 20); + } + }); + + let handle = std::thread::spawn(move || { + for _ in 1..=30 { + node.spin_once(std::time::Duration::from_millis(100)); + } + }); + + sub_int_handle.await.unwrap(); + sub_array_handle.await.unwrap(); + handle.join().unwrap(); + + println!("Going to drop tokio_subscribe_raw_testing iteration {i_cycle}"); + } + + }))); + + while let Some(thread) = threads.next().await { + thread.unwrap(); + } + + Ok(()) +} + + +// Limit the number of threads to force threads to be reused +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn tokio_publish_raw_testing() -> Result<(), Box> { + + let mut threads = futures::stream::FuturesUnordered::from_iter( + (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| tokio::spawn(async move { + // Iterate to check for memory corruption on node setup/teardown + for i_cycle in 0..N_TEARDOWN_CYCLES { + + println!("tokio_publish_raw_testing iteration {i_cycle}"); + + let ctx = r2r::Context::create().unwrap(); + let mut node = r2r::Node::create(ctx, &format!("testnode3_{i_context}"), "").unwrap(); + + let mut sub_int = node.subscribe::("/int", QosProfile::default()).unwrap(); + + let mut sub_array = + node.subscribe::("/int_array", QosProfile::default()).unwrap(); + + let pub_int = node.create_publisher_untyped( + "/int", + "std_msgs/msg/Int32", + QosProfile::default() + ).unwrap(); + + // Use an array as well since its a variable sized type + let pub_array = node.create_publisher_untyped( + "/int_array", + "std_msgs/msg/Int32MultiArray", + QosProfile::default(), + ).unwrap(); + + task::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + (0..10).for_each(|i| { + pub_int + .publish_raw(&r2r::std_msgs::msg::Int32 { data: i }.to_serialized_bytes().unwrap()) + .unwrap(); + + pub_array + .publish_raw( + &r2r::std_msgs::msg::Int32MultiArray { + layout: r2r::std_msgs::msg::MultiArrayLayout::default(), + data: vec![i], + }.to_serialized_bytes().unwrap() + ) + .unwrap(); + }); + }); + + let sub_int_handle = task::spawn(async move { + while let Some(msg) = sub_int.next().await { + // Try to check for any possible corruption + msg.to_serialized_bytes().unwrap(); + + println!("Got int msg with value {}", msg.data); + assert!(msg.data >= 0); + assert!(msg.data < 10); + + } + }); + + let sub_array_handle = task::spawn(async move { + while let Some(msg) = sub_array.next().await { + // Try to check for any possible corruption + msg.to_serialized_bytes().unwrap(); + + println!("Got array msg with value {:?}", msg.data); + assert_eq!(msg.data.len(), 1); + assert!(msg.data[0] >= 0); + assert!(msg.data[0] < 10); + } + }); + + let handle = std::thread::spawn(move || { + for _ in 1..=30 { + node.spin_once(std::time::Duration::from_millis(100)); + } + }); + + sub_int_handle.await.unwrap(); + sub_array_handle.await.unwrap(); + handle.join().unwrap(); + + println!("Going to drop tokio_publish_raw_testing iteration {i_cycle}"); + + } + }))); + + while let Some(thread) = threads.next().await { + thread.unwrap(); + } Ok(()) } diff --git a/r2r/tests/tokio_testing.rs b/r2r/tests/tokio_testing.rs index 0736cb58f..aa4077f93 100644 --- a/r2r/tests/tokio_testing.rs +++ b/r2r/tests/tokio_testing.rs @@ -4,60 +4,94 @@ use r2r::QosProfile; use std::sync::{Arc, Mutex}; use tokio::task; -#[tokio::test(flavor = "multi_thread")] +const N_CONCURRENT_ROS_CONTEXT: usize = 3; +const N_TEARDOWN_CYCLES: usize = 2; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn tokio_testing() -> Result<(), Box> { - let ctx = r2r::Context::create()?; - let mut node = r2r::Node::create(ctx, "testnode", "")?; - let mut s_the_no = - node.subscribe::("/the_no", QosProfile::default())?; - let mut s_new_no = - node.subscribe::("/new_no", QosProfile::default())?; - let p_the_no = - node.create_publisher::("/the_no", QosProfile::default())?; - let p_new_no = - node.create_publisher::("/new_no", QosProfile::default())?; - let state = Arc::new(Mutex::new(0)); - - task::spawn(async move { - (0..10).for_each(|i| { - p_the_no - .publish(&r2r::std_msgs::msg::Int32 { data: i }) - .unwrap(); - }); - }); - - task::spawn(async move { - while let Some(msg) = s_the_no.next().await { - p_new_no - .publish(&r2r::std_msgs::msg::Int32 { - data: msg.data + 10, - }) - .unwrap(); - } - }); - - let s = state.clone(); - task::spawn(async move { - while let Some(msg) = s_new_no.next().await { - let i = msg.data; - if i == 19 { - *s.lock().unwrap() = 19; - } - } - }); - - let handle = std::thread::spawn(move || { - for _ in 1..=30 { - node.spin_once(std::time::Duration::from_millis(100)); - let x = state.lock().unwrap(); - if *x == 19 { - break; + + let mut threads = futures::stream::FuturesUnordered::from_iter( + (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| tokio::spawn(async move { + // Iterate to check for memory corruption on node setup/teardown + for i_cycle in 0..N_TEARDOWN_CYCLES { + + println!("tokio_testing iteration {i_cycle}"); + + let ctx = r2r::Context::create().unwrap(); + // let ctx = std::thread::spawn(|| r2r::Context::create().unwrap()).join().unwrap(); + + let mut node = r2r::Node::create(ctx, &format!("testnode_{i_context}"), "").unwrap(); + let mut s_the_no = + node.subscribe::(&format!("/the_no_{i_context}"), QosProfile::default()).unwrap(); + let mut s_new_no = + node.subscribe::(&format!("/new_no_{i_context}"), QosProfile::default()).unwrap(); + let p_the_no = + node.create_publisher::(&format!("/the_no_{i_context}"), QosProfile::default()).unwrap(); + let p_new_no = + node.create_publisher::(&format!("/new_no_{i_context}"), QosProfile::default()).unwrap(); + let state = Arc::new(Mutex::new(0)); + + task::spawn(async move { + (0..10).for_each(|i| { + p_the_no + .publish(&r2r::std_msgs::msg::Int32 { data: i }) + .unwrap(); + + println!("send {i}"); + + }); + }); + + task::spawn(async move { + while let Some(msg) = s_the_no.next().await { + p_new_no + .publish(&r2r::std_msgs::msg::Int32 { + data: msg.data + 10, + }) + .unwrap(); + + println!("got {}, send {}", msg.data, msg.data + 10); + } + }); + + let s = state.clone(); + task::spawn(async move { + while let Some(msg) = s_new_no.next().await { + + println!("got {}", msg.data); + + let i = msg.data; + + *s.lock().unwrap() = i; + } + }); + + // std::thread::spawn doesn't work here anymore? + let handle = task::spawn_blocking(move || { + for _ in 1..30 { + node.spin_once(std::time::Duration::from_millis(100)); + let x = state.lock().unwrap(); + + println!("rec {}", x); + + if *x == 19 { + break; + } + } + + *state.lock().unwrap() + }); + let x = handle.await.unwrap(); + assert_eq!(x, 19); + + println!("tokio_testing finish iteration {i_cycle}"); + } - } + }))); + + while let Some(thread) = threads.next().await { + thread.unwrap(); + } - *state.lock().unwrap() - }); - let x = handle.join().unwrap(); - assert_eq!(x, 19); Ok(()) }