Skip to content

Commit

Permalink
[LocalTransformerBugFix] corrected used world in update transforms.
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdelrhmanBassiouny committed Nov 24, 2024
1 parent f826231 commit 8081f94
Showing 1 changed file with 28 additions and 16 deletions.
44 changes: 28 additions & 16 deletions src/pycram/local_transformer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import sys
import logging

Expand All @@ -10,7 +11,11 @@
from tf import TransformerROS

from .datastructures.pose import Pose, Transform
from typing_extensions import List, Optional, Union, Iterable
from typing_extensions import List, Optional, Union, Iterable, TYPE_CHECKING

if TYPE_CHECKING:
from .world_concepts.world_object import Object
from .datastructures.world import World


class LocalTransformer(TransformerROS):
Expand Down Expand Up @@ -49,7 +54,7 @@ def __init__(self):
self._initialized = True

def transform_to_object_frame(self, pose: Pose,
world_object: 'world_concepts.world_object.Object', link_name: str = None) -> Union[
world_object: Object, link_name: str = None) -> Union[
Pose, None]:
"""
Transforms the given pose to the coordinate frame of the given World object. If no link name is given the
Expand All @@ -66,13 +71,12 @@ def transform_to_object_frame(self, pose: Pose,
target_frame = world_object.tf_frame
return self.transform_pose(pose, target_frame)

def update_transforms_for_objects(self, object_names: List[str]) -> None:
def update_transforms_for_objects(self, objects: List[Object]) -> None:
"""
Updates the transforms for objects affected by the transformation. The objects are identified by their names.
:param object_names: List of object names for which the transforms should be updated
:param objects: List of objects for which the transforms should be updated
"""
objects = list(map(self.world.get_object_by_name, object_names))
[obj.update_link_transforms() for obj in objects]

def transform_pose(self, pose: Pose, target_frame: str) -> Optional[Pose]:
Expand All @@ -83,7 +87,7 @@ def transform_pose(self, pose: Pose, target_frame: str) -> Optional[Pose]:
:param target_frame: Name of the TF frame into which the Pose should be transformed
:return: A transformed pose in the target frame
"""
objects = list(map(self.get_object_name_for_frame, [pose.frame, target_frame]))
objects = list(map(self.get_object_from_frame, [pose.frame, target_frame]))
self.update_transforms_for_objects([obj for obj in objects if obj is not None])

copy_pose = pose.copy()
Expand All @@ -101,30 +105,38 @@ def transform_pose(self, pose: Pose, target_frame: str) -> Optional[Pose]:

return Pose(*copy_pose.to_list(), frame=new_pose.header.frame_id)

def get_object_name_for_frame(self, frame: str) -> Optional[str]:
def get_object_from_frame(self, frame: str) -> Optional[Object]:
"""
Get the name of the object that is associated with the given frame.
:param frame: The frame for which the object name should be returned
:return: The name of the object associated with the frame
"""
world = self.prospection_world if self.prospection_prefix in frame else self.world
if frame == "map":
return None
obj_name = [obj.name for obj in world.objects if frame == obj.tf_frame]
return obj_name[0] if len(obj_name) > 0 else self.get_object_name_for_link_frame(frame)
world = self.get_world_from_frame(frame)
found_objects = [obj for obj in world.objects if frame == obj.tf_frame]
return found_objects[0] if len(found_objects) > 0 else self.get_object_from_link_frame(frame)

def get_object_name_for_link_frame(self, link_frame: str) -> Optional[str]:
def get_object_from_link_frame(self, link_frame: str) -> Optional[Object]:
"""
Get the name of the object that is associated with the given link frame.
:param link_frame: The frame of the link for which the object name should be returned
:return: The name of the object associated with the link frame
"""
world = self.prospection_world if self.prospection_prefix in link_frame else self.world
object_name = [obj.name for obj in world.objects for link in obj.links.values()
if link_frame in (link.name, link.tf_frame)]
return object_name[0] if len(object_name) > 0 else None
world = self.get_world_from_frame(link_frame)
found_objects = [obj for obj in world.objects for link in obj.links.values()
if link_frame in (link.name, link.tf_frame)]
return found_objects[0] if len(found_objects) > 0 else None

def get_world_from_frame(self, frame: str) -> World:
"""
Get the world that is associated with the given frame name.
:param frame: The frame name.
"""
return self.prospection_world if self.prospection_prefix in frame else self.world

def lookup_transform_from_source_to_target_frame(self, source_frame: str, target_frame: str,
time: Optional[Time] = None) -> Transform:
Expand All @@ -137,7 +149,7 @@ def lookup_transform_from_source_to_target_frame(self, source_frame: str, target
:param time: Time at which the transform should be looked up
:return: The transform from source_frame to target_frame
"""
objects = list(map(self.get_object_name_for_frame, [source_frame, target_frame]))
objects = list(map(self.get_object_from_frame, [source_frame, target_frame]))
self.update_transforms_for_objects([obj for obj in objects if obj is not None])

tf_time = time if time else self.getLatestCommonTime(source_frame, target_frame)
Expand Down

0 comments on commit 8081f94

Please sign in to comment.