From dec18cd20c012e637fb0e1c5a8a1cb4b53464905 Mon Sep 17 00:00:00 2001 From: Dominic Reber <71256590+domire8@users.noreply.github.com> Date: Sun, 8 Dec 2024 16:47:22 +0100 Subject: [PATCH] feat(components): get clproto message type from attribute (#175) --- CHANGELOG.md | 1 + .../modulo_components/component.py | 4 +- .../modulo_components/component_interface.py | 19 ++++---- .../modulo_components/lifecycle_component.py | 4 +- .../modulo_components/test/python/conftest.py | 6 +-- .../translators/message_writers.py | 44 +++++++++++++++++++ 6 files changed, 62 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dffdade6..7c57d5ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Release Versions: - feat(controllers): add TF listener in BaseControllerInterface (#169) - feat(controllers): add TF broadcaster in BaseControllerInterface (#170) - test(controllers): add TF listener and broadcaster tests (#172) + - feat(components): get clproto message type from attribute (#175) ## 5.0.2 diff --git a/source/modulo_components/modulo_components/component.py b/source/modulo_components/modulo_components/component.py index 7d482218..59e331e2 100644 --- a/source/modulo_components/modulo_components/component.py +++ b/source/modulo_components/modulo_components/component.py @@ -1,5 +1,5 @@ from threading import Thread -from typing import TypeVar +from typing import Optional, TypeVar import clproto from modulo_components.component_interface import ComponentInterface @@ -74,7 +74,7 @@ def on_execute_callback(self) -> bool: return True def add_output(self, signal_name: str, data: str, message_type: MsgT, - clproto_message_type=clproto.MessageType.UNKNOWN_MESSAGE, default_topic="", fixed_topic=False, + clproto_message_type: Optional[clproto.MessageType] = None, default_topic="", fixed_topic=False, publish_on_step=True): """ Add and configure an output signal of the component. diff --git a/source/modulo_components/modulo_components/component_interface.py b/source/modulo_components/modulo_components/component_interface.py index 62377e60..ef53c127 100644 --- a/source/modulo_components/modulo_components/component_interface.py +++ b/source/modulo_components/modulo_components/component_interface.py @@ -422,7 +422,7 @@ def remove_output(self, signal_name): self.get_logger().debug(f"Removing signal '{signal_name}'.") def __create_output(self, signal_name: str, data: str, message_type: MsgT, - clproto_message_type: clproto.MessageType, default_topic: str, fixed_topic: bool, + clproto_message_type: Union[clproto.MessageType, None], default_topic: str, fixed_topic: bool, publish_on_step: bool) -> str: """ Helper function to parse the signal name and add an output without Publisher to the dict of outputs. @@ -438,23 +438,26 @@ def __create_output(self, signal_name: str, data: str, message_type: MsgT, :return: The parsed signal name """ try: - if message_type == EncodedState and clproto_message_type == clproto.MessageType.UNKNOWN_MESSAGE: - raise AddSignalError(f"Provide a valid clproto message type for outputs of type EncodedState.") - self.declare_output(signal_name, default_topic, fixed_topic) - parsed_signal_name = parse_topic_name(signal_name) if message_type == Bool or message_type == Float64 or \ message_type == Float64MultiArray or message_type == Int32 or message_type == String: translator = modulo_writers.write_std_message elif message_type == EncodedState: - translator = partial(modulo_writers.write_clproto_message, - clproto_message_type=clproto_message_type) + cl_msg_type = clproto_message_type if clproto_message_type else modulo_writers.get_clproto_msg_type( + self.__getattribute__(data)) + if cl_msg_type == clproto.MessageType.UNKNOWN_MESSAGE: + raise AddSignalError(f"Provide a valid clproto message type for output '{ + signal_name}' of type EncodedState.") + translator = partial(modulo_writers.write_clproto_message, clproto_message_type=cl_msg_type) elif hasattr(message_type, 'get_fields_and_field_types'): def write_ros_msg(message, data): for field in message.get_fields_and_field_types().keys(): setattr(message, field, getattr(data, field)) translator = write_ros_msg else: - raise AddSignalError("The provided message type is not supported to create a component output.") + raise AddSignalError( + f"The provided message type is not supported to create component output '{signal_name}'.") + self.declare_output(signal_name, default_topic, fixed_topic) + parsed_signal_name = parse_topic_name(signal_name) self.__outputs[parsed_signal_name] = {"attribute": data, "message_type": message_type, "translator": translator} self.__periodic_outputs[parsed_signal_name] = publish_on_step diff --git a/source/modulo_components/modulo_components/lifecycle_component.py b/source/modulo_components/modulo_components/lifecycle_component.py index ed77b6e5..c65e5dd3 100644 --- a/source/modulo_components/modulo_components/lifecycle_component.py +++ b/source/modulo_components/modulo_components/lifecycle_component.py @@ -1,4 +1,4 @@ -from typing import TypeVar +from typing import Optional, TypeVar import clproto from lifecycle_msgs.msg import State @@ -361,7 +361,7 @@ def __configure_outputs(self) -> bool: return success def add_output(self, signal_name: str, data: str, message_type: MsgT, - clproto_message_type=clproto.MessageType.UNKNOWN_MESSAGE, default_topic="", fixed_topic=False, + clproto_message_type: Optional[clproto.MessageType] = None, default_topic="", fixed_topic=False, publish_on_step=True): """ Add an output signal of the component. diff --git a/source/modulo_components/test/python/conftest.py b/source/modulo_components/test/python/conftest.py index af69a29b..a1141daa 100644 --- a/source/modulo_components/test/python/conftest.py +++ b/source/modulo_components/test/python/conftest.py @@ -49,8 +49,7 @@ def publish(self): component = component_type("minimal_cartesian_output") component._output = random_pose - component.add_output("cartesian_pose", "_output", EncodedState, clproto.MessageType.CARTESIAN_STATE_MESSAGE, - topic, publish_on_step=publish_on_step) + component.add_output("cartesian_pose", "_output", EncodedState, default_topic=topic, publish_on_step=publish_on_step) component.publish = publish.__get__(component) return component @@ -65,8 +64,7 @@ def publish(self): component = component_type("minimal_joint_output") component._output = random_joint - component.add_output("joint_state", "_output", EncodedState, clproto.MessageType.JOINT_STATE_MESSAGE, - topic, publish_on_step=publish_on_step) + component.add_output("joint_state", "_output", EncodedState, default_topic=topic, publish_on_step=publish_on_step) component.publish = publish.__get__(component) return component diff --git a/source/modulo_core/modulo_core/translators/message_writers.py b/source/modulo_core/modulo_core/translators/message_writers.py index 80163b8c..5cbe4325 100644 --- a/source/modulo_core/modulo_core/translators/message_writers.py +++ b/source/modulo_core/modulo_core/translators/message_writers.py @@ -14,6 +14,50 @@ StateT = TypeVar('StateT') +def get_clproto_msg_type(state: StateT) -> clproto.MessageType: + if not isinstance(state, sr.State) or not hasattr(state, 'get_type') or not callable(state.get_type): + return clproto.MessageType.UNKNOWN_MESSAGE + + state_type = state.get_type() + if state_type == sr.StateType.STATE: + return clproto.MessageType.STATE_MESSAGE + elif state_type == sr.StateType.SPATIAL_STATE: + return clproto.MessageType.SPATIAL_STATE_MESSAGE + elif state_type == sr.StateType.CARTESIAN_STATE: + return clproto.MessageType.CARTESIAN_STATE_MESSAGE + elif state_type == sr.StateType.CARTESIAN_POSE: + return clproto.MessageType.CARTESIAN_POSE_MESSAGE + elif state_type == sr.StateType.CARTESIAN_TWIST: + return clproto.MessageType.CARTESIAN_TWIST_MESSAGE + elif state_type == sr.StateType.CARTESIAN_ACCELERATION: + return clproto.MessageType.CARTESIAN_ACCELERATION_MESSAGE + elif state_type == sr.StateType.CARTESIAN_WRENCH: + return clproto.MessageType.CARTESIAN_WRENCH_MESSAGE + elif state_type == sr.StateType.JACOBIAN: + return clproto.MessageType.JACOBIAN_MESSAGE + elif state_type == sr.StateType.JOINT_STATE: + return clproto.MessageType.JOINT_STATE_MESSAGE + elif state_type == sr.StateType.JOINT_POSITIONS: + return clproto.MessageType.JOINT_POSITIONS_MESSAGE + elif state_type == sr.StateType.JOINT_VELOCITIES: + return clproto.MessageType.JOINT_VELOCITIES_MESSAGE + elif state_type == sr.StateType.JOINT_ACCELERATIONS: + return clproto.MessageType.JOINT_ACCELERATIONS_MESSAGE + elif state_type == sr.StateType.JOINT_TORQUES: + return clproto.MessageType.JOINT_TORQUES_MESSAGE + elif state_type == sr.StateType.GEOMETRY_SHAPE: + return clproto.MessageType.SHAPE_MESSAGE + elif state_type == sr.StateType.GEOMETRY_ELLIPSOID: + return clproto.MessageType.ELLIPSOID_MESSAGE + elif state_type == sr.StateType.PARAMETER: + return clproto.MessageType.PARAMETER_MESSAGE + elif state_type == sr.StateType.DIGITAL_IO_STATE: + return clproto.MessageType.DIGITAL_IO_STATE_MESSAGE + elif state_type == sr.StateType.ANALOG_IO_STATE: + return clproto.MessageType.ANALOG_IO_STATE_MESSAGE + return clproto.MessageType.UNKNOWN_MESSAGE + + def write_xyz(message: Union[geometry.Point, geometry.Vector3], vector: np.array): """ Helper function to write a vector to a Point or Vector3 message.