Skip to content

Commit

Permalink
bind
Browse files Browse the repository at this point in the history
  • Loading branch information
WT-MM committed Dec 16, 2024
1 parent 1569dcf commit b0f3ff6
Showing 1 changed file with 40 additions and 41 deletions.
81 changes: 40 additions & 41 deletions actuator/bindings/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use pyo3::prelude::*;
use pyo3_stub_gen::define_stub_info_gatherer;
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};use robstride::{CH341Transport, ControlConfig, SocketCanTransport, Supervisor, TransportType};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
use robstride::{
ActuatorType, CH341Transport, ControlConfig, SocketCanTransport,
Supervisor, TransportType, ActuatorConfiguration
ActuatorType, ActuatorConfiguration, Supervisor, TransportType,
CH341Transport, ControlConfig, SocketCanTransport,
};
use std::sync::Arc;
use tokio::sync::Mutex;
Expand Down Expand Up @@ -145,59 +145,52 @@ impl PyRobstrideActuator {
#[new]
fn new(
ports: Vec<String>,
actuator_timeout: f64,
polling_interval: f64,
py_actuators_config: Vec<(u8, PyRobstrideActuatorConfig)>,
polling_interval: f64,
) -> PyResult<Self> {
let rt = Runtime::new().unwrap();

let actuators_config: Vec<(u8, robstride::ActuatorConfiguration)> =
py_actuators_config.into_iter()
.map(|(id, config)| (id, config.into()))
.collect();
let actuators_config: Vec<(u8, ActuatorConfiguration)> = py_actuators_config
.into_iter()
.map(|(id, config)| (id, config.into()))
.collect();

let rt = Runtime::new().map_err(|e| ErrReportWrapper(e.into()))?;

let supervisor = rt.block_on(async {
let mut supervisor = Supervisor::new(Duration::from_secs_f64(actuator_timeout))?;
let mut supervisor = Supervisor::new()
.map_err(|e| ErrReportWrapper(e))?;

// Add transports
for port in &ports {
if port.starts_with("/dev/tty") {
let serial = CH341Transport::new(port.clone()).await?;
let serial = CH341Transport::new(port.clone()).await
.map_err(|e| ErrReportWrapper(e))?;
supervisor
.add_transport(port.clone(), TransportType::CH341(serial))
.await?;
.await
.map_err(|e| ErrReportWrapper(e))?;
} else if port.starts_with("can") {
let can = SocketCanTransport::new(port.clone()).await?;
let can = SocketCanTransport::new(port.clone()).await
.map_err(|e| ErrReportWrapper(e))?;
supervisor
.add_transport(port.clone(), TransportType::SocketCAN(can))
.await?;
.await
.map_err(|e| ErrReportWrapper(e))?;
} else {
return Err(eyre::eyre!("Invalid port: {}", port));
return Err(ErrReportWrapper(eyre::eyre!("Invalid port: {}", port)));
}
}

// Start supervisor task
let mut supervisor_runner = supervisor.clone_controller();
tokio::spawn(async move {
if let Err(e) = supervisor_runner
.run(Duration::from_secs_f64(polling_interval))
.await
{
tracing::error!("Supervisor task failed: {}", e);
}
});

// Scan for motors
for port in &ports {
let discovered_ids = supervisor.scan_bus(0xFD, port, &actuators_config).await?;
let discovered_ids = supervisor.scan_bus(0xFD, port, &actuators_config).await
.map_err(|e| ErrReportWrapper(e))?;
for (motor_id, _) in &actuators_config {
if !discovered_ids.contains(motor_id) {
tracing::warn!("Configured motor not found - ID: {}", motor_id);
}
}
}

Ok::<_, eyre::Error>(supervisor)
Ok(supervisor)
})?;

Ok(PyRobstrideActuator {
Expand All @@ -219,7 +212,8 @@ impl PyRobstrideActuator {
cmd.velocity.map(|v| v.to_radians() as f32).unwrap_or(0.0),
cmd.torque.map(|t| t as f32).unwrap_or(0.0),
)
.await;
.await
.map_err(|e| ErrReportWrapper(e))?;
results.push(result.is_ok());
}
Ok(results)
Expand All @@ -238,27 +232,32 @@ impl PyRobstrideActuator {
max_current: Some(10.0),
};

let result = supervisor.configure(config.actuator_id as u8, control_config).await;
let result = supervisor.configure(config.actuator_id as u8, control_config).await
.map_err(|e| ErrReportWrapper(e))?;

if let Some(torque_enabled) = config.torque_enabled {
if torque_enabled {
supervisor.enable(config.actuator_id as u8).await?;
supervisor.enable(config.actuator_id as u8).await
.map_err(|e| ErrReportWrapper(e))?;
} else {
supervisor.disable(config.actuator_id as u8, true).await?;
supervisor.disable(config.actuator_id as u8, true).await
.map_err(|e| ErrReportWrapper(e))?;
}
}

if let Some(true) = config.zero_position {
supervisor.zero(config.actuator_id as u8).await?;
supervisor.zero(config.actuator_id as u8).await
.map_err(|e| ErrReportWrapper(e))?;
}

if let Some(new_id) = config.new_actuator_id {
supervisor
.change_id(config.actuator_id as u8, new_id as u8)
.await?;
.await
.map_err(|e| ErrReportWrapper(e))?;
}

Ok(result.is_ok())
Ok(true)
})
}

Expand Down Expand Up @@ -308,8 +307,8 @@ impl From<eyre::Report> for PyErr {
}

#[pymodule]
fn robstride_bindings(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(get_version))?;
fn robstride_bindings(m: &Bound<PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_version, m)?)?;
m.add_class::<PyRobstrideActuator>()?;
m.add_class::<PyRobstrideActuatorCommand>()?;
m.add_class::<PyRobstrideConfigureRequest>()?;
Expand All @@ -318,4 +317,4 @@ fn robstride_bindings(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
Ok(())
}

define_stub_info_gatherer!(robstride_bindings);
define_stub_info_gatherer!(stub_info);

0 comments on commit b0f3ff6

Please sign in to comment.