Skip to content

Commit

Permalink
electrical_protocol: Add packet driver that does not use ROS (WIP for #…
Browse files Browse the repository at this point in the history
  • Loading branch information
cbrxyz committed Oct 29, 2024
1 parent a62159b commit 0fd1561
Show file tree
Hide file tree
Showing 4 changed files with 337 additions and 0 deletions.
1 change: 1 addition & 0 deletions mil_common/drivers/electrical_protocol/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ catkin_package()
if(CATKIN_ENABLE_TESTING)
find_package(rostest REQUIRED)
add_rostest(test/simulated.test)
add_rostest(test/noros.test)
endif()
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
#! /usr/bin/env python3
##################################3
# electrical_protocol driver without ROS
#
# To run:
# - Leave running in another process/terminal session
# - Inherit from this class to write your own packet sender/receiver
##################################3
from __future__ import annotations

import abc
import contextlib
import logging
import threading
from typing import Any, Generic, TypeVar, Union, cast, get_args, get_origin

import serial

from .packet import SYNC_CHAR_1, Packet

SendPackets = TypeVar("SendPackets", bound=Packet)
RecvPackets = TypeVar("RecvPackets", bound=Packet)


logger = logging.getLogger(__name__)


class BufferThread(threading.Thread):
def __init__(self, event, callable):
super().__init__()
self.stopped = event
self.hz = 20.0
self.callable = callable

def run(self):
while not self.stopped.wait(1 / self.hz):
self.callable()

def set_hz(self, hz: float):
self.hz = hz


class ROSSerialDevice(Generic[SendPackets, RecvPackets]):
"""
Represents a generic serial device, which is expected to be the main component
of an individual ROS node.
Attributes:
port (Optional[str]): The port used for the serial connection, if provided.
baudrate (Optional[int]): The baudrate to use with the device, if provided.
device (Optional[serial.Serial]): The serial class used to communicate with
the device.
rate (float): The reading rate of the device, in Hertz. Set to `20` by default.
"""

device: serial.Serial | None
_recv_T: Any
_send_T: Any

def is_connected(self) -> bool:
return self.device is not None

def is_open(self) -> bool:
return bool(self.device) and self.device.is_open

def __init__(
self,
port: str | None,
baudrate: int | None,
buffer_process_hz: float = 20.0,
) -> None:
"""
Arguments:
port (Optional[str]): The serial port to connect to. If ``None``, connection
will not be established on initialization; rather, the user can use
:meth:`~.connect` to connect later.
baudrate (Optional[int]): The baudrate to connect with. If ``None`` and
a port is specified, then 115200 is assumed.
"""
self.port = port
self.baudrate = baudrate
if port:
self.device = serial.Serial(port, baudrate or 115200, timeout=0.1)
if not self.device.is_open:
self.device.open()
self.device.reset_input_buffer()
self.device.reset_output_buffer()
else:
self.device = None
self.lock = threading.Lock()
self.rate = buffer_process_hz
self.buff_event = threading.Event()
self.buff_thread = BufferThread(self.buff_event, self._process_buffer)
self.buff_thread.daemon = True
self.buff_thread.start()

def __init_subclass__(cls) -> None:
# this is a super hack lol :P
# cred: https://stackoverflow.com/a/71720366
cls._send_T = get_args(cls.__orig_bases__[0])[0] # type: ignore
cls._recv_T = get_args(cls.__orig_bases__[0])[1] # type: ignore

def connect(self, port: str, baudrate: int) -> None:
"""
Connects to the port with the given baudrate. If the device is already
connected, the input and output buffers will be flushed.
Arguments:
port (str): The serial port to connect to.
baudrate (int): The baudrate to connect with.
"""
self.port = port
self.baudrate = baudrate
self.device = serial.Serial(port, baudrate, timeout=0.1)
if not self.device.is_open:
self.device.open()
self.device.reset_input_buffer()
self.device.reset_output_buffer()

def close(self) -> None:
"""
Closes the serial device.
"""
logger.info("Shutting down thread...")
self.buff_event.set()
logger.info("Closing serial device...")
if not self.device:
raise RuntimeError("Device is not connected.")
else:
# TODO: Find a better way to deal with these os errors
with contextlib.suppress(OSError):
if self.device.in_waiting:
logger.warn(
"Shutting down device, but packets were left in buffer...",
)
self.device.close()

def write(self, data: bytes) -> None:
"""
Writes a series of raw bytes to the device. This method should rarely be
used; using :meth:`~.send_packet` is preferred because of the guarantees
it provides through the packet class.
Arguments:
data (bytes): The data to send.
"""
if not self.device:
raise RuntimeError("Device is not connected.")
self.device.write(data)

def send_packet(self, packet: SendPackets) -> None:
"""
Sends a given packet to the device.
Arguments:
packet (:class:`~.Packet`): The packet to send.
"""
with self.lock:
self.write(bytes(packet))

def _read_from_stream(self) -> bytes:
# Read until SOF is encourntered in case buffer contains the end of a previous packet
if not self.device:
raise RuntimeError("Device is not connected.")
sof = None
for _ in range(10):
sof = self.device.read(1)
if not len(sof):
continue
sof_int = int.from_bytes(sof, byteorder="big")
if sof_int == SYNC_CHAR_1:
break
if not isinstance(sof, bytes):
raise TimeoutError("No SOF received in one second.")
sof_int = int.from_bytes(sof, byteorder="big")
if sof_int != SYNC_CHAR_1:
logger.error("Where da start char at?")
data = sof
# Read sync char 2, msg ID, subclass ID
data += self.device.read(3)
length = self.device.read(2) # read payload length
data += length
data += self.device.read(
int.from_bytes(length, byteorder="little") + 2,
) # read data and checksum
return data

def _correct_type(self, provided: Any) -> bool:
# either:
# 1. RecvPackets is a Packet --> check isinstance on the type var
# 2. RecvPackets is a Union of Packets --> check isinstance on all
if get_origin(self._recv_T) is Union:
return isinstance(provided, get_args(self._recv_T))
else:
return isinstance(provided, self._recv_T)

def adjust_read_rate(self, rate: float) -> None:
"""
Sets the reading rate to a specified amount.
Arguments:
rate (float): The reading speed to use, in hz.
"""
self.rate = min(rate, 1_000)
self.buff_thread.set_hz(rate)

def scale_read_rate(self, scale: float) -> None:
"""
Scales the reading rate of the device handle by some factor.
Arguments:
scale (float): The amount to scale the reading rate by.
"""
self.adjust_read_rate(self.rate * scale)

def _read_packet(self) -> bool:
if not self.device:
raise RuntimeError("Device is not connected.")
try:
with self.lock:
if not self.is_open() or self.device.in_waiting == 0:
return False
if self.device.in_waiting > 200:
logger.warn(
"Packets are coming in much quicker than expected, upping rate...",
)
self.scale_read_rate(1 + self.device.in_waiting / 1000)
packed_packet = self._read_from_stream()
assert isinstance(packed_packet, bytes)
packet = Packet.from_bytes(packed_packet)
except serial.SerialException as e:
logger.error(f"Error reading packet: {e}")
return False
except OSError:
logger.error("Cannot read from serial device.")
return False
if not self._correct_type(packet):
logger.error(
f"Received unexpected packet: {packet} (expected: {self._recv_T})",
)
return False
packet = cast(RecvPackets, packet)
self.on_packet_received(packet)
return True

def _process_buffer(self) -> None:
if not self.is_open():
return
try:
self._read_packet()
except Exception as e:
logger.error(f"Error reading recent packet: {e}")
import traceback

traceback.print_exc()

@abc.abstractmethod
def on_packet_received(self, packet: RecvPackets) -> None:
"""
Abstract method to be implemented by subclasses for handling packets
sent by the physical electrical board.
Arguments:
packet (:class:`~.Packet`): The packet that is received.
"""
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#!/usr/bin/env python3

from __future__ import annotations

from dataclasses import dataclass
from threading import Event
from typing import Union

import rospy
from electrical_protocol import Packet
from electrical_protocol.driver_noros import ROSSerialDevice
from std_msgs.msg import Float32, String
from std_srvs.srv import Empty, EmptyRequest, EmptyResponse


@dataclass
class RequestAddPacket(Packet, class_id=0x37, subclass_id=0x00, payload_format="<ff"):
number_one: float
number_two: float


@dataclass
class RequestSubPacket(Packet, class_id=0x37, subclass_id=0x01, payload_format="<ff"):
start: float
minus: float


@dataclass
class AnswerPacket(Packet, class_id=0x37, subclass_id=0x02, payload_format="<f"):
result: float


class CalculatorDevice(
ROSSerialDevice[Union[RequestAddPacket, RequestSubPacket], AnswerPacket],
):
def __init__(self):
self.port_topic = rospy.Subscriber("~port", String, self.port_callback)
self.start_service = rospy.Service("~trigger", Empty, self.trigger)
self.answer_topic = rospy.Publisher("~answer", Float32, queue_size=10)
self.next_packet = Event()
self.i = 0
super().__init__(None, 115200)

def port_callback(self, msg: String):
self.connect(msg.data, 115200)

def trigger(self, _: EmptyRequest):
self.num_one, self.num_two = self.i, 1000 - self.i
self.i += 1
self.send_packet(
RequestAddPacket(number_one=self.num_one, number_two=self.num_two),
)
return EmptyResponse()

def on_packet_received(self, packet) -> None:
self.answer_topic.publish(Float32(data=packet.result))
self.next_packet.set()


if __name__ == "__main__":
rospy.init_node("calculator_device")
device = CalculatorDevice()
rospy.on_shutdown(device.close)
rospy.spin()
6 changes: 6 additions & 0 deletions mil_common/drivers/electrical_protocol/test/noros.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<launch>
<param name="/is_simulation" value="True" />
<node pkg="electrical_protocol" type="calculator_device_noros.py" name="calculator_device" output="screen" />
<test pkg="electrical_protocol" test-name="test_simulated_basic" type="test_simulated_basic.py" />
</launch>

0 comments on commit 0fd1561

Please sign in to comment.