diff --git a/src/hal/serial.rs b/src/hal/serial.rs index f899229..cfd0563 100644 --- a/src/hal/serial.rs +++ b/src/hal/serial.rs @@ -12,22 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. //! Handlers for UART connections to/from nodes -use std::error::Error; -use std::fmt::Display; use std::io::{Read, Write}; use std::sync::Arc; +use thiserror::Error; +use tokio::sync::mpsc::error::SendError; +use super::NodeId; use anyhow::Result; use bytes::{Bytes, BytesMut}; use circular_buffer::CircularBuffer; use futures::{SinkExt, StreamExt}; -use tokio::sync::mpsc::{channel, Sender}; -use tokio::sync::Mutex; +use tokio::sync::mpsc::{self, channel, Sender}; +use tokio::sync::{broadcast, watch, Mutex}; use tokio_serial::{DataBits, Parity, SerialPortBuilderExt, StopBits}; use tokio_util::codec::{BytesCodec, Decoder}; -use super::NodeId; - const OUTPUT_BUF_SIZE: usize = 16 * 1024; type RingBuffer = CircularBuffer; @@ -44,7 +43,16 @@ impl SerialConnections { let handlers: Vec> = paths .iter() .enumerate() - .map(|(i, path)| Mutex::new(Handler::new(i + 1, path))) + .map(|(i, path)| { + Mutex::new(Handler::new( + i + 1, + path, + 115200, + DataBits::Eight, + Parity::None, + StopBits::One, + )) + }) .collect(); Ok(SerialConnections { handlers }) @@ -52,77 +60,92 @@ impl SerialConnections { pub async fn run(&self) -> Result<(), SerialError> { for h in &self.handlers { - h.lock().await.start_reader()?; + h.lock().await.run_handler()?; } Ok(()) } - pub async fn read(&self, node: NodeId) -> Result { + pub async fn start_session( + &self, + node: NodeId, + ) -> Result<(BytesMut, SerialChannel), SerialError> { let idx = node as usize; - self.handlers[idx].lock().await.read().await + let locked = self.handlers[idx].lock().await; + let buffer = locked.read_whole_buffer().await?; + Ok((buffer, locked.open_channel()?)) } +} - pub async fn write>(&self, node: NodeId, data: B) -> Result<(), SerialError> { - let idx = node as usize; - self.handlers[idx].lock().await.write(data.into()).await - } +pub struct SerialChannel { + inner: (broadcast::Receiver, mpsc::Sender), } #[derive(Debug)] struct Handler { node: usize, + baud_rate: u32, + data_bits: DataBits, + parity: Parity, + stop_bits: StopBits, path: &'static str, ring_buffer: Arc>>, - worker_context: Option>, + worker_context: Option<(broadcast::Sender, mpsc::Sender)>, } impl Handler { - fn new(node: usize, path: &'static str) -> Self { + fn new( + node: usize, + path: &'static str, + baud_rate: u32, + data_bits: DataBits, + parity: Parity, + stop_bits: StopBits, + ) -> Self { + let (sender, _) = watch::channel(BytesMut::new()); Handler { node, path, + baud_rate, + parity, + stop_bits, ring_buffer: Arc::new(Mutex::new(RingBuffer::boxed())), worker_context: None, } } - async fn write>(&self, data: B) -> Result<(), SerialError> { - let Some(sender) = &self.worker_context else { + pub fn open_channel(&self) -> Result { + let Some((read_sender, write_sender)) = self.worker_context else { return Err(SerialError::NotStarted); }; - sender - .send(data.into()) - .await - .map_err(|e| SerialError::InternalError(e.to_string())) + Ok(SerialChannel { + inner: (read_sender.subscribe(), write_sender.clone()), + }) } - async fn read(&self) -> Result { + /// This function returns all the cached data. + /// Time complexity: O(N) + async fn read_whole_buffer(&self) -> Result { if self.worker_context.is_none() { return Err(SerialError::NotStarted); }; let mut rb = self.ring_buffer.lock().await; - let mut buf = vec![0; rb.len()]; - - rb.read(&mut buf) - .map_err(|e| SerialError::InternalError(format!("failed to read: {}", e)))?; - - Ok(buf.into()) + let mut bytes = BytesMut::new(); + bytes.extend_from_slice(rb.make_contiguous()); + Ok(bytes) } - fn start_reader(&mut self) -> Result<(), SerialError> { + fn run_handler(&mut self) -> Result<(), SerialError> { if self.worker_context.take().is_some() { return Err(SerialError::AlreadyRunning); }; - let baud_rate = 115200; - let mut port = tokio_serial::new(self.path, baud_rate) - .data_bits(DataBits::Eight) - .parity(Parity::None) - .stop_bits(StopBits::One) - .open_native_async() - .map_err(|e| SerialError::InternalError(e.to_string()))?; + let mut port = tokio_serial::new(self.path, self.baud_rate) + .data_bits(self.data_bits) + .parity(self.parity) + .stop_bits(self.stop_bits) + .open_native_async()?; // Disable exclusivity of the port to allow other applications to open it. // Not a reason to abort if we can't. @@ -130,8 +153,9 @@ impl Handler { log::warn!("Unable to set exclusivity of port {}: {}", self.path, e); } - let (sender, mut receiver) = channel::(64); - self.worker_context = Some(sender); + let (read_sender, _) = broadcast::channel::(8); + let (write_sender, mut write_receiver) = mpsc::channel::(8); + self.worker_context = Some((read_sender.clone(), write_sender)); let node = self.node; let buffer = self.ring_buffer.clone(); @@ -139,7 +163,7 @@ impl Handler { let (mut sink, mut stream) = BytesCodec::new().framed(port).split(); loop { tokio::select! { - res = receiver.recv() => { + res = write_receiver.recv() => { let Some(data) = res else { log::error!("error sending data to uart"); break; @@ -161,10 +185,15 @@ impl Handler { }; // Implementation is actually infallible in the currently used v0.1.3 - let Ok(_) = buffer.lock().await.write(&bytes) else { + if buffer.lock().await.write(&bytes).is_err() { log::error!("Failed to write to buffer of node {}", node); break; }; + + if let Err(e) = read_sender.send(bytes) { + log::error!("broadcast error: {:#}", e); + break; + } }, } } @@ -175,21 +204,16 @@ impl Handler { } } -#[derive(Debug)] +#[derive(Error, Debug)] pub enum SerialError { + #[error("serial worker not started")] NotStarted, + #[error("already running")] AlreadyRunning, - InternalError(String), -} - -impl Error for SerialError {} - -impl Display for SerialError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - SerialError::NotStarted => write!(f, "serial worker not started"), - SerialError::AlreadyRunning => write!(f, "already running"), - SerialError::InternalError(e) => e.fmt(f), - } - } + #[error(transparent)] + SendError(#[from] SendError), + #[error(transparent)] + SerialError(#[from] tokio_serial::Error), + #[error(transparent)] + IoError(#[from] std::io::Error), }