Skip to content

Commit

Permalink
use read-write locks (#26)
Browse files Browse the repository at this point in the history
* use read-write lock instead of lock more

* always sleep at least some amount

* modify sleep logic
  • Loading branch information
codekansas authored Oct 12, 2024
1 parent b61109e commit 513e8a9
Showing 1 changed file with 59 additions and 63 deletions.
122 changes: 59 additions & 63 deletions actuator/rust/robstride/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use serialport::SerialPort;
use std::collections::{HashMap, HashSet};
use std::io::{Read, Write};
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;

Expand Down Expand Up @@ -884,17 +884,17 @@ impl Motors {

pub struct MotorsSupervisor {
motors: Arc<Mutex<Motors>>,
target_params: Arc<Mutex<HashMap<u8, MotorControlParams>>>,
running: Arc<Mutex<bool>>,
latest_feedback: Arc<Mutex<HashMap<u8, MotorFeedback>>>,
target_params: Arc<RwLock<HashMap<u8, MotorControlParams>>>,
running: Arc<RwLock<bool>>,
latest_feedback: Arc<RwLock<HashMap<u8, MotorFeedback>>>,
motors_to_zero: Arc<Mutex<HashSet<u8>>>,
paused: Arc<Mutex<bool>>,
paused: Arc<RwLock<bool>>,
restart: Arc<Mutex<bool>>,
total_commands: Arc<Mutex<u64>>,
failed_commands: Arc<Mutex<u64>>,
min_update_rate: Arc<Mutex<f64>>,
target_update_rate: Arc<Mutex<f64>>,
actual_update_rate: Arc<Mutex<f64>>,
total_commands: Arc<RwLock<u64>>,
failed_commands: Arc<RwLock<u64>>,
min_update_rate: Arc<RwLock<f64>>,
target_update_rate: Arc<RwLock<f64>>,
actual_update_rate: Arc<RwLock<f64>>,
}

impl MotorsSupervisor {
Expand Down Expand Up @@ -934,26 +934,19 @@ impl MotorsSupervisor {
.map(|(&id, _)| id)
.collect::<HashSet<u8>>();

let motors = Arc::new(Mutex::new(motors));
let motors_to_zero = Arc::new(Mutex::new(zero_on_init_motors));
let target_params = Arc::new(Mutex::new(target_params));
let running = Arc::new(Mutex::new(true));
let paused = Arc::new(Mutex::new(false));
let restart = Arc::new(Mutex::new(false));

let controller = MotorsSupervisor {
motors,
target_params,
running,
latest_feedback: Arc::new(Mutex::new(HashMap::new())),
motors_to_zero,
paused,
restart,
total_commands: Arc::new(Mutex::new(0)),
failed_commands: Arc::new(Mutex::new(0)),
min_update_rate: Arc::new(Mutex::new(min_update_rate)),
target_update_rate: Arc::new(Mutex::new(target_update_rate)),
actual_update_rate: Arc::new(Mutex::new(0.0)),
motors: Arc::new(Mutex::new(motors)),
target_params: Arc::new(RwLock::new(target_params)),
running: Arc::new(RwLock::new(true)),
latest_feedback: Arc::new(RwLock::new(HashMap::new())),
motors_to_zero: Arc::new(Mutex::new(zero_on_init_motors)),
paused: Arc::new(RwLock::new(false)),
restart: Arc::new(Mutex::new(false)),
total_commands: Arc::new(RwLock::new(0)),
failed_commands: Arc::new(RwLock::new(0)),
min_update_rate: Arc::new(RwLock::new(min_update_rate)),
target_update_rate: Arc::new(RwLock::new(target_update_rate)),
actual_update_rate: Arc::new(RwLock::new(0.0)),
};

controller.start_control_thread();
Expand Down Expand Up @@ -982,22 +975,22 @@ impl MotorsSupervisor {
let _ = motors.send_start();

// Set CAN timeout based on minimum update rate
let can_timeout = (1000.0 / *min_update_rate.lock().unwrap()) as u32;
let can_timeout = (1000.0 / *min_update_rate.read().unwrap()) as u32;
let _ = motors.send_can_timeout(can_timeout);

let mut last_update_time = std::time::Instant::now();

loop {
{
// If not running, break the loop.
if !*running.lock().unwrap() {
if !*running.read().unwrap() {
break;
}
}

{
// If paused, just wait a short time without sending any commands.
if *paused.lock().unwrap() {
if *paused.read().unwrap() {
std::thread::sleep(Duration::from_millis(10));
continue;
}
Expand All @@ -1018,7 +1011,7 @@ impl MotorsSupervisor {
{
// Read latest feedback from motors.
let latest_feedback_from_motors = motors.get_latest_feedback();
let mut latest_feedback = latest_feedback.lock().unwrap();
let mut latest_feedback = latest_feedback.write().unwrap();
*latest_feedback = latest_feedback_from_motors.clone();
}

Expand All @@ -1028,7 +1021,7 @@ impl MotorsSupervisor {
let motor_ids = motor_ids_to_zero.iter().cloned().collect::<Vec<u8>>();
if !motor_ids.is_empty() {
if let Err(_) = motors.send_set_zero(Some(&motor_ids)) {
*failed_commands.lock().unwrap() += 1;
*failed_commands.write().unwrap() += 1;
}
motor_ids_to_zero.clear();
}
Expand All @@ -1038,32 +1031,35 @@ impl MotorsSupervisor {
.map(|id| (*id, MotorControlParams::default())),
);
if let Err(_) = motors.send_motor_controls(&torque_commands) {
*failed_commands.lock().unwrap() += 1;
*failed_commands.write().unwrap() += 1;
}
*total_commands.lock().unwrap() += 1;
*total_commands.write().unwrap() += 1;
}

// Send PD commands to motors.
{
let target_params = target_params.lock().unwrap();
let target_params = target_params.read().unwrap();
if let Err(_) = motors.send_motor_controls(&target_params) {
*failed_commands.lock().unwrap() += 1;
*failed_commands.write().unwrap() += 1;
}
*total_commands.lock().unwrap() += 1;
*total_commands.write().unwrap() += 1;
}

// Calculate actual update rate
let elapsed = loop_start_time.duration_since(last_update_time);
last_update_time = loop_start_time;
let current_rate = 1.0 / elapsed.as_secs_f64();
*actual_update_rate.lock().unwrap() = current_rate;
*actual_update_rate.write().unwrap() = current_rate;

// Sleep to maintain target update rate
let target_duration =
Duration::from_secs_f64(1.0 / *target_update_rate.lock().unwrap());
Duration::from_secs_f64(1.0 / *target_update_rate.read().unwrap());
let elapsed = loop_start_time.elapsed();
if elapsed < target_duration {
let min_sleep_duration = Duration::from_micros(1);
if target_duration > elapsed + min_sleep_duration {
std::thread::sleep(target_duration - elapsed);
} else {
std::thread::sleep(min_sleep_duration);
}
}

Expand All @@ -1085,32 +1081,32 @@ impl MotorsSupervisor {

// New methods to access the command counters
pub fn get_total_commands(&self) -> u64 {
*self.total_commands.lock().unwrap()
*self.total_commands.read().unwrap()
}

pub fn get_failed_commands(&self) -> u64 {
*self.failed_commands.lock().unwrap()
*self.failed_commands.read().unwrap()
}

pub fn reset_command_counters(&self) {
*self.total_commands.lock().unwrap() = 0;
*self.failed_commands.lock().unwrap() = 0;
*self.total_commands.write().unwrap() = 0;
*self.failed_commands.write().unwrap() = 0;
}

pub fn set_params(&self, motor_id: u8, params: MotorControlParams) {
let mut target_params = self.target_params.lock().unwrap();
let mut target_params = self.target_params.write().unwrap();
target_params.insert(motor_id, params);
}

pub fn set_position(&self, motor_id: u8, position: f32) {
let mut target_params = self.target_params.lock().unwrap();
let mut target_params = self.target_params.write().unwrap();
if let Some(params) = target_params.get_mut(&motor_id) {
params.position = position;
}
}

pub fn get_position(&self, motor_id: u8) -> Result<f32, std::io::Error> {
let target_params = self.target_params.lock().unwrap();
let target_params = self.target_params.read().unwrap();
target_params
.get(&motor_id)
.map(|params| params.position)
Expand All @@ -1123,14 +1119,14 @@ impl MotorsSupervisor {
}

pub fn set_velocity(&self, motor_id: u8, velocity: f32) {
let mut target_params = self.target_params.lock().unwrap();
let mut target_params = self.target_params.write().unwrap();
if let Some(params) = target_params.get_mut(&motor_id) {
params.velocity = velocity;
}
}

pub fn get_velocity(&self, motor_id: u8) -> Result<f32, std::io::Error> {
let target_params = self.target_params.lock().unwrap();
let target_params = self.target_params.read().unwrap();
target_params
.get(&motor_id)
.map(|params| params.velocity)
Expand All @@ -1143,14 +1139,14 @@ impl MotorsSupervisor {
}

pub fn set_kp(&self, motor_id: u8, kp: f32) {
let mut target_params = self.target_params.lock().unwrap();
let mut target_params = self.target_params.write().unwrap();
if let Some(params) = target_params.get_mut(&motor_id) {
params.kp = kp.clamp(0.0, kp); // Clamp kp to be non-negative.
}
}

pub fn get_kp(&self, motor_id: u8) -> Result<f32, std::io::Error> {
let target_params = self.target_params.lock().unwrap();
let target_params = self.target_params.read().unwrap();
target_params
.get(&motor_id)
.map(|params| params.kp)
Expand All @@ -1163,14 +1159,14 @@ impl MotorsSupervisor {
}

pub fn set_kd(&self, motor_id: u8, kd: f32) {
let mut target_params = self.target_params.lock().unwrap();
let mut target_params = self.target_params.write().unwrap();
if let Some(params) = target_params.get_mut(&motor_id) {
params.kd = kd.clamp(0.0, kd); // Clamp kd to be non-negative.
}
}

pub fn get_kd(&self, motor_id: u8) -> Result<f32, std::io::Error> {
let target_params = self.target_params.lock().unwrap();
let target_params = self.target_params.read().unwrap();
target_params
.get(&motor_id)
.map(|params| params.kd)
Expand All @@ -1183,14 +1179,14 @@ impl MotorsSupervisor {
}

pub fn set_torque(&self, motor_id: u8, torque: f32) {
let mut target_params = self.target_params.lock().unwrap();
let mut target_params = self.target_params.write().unwrap();
if let Some(params) = target_params.get_mut(&motor_id) {
params.torque = torque;
}
}

pub fn get_torque(&self, motor_id: u8) -> Result<f32, std::io::Error> {
let target_params = self.target_params.lock().unwrap();
let target_params = self.target_params.read().unwrap();
target_params
.get(&motor_id)
.map(|params| params.torque)
Expand All @@ -1213,12 +1209,12 @@ impl MotorsSupervisor {
}

pub fn get_latest_feedback(&self) -> HashMap<u8, MotorFeedback> {
let latest_feedback = self.latest_feedback.lock().unwrap();
let latest_feedback = self.latest_feedback.read().unwrap();
latest_feedback.clone()
}

pub fn toggle_pause(&self) {
let mut paused = self.paused.lock().unwrap();
let mut paused = self.paused.write().unwrap();
*paused = !*paused;
}

Expand All @@ -1229,27 +1225,27 @@ impl MotorsSupervisor {

pub fn stop(&self) {
{
let mut running = self.running.lock().unwrap();
let mut running = self.running.write().unwrap();
*running = false;
}
std::thread::sleep(Duration::from_millis(200));
}

pub fn set_min_update_rate(&self, rate: f64) {
let mut min_rate = self.min_update_rate.lock().unwrap();
let mut min_rate = self.min_update_rate.write().unwrap();
*min_rate = rate;
let can_timeout = (1000.0 / rate) as u32;
let mut motors = self.motors.lock().unwrap();
let _ = motors.send_can_timeout(can_timeout);
}

pub fn set_target_update_rate(&self, rate: f64) {
let mut target_rate = self.target_update_rate.lock().unwrap();
let mut target_rate = self.target_update_rate.write().unwrap();
*target_rate = rate;
}

pub fn get_actual_update_rate(&self) -> f64 {
*self.actual_update_rate.lock().unwrap()
*self.actual_update_rate.read().unwrap()
}
}

Expand Down

0 comments on commit 513e8a9

Please sign in to comment.