-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
electrical_protocol: Add packet driver that does not use ROS (WIP for #…
- Loading branch information
Showing
4 changed files
with
337 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
266 changes: 266 additions & 0 deletions
266
mil_common/drivers/electrical_protocol/electrical_protocol/driver_noros.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,266 @@ | ||
#! /usr/bin/env python3 | ||
##################################3 | ||
# electrical_protocol driver without ROS | ||
# | ||
# To run: | ||
# - Leave running in another process/terminal session | ||
# - Inherit from this class to write your own packet sender/receiver | ||
##################################3 | ||
from __future__ import annotations | ||
|
||
import abc | ||
import contextlib | ||
import logging | ||
import threading | ||
from typing import Any, Generic, TypeVar, Union, cast, get_args, get_origin | ||
|
||
import serial | ||
|
||
from .packet import SYNC_CHAR_1, Packet | ||
|
||
SendPackets = TypeVar("SendPackets", bound=Packet) | ||
RecvPackets = TypeVar("RecvPackets", bound=Packet) | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class BufferThread(threading.Thread): | ||
def __init__(self, event, callable): | ||
super().__init__() | ||
self.stopped = event | ||
self.hz = 20.0 | ||
self.callable = callable | ||
|
||
def run(self): | ||
while not self.stopped.wait(1 / self.hz): | ||
self.callable() | ||
|
||
def set_hz(self, hz: float): | ||
self.hz = hz | ||
|
||
|
||
class ROSSerialDevice(Generic[SendPackets, RecvPackets]): | ||
""" | ||
Represents a generic serial device, which is expected to be the main component | ||
of an individual ROS node. | ||
Attributes: | ||
port (Optional[str]): The port used for the serial connection, if provided. | ||
baudrate (Optional[int]): The baudrate to use with the device, if provided. | ||
device (Optional[serial.Serial]): The serial class used to communicate with | ||
the device. | ||
rate (float): The reading rate of the device, in Hertz. Set to `20` by default. | ||
""" | ||
|
||
device: serial.Serial | None | ||
_recv_T: Any | ||
_send_T: Any | ||
|
||
def is_connected(self) -> bool: | ||
return self.device is not None | ||
|
||
def is_open(self) -> bool: | ||
return bool(self.device) and self.device.is_open | ||
|
||
def __init__( | ||
self, | ||
port: str | None, | ||
baudrate: int | None, | ||
buffer_process_hz: float = 20.0, | ||
) -> None: | ||
""" | ||
Arguments: | ||
port (Optional[str]): The serial port to connect to. If ``None``, connection | ||
will not be established on initialization; rather, the user can use | ||
:meth:`~.connect` to connect later. | ||
baudrate (Optional[int]): The baudrate to connect with. If ``None`` and | ||
a port is specified, then 115200 is assumed. | ||
""" | ||
self.port = port | ||
self.baudrate = baudrate | ||
if port: | ||
self.device = serial.Serial(port, baudrate or 115200, timeout=0.1) | ||
if not self.device.is_open: | ||
self.device.open() | ||
self.device.reset_input_buffer() | ||
self.device.reset_output_buffer() | ||
else: | ||
self.device = None | ||
self.lock = threading.Lock() | ||
self.rate = buffer_process_hz | ||
self.buff_event = threading.Event() | ||
self.buff_thread = BufferThread(self.buff_event, self._process_buffer) | ||
self.buff_thread.daemon = True | ||
self.buff_thread.start() | ||
|
||
def __init_subclass__(cls) -> None: | ||
# this is a super hack lol :P | ||
# cred: https://stackoverflow.com/a/71720366 | ||
cls._send_T = get_args(cls.__orig_bases__[0])[0] # type: ignore | ||
cls._recv_T = get_args(cls.__orig_bases__[0])[1] # type: ignore | ||
|
||
def connect(self, port: str, baudrate: int) -> None: | ||
""" | ||
Connects to the port with the given baudrate. If the device is already | ||
connected, the input and output buffers will be flushed. | ||
Arguments: | ||
port (str): The serial port to connect to. | ||
baudrate (int): The baudrate to connect with. | ||
""" | ||
self.port = port | ||
self.baudrate = baudrate | ||
self.device = serial.Serial(port, baudrate, timeout=0.1) | ||
if not self.device.is_open: | ||
self.device.open() | ||
self.device.reset_input_buffer() | ||
self.device.reset_output_buffer() | ||
|
||
def close(self) -> None: | ||
""" | ||
Closes the serial device. | ||
""" | ||
logger.info("Shutting down thread...") | ||
self.buff_event.set() | ||
logger.info("Closing serial device...") | ||
if not self.device: | ||
raise RuntimeError("Device is not connected.") | ||
else: | ||
# TODO: Find a better way to deal with these os errors | ||
with contextlib.suppress(OSError): | ||
if self.device.in_waiting: | ||
logger.warn( | ||
"Shutting down device, but packets were left in buffer...", | ||
) | ||
self.device.close() | ||
|
||
def write(self, data: bytes) -> None: | ||
""" | ||
Writes a series of raw bytes to the device. This method should rarely be | ||
used; using :meth:`~.send_packet` is preferred because of the guarantees | ||
it provides through the packet class. | ||
Arguments: | ||
data (bytes): The data to send. | ||
""" | ||
if not self.device: | ||
raise RuntimeError("Device is not connected.") | ||
self.device.write(data) | ||
|
||
def send_packet(self, packet: SendPackets) -> None: | ||
""" | ||
Sends a given packet to the device. | ||
Arguments: | ||
packet (:class:`~.Packet`): The packet to send. | ||
""" | ||
with self.lock: | ||
self.write(bytes(packet)) | ||
|
||
def _read_from_stream(self) -> bytes: | ||
# Read until SOF is encourntered in case buffer contains the end of a previous packet | ||
if not self.device: | ||
raise RuntimeError("Device is not connected.") | ||
sof = None | ||
for _ in range(10): | ||
sof = self.device.read(1) | ||
if not len(sof): | ||
continue | ||
sof_int = int.from_bytes(sof, byteorder="big") | ||
if sof_int == SYNC_CHAR_1: | ||
break | ||
if not isinstance(sof, bytes): | ||
raise TimeoutError("No SOF received in one second.") | ||
sof_int = int.from_bytes(sof, byteorder="big") | ||
if sof_int != SYNC_CHAR_1: | ||
logger.error("Where da start char at?") | ||
data = sof | ||
# Read sync char 2, msg ID, subclass ID | ||
data += self.device.read(3) | ||
length = self.device.read(2) # read payload length | ||
data += length | ||
data += self.device.read( | ||
int.from_bytes(length, byteorder="little") + 2, | ||
) # read data and checksum | ||
return data | ||
|
||
def _correct_type(self, provided: Any) -> bool: | ||
# either: | ||
# 1. RecvPackets is a Packet --> check isinstance on the type var | ||
# 2. RecvPackets is a Union of Packets --> check isinstance on all | ||
if get_origin(self._recv_T) is Union: | ||
return isinstance(provided, get_args(self._recv_T)) | ||
else: | ||
return isinstance(provided, self._recv_T) | ||
|
||
def adjust_read_rate(self, rate: float) -> None: | ||
""" | ||
Sets the reading rate to a specified amount. | ||
Arguments: | ||
rate (float): The reading speed to use, in hz. | ||
""" | ||
self.rate = min(rate, 1_000) | ||
self.buff_thread.set_hz(rate) | ||
|
||
def scale_read_rate(self, scale: float) -> None: | ||
""" | ||
Scales the reading rate of the device handle by some factor. | ||
Arguments: | ||
scale (float): The amount to scale the reading rate by. | ||
""" | ||
self.adjust_read_rate(self.rate * scale) | ||
|
||
def _read_packet(self) -> bool: | ||
if not self.device: | ||
raise RuntimeError("Device is not connected.") | ||
try: | ||
with self.lock: | ||
if not self.is_open() or self.device.in_waiting == 0: | ||
return False | ||
if self.device.in_waiting > 200: | ||
logger.warn( | ||
"Packets are coming in much quicker than expected, upping rate...", | ||
) | ||
self.scale_read_rate(1 + self.device.in_waiting / 1000) | ||
packed_packet = self._read_from_stream() | ||
assert isinstance(packed_packet, bytes) | ||
packet = Packet.from_bytes(packed_packet) | ||
except serial.SerialException as e: | ||
logger.error(f"Error reading packet: {e}") | ||
return False | ||
except OSError: | ||
logger.error("Cannot read from serial device.") | ||
return False | ||
if not self._correct_type(packet): | ||
logger.error( | ||
f"Received unexpected packet: {packet} (expected: {self._recv_T})", | ||
) | ||
return False | ||
packet = cast(RecvPackets, packet) | ||
self.on_packet_received(packet) | ||
return True | ||
|
||
def _process_buffer(self) -> None: | ||
if not self.is_open(): | ||
return | ||
try: | ||
self._read_packet() | ||
except Exception as e: | ||
logger.error(f"Error reading recent packet: {e}") | ||
import traceback | ||
|
||
traceback.print_exc() | ||
|
||
@abc.abstractmethod | ||
def on_packet_received(self, packet: RecvPackets) -> None: | ||
""" | ||
Abstract method to be implemented by subclasses for handling packets | ||
sent by the physical electrical board. | ||
Arguments: | ||
packet (:class:`~.Packet`): The packet that is received. | ||
""" | ||
pass |
64 changes: 64 additions & 0 deletions
64
mil_common/drivers/electrical_protocol/test/calculator_device_noros.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
#!/usr/bin/env python3 | ||
|
||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from threading import Event | ||
from typing import Union | ||
|
||
import rospy | ||
from electrical_protocol import Packet | ||
from electrical_protocol.driver_noros import ROSSerialDevice | ||
from std_msgs.msg import Float32, String | ||
from std_srvs.srv import Empty, EmptyRequest, EmptyResponse | ||
|
||
|
||
@dataclass | ||
class RequestAddPacket(Packet, class_id=0x37, subclass_id=0x00, payload_format="<ff"): | ||
number_one: float | ||
number_two: float | ||
|
||
|
||
@dataclass | ||
class RequestSubPacket(Packet, class_id=0x37, subclass_id=0x01, payload_format="<ff"): | ||
start: float | ||
minus: float | ||
|
||
|
||
@dataclass | ||
class AnswerPacket(Packet, class_id=0x37, subclass_id=0x02, payload_format="<f"): | ||
result: float | ||
|
||
|
||
class CalculatorDevice( | ||
ROSSerialDevice[Union[RequestAddPacket, RequestSubPacket], AnswerPacket], | ||
): | ||
def __init__(self): | ||
self.port_topic = rospy.Subscriber("~port", String, self.port_callback) | ||
self.start_service = rospy.Service("~trigger", Empty, self.trigger) | ||
self.answer_topic = rospy.Publisher("~answer", Float32, queue_size=10) | ||
self.next_packet = Event() | ||
self.i = 0 | ||
super().__init__(None, 115200) | ||
|
||
def port_callback(self, msg: String): | ||
self.connect(msg.data, 115200) | ||
|
||
def trigger(self, _: EmptyRequest): | ||
self.num_one, self.num_two = self.i, 1000 - self.i | ||
self.i += 1 | ||
self.send_packet( | ||
RequestAddPacket(number_one=self.num_one, number_two=self.num_two), | ||
) | ||
return EmptyResponse() | ||
|
||
def on_packet_received(self, packet) -> None: | ||
self.answer_topic.publish(Float32(data=packet.result)) | ||
self.next_packet.set() | ||
|
||
|
||
if __name__ == "__main__": | ||
rospy.init_node("calculator_device") | ||
device = CalculatorDevice() | ||
rospy.on_shutdown(device.close) | ||
rospy.spin() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<launch> | ||
<param name="/is_simulation" value="True" /> | ||
<node pkg="electrical_protocol" type="calculator_device_noros.py" name="calculator_device" output="screen" /> | ||
<test pkg="electrical_protocol" test-name="test_simulated_basic" type="test_simulated_basic.py" /> | ||
</launch> |