diff --git a/src/pycram/local_transformer.py b/src/pycram/local_transformer.py index 9b5520ccb..af2c56839 100644 --- a/src/pycram/local_transformer.py +++ b/src/pycram/local_transformer.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import logging @@ -10,7 +11,11 @@ from tf import TransformerROS from .datastructures.pose import Pose, Transform -from typing_extensions import List, Optional, Union, Iterable +from typing_extensions import List, Optional, Union, Iterable, TYPE_CHECKING + +if TYPE_CHECKING: + from .world_concepts.world_object import Object + from .datastructures.world import World class LocalTransformer(TransformerROS): @@ -49,7 +54,7 @@ def __init__(self): self._initialized = True def transform_to_object_frame(self, pose: Pose, - world_object: 'world_concepts.world_object.Object', link_name: str = None) -> Union[ + world_object: Object, link_name: str = None) -> Union[ Pose, None]: """ Transforms the given pose to the coordinate frame of the given World object. If no link name is given the @@ -66,13 +71,12 @@ def transform_to_object_frame(self, pose: Pose, target_frame = world_object.tf_frame return self.transform_pose(pose, target_frame) - def update_transforms_for_objects(self, object_names: List[str]) -> None: + def update_transforms_for_objects(self, objects: List[Object]) -> None: """ Updates the transforms for objects affected by the transformation. The objects are identified by their names. - :param object_names: List of object names for which the transforms should be updated + :param objects: List of objects for which the transforms should be updated """ - objects = list(map(self.world.get_object_by_name, object_names)) [obj.update_link_transforms() for obj in objects] def transform_pose(self, pose: Pose, target_frame: str) -> Optional[Pose]: @@ -83,7 +87,7 @@ def transform_pose(self, pose: Pose, target_frame: str) -> Optional[Pose]: :param target_frame: Name of the TF frame into which the Pose should be transformed :return: A transformed pose in the target frame """ - objects = list(map(self.get_object_name_for_frame, [pose.frame, target_frame])) + objects = list(map(self.get_object_from_frame, [pose.frame, target_frame])) self.update_transforms_for_objects([obj for obj in objects if obj is not None]) copy_pose = pose.copy() @@ -101,30 +105,38 @@ def transform_pose(self, pose: Pose, target_frame: str) -> Optional[Pose]: return Pose(*copy_pose.to_list(), frame=new_pose.header.frame_id) - def get_object_name_for_frame(self, frame: str) -> Optional[str]: + def get_object_from_frame(self, frame: str) -> Optional[Object]: """ Get the name of the object that is associated with the given frame. :param frame: The frame for which the object name should be returned :return: The name of the object associated with the frame """ - world = self.prospection_world if self.prospection_prefix in frame else self.world if frame == "map": return None - obj_name = [obj.name for obj in world.objects if frame == obj.tf_frame] - return obj_name[0] if len(obj_name) > 0 else self.get_object_name_for_link_frame(frame) + world = self.get_world_from_frame(frame) + found_objects = [obj for obj in world.objects if frame == obj.tf_frame] + return found_objects[0] if len(found_objects) > 0 else self.get_object_from_link_frame(frame) - def get_object_name_for_link_frame(self, link_frame: str) -> Optional[str]: + def get_object_from_link_frame(self, link_frame: str) -> Optional[Object]: """ Get the name of the object that is associated with the given link frame. :param link_frame: The frame of the link for which the object name should be returned :return: The name of the object associated with the link frame """ - world = self.prospection_world if self.prospection_prefix in link_frame else self.world - object_name = [obj.name for obj in world.objects for link in obj.links.values() - if link_frame in (link.name, link.tf_frame)] - return object_name[0] if len(object_name) > 0 else None + world = self.get_world_from_frame(link_frame) + found_objects = [obj for obj in world.objects for link in obj.links.values() + if link_frame in (link.name, link.tf_frame)] + return found_objects[0] if len(found_objects) > 0 else None + + def get_world_from_frame(self, frame: str) -> World: + """ + Get the world that is associated with the given frame name. + + :param frame: The frame name. + """ + return self.prospection_world if self.prospection_prefix in frame else self.world def lookup_transform_from_source_to_target_frame(self, source_frame: str, target_frame: str, time: Optional[Time] = None) -> Transform: @@ -137,7 +149,7 @@ def lookup_transform_from_source_to_target_frame(self, source_frame: str, target :param time: Time at which the transform should be looked up :return: The transform from source_frame to target_frame """ - objects = list(map(self.get_object_name_for_frame, [source_frame, target_frame])) + objects = list(map(self.get_object_from_frame, [source_frame, target_frame])) self.update_transforms_for_objects([obj for obj in objects if obj is not None]) tf_time = time if time else self.getLatestCommonTime(source_frame, target_frame)