diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/config/multiverse_conf.py b/config/multiverse_conf.py new file mode 100644 index 000000000..6463cf151 --- /dev/null +++ b/config/multiverse_conf.py @@ -0,0 +1,72 @@ +import datetime + +from typing_extensions import Type + +from .world_conf import WorldConfig +from pycram.description import ObjectDescription +from pycram.helper import find_multiverse_resources_path +from pycram.object_descriptors.mjcf import ObjectDescription as MJCF + + +class MultiverseConfig(WorldConfig): + # Multiverse Configuration + resources_path = find_multiverse_resources_path() + """ + The path to the Multiverse resources directory. + """ + + # Multiverse Socket Configuration + HOST: str = "tcp://127.0.0.1" + SERVER_HOST: str = HOST + SERVER_PORT: str = 7000 + BASE_CLIENT_PORT: int = 9000 + + # Multiverse Client Configuration + READER_MAX_WAIT_TIME_FOR_DATA: datetime.timedelta = datetime.timedelta(milliseconds=1000) + """ + The maximum wait time for the data in seconds. + """ + + # Multiverse Simulation Configuration + simulation_time_step: datetime.timedelta = datetime.timedelta(milliseconds=10) + simulation_frequency: int = int(1 / simulation_time_step.total_seconds()) + """ + The time step of the simulation in seconds and the frequency of the simulation in Hz. + """ + + simulation_wait_time_factor: float = 1.0 + """ + The factor to multiply the simulation wait time with, this is used to adjust the simulation wait time to account for + the time taken by the simulation to process the request, this depends on the computational power of the machine + running the simulation. + """ + + use_static_mode: bool = True + """ + If True, the simulation will always be in paused state unless the simulate() function is called, this behaves + similar to bullet_world which uses the bullet physics engine. + """ + + use_controller: bool = False + use_controller = use_controller and not use_static_mode + """ + Only used when use_static_mode is False. This turns on the controller for the robot joints. + """ + + default_description_type: Type[ObjectDescription] = MJCF + """ + The default description type for the objects. + """ + + use_physics_simulator_state: bool = True + """ + Whether to use the physics simulator state when restoring or saving the world state. + """ + + clear_cache_at_start = False + + let_pycram_move_attached_objects = False + let_pycram_handle_spawning = False + + position_tolerance = 2e-2 + prismatic_joint_position_tolerance = 2e-2 diff --git a/config/world_conf.py b/config/world_conf.py new file mode 100644 index 000000000..76db4b330 --- /dev/null +++ b/config/world_conf.py @@ -0,0 +1,93 @@ +import math +import os + +from typing_extensions import Tuple, Type +from pycram.description import ObjectDescription +from pycram.object_descriptors.urdf import ObjectDescription as URDF + + +class WorldConfig: + + """ + A class to store the configuration of the world, this can be inherited to create a new configuration class for a + specific world (e.g. multiverse has MultiverseConfig which inherits from this class). + """ + + resources_path = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'resources') + resources_path = os.path.abspath(resources_path) + """ + Global reference for the resources path, this is used to search for the description files of the robot and + the objects. + """ + + cache_dir_name: str = 'cached' + """ + The name of the cache directory. + """ + + cache_dir: str = os.path.join(resources_path, cache_dir_name) + """ + Global reference for the cache directory, this is used to cache the description files of the robot and the objects. + """ + + clear_cache_at_start: bool = True + """ + Whether to clear the cache directory at the start. + """ + + prospection_world_prefix: str = "prospection_" + """ + The prefix for the prospection world name. + """ + + simulation_frequency: int = 240 + """ + The simulation frequency (Hz), used for calculating the equivalent real time in the simulation. + """ + + update_poses_from_sim_on_get: bool = True + """ + Whether to update the poses from the simulator when getting the object poses. + """ + + default_description_type: Type[ObjectDescription] = URDF + """ + The default description type for the objects. + """ + + use_physics_simulator_state: bool = False + """ + Whether to use the physics simulator state when restoring or saving the world state. + Currently with PyBullet, this causes a bug where ray_test does not work correctly after restoring the state using the + simulator, so it is recommended to set this to False in PyBullet. + """ + + let_pycram_move_attached_objects: bool = True + let_pycram_handle_spawning: bool = True + let_pycram_handle_world_sync: bool = True + """ + Whether to let PyCRAM handle the movement of attached objects, the spawning of objects, + and the world synchronization. + """ + + position_tolerance: float = 1e-2 + orientation_tolerance: float = 10 * math.pi / 180 + prismatic_joint_position_tolerance: float = 1e-2 + revolute_joint_position_tolerance: float = 5 * math.pi / 180 + """ + The acceptable error for the position and orientation of an object/link, and the joint positions. + """ + + use_percentage_of_goal: bool = False + acceptable_percentage_of_goal: float = 0.5 + """ + Whether to use a percentage of the goal as the acceptable error. + """ + + raise_goal_validator_error: bool = False + """ + Whether to raise an error if the goals are not achieved. + """ + @classmethod + def get_pose_tolerance(cls) -> Tuple[float, float]: + return cls.position_tolerance, cls.orientation_tolerance diff --git a/demos/pycram_bullet_world_demo/demo.py b/demos/pycram_bullet_world_demo/demo.py index a60df7771..f06fc3d38 100644 --- a/demos/pycram_bullet_world_demo/demo.py +++ b/demos/pycram_bullet_world_demo/demo.py @@ -8,10 +8,12 @@ from pycram.object_descriptors.urdf import ObjectDescription from pycram.world_concepts.world_object import Object from pycram.datastructures.dataclasses import Color +from pycram.ros.viz_marker_publisher import VizMarkerPublisher extension = ObjectDescription.get_file_extension() world = BulletWorld(WorldMode.GUI) + robot = Object("pr2", ObjectType.ROBOT, f"pr2{extension}", pose=Pose([1, 2, 0])) apartment = Object("apartment", ObjectType.ENVIRONMENT, f"apartment{extension}") @@ -94,3 +96,5 @@ def move_and_detect(obj_type): PlaceAction(spoon_desig, [spoon_target_pose], [pickup_arm]).resolve().perform() ParkArmsAction([Arms.BOTH]).resolve().perform() + +world.exit() diff --git a/demos/pycram_multiverse_demo/demo.py b/demos/pycram_multiverse_demo/demo.py new file mode 100644 index 000000000..cdc625dac --- /dev/null +++ b/demos/pycram_multiverse_demo/demo.py @@ -0,0 +1,101 @@ +from pycram.datastructures.dataclasses import Color +from pycram.datastructures.enums import ObjectType, Arms, Grasp +from pycram.datastructures.pose import Pose +from pycram.designators.action_designator import ParkArmsAction, MoveTorsoAction, TransportAction, NavigateAction, \ + LookAtAction, DetectAction, OpenAction, PickUpAction, CloseAction, PlaceAction +from pycram.designators.location_designator import CostmapLocation, AccessingLocation +from pycram.designators.object_designator import BelieveObject, ObjectPart +from pycram.object_descriptors.urdf import ObjectDescription +from pycram.process_module import simulated_robot, with_simulated_robot +from pycram.world_concepts.world_object import Object +from pycram.worlds.multiverse import Multiverse + + +@with_simulated_robot +def move_and_detect(obj_type: ObjectType, pick_pose: Pose): + NavigateAction(target_locations=[Pose([1.7, 2, 0])]).resolve().perform() + + LookAtAction(targets=[pick_pose]).resolve().perform() + + object_desig = DetectAction(BelieveObject(types=[obj_type])).resolve().perform() + + return object_desig + + +world = Multiverse(simulation_name='pycram_test') +extension = ObjectDescription.get_file_extension() +robot = Object('pr2', ObjectType.ROBOT, f'pr2{extension}', pose=Pose([1.3, 2, 0.01])) +apartment = Object("apartment", ObjectType.ENVIRONMENT, f"apartment{extension}") + +milk = Object("milk", ObjectType.MILK, f"milk.stl", pose=Pose([2.4, 2, 1.02]), + color=Color(1, 0, 0, 1)) + +spoon = Object("spoon", ObjectType.SPOON, "spoon.stl", pose=Pose([2.5, 2.2, 0.85]), + color=Color(0, 0, 1, 1)) +apartment.attach(spoon, 'cabinet10_drawer1') + +robot_desig = BelieveObject(names=[robot.name]) +apartment_desig = BelieveObject(names=[apartment.name]) + +with simulated_robot: + + # Transport the milk + ParkArmsAction([Arms.BOTH]).resolve().perform() + + MoveTorsoAction([0.25]).resolve().perform() + + NavigateAction(target_locations=[Pose([1.7, 2, 0])]).resolve().perform() + + LookAtAction(targets=[Pose([2.6, 2.15, 1])]).resolve().perform() + + milk_desig = DetectAction(BelieveObject(types=[milk.obj_type])).resolve().perform() + + TransportAction(milk_desig, [Arms.LEFT], [Pose([2.4, 3, 1.02])]).resolve().perform() + + # Find and navigate to the drawer containing the spoon + handle_desig = ObjectPart(names=["cabinet10_drawer1_handle"], part_of=apartment_desig.resolve()) + drawer_open_loc = AccessingLocation(handle_desig=handle_desig.resolve(), + robot_desig=robot_desig.resolve()).resolve() + + NavigateAction([drawer_open_loc.pose]).resolve().perform() + + OpenAction(object_designator_description=handle_desig, arms=[drawer_open_loc.arms[0]]).resolve().perform() + spoon.detach(apartment) + + # Detect and pickup the spoon + LookAtAction([apartment.get_link_pose("cabinet10_drawer1_handle")]).resolve().perform() + + spoon_desig = DetectAction(BelieveObject(types=[ObjectType.SPOON])).resolve().perform() + + pickup_arm = Arms.LEFT if drawer_open_loc.arms[0] == Arms.RIGHT else Arms.RIGHT + PickUpAction(spoon_desig, [pickup_arm], [Grasp.TOP]).resolve().perform() + + ParkArmsAction([Arms.LEFT if pickup_arm == Arms.LEFT else Arms.RIGHT]).resolve().perform() + + CloseAction(object_designator_description=handle_desig, arms=[drawer_open_loc.arms[0]]).resolve().perform() + + ParkArmsAction([Arms.BOTH]).resolve().perform() + + MoveTorsoAction([0.15]).resolve().perform() + + # Find a pose to place the spoon, move and then place it + spoon_target_pose = Pose([2.35, 2.6, 0.95], [0, 0, 0, 1]) + placing_loc = CostmapLocation(target=spoon_target_pose, reachable_for=robot_desig.resolve()).resolve() + + NavigateAction([placing_loc.pose]).resolve().perform() + + PlaceAction(spoon_desig, [spoon_target_pose], [pickup_arm]).resolve().perform() + + ParkArmsAction([Arms.BOTH]).resolve().perform() + +world.exit() + + +def debug_place_spoon(): + robot.set_pose(Pose([1.66, 2.56, 0.01], [0.0, 0.0, -0.04044101807153309, 0.9991819274072855])) + spoon.set_pose(Pose([1.9601757566599975, 2.06718019258626, 1.0427727691273496], + [0.11157527804553227, -0.7076243776942466, 0.12773644958149588, 0.685931554070963])) + robot.attach(spoon, 'r_gripper_tool_frame') + pickup_arm = Arms.RIGHT + spoon_desig = BelieveObject(names=[spoon.name]) + return pickup_arm, spoon_desig diff --git a/examples/cram_plan_tutorial.md b/examples/cram_plan_tutorial.md index 7606b9d3d..7ca106a34 100644 --- a/examples/cram_plan_tutorial.md +++ b/examples/cram_plan_tutorial.md @@ -28,7 +28,7 @@ from pycram.designators.location_designator import * from pycram.process_module import simulated_robot from pycram.designators.object_designator import * import anytree -import pycram.plan_failures +import pycram.failures ``` Next we will create a bullet world with a PR2 in a kitchen containing milk and cereal. diff --git a/examples/improving_actions.md b/examples/improving_actions.md index dd07b67cb..4cb46a28e 100644 --- a/examples/improving_actions.md +++ b/examples/improving_actions.md @@ -44,7 +44,7 @@ from random_events.product_algebra import Event, SimpleEvent import pycram.orm.base from pycram.designators.action_designator import MoveTorsoActionPerformable -from pycram.plan_failures import PlanFailure +from pycram.failures import PlanFailure from pycram.designators.object_designator import ObjectDesignatorDescription from pycram.worlds.bullet_world import BulletWorld from pycram.world_concepts.world_object import Object diff --git a/examples/language.md b/examples/language.md index 3b5a03fb5..197c879bb 100644 --- a/examples/language.md +++ b/examples/language.md @@ -254,7 +254,7 @@ We will see how exceptions are handled at a simple example. from pycram.designators.action_designator import * from pycram.process_module import simulated_robot from pycram.language import Code -from pycram.plan_failures import PlanFailure +from pycram.failures import PlanFailure def code_test(): diff --git a/examples/minimal_task_tree.md b/examples/minimal_task_tree.md index ae4e16492..ff93d7680 100644 --- a/examples/minimal_task_tree.md +++ b/examples/minimal_task_tree.md @@ -31,7 +31,7 @@ from pycram.designators.object_designator import * from pycram.datastructures.pose import Pose from pycram.datastructures.enums import ObjectType, WorldMode import anytree -import pycram.plan_failures +import pycram.failures ``` Next we will create a bullet world with a PR2 in a kitchen containing milk and cereal. diff --git a/examples/orm_querying_examples.md b/examples/orm_querying_examples.md index 24d1625ab..b24c7f740 100644 --- a/examples/orm_querying_examples.md +++ b/examples/orm_querying_examples.md @@ -30,9 +30,10 @@ import tqdm import pycram.orm.base from pycram.worlds.bullet_world import BulletWorld from pycram.world_concepts.world_object import Object as BulletWorldObject -from pycram.designators.action_designator import MoveTorsoAction, PickUpAction, NavigateAction, ParkArmsAction, ParkArmsActionPerformable, MoveTorsoActionPerformable +from pycram.designators.action_designator import MoveTorsoAction, PickUpAction, NavigateAction, ParkArmsAction, + ParkArmsActionPerformable, MoveTorsoActionPerformable from pycram.designators.object_designator import ObjectDesignatorDescription -from pycram.plan_failures import PlanFailure +from pycram.failures import PlanFailure from pycram.process_module import ProcessModule from pycram.datastructures.enums import Arms, ObjectType, Grasp, WorldMode from pycram.process_module import simulated_robot diff --git a/requirements.txt b/requirements.txt index 9586bf3aa..0276c54ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,6 +24,9 @@ pynput~=1.7.7 playsound~=1.3.0 pydub~=0.25.1 gTTS~=2.5.3 +dm_control +trimesh +deprecated probabilistic_model>=5.1.0 random_events>=3.0.7 sympy diff --git a/resources/IAI_kitchen.urdf b/resources/objects/IAI_kitchen.urdf similarity index 100% rename from resources/IAI_kitchen.urdf rename to resources/objects/IAI_kitchen.urdf diff --git a/resources/Static_CokeBottle.stl b/resources/objects/Static_CokeBottle.stl similarity index 100% rename from resources/Static_CokeBottle.stl rename to resources/objects/Static_CokeBottle.stl diff --git a/resources/Static_MilkPitcher.stl b/resources/objects/Static_MilkPitcher.stl similarity index 100% rename from resources/Static_MilkPitcher.stl rename to resources/objects/Static_MilkPitcher.stl diff --git a/resources/apartment.urdf b/resources/objects/apartment.urdf similarity index 100% rename from resources/apartment.urdf rename to resources/objects/apartment.urdf diff --git a/resources/apartment_bowl.stl b/resources/objects/apartment_bowl.stl similarity index 100% rename from resources/apartment_bowl.stl rename to resources/objects/apartment_bowl.stl diff --git a/resources/apartment_without_walls.urdf b/resources/objects/apartment_without_walls.urdf similarity index 100% rename from resources/apartment_without_walls.urdf rename to resources/objects/apartment_without_walls.urdf diff --git a/resources/bowl.stl b/resources/objects/bowl.stl similarity index 100% rename from resources/bowl.stl rename to resources/objects/bowl.stl diff --git a/resources/box.urdf b/resources/objects/box.urdf similarity index 100% rename from resources/box.urdf rename to resources/objects/box.urdf diff --git a/resources/breakfast_cereal.stl b/resources/objects/breakfast_cereal.stl similarity index 100% rename from resources/breakfast_cereal.stl rename to resources/objects/breakfast_cereal.stl diff --git a/resources/broken_kitchen.urdf b/resources/objects/broken_kitchen.urdf similarity index 100% rename from resources/broken_kitchen.urdf rename to resources/objects/broken_kitchen.urdf diff --git a/resources/jeroen_cup.stl b/resources/objects/jeroen_cup.stl similarity index 100% rename from resources/jeroen_cup.stl rename to resources/objects/jeroen_cup.stl diff --git a/resources/kitchen.urdf b/resources/objects/kitchen.urdf similarity index 100% rename from resources/kitchen.urdf rename to resources/objects/kitchen.urdf diff --git a/resources/milk.stl b/resources/objects/milk.stl similarity index 100% rename from resources/milk.stl rename to resources/objects/milk.stl diff --git a/resources/plane.obj b/resources/objects/plane.obj similarity index 100% rename from resources/plane.obj rename to resources/objects/plane.obj diff --git a/resources/plane.urdf b/resources/objects/plane.urdf similarity index 100% rename from resources/plane.urdf rename to resources/objects/plane.urdf diff --git a/resources/spoon.stl b/resources/objects/spoon.stl similarity index 100% rename from resources/spoon.stl rename to resources/objects/spoon.stl diff --git a/src/pycram/__init__.py b/src/pycram/__init__.py index 970badb95..9fde2cdac 100644 --- a/src/pycram/__init__.py +++ b/src/pycram/__init__.py @@ -1,4 +1,5 @@ -import pycram.process_modules +from . import process_modules +from . import robot_descriptions # from .specialized_designators import * from .datastructures.world import World diff --git a/src/pycram/cache_manager.py b/src/pycram/cache_manager.py index 1c7b86cea..3e6a9889d 100644 --- a/src/pycram/cache_manager.py +++ b/src/pycram/cache_manager.py @@ -1,8 +1,9 @@ import glob import os import pathlib +import shutil -from typing_extensions import List, TYPE_CHECKING +from typing_extensions import List, TYPE_CHECKING, Optional if TYPE_CHECKING: from .description import ObjectDescription @@ -14,31 +15,52 @@ class CacheManager: The CacheManager is responsible for caching object description files and managing the cache directory. """ - mesh_extensions: List[str] = [".obj", ".stl"] + cache_cleared: bool = False """ - The file extensions of mesh files. + Indicate whether the cache directory has been cleared at least once since beginning or not. """ - def __init__(self, cache_dir: str, data_directory: List[str]): + def __init__(self, cache_dir: str, data_directory: List[str], clear_cache: bool = True): """ - Initializes the CacheManager. + Initialize the CacheManager. :param cache_dir: The directory where the cached files are stored. :param data_directory: The directory where all resource files are stored. + :param clear_cache: If True, the cache directory will be cleared. """ self.cache_dir = cache_dir - self.data_directory = data_directory + self.data_directories = data_directory + if clear_cache: + self.clear_cache() + + def clear_cache(self): + """ + Clear the cache directory. + """ + self.delete_cache_dir() + self.create_cache_dir_if_not_exists() + self.cache_cleared = True + + def delete_cache_dir(self): + """ + Delete the cache directory. + """ + if pathlib.Path(self.cache_dir).exists(): + shutil.rmtree(self.cache_dir) def update_cache_dir_with_object(self, path: str, ignore_cached_files: bool, - object_description: 'ObjectDescription', object_name: str) -> str: + object_description: 'ObjectDescription', object_name: str, + scale_mesh: Optional[float] = None) -> str: """ - Checks if the file is already in the cache directory, if not it will be preprocessed and saved in the cache. + Check if the file is already in the cache directory, if not preprocess and save in the cache. :param path: The path of the file to preprocess and save in the cache directory. :param ignore_cached_files: If True, the file will be preprocessed and saved in the cache directory even if it is already cached. :param object_description: The object description of the file. :param object_name: The name of the object. + :param scale_mesh: The scale of the mesh. + :return: The path of the cached file. """ path_object = pathlib.Path(path) extension = path_object.suffix @@ -46,49 +68,24 @@ def update_cache_dir_with_object(self, path: str, ignore_cached_files: bool, self.create_cache_dir_if_not_exists() # save correct path in case the file is already in the cache directory - cache_path = self.cache_dir + object_description.get_file_name(path_object, extension, object_name) + cache_path = os.path.join(self.cache_dir, object_description.get_file_name(path_object, extension, object_name)) if not self.is_cached(path, object_description) or ignore_cached_files: # if file is not yet cached preprocess the description file and save it in the cache directory. path = self.look_for_file_in_data_dir(path_object) - self.generate_description_and_write_to_cache(path, object_name, extension, cache_path, object_description) + object_description.generate_description_from_file(path, object_name, extension, cache_path, scale_mesh) return cache_path - def generate_description_and_write_to_cache(self, path: str, name: str, extension: str, cache_path: str, - object_description: 'ObjectDescription') -> None: - """ - Generates the description from the file at the given path and writes it to the cache directory. - - :param path: The path of the file to preprocess. - :param name: The name of the object. - :param extension: The file extension of the file to preprocess. - :param cache_path: The path of the file in the cache directory. - :param object_description: The object description of the file. - """ - description_string = object_description.generate_description_from_file(path, name, extension) - self.write_to_cache(description_string, cache_path) - - @staticmethod - def write_to_cache(description_string: str, cache_path: str) -> None: - """ - Writes the description string to the cache directory. - - :param description_string: The description string to write to the cache directory. - :param cache_path: The path of the file in the cache directory. - """ - with open(cache_path, "w") as file: - file.write(description_string) - def look_for_file_in_data_dir(self, path_object: pathlib.Path) -> str: """ - Looks for a file in the data directory of the World. If the file is not found in the data directory, this method - raises a FileNotFoundError. + Look for a file in the data directory of the World. If the file is not found in the data directory, raise a + FileNotFoundError. :param path_object: The pathlib object of the file to look for. """ name = path_object.name - for data_dir in self.data_directory: + for data_dir in self.data_directories: data_path = pathlib.Path(data_dir).joinpath("**") for file in glob.glob(str(data_path), recursive=True): file_path = pathlib.Path(file) @@ -97,18 +94,18 @@ def look_for_file_in_data_dir(self, path_object: pathlib.Path) -> str: return str(file_path) raise FileNotFoundError( - f"File {name} could not be found in the resource directory {self.data_directory}") + f"File {name} could not be found in the resource directory {self.data_directories}") def create_cache_dir_if_not_exists(self): """ - Creates the cache directory if it does not exist. + Create the cache directory if it does not exist. """ if not pathlib.Path(self.cache_dir).exists(): os.mkdir(self.cache_dir) def is_cached(self, path: str, object_description: 'ObjectDescription') -> bool: """ - Checks if the file in the given path is already cached or if + Check if the file in the given path is already cached or if there is already a cached file with the given name, this is the case if a .stl, .obj file or a description from the parameter server is used. @@ -116,26 +113,26 @@ def is_cached(self, path: str, object_description: 'ObjectDescription') -> bool: :param object_description: The object description of the file. :return: True if there already exists a cached file, False in any other case. """ - return True if self.check_with_extension(path) else self.check_without_extension(path, object_description) + return self.check_with_extension(path) or self.check_without_extension(path, object_description) def check_with_extension(self, path: str) -> bool: """ - Checks if the file in the given ath exists in the cache directory including file extension. + Check if the file in the given ath exists in the cache directory including file extension. :param path: The path of the file to check. """ file_name = pathlib.Path(path).name - full_path = pathlib.Path(self.cache_dir + file_name) + full_path = pathlib.Path(os.path.join(self.cache_dir, file_name)) return full_path.exists() def check_without_extension(self, path: str, object_description: 'ObjectDescription') -> bool: """ - Checks if the file in the given path exists in the cache directory without file extension, - the extension is added after the file name manually in this case. + Check if the file in the given path exists in the cache directory the given file extension. + Instead, replace the given extension with the extension of the used ObjectDescription and check for that one. :param path: The path of the file to check. :param object_description: The object description of the file. """ file_stem = pathlib.Path(path).stem - full_path = pathlib.Path(self.cache_dir + file_stem + object_description.get_file_extension()) + full_path = pathlib.Path(os.path.join(self.cache_dir, file_stem + object_description.get_file_extension())) return full_path.exists() diff --git a/src/pycram/config b/src/pycram/config new file mode 120000 index 000000000..899f69898 --- /dev/null +++ b/src/pycram/config @@ -0,0 +1 @@ +../../config \ No newline at end of file diff --git a/src/pycram/costmaps.py b/src/pycram/costmaps.py index 8b898617d..ad34677e6 100644 --- a/src/pycram/costmaps.py +++ b/src/pycram/costmaps.py @@ -29,8 +29,7 @@ from .datastructures.pose import Pose, Transform from .datastructures.world import World from .datastructures.dataclasses import AxisAlignedBoundingBox, BoxVisualShape, Color - -import pycram_bullet as p +from tf.transformations import quaternion_from_matrix @dataclass @@ -121,42 +120,21 @@ def visualize(self) -> None: # Creation of the visual shapes, for documentation of the visual shapes # please look here: https://docs.google.com/document/d/10sXEhzFRSnvFcl3XxNGhnD4N2SedqwdAvK3dsihxVUA/edit#heading=h.q1gn7v6o58bf for box in boxes: - visual = p.createVisualShape(p.GEOM_BOX, - halfExtents=[(box[1] * self.resolution) / 2, (box[2] * self.resolution) / 2, - 0.001], - rgbaColor=[1, 0, 0, 0.6], - visualFramePosition=[(box[0][0] + box[1] / 2) * self.resolution, - (box[0][1] + box[2] / 2) * self.resolution, 0.]) + box = BoxVisualShape(Color(1, 0, 0, 0.6), + [(box[0][0] + box[1] / 2) * self.resolution, + (box[0][1] + box[2] / 2) * self.resolution, 0.], + [(box[1] * self.resolution) / 2, (box[2] * self.resolution) / 2, 0.001]) + visual = self.world.create_visual_shape(box) cells.append(visual) # Set to 127 for since this is the maximal amount of links in a multibody for cell_parts in self._chunks(cells, 127): - # Dummy paramater since these are needed to spawn visual shapes as a - # multibody. - link_poses = [[0, 0, 0] for c in cell_parts] - link_orientations = [[0, 0, 0, 1] for c in cell_parts] - link_masses = [1.0 for c in cell_parts] - link_parent = [0 for c in cell_parts] - link_joints = [p.JOINT_FIXED for c in cell_parts] - link_collision = [-1 for c in cell_parts] - link_joint_axis = [[1, 0, 0] for c in cell_parts] - # The position at which the multibody will be spawned. Offset such that - # the origin referes to the centre of the costmap. - # origin_pose = self.origin.position_as_list() - # base_pose = [origin_pose[0] - self.height / 2 * self.resolution, - # origin_pose[1] - self.width / 2 * self.resolution, origin_pose[2]] - offset = [[-self.height / 2 * self.resolution, -self.width / 2 * self.resolution, 0.05], [0, 0, 0, 1]] - new_pose = p.multiplyTransforms(self.origin.position_as_list(), self.origin.orientation_as_list(), - offset[0], offset[1]) - - map_obj = p.createMultiBody(baseVisualShapeIndex=-1, linkVisualShapeIndices=cell_parts, - basePosition=new_pose[0], baseOrientation=new_pose[1], linkPositions=link_poses, - # [0, 0, 1, 0] - linkMasses=link_masses, linkOrientations=link_orientations, - linkInertialFramePositions=link_poses, - linkInertialFrameOrientations=link_orientations, linkParentIndices=link_parent, - linkJointTypes=link_joints, linkJointAxis=link_joint_axis, - linkCollisionShapeIndices=link_collision) + origin_transform = (Transform(self.origin.position_as_list(), self.origin.orientation_as_list()) + .get_homogeneous_matrix()) + offset_transform = (Transform(offset[0], offset[1]).get_homogeneous_matrix()) + new_pose_transform = np.dot(origin_transform, offset_transform) + new_pose = Pose(new_pose_transform[:3, 3].tolist(), quaternion_from_matrix(new_pose_transform)) + map_obj = self.world.create_multi_body_from_visual_shapes(cell_parts, new_pose) self.vis_ids.append(map_obj) def _chunks(self, lst: List, n: int) -> List: @@ -175,7 +153,7 @@ def close_visualization(self) -> None: Removes the visualization from the World. """ for v_id in self.vis_ids: - self.world.remove_object_by_id(v_id) + self.world.remove_visual_object(v_id) self.vis_ids = [] def _find_consectuive_line(self, start: Tuple[int, int], map: np.ndarray) -> int: @@ -471,7 +449,6 @@ def _create_from_world(self, size: int, resolution: float) -> np.ndarray: i = 0 j = 0 for n in self._chunks(np.array(rays), 16380): - # with UseProspectionWorld(): r_t = World.current_world.ray_test_batch(n[:, 0], n[:, 1], num_threads=0) while r_t is None: r_t = World.current_world.ray_test_batch(n[:, 0], n[:, 1], num_threads=0) @@ -797,11 +774,10 @@ def generate_map(self) -> None: def get_aabb_for_link(self) -> AxisAlignedBoundingBox: """ - Returns the axis aligned bounding box (AABB) of the link provided when creating this costmap. To try and let the - AABB as close to the actual object as possible, the Object will be rotated such that the link will be in the - identity orientation. - :return: Two points in world coordinate space, which span a rectangle + :return: The axis aligned bounding box (AABB) of the link provided when creating this costmap. To try and let + the AABB as close to the actual object as possible, the Object will be rotated such that the link will be in the + identity orientation. """ prospection_object = World.current_world.get_prospection_object_for_object(self.object) with UseProspectionWorld(): diff --git a/src/pycram/datastructures/dataclasses.py b/src/pycram/datastructures/dataclasses.py index d83bcd00f..040189ce3 100644 --- a/src/pycram/datastructures/dataclasses.py +++ b/src/pycram/datastructures/dataclasses.py @@ -1,10 +1,14 @@ from __future__ import annotations +from abc import ABC, abstractmethod +from copy import deepcopy, copy from dataclasses import dataclass + from typing_extensions import List, Optional, Tuple, Callable, Dict, Any, Union, TYPE_CHECKING -from .enums import JointType, Shape + +from .enums import JointType, Shape, VirtualMobileBaseJointName from .pose import Pose, Point -from abc import ABC, abstractmethod +from ..validation.error_checkers import calculate_joint_position_error, is_error_acceptable if TYPE_CHECKING: from ..description import Link @@ -14,7 +18,7 @@ def get_point_as_list(point: Point) -> List[float]: """ - Returns the point as a list. + Return the point as a list. :param point: The point. :return: The point as a list @@ -37,7 +41,7 @@ class Color: @classmethod def from_list(cls, color: List[float]): """ - Sets the rgba_color from a list of RGBA values. + Set the rgba_color from a list of RGBA values. :param color: The list of RGBA values """ @@ -51,7 +55,7 @@ def from_list(cls, color: List[float]): @classmethod def from_rgb(cls, rgb: List[float]): """ - Sets the rgba_color from a list of RGB values. + Set the rgba_color from a list of RGB values. :param rgb: The list of RGB values """ @@ -60,7 +64,7 @@ def from_rgb(cls, rgb: List[float]): @classmethod def from_rgba(cls, rgba: List[float]): """ - Sets the rgba_color from a list of RGBA values. + Set the rgba_color from a list of RGBA values. :param rgba: The list of RGBA values """ @@ -68,7 +72,7 @@ def from_rgba(cls, rgba: List[float]): def get_rgba(self) -> List[float]: """ - Returns the rgba_color as a list of RGBA values. + Return the rgba_color as a list of RGBA values. :return: The rgba_color as a list of RGBA values """ @@ -76,7 +80,7 @@ def get_rgba(self) -> List[float]: def get_rgb(self) -> List[float]: """ - Returns the rgba_color as a list of RGB values. + Return the rgba_color as a list of RGB values. :return: The rgba_color as a list of RGB values """ @@ -98,7 +102,7 @@ class AxisAlignedBoundingBox: @classmethod def from_min_max(cls, min_point: List[float], max_point: List[float]): """ - Sets the axis-aligned bounding box from a minimum and maximum point. + Set the axis-aligned bounding box from a minimum and maximum point. :param min_point: The minimum point :param max_point: The maximum point @@ -107,48 +111,36 @@ def from_min_max(cls, min_point: List[float], max_point: List[float]): def get_min_max_points(self) -> Tuple[Point, Point]: """ - Returns the axis-aligned bounding box as a tuple of minimum and maximum points. - :return: The axis-aligned bounding box as a tuple of minimum and maximum points """ return self.get_min_point(), self.get_max_point() def get_min_point(self) -> Point: """ - Returns the axis-aligned bounding box as a minimum point. - :return: The axis-aligned bounding box as a minimum point """ return Point(self.min_x, self.min_y, self.min_z) def get_max_point(self) -> Point: """ - Returns the axis-aligned bounding box as a maximum point. - :return: The axis-aligned bounding box as a maximum point """ return Point(self.max_x, self.max_y, self.max_z) def get_min_max(self) -> Tuple[List[float], List[float]]: """ - Returns the axis-aligned bounding box as a tuple of minimum and maximum points. - :return: The axis-aligned bounding box as a tuple of minimum and maximum points """ return self.get_min(), self.get_max() def get_min(self) -> List[float]: """ - Returns the minimum point of the axis-aligned bounding box. - :return: The minimum point of the axis-aligned bounding box """ return [self.min_x, self.min_y, self.min_z] def get_max(self) -> List[float]: """ - Returns the maximum point of the axis-aligned bounding box. - :return: The maximum point of the axis-aligned bounding box """ return [self.max_x, self.max_y, self.max_z] @@ -156,12 +148,19 @@ def get_max(self) -> List[float]: @dataclass class CollisionCallbacks: + """ + Dataclass for storing the collision callbacks which are callables that get called when there is a collision + or when a collision is no longer there. + """ on_collision_cb: Callable no_collision_cb: Optional[Callable] = None @dataclass class MultiBody: + """ + Dataclass for storing the information of a multibody which consists of a base and multiple links with joints. + """ base_visual_shape_index: int base_pose: Pose link_visual_shape_indices: List[int] @@ -176,13 +175,16 @@ class MultiBody: @dataclass class VisualShape(ABC): + """ + Abstract dataclass for storing the information of a visual shape. + """ rgba_color: Color visual_frame_position: List[float] @abstractmethod def shape_data(self) -> Dict[str, Any]: """ - Returns the shape data of the visual shape (e.g. half extents for a box, radius for a sphere). + :return: the shape data of the visual shape (e.g. half extents for a box, radius for a sphere) as a dictionary. """ pass @@ -190,13 +192,16 @@ def shape_data(self) -> Dict[str, Any]: @abstractmethod def visual_geometry_type(self) -> Shape: """ - Returns the visual geometry type of the visual shape (e.g. box, sphere). + :return: The visual geometry type of the visual shape (e.g. box, sphere) as a Shape object. """ pass @dataclass class BoxVisualShape(VisualShape): + """ + Dataclass for storing the information of a box visual shape + """ half_extents: List[float] def shape_data(self) -> Dict[str, List[float]]: @@ -213,6 +218,9 @@ def size(self) -> List[float]: @dataclass class SphereVisualShape(VisualShape): + """ + Dataclass for storing the information of a sphere visual shape + """ radius: float def shape_data(self) -> Dict[str, float]: @@ -225,6 +233,9 @@ def visual_geometry_type(self) -> Shape: @dataclass class CapsuleVisualShape(VisualShape): + """ + Dataclass for storing the information of a capsule visual shape + """ radius: float length: float @@ -238,6 +249,9 @@ def visual_geometry_type(self) -> Shape: @dataclass class CylinderVisualShape(CapsuleVisualShape): + """ + Dataclass for storing the information of a cylinder visual shape + """ @property def visual_geometry_type(self) -> Shape: @@ -246,6 +260,9 @@ def visual_geometry_type(self) -> Shape: @dataclass class MeshVisualShape(VisualShape): + """ + Dataclass for storing the information of a mesh visual shape + """ scale: List[float] file_name: str @@ -259,6 +276,9 @@ def visual_geometry_type(self) -> Shape: @dataclass class PlaneVisualShape(VisualShape): + """ + Dataclass for storing the information of a plane visual shape + """ normal: List[float] def shape_data(self) -> Dict[str, List[float]]: @@ -271,28 +291,409 @@ def visual_geometry_type(self) -> Shape: @dataclass class State(ABC): + """ + Abstract dataclass for storing the state of an entity (e.g. world, object, link, joint). + """ pass @dataclass class LinkState(State): + """ + Dataclass for storing the state of a link. + """ constraint_ids: Dict[Link, int] + def __eq__(self, other: 'LinkState'): + return self.all_constraints_exist(other) and self.all_constraints_are_equal(other) + + def all_constraints_exist(self, other: 'LinkState') -> bool: + """ + Check if all constraints exist in the other link state. + + :param other: The state of the other link. + :return: True if all constraints exist, False otherwise. + """ + return (all([cid_k in other.constraint_ids.keys() for cid_k in self.constraint_ids.keys()]) + and len(self.constraint_ids.keys()) == len(other.constraint_ids.keys())) + + def all_constraints_are_equal(self, other: 'LinkState') -> bool: + """ + Check if all constraints are equal to the ones in the other link state. + + :param other: The state of the other link. + :return: True if all constraints are equal, False otherwise. + """ + return all([cid == other_cid for cid, other_cid in zip(self.constraint_ids.values(), + other.constraint_ids.values())]) + + def __copy__(self): + return LinkState(constraint_ids=copy(self.constraint_ids)) + @dataclass class JointState(State): + """ + Dataclass for storing the state of a joint. + """ position: float + acceptable_error: float + + def __eq__(self, other: 'JointState'): + error = calculate_joint_position_error(self.position, other.position) + return is_error_acceptable(error, other.acceptable_error) + + def __copy__(self): + return JointState(position=self.position, acceptable_error=self.acceptable_error) @dataclass class ObjectState(State): + """ + Dataclass for storing the state of an object. + """ pose: Pose attachments: Dict[Object, Attachment] link_states: Dict[int, LinkState] joint_states: Dict[int, JointState] + acceptable_pose_error: Tuple[float, float] + + def __eq__(self, other: 'ObjectState'): + return (self.pose_is_almost_equal(other) + and self.all_attachments_exist(other) and self.all_attachments_are_equal(other) + and self.link_states == other.link_states + and self.joint_states == other.joint_states) + + def pose_is_almost_equal(self, other: 'ObjectState') -> bool: + """ + Check if the pose of the object is almost equal to the pose of another object. + + :param other: The state of the other object. + :return: True if the poses are almost equal, False otherwise. + """ + return self.pose.almost_equal(other.pose, other.acceptable_pose_error[0], other.acceptable_pose_error[1]) + + def all_attachments_exist(self, other: 'ObjectState') -> bool: + """ + Check if all attachments exist in the other object state. + + :param other: The state of the other object. + :return: True if all attachments exist, False otherwise. + """ + return (all([att_k in other.attachments.keys() for att_k in self.attachments.keys()]) + and len(self.attachments.keys()) == len(other.attachments.keys())) + + def all_attachments_are_equal(self, other: 'ObjectState') -> bool: + """ + Check if all attachments are equal to the ones in the other object state. + + :param other: The state of the other object. + :return: True if all attachments are equal, False otherwise + """ + return all([att == other_att for att, other_att in zip(self.attachments.values(), other.attachments.values())]) + + def __copy__(self): + return ObjectState(pose=self.pose.copy(), attachments=copy(self.attachments), + link_states=copy(self.link_states), + joint_states=copy(self.joint_states), + acceptable_pose_error=deepcopy(self.acceptable_pose_error)) @dataclass class WorldState(State): - simulator_state_id: int + """ + Dataclass for storing the state of the world. + """ + simulator_state_id: Optional[int] object_states: Dict[str, ObjectState] + + def __eq__(self, other: 'WorldState'): + return (self.simulator_state_is_equal(other) and self.all_objects_exist(other) + and self.all_objects_states_are_equal(other)) + + def simulator_state_is_equal(self, other: 'WorldState') -> bool: + """ + Check if the simulator state is equal to the simulator state of another world state. + + :param other: The state of the other world. + :return: True if the simulator states are equal, False otherwise. + """ + return self.simulator_state_id == other.simulator_state_id + + def all_objects_exist(self, other: 'WorldState') -> bool: + """ + Check if all objects exist in the other world state. + + :param other: The state of the other world. + :return: True if all objects exist, False otherwise. + """ + return (all([obj_name in other.object_states.keys() for obj_name in self.object_states.keys()]) + and len(self.object_states.keys()) == len(other.object_states.keys())) + + def all_objects_states_are_equal(self, other: 'WorldState') -> bool: + """ + Check if all object states are equal to the ones in the other world state. + + :param other: The state of the other world. + :return: True if all object states are equal, False otherwise. + """ + return all([obj_state == other_obj_state + for obj_state, other_obj_state in zip(self.object_states.values(), + other.object_states.values())]) + + def __copy__(self): + return WorldState(simulator_state_id=self.simulator_state_id, + object_states=deepcopy(self.object_states)) + + +@dataclass +class LateralFriction: + """ + Dataclass for storing the information of the lateral friction. + """ + lateral_friction: float + lateral_friction_direction: List[float] + + +@dataclass +class ContactPoint: + """ + Dataclass for storing the information of a contact point between two objects. + """ + link_a: Link + link_b: Link + position_on_object_a: Optional[List[float]] = None + position_on_object_b: Optional[List[float]] = None + normal_on_b: Optional[List[float]] = None # normal on object b pointing towards object a + distance: Optional[float] = None + normal_force: Optional[List[float]] = None # normal force applied during last step simulation + lateral_friction_1: Optional[LateralFriction] = None + lateral_friction_2: Optional[LateralFriction] = None + force_x_in_world_frame: Optional[float] = None + force_y_in_world_frame: Optional[float] = None + force_z_in_world_frame: Optional[float] = None + + def __str__(self): + return f"ContactPoint: {self.link_a.object.name} - {self.link_b.object.name}" + + def __repr__(self): + return self.__str__() + + +ClosestPoint = ContactPoint +""" +The closest point between two objects which has the same structure as ContactPoint. +""" + + +class ContactPointsList(list): + """ + A list of contact points. + """ + def get_links_that_got_removed(self, previous_points: 'ContactPointsList') -> List[Link]: + """ + Return the links that are not in the current points list but were in the initial points list. + + :param previous_points: The initial points list. + :return: A list of Link instances that represent the links that got removed. + """ + initial_links_in_contact = previous_points.get_links_in_contact() + current_links_in_contact = self.get_links_in_contact() + return [link for link in initial_links_in_contact if link not in current_links_in_contact] + + def get_links_in_contact(self) -> List[Link]: + """ + Get the links in contact. + + :return: A list of Link instances that represent the links in contact. + """ + return [point.link_b for point in self] + + def check_if_two_objects_are_in_contact(self, obj_a: Object, obj_b: Object) -> bool: + """ + Check if two objects are in contact. + + :param obj_a: An instance of the Object class that represents the first object. + :param obj_b: An instance of the Object class that represents the second object. + :return: True if the objects are in contact, False otherwise. + """ + return (any([point.link_b.object == obj_b and point.link_a.object == obj_a for point in self]) or + any([point.link_a.object == obj_b and point.link_b.object == obj_a for point in self])) + + def get_normals_of_object(self, obj: Object) -> List[List[float]]: + """ + Get the normals of the object. + + :param obj: An instance of the Object class that represents the object. + :return: A list of float vectors that represent the normals of the object. + """ + return self.get_points_of_object(obj).get_normals() + + def get_normals(self) -> List[List[float]]: + """ + Get the normals of the points. + + :return: A list of float vectors that represent the normals of the contact points. + """ + return [point.normal_on_b for point in self] + + def get_links_in_contact_of_object(self, obj: Object) -> List[Link]: + """ + Get the links in contact of the object. + + :param obj: An instance of the Object class that represents the object. + :return: A list of Link instances that represent the links in contact of the object. + """ + return [point.link_b for point in self if point.link_b.object == obj] + + def get_points_of_object(self, obj: Object) -> 'ContactPointsList': + """ + Get the points of the object. + + :param obj: An instance of the Object class that represents the object that the points are related to. + :return: A ContactPointsList instance that represents the contact points of the object. + """ + return ContactPointsList([point for point in self if point.link_b.object == obj]) + + def get_objects_that_got_removed(self, previous_points: 'ContactPointsList') -> List[Object]: + """ + Return the object that is not in the current points list but was in the initial points list. + + :param previous_points: The initial points list. + :return: A list of Object instances that represent the objects that got removed. + """ + initial_objects_in_contact = previous_points.get_objects_that_have_points() + current_objects_in_contact = self.get_objects_that_have_points() + return [obj for obj in initial_objects_in_contact if obj not in current_objects_in_contact] + + def get_new_objects(self, previous_points: 'ContactPointsList') -> List[Object]: + """ + Return the object that is not in the initial points list but is in the current points list. + + :param previous_points: The initial points list. + :return: A list of Object instances that represent the new objects. + """ + initial_objects_in_contact = previous_points.get_objects_that_have_points() + current_objects_in_contact = self.get_objects_that_have_points() + return [obj for obj in current_objects_in_contact if obj not in initial_objects_in_contact] + + def is_object_in_the_list(self, obj: Object) -> bool: + """ + Check if the object is one of the objects that have points in the list. + + :param obj: An instance of the Object class that represents the object. + :return: True if the object is in the list, False otherwise. + """ + return obj in self.get_objects_that_have_points() + + def get_names_of_objects_that_have_points(self) -> List[str]: + """ + Return the names of the objects that have points in the list. + + :return: A list of strings that represent the names of the objects that have points in the list. + """ + return [obj.name for obj in self.get_objects_that_have_points()] + + def get_objects_that_have_points(self) -> List[Object]: + """ + Return the objects that have points in the list. + + :return: A list of Object instances that represent the objects that have points in the list. + """ + return list({point.link_b.object for point in self}) + + def __str__(self): + return f"ContactPointsList: {', '.join(self.get_names_of_objects_that_have_points())}" + + def __repr__(self): + return self.__str__() + + +ClosestPointsList = ContactPointsList +""" +The list of closest points which has same structure as ContactPointsList. +""" + + +@dataclass +class TextAnnotation: + """ + Dataclass for storing text annotations that can be displayed in the simulation. + """ + text: str + position: List[float] + id: int + color: Color = Color(0, 0, 0, 1) + size: float = 0.1 + + +@dataclass +class VirtualMobileBaseJoints: + """ + Dataclass for storing the names, types and axes of the virtual mobile base joints of a mobile robot. + """ + translation_x: Optional[str] = VirtualMobileBaseJointName.LINEAR_X.value + translation_y: Optional[str] = VirtualMobileBaseJointName.LINEAR_Y.value + angular_z: Optional[str] = VirtualMobileBaseJointName.ANGULAR_Z.value + + @property + def names(self) -> List[str]: + """ + Return the names of the virtual mobile base joints. + """ + return [self.translation_x, self.translation_y, self.angular_z] + + def get_types(self) -> Dict[str, JointType]: + """ + Return the joint types of the virtual mobile base joints. + """ + return {self.translation_x: JointType.PRISMATIC, + self.translation_y: JointType.PRISMATIC, + self.angular_z: JointType.REVOLUTE} + + def get_axes(self) -> Dict[str, Point]: + """ + Return the axes (i.e. The axis on which the joint moves) of the virtual mobile base joints. + """ + return {self.translation_x: Point(1, 0, 0), + self.translation_y: Point(0, 1, 0), + self.angular_z: Point(0, 0, 1)} + + +@dataclass +class MultiverseMetaData: + """Meta data for the Multiverse Client, the simulation_name should be non-empty and unique for each simulation""" + world_name: str = "world" + simulation_name: str = "cram" + length_unit: str = "m" + angle_unit: str = "rad" + mass_unit: str = "kg" + time_unit: str = "s" + handedness: str = "rhs" + + +@dataclass +class RayResult: + """ + A dataclass to store the ray result. The ray result contains the body name that the ray intersects with and the + distance from the ray origin to the intersection point. + """ + body_name: str + distance: float + + def intersected(self) -> bool: + """ + Check if the ray intersects with a body. + return: Whether the ray intersects with a body. + """ + return self.distance >= 0 and self.body_name != "" + + +@dataclass +class MultiverseContactPoint: + """ + A dataclass to store the contact point returned from Multiverse. + """ + body_name: str + contact_force: List[float] + contact_torque: List[float] diff --git a/src/pycram/datastructures/enums.py b/src/pycram/datastructures/enums.py index 1365c94b0..92968ba61 100644 --- a/src/pycram/datastructures/enums.py +++ b/src/pycram/datastructures/enums.py @@ -2,12 +2,16 @@ from enum import Enum, auto +from ..failures import UnsupportedJointType + + class ExecutionType(Enum): """Enum for Execution Process Module types.""" REAL = auto() SIMULATED = auto() SEMI_REAL = auto() + class Arms(int, Enum): """Enum for Arms.""" LEFT = 0 @@ -158,3 +162,112 @@ class ImageEnum(Enum): SOFA = 17 INSPECT = 18 CHAIR = 37 + + +class VirtualMobileBaseJointName(Enum): + """ + Enum for the joint names of the virtual mobile base. + """ + LINEAR_X = "odom_vel_lin_x_joint" + LINEAR_Y = "odom_vel_lin_y_joint" + ANGULAR_Z = "odom_vel_ang_z_joint" + + +class MJCFGeomType(Enum): + """ + Enum for the different geom types in a MuJoCo XML file. + """ + BOX = "box" + CYLINDER = "cylinder" + CAPSULE = "capsule" + SPHERE = "sphere" + PLANE = "plane" + MESH = "mesh" + ELLIPSOID = "ellipsoid" + HFIELD = "hfield" + SDF = "sdf" + + +MJCFBodyType = MJCFGeomType +""" +Alias for MJCFGeomType. As the body type is the same as the geom type. +""" + + +class MJCFJointType(Enum): + """ + Enum for the different joint types in a MuJoCo XML file. + """ + FREE = "free" + BALL = "ball" + SLIDE = "slide" + HINGE = "hinge" + FIXED = "fixed" # Added for compatibility with PyCRAM, but not a real joint type in MuJoCo. + + +class MultiverseAPIName(Enum): + """ + Enum for the different APIs of the Multiverse. + """ + GET_CONTACT_BODIES = "get_contact_bodies" + GET_CONSTRAINT_EFFORT = "get_constraint_effort" + ATTACH = "attach" + DETACH = "detach" + GET_RAYS = "get_rays" + EXIST = "exist" + PAUSE = "pause" + UNPAUSE = "unpause" + SAVE = "save" + LOAD = "load" + + +class MultiverseProperty(Enum): + def __str__(self): + return self.value + + +class MultiverseBodyProperty(MultiverseProperty): + """ + Enum for the different properties of a body the Multiverse. + """ + POSITION = "position" + ORIENTATION = "quaternion" + RELATIVE_VELOCITY = "relative_velocity" + + +class MultiverseJointProperty(MultiverseProperty): + pass + + +class MultiverseJointPosition(MultiverseJointProperty): + """ + Enum for the Position names of the different joint types in the Multiverse. + """ + REVOLUTE_JOINT_POSITION = "joint_rvalue" + PRISMATIC_JOINT_POSITION = "joint_tvalue" + + @classmethod + def from_pycram_joint_type(cls, joint_type: JointType) -> 'MultiverseJointPosition': + if joint_type in [JointType.REVOLUTE, JointType.CONTINUOUS]: + return MultiverseJointPosition.REVOLUTE_JOINT_POSITION + elif joint_type == JointType.PRISMATIC: + return MultiverseJointPosition.PRISMATIC_JOINT_POSITION + else: + raise UnsupportedJointType(joint_type) + + +class MultiverseJointCMD(MultiverseJointProperty): + """ + Enum for the Command names of the different joint types in the Multiverse. + """ + REVOLUTE_JOINT_CMD = "cmd_joint_rvalue" + PRISMATIC_JOINT_CMD = "cmd_joint_tvalue" + + @classmethod + def from_pycram_joint_type(cls, joint_type: JointType) -> 'MultiverseJointCMD': + if joint_type in [JointType.REVOLUTE, JointType.CONTINUOUS]: + return MultiverseJointCMD.REVOLUTE_JOINT_CMD + elif joint_type == JointType.PRISMATIC: + return MultiverseJointCMD.PRISMATIC_JOINT_CMD + else: + raise UnsupportedJointType(joint_type) diff --git a/src/pycram/datastructures/pose.py b/src/pycram/datastructures/pose.py index 4ca28b267..490a56e9f 100644 --- a/src/pycram/datastructures/pose.py +++ b/src/pycram/datastructures/pose.py @@ -3,7 +3,9 @@ import math import datetime -from typing_extensions import List, Union, Optional + +from tf.transformations import euler_from_quaternion +from typing_extensions import List, Union, Optional, Sized, Self import numpy as np import rospy @@ -12,6 +14,7 @@ from geometry_msgs.msg import (Pose as GeoPose, Quaternion as GeoQuaternion) from tf import transformations from ..orm.base import Pose as ORMPose, Position, Quaternion, ProcessMetaData +from ..validation.error_checkers import calculate_pose_error def get_normalized_quaternion(quaternion: np.ndarray) -> GeoQuaternion: @@ -85,6 +88,32 @@ def from_pose_stamped(pose_stamped: PoseStamped) -> Pose: p.pose = pose_stamped.pose return p + def get_position_diff(self, target_pose: Self) -> Point: + """ + Get the difference between the target and the current positions. + + :param target_pose: The target pose. + :return: The difference between the two positions. + """ + return Point(target_pose.position.x - self.position.x, target_pose.position.y - self.position.y, + target_pose.position.z - self.position.z) + + def get_z_angle_difference(self, target_pose: Self) -> float: + """ + Get the difference between two z angles. + + :param target_pose: The target pose. + :return: The difference between the two z angles. + """ + return target_pose.z_angle - self.z_angle + + @property + def z_angle(self) -> float: + """ + The z angle of the orientation of this Pose in radians. + """ + return euler_from_quaternion(self.orientation_as_list())[2] + @property def frame(self) -> str: """ @@ -144,21 +173,22 @@ def orientation(self, value) -> None: :param value: New orientation, either a list or geometry_msgs/Quaternion """ - if not isinstance(value, list) and not isinstance(value, tuple) and not isinstance(value, GeoQuaternion): - rospy.logwarn("Orientation can only be a list or geometry_msgs/Quaternion") + if not isinstance(value, Sized) and not isinstance(value, GeoQuaternion): + rospy.logwarn("Orientation can only be an iterable (list, tuple, ...etc.) or a geometry_msgs/Quaternion") return - if isinstance(value, list) or isinstance(value, tuple) and len(value) == 4: + if isinstance(value, Sized) and len(value) == 4: orientation = np.array(value) - else: + elif isinstance(value, GeoQuaternion): orientation = np.array([value.x, value.y, value.z, value.w]) + else: + rospy.logerr("Orientation has to be a list or geometry_msgs/Quaternion") + raise TypeError("Orientation has to be a list or geometry_msgs/Quaternion") # This is used instead of np.linalg.norm since numpy is too slow on small arrays self.pose.orientation = get_normalized_quaternion(orientation) def to_list(self) -> List[List[float]]: """ - Returns the position and orientation of this pose as a list containing two list. - :return: The position and orientation as lists """ return [[self.pose.position.x, self.pose.position.y, self.pose.position.z], @@ -186,16 +216,12 @@ def copy(self) -> Pose: def position_as_list(self) -> List[float]: """ - Returns only the position as a list of xyz. - - :return: The position as a list + :return: The position as a list of xyz values. """ return [self.position.x, self.position.y, self.position.z] def orientation_as_list(self) -> List[float]: """ - Returns only the orientation as a list of a quaternion - :return: The orientation as a quaternion with xyzw """ return [self.pose.orientation.x, self.pose.orientation.y, self.pose.orientation.z, self.pose.orientation.w] @@ -230,6 +256,22 @@ def __eq__(self, other: Pose) -> bool: return self_position == other_position and self_orient == other_orient and self.frame == other.frame + def almost_equal(self, other: Pose, position_tolerance_in_meters: float = 1e-3, + orientation_tolerance_in_degrees: float = 1) -> bool: + """ + Checks if the given Pose is almost equal to this Pose. The position and orientation can have a certain + tolerance. The position tolerance is given in meters and the orientation tolerance in degrees. The position + error is calculated as the euclidian distance between the positions and the orientation error as the angle + between the quaternions. + + :param other: The other Pose which should be compared + :param position_tolerance_in_meters: The tolerance for the position in meters + :param orientation_tolerance_in_degrees: The tolerance for the orientation in degrees + :return: True if the Poses are almost equal, False otherwise + """ + error = calculate_pose_error(self, other) + return error[0] <= position_tolerance_in_meters and error[1] <= orientation_tolerance_in_degrees * math.pi / 180 + def set_position(self, new_position: List[float]) -> None: """ Sets the position of this Pose to the given position. Position has to be given as a vector in cartesian space. @@ -285,25 +327,6 @@ def multiply_quaternions(self, quaternion: List) -> None: self.orientation = (x, y, z, w) - def set_orientation_from_euler(self, axis: List, euler_angles: List[float]) -> None: - """ - Convert axis-angle to quaternion. - - :param axis: (x, y, z) tuple representing rotation axis. - :param angle: rotation angle in degree - :return: The quaternion representing the axis angle - """ - angle = math.radians(euler_angles) - axis_length = math.sqrt(sum([i ** 2 for i in axis])) - normalized_axis = tuple(i / axis_length for i in axis) - - x = normalized_axis[0] * math.sin(angle / 2) - y = normalized_axis[1] * math.sin(angle / 2) - z = normalized_axis[2] * math.sin(angle / 2) - w = math.cos(angle / 2) - - return (x, y, z, w) - class Transform(TransformStamped): """ @@ -346,6 +369,27 @@ def __init__(self, translation: Optional[List[float]] = None, rotation: Optional self.frame = frame + def apply_transform_to_array_of_points(self, points: np.ndarray) -> np.ndarray: + """ + Applies this Transform to an array of points. The points are given as a Nx3 matrix, where N is the number of + points. The points are transformed from the child_frame_id to the frame_id of this Transform. + + :param points: The points that should be transformed, given as a Nx3 matrix. + """ + homogeneous_transform = self.get_homogeneous_matrix() + # add the homogeneous coordinate, by adding a column of ones to the position vectors, becoming 4xN matrix + homogenous_points = np.concatenate((points, np.ones((points.shape[0], 1))), axis=1).T + rays_end_positions = homogeneous_transform @ homogenous_points + return rays_end_positions[:3, :].T + + def get_homogeneous_matrix(self) -> np.ndarray: + """ + :return: The homogeneous matrix of this Transform + """ + translation = transformations.translation_matrix(self.translation_as_list()) + rotation = transformations.quaternion_matrix(self.rotation_as_list()) + return np.dot(translation, rotation) + @classmethod def from_pose_and_child_frame(cls, pose: Pose, child_frame_name: str) -> Transform: return cls(pose.position_as_list(), pose.orientation_as_list(), pose.frame, child_frame_name, @@ -386,7 +430,7 @@ def frame(self, value: str) -> None: self.header.frame_id = value @property - def translation(self) -> None: + def translation(self) -> Vector3: """ Property that points to the translation of this Transform """ @@ -411,7 +455,7 @@ def translation(self, value) -> None: self.transform.translation = value @property - def rotation(self) -> None: + def rotation(self) -> Quaternion: """ Property that points to the rotation of this Transform """ @@ -449,16 +493,12 @@ def copy(self) -> Transform: def translation_as_list(self) -> List[float]: """ - Returns the translation of this Transform as a list. - :return: The translation as a list of xyz """ return [self.transform.translation.x, self.transform.translation.y, self.transform.translation.z] def rotation_as_list(self) -> List[float]: """ - Returns the rotation of this Transform as a list representing a quaternion. - :return: The rotation of this Transform as a list with xyzw """ return [self.transform.rotation.x, self.transform.rotation.y, self.transform.rotation.z, @@ -553,5 +593,3 @@ def set_rotation(self, new_rotation: List[float]) -> None: :param new_rotation: The new rotation as a quaternion with xyzw """ self.rotation = new_rotation - - diff --git a/src/pycram/datastructures/world.py b/src/pycram/datastructures/world.py index 70714a6e7..d9c16e5e9 100644 --- a/src/pycram/datastructures/world.py +++ b/src/pycram/datastructures/world.py @@ -6,100 +6,38 @@ import time from abc import ABC, abstractmethod from copy import copy -from queue import Queue - import numpy as np import rospy from geometry_msgs.msg import Point -from typing_extensions import List, Optional, Dict, Tuple, Callable, TYPE_CHECKING -from typing_extensions import Union +from typing_extensions import List, Optional, Dict, Tuple, Callable, TYPE_CHECKING, Union, Type from ..cache_manager import CacheManager -from .enums import JointType, ObjectType, WorldMode -from ..world_concepts.event import Event +from ..config.world_conf import WorldConfig +from ..datastructures.dataclasses import (Color, AxisAlignedBoundingBox, CollisionCallbacks, + MultiBody, VisualShape, BoxVisualShape, CylinderVisualShape, + SphereVisualShape, + CapsuleVisualShape, PlaneVisualShape, MeshVisualShape, + ObjectState, WorldState, ClosestPointsList, + ContactPointsList, VirtualMobileBaseJoints) +from ..datastructures.enums import JointType, ObjectType, WorldMode, Arms +from ..datastructures.pose import Pose, Transform +from ..datastructures.world_entity import StateEntity +from ..failures import ProspectionObjectNotFound, WorldObjectNotFound from ..local_transformer import LocalTransformer -from .pose import Pose, Transform +from ..robot_description import RobotDescription +from ..validation.goal_validator import (MultiPoseGoalValidator, + PoseGoalValidator, JointPositionGoalValidator, + MultiJointPositionGoalValidator, GoalValidator, + validate_joint_position, validate_multiple_joint_positions, + validate_object_pose, validate_multiple_object_poses) from ..world_concepts.constraints import Constraint -from .dataclasses import (Color, AxisAlignedBoundingBox, CollisionCallbacks, - MultiBody, VisualShape, BoxVisualShape, CylinderVisualShape, SphereVisualShape, - CapsuleVisualShape, PlaneVisualShape, MeshVisualShape, - ObjectState, State, WorldState) +from ..world_concepts.event import Event if TYPE_CHECKING: from ..world_concepts.world_object import Object - from ..description import Link, Joint - - -class StateEntity: - """ - The StateEntity class is used to store the state of an object or the physics simulator. This is used to save and - restore the state of the World. - """ - - def __init__(self): - self._saved_states: Dict[int, State] = {} - - @property - def saved_states(self) -> Dict[int, State]: - """ - Returns the saved states of this entity. - """ - return self._saved_states - - def save_state(self, state_id: int) -> int: - """ - Saves the state of this entity with the given state id. - - :param state_id: The unique id of the state. - """ - self._saved_states[state_id] = self.current_state - return state_id - - @property - @abstractmethod - def current_state(self) -> State: - """ - Returns the current state of this entity. - - :return: The current state of this entity. - """ - pass - - @current_state.setter - @abstractmethod - def current_state(self, state: State) -> None: - """ - Sets the current state of this entity. - - :param state: The new state of this entity. - """ - pass - - def restore_state(self, state_id: int) -> None: - """ - Restores the state of this entity from a saved state using the given state id. - - :param state_id: The unique id of the state. - """ - self.current_state = self.saved_states[state_id] - - def remove_saved_states(self) -> None: - """ - Removes all saved states of this entity. - """ - self._saved_states = {} - - -class WorldEntity(StateEntity, ABC): - """ - A data class that represents an entity of the world, such as an object or a link. - """ - - def __init__(self, _id: int, world: Optional[World] = None): - StateEntity.__init__(self) - self.id = _id - self.world: World = world if world is not None else World.current_world + from ..description import Link, Joint, ObjectDescription + from ..object_descriptors.generic import ObjectDescription as GenericObjectDescription class World(StateEntity, ABC): @@ -109,17 +47,17 @@ class World(StateEntity, ABC): current_world which is managed by the World class itself. """ - simulation_frequency: float + conf: Type[WorldConfig] = WorldConfig """ - Global reference for the simulation frequency (Hz), used in calculating the equivalent real time in the simulation. + The configurations of the world, the default configurations are defined in world_conf.py in the config folder. """ current_world: Optional[World] = None """ - Global reference to the currently used World, usually this is the - graphical one. However, if you are inside a UseProspectionWorld() environment the current_world points to the - prospection world. In this way you can comfortably use the current_world, which should point towards the World - used at the moment. + Global reference to the currently used World, usually this is the + graphical one. However, if you are inside a UseProspectionWorld() environment the current_world points to the + prospection world. In this way you can comfortably use the current_world, which should point towards the World + used at the moment. """ robot: Optional[Object] = None @@ -128,51 +66,49 @@ class World(StateEntity, ABC): the URDF with the name of the URDF on the parameter server. """ - data_directory: List[str] = [os.path.join(os.path.dirname(__file__), '..', '..', '..', 'resources')] - """ - Global reference for the data directories, this is used to search for the description files of the robot - and the objects. - """ - - cache_dir = data_directory[0] + '/cached/' + cache_manager: CacheManager = CacheManager(conf.cache_dir, [conf.resources_path], False) """ - Global reference for the cache directory, this is used to cache the description files of the robot and the objects. + Global reference for the cache manager, this is used to cache the description files of the robot and the objects. """ - def __init__(self, mode: WorldMode, is_prospection_world: bool, simulation_frequency: float): + def __init__(self, mode: WorldMode, is_prospection_world: bool = False, clear_cache: bool = False): """ - Creates a new simulation, the mode decides if the simulation should be a rendered window or just run in the - background. There can only be one rendered simulation. - The World object also initializes the Events for attachment, detachment and for manipulating the world. + Create a new simulation, the mode decides if the simulation should be a rendered window or just run in the + background. There can only be one rendered simulation. + The World object also initializes the Events for attachment, detachment and for manipulating the world. - :param mode: Can either be "GUI" for rendered window or "DIRECT" for non-rendered. The default parameter is "GUI" - :param is_prospection_world: For internal usage, decides if this World should be used as a prospection world. + :param mode: Can either be "GUI" for rendered window or "DIRECT" for non-rendered. The default parameter is + "GUI" + :param is_prospection_world: For internal usage, decides if this World should be used as a prospection world. + :param clear_cache: Whether to clear the cache directory. """ StateEntity.__init__(self) + if clear_cache or (self.conf.clear_cache_at_start and not self.cache_manager.cache_cleared): + self.cache_manager.clear_cache() + + GoalValidator.raise_error = self.conf.raise_goal_validator_error + if World.current_world is None: World.current_world = self - World.simulation_frequency = simulation_frequency - self.cache_manager = CacheManager(self.cache_dir, self.data_directory) + self.object_lock: threading.Lock = threading.Lock() self.id: Optional[int] = -1 # This is used to connect to the physics server (allows multiple clients) self._init_world(mode) + self.objects: List[Object] = [] + # List of all Objects in the World + self.is_prospection_world: bool = is_prospection_world self._init_and_sync_prospection_world() self.local_transformer = LocalTransformer() self._update_local_transformer_worlds() - self.objects: List[Object] = [] - # List of all Objects in the World - - - self.mode: WorldMode = mode # The mode of the simulation, can be "GUI" or "DIRECT" @@ -182,16 +118,104 @@ def __init__(self, mode: WorldMode, is_prospection_world: bool, simulation_frequ self._current_state: Optional[WorldState] = None + self._init_goal_validators() + + self.original_state_id = self.save_state() + + @classmethod + def get_cache_dir(cls) -> str: + """ + Return the cache directory. + """ + return cls.cache_manager.cache_dir + + def add_object(self, obj: Object) -> None: + """ + Add an object to the world. + + :param obj: The object to be added. + """ + self.object_lock.acquire() + self.objects.append(obj) + self.add_object_to_original_state(obj) + self.object_lock.release() + + @property + def robot_description(self) -> RobotDescription: + """ + Return the current robot description. + """ + return RobotDescription.current_robot_description + + @property + def robot_has_actuators(self) -> bool: + """ + Return whether the robot has actuators. + """ + return self.robot_description.has_actuators + + def get_actuator_for_joint(self, joint: Joint) -> str: + """ + Get the actuator name for a given joint. + """ + return self.robot_joint_actuators[joint.name] + + def joint_has_actuator(self, joint: Joint) -> bool: + """ + Return whether the joint has an actuator. + """ + return joint.name in self.robot_joint_actuators + + @property + def robot_joint_actuators(self) -> Dict[str, str]: + """ + Return the joint actuators of the robot. + """ + return self.robot_description.joint_actuators + + def _init_goal_validators(self): + """ + Initialize the goal validators for the World objects' poses, positions, and orientations. + """ + + # Objects Pose goal validators + self.pose_goal_validator = PoseGoalValidator(self.get_object_pose, self.conf.get_pose_tolerance(), + self.conf.acceptable_percentage_of_goal) + self.multi_pose_goal_validator = MultiPoseGoalValidator( + lambda x: list(self.get_multiple_object_poses(x).values()), + self.conf.get_pose_tolerance(), self.conf.acceptable_percentage_of_goal) + + # Joint Goal validators + self.joint_position_goal_validator = JointPositionGoalValidator( + self.get_joint_position, + acceptable_revolute_joint_position_error=self.conf.revolute_joint_position_tolerance, + acceptable_prismatic_joint_position_error=self.conf.prismatic_joint_position_tolerance, + acceptable_percentage_of_goal_achieved=self.conf.acceptable_percentage_of_goal) + self.multi_joint_position_goal_validator = MultiJointPositionGoalValidator( + lambda x: list(self.get_multiple_joint_positions(x).values()), + acceptable_revolute_joint_position_error=self.conf.revolute_joint_position_tolerance, + acceptable_prismatic_joint_position_error=self.conf.prismatic_joint_position_tolerance, + acceptable_percentage_of_goal_achieved=self.conf.acceptable_percentage_of_goal) + + def check_object_exists(self, obj: Object) -> bool: + """ + Check if the object exists in the simulator. + + :param obj: The object to check. + :return: True if the object is in the world, False otherwise. + """ + raise NotImplementedError + @abstractmethod def _init_world(self, mode: WorldMode): """ - Initializes the physics simulation. + Initialize the physics simulation. """ raise NotImplementedError def _init_events(self): """ - Initializes dynamic events that can be used to react to changes in the World. + Initialize dynamic events that can be used to react to changes in the World. """ self.detachment_event: Event = Event() self.attachment_event: Event = Event() @@ -199,86 +223,108 @@ def _init_events(self): def _init_and_sync_prospection_world(self): """ - Initializes the prospection world and the synchronization between the main and the prospection world. + Initialize the prospection world and the synchronization between the main and the prospection world. """ self._init_prospection_world() self._sync_prospection_world() def _update_local_transformer_worlds(self): """ - Updates the local transformer worlds with the current world and prospection world. + Update the local transformer worlds with the current world and prospection world. """ self.local_transformer.world = self self.local_transformer.prospection_world = self.prospection_world def _init_prospection_world(self): """ - Initializes the prospection world, if this is a prospection world itself it will not create another prospection, + Initialize the prospection world, if this is a prospection world itself it will not create another prospection, world, but instead set the prospection world to None, else it will create a prospection world. """ if self.is_prospection_world: # then no need to add another prospection world self.prospection_world = None else: self.prospection_world: World = self.__class__(WorldMode.DIRECT, - True, - World.simulation_frequency) + True) def _sync_prospection_world(self): """ - Synchronizes the prospection world with the main world, this means that every object in the main world will be + Synchronize the prospection world with the main world, this means that every object in the main world will be added to the prospection world and vice versa. """ if self.is_prospection_world: # then no need to add another prospection world self.world_sync = None else: self.world_sync: WorldSync = WorldSync(self, self.prospection_world) + self.pause_world_sync() self.world_sync.start() - def update_cache_dir_with_object(self, path: str, ignore_cached_files: bool, - obj: Object) -> str: + def preprocess_object_file_and_get_its_cache_path(self, path: str, ignore_cached_files: bool, + description: ObjectDescription, name: str, + scale_mesh: Optional[float] = None) -> str: """ - Updates the cache directory with the given object. + Update the cache directory with the given object. :param path: The path to the object. :param ignore_cached_files: If the cached files should be ignored. - :param obj: The object to be added to the cache directory. + :param description: The object description. + :param name: The name of the object. + :param scale_mesh: The scale of the mesh. + :return: The path of the cached object. """ - return self.cache_manager.update_cache_dir_with_object(path, ignore_cached_files, obj.description, obj.name) + return self.cache_manager.update_cache_dir_with_object(path, ignore_cached_files, description, name, scale_mesh) @property def simulation_time_step(self): """ The time step of the simulation in seconds. """ - return 1 / World.simulation_frequency + return 1 / self.__class__.conf.simulation_frequency @abstractmethod - def load_object_and_get_id(self, path: Optional[str] = None, pose: Optional[Pose] = None) -> int: + def load_object_and_get_id(self, path: Optional[str] = None, pose: Optional[Pose] = None, + obj_type: Optional[ObjectType] = None) -> int: """ - Loads a description file (e.g. URDF) at the given pose and returns the id of the loaded object. + Load a description file (e.g. URDF) at the given pose and returns the id of the loaded object. :param path: The path to the description file, if None the description file is assumed to be already loaded. :param pose: The pose at which the object should be loaded. + :param obj_type: The type of the object. :return: The id of the loaded object. """ pass + def load_generic_object_and_get_id(self, description: GenericObjectDescription, + pose: Optional[Pose] = None) -> int: + """ + Create a visual and collision box in the simulation and returns the id of the loaded object. + + :param description: The object description. + :param pose: The pose at which the object should be loaded. + """ + raise NotImplementedError + + def get_object_names(self) -> List[str]: + """ + Return the names of all objects in the World. + + :return: A list of object names. + """ + return [obj.name for obj in self.objects] + def get_object_by_name(self, name: str) -> Optional[Object]: """ - Returns the object with the given name. If there is no object with the given name, None is returned. + Return the object with the given name. If there is no object with the given name, None is returned. :param name: The name of the returned Objects. :return: The object with the given name, if there is one. """ - object = list(filter(lambda obj: obj.name == name, self.objects)) - if len(object) > 0: - return object[0] - return None + matching_objects = list(filter(lambda obj: obj.name == name, self.objects)) + return matching_objects[0] if len(matching_objects) > 0 else None def get_object_by_type(self, obj_type: ObjectType) -> List[Object]: """ - Returns a list of all Objects which have the type 'obj_type'. + Return a list of all Objects which have the type 'obj_type'. :param obj_type: The type of the returned Objects. :return: A list of all Objects that have the type 'obj_type'. @@ -287,59 +333,92 @@ def get_object_by_type(self, obj_type: ObjectType) -> List[Object]: def get_object_by_id(self, obj_id: int) -> Object: """ - Returns the single Object that has the unique id. + Return the single Object that has the unique id. :param obj_id: The unique id for which the Object should be returned. :return: The Object with the id 'id'. """ return list(filter(lambda obj: obj.id == obj_id, self.objects))[0] - @abstractmethod - def remove_object_by_id(self, obj_id: int) -> None: + def remove_visual_object(self, obj_id: int) -> bool: """ - Removes the object with the given id from the world. + Remove the object with the given id from the world, and saves a new original state for the world. :param obj_id: The unique id of the object to be removed. + :return: Whether the object was removed successfully. + """ + + removed = self._remove_visual_object(obj_id) + if removed: + self.update_simulator_state_id_in_original_state() + else: + rospy.logwarn(f"Object with id {obj_id} could not be removed.") + return removed + + @abstractmethod + def _remove_visual_object(self, obj_id: int) -> bool: + """ + Remove the visual object with the given id from the world, and update the simulator state in the original state. + + :param obj_id: The unique id of the visual object to be removed. + :return: Whether the object was removed successfully. """ pass @abstractmethod - def remove_object_from_simulator(self, obj: Object) -> None: + def remove_object_from_simulator(self, obj: Object) -> bool: """ - Removes an object from the physics simulator. + Remove an object from the physics simulator. :param obj: The object to be removed. + :return: Whether the object was removed successfully. """ pass def remove_object(self, obj: Object) -> None: """ - Removes this object from the current world. + Remove this object from the current world. For the object to be removed it has to be detached from all objects it is currently attached to. After this is done a call to world remove object is done to remove this Object from the simulation/world. :param obj: The object to be removed. """ - obj.detach_all() - - self.objects.remove(obj) + self.object_lock.acquire() - # This means the current world of the object is not the prospection world, since it - # has a reference to the prospection world - if self.prospection_world is not None: - self.world_sync.remove_obj_queue.put(obj) - self.world_sync.remove_obj_queue.join() + obj.detach_all() - self.remove_object_from_simulator(obj) + if self.remove_object_from_simulator(obj): + self.objects.remove(obj) + self.remove_object_from_original_state(obj) if World.robot == obj: World.robot = None + self.object_lock.release() + + def remove_object_from_original_state(self, obj: Object) -> None: + """ + Remove an object from the original state of the world. + + :param obj: The object to be removed. + """ + self.original_state.object_states.pop(obj.name) + self.original_state.simulator_state_id = self.save_physics_simulator_state(use_same_id=True) + + def add_object_to_original_state(self, obj: Object) -> None: + """ + Add an object to the original state of the world. + + :param obj: The object to be added. + """ + self.original_state.object_states[obj.name] = obj.current_state + self.update_simulator_state_id_in_original_state() + def add_fixed_constraint(self, parent_link: Link, child_link: Link, child_to_parent_transform: Transform) -> int: """ - Creates a fixed joint constraint between the given parent and child links, + Create a fixed joint constraint between the given parent and child links, the joint frame will be at the origin of the child link frame, and would have the same orientation as the child link frame. @@ -390,7 +469,7 @@ def get_joint_position(self, joint: Joint) -> float: @abstractmethod def get_object_joint_names(self, obj: Object) -> List[str]: """ - Returns the names of all joints of this object. + Return the names of all joints of this object. :param obj: The object. :return: A list of joint names. @@ -407,10 +486,60 @@ def get_link_pose(self, link: Link) -> Pose: """ pass + @abstractmethod + def get_multiple_link_poses(self, links: List[Link]) -> Dict[str, Pose]: + """ + Get the poses of multiple links of an articulated object with respect to the world frame. + + :param links: The links as a list of AbstractLink objects. + :return: A dictionary with link names as keys and Pose objects as values. + """ + pass + + @abstractmethod + def get_link_position(self, link: Link) -> List[float]: + """ + Get the position of a link of an articulated object with respect to the world frame. + + :param link: The link as a AbstractLink object. + :return: The position of the link as a list of floats. + """ + pass + + @abstractmethod + def get_link_orientation(self, link: Link) -> List[float]: + """ + Get the orientation of a link of an articulated object with respect to the world frame. + + :param link: The link as a AbstractLink object. + :return: The orientation of the link as a list of floats. + """ + pass + + @abstractmethod + def get_multiple_link_positions(self, links: List[Link]) -> Dict[str, List[float]]: + """ + Get the positions of multiple links of an articulated object with respect to the world frame. + + :param links: The links as a list of AbstractLink objects. + :return: A dictionary with link names as keys and lists of floats as values. + """ + pass + + @abstractmethod + def get_multiple_link_orientations(self, links: List[Link]) -> Dict[str, List[float]]: + """ + Get the orientations of multiple links of an articulated object with respect to the world frame. + + :param links: The links as a list of AbstractLink objects. + :return: A dictionary with link names as keys and lists of floats as values. + """ + pass + @abstractmethod def get_object_link_names(self, obj: Object) -> List[str]: """ - Returns the names of all links of this object. + Return the names of all links of this object. :param obj: The object. :return: A list of link names. @@ -419,7 +548,7 @@ def get_object_link_names(self, obj: Object) -> List[str]: def simulate(self, seconds: float, real_time: Optional[bool] = False) -> None: """ - Simulates Physics in the World for a given amount of seconds. Usually this simulation is faster than real + Simulate Physics in the World for a given amount of seconds. Usually this simulation is faster than real time. By setting the 'real_time' parameter this simulation is slowed down such that the simulated time is equal to real time. @@ -427,12 +556,12 @@ def simulate(self, seconds: float, real_time: Optional[bool] = False) -> None: :param real_time: If the simulation should happen in real time or faster. """ self.set_realtime(real_time) - for i in range(0, int(seconds * self.simulation_frequency)): + for i in range(0, int(seconds * self.conf.simulation_frequency)): curr_time = rospy.Time.now() self.step() for objects, callbacks in self.coll_callbacks.items(): contact_points = self.get_contact_points_between_two_objects(objects[0], objects[1]) - if contact_points != (): + if len(contact_points) > 0: callbacks.on_collision_cb() elif callbacks.no_collision_cb is not None: callbacks.no_collision_cb() @@ -444,7 +573,7 @@ def simulate(self, seconds: float, real_time: Optional[bool] = False) -> None: def update_all_objects_poses(self) -> None: """ - Updates the positions of all objects in the world. + Update the positions of all objects in the world. """ for obj in self.objects: obj.update_pose() @@ -453,20 +582,89 @@ def update_all_objects_poses(self) -> None: def get_object_pose(self, obj: Object) -> Pose: """ Get the pose of an object in the world frame from the current object pose in the simulator. + + :param obj: The object. + """ + pass + + @abstractmethod + def get_multiple_object_poses(self, objects: List[Object]) -> Dict[str, Pose]: + """ + Get the poses of multiple objects in the world frame from the current object poses in the simulator. + + :param objects: The objects. """ pass + @abstractmethod + def get_multiple_object_positions(self, objects: List[Object]) -> Dict[str, List[float]]: + """ + Get the positions of multiple objects in the world frame from the current object poses in the simulator. + + :param objects: The objects. + """ + pass + + @abstractmethod + def get_object_position(self, obj: Object) -> List[float]: + """ + Get the position of an object in the world frame from the current object pose in the simulator. + + :param obj: The object. + """ + pass + + @abstractmethod + def get_multiple_object_orientations(self, objects: List[Object]) -> Dict[str, List[float]]: + """ + Get the orientations of multiple objects in the world frame from the current object poses in the simulator. + + :param objects: The objects. + """ + pass + + @abstractmethod + def get_object_orientation(self, obj: Object) -> List[float]: + """ + Get the orientation of an object in the world frame from the current object pose in the simulator. + + :param obj: The object. + """ + pass + + @property + def robot_virtual_joints(self) -> List[Joint]: + """ + The virtual joints of the robot. + """ + return [self.robot.joints[name] for name in self.robot_virtual_joints_names] + + @property + def robot_virtual_joints_names(self) -> List[str]: + """ + The names of the virtual joints of the robot. + """ + return self.robot_description.virtual_mobile_base_joints.names + + def get_robot_mobile_base_joints(self) -> VirtualMobileBaseJoints: + """ + Get the mobile base joints of the robot. + + :return: The mobile base joints. + """ + return self.robot_description.virtual_mobile_base_joints + @abstractmethod def perform_collision_detection(self) -> None: """ - Checks for collisions between all objects in the World and updates the contact points. + Check for collisions between all objects in the World and updates the contact points. """ pass @abstractmethod - def get_object_contact_points(self, obj: Object) -> List: + def get_object_contact_points(self, obj: Object) -> ContactPointsList: """ - Returns a list of contact points of this Object with all other Objects. + Return a list of contact points of this Object with all other Objects. :param obj: The object. :return: A list of all contact points with other objects @@ -474,9 +672,9 @@ def get_object_contact_points(self, obj: Object) -> List: pass @abstractmethod - def get_contact_points_between_two_objects(self, obj1: Object, obj2: Object) -> List: + def get_contact_points_between_two_objects(self, obj1: Object, obj2: Object) -> ContactPointsList: """ - Returns a list of contact points between obj1 and obj2. + Return a list of contact points between obj_a and obj_b. :param obj1: The first object. :param obj2: The second object. @@ -484,24 +682,97 @@ def get_contact_points_between_two_objects(self, obj1: Object, obj2: Object) -> """ pass + def get_object_closest_points(self, obj: Object, max_distance: float) -> ClosestPointsList: + """ + Return the closest points of this object with all other objects in the world. + + :param obj: The object. + :param max_distance: The maximum distance between the points. + :return: A list of the closest points. + """ + all_obj_closest_points = [self.get_closest_points_between_objects(obj, other_obj, max_distance) for other_obj in + self.objects + if other_obj != obj] + return ClosestPointsList([point for closest_points in all_obj_closest_points for point in closest_points]) + + def get_closest_points_between_objects(self, object_a: Object, object_b: Object, max_distance: float) \ + -> ClosestPointsList: + """ + Return the closest points between two objects. + + :param object_a: The first object. + :param object_b: The second object. + :param max_distance: The maximum distance between the points. + :return: A list of the closest points. + """ + raise NotImplementedError + + @validate_joint_position @abstractmethod - def reset_joint_position(self, joint: Joint, joint_position: float) -> None: + def reset_joint_position(self, joint: Joint, joint_position: float) -> bool: """ Reset the joint position instantly without physics simulation + .. note:: + It is recommended to use the validate_joint_position decorator to validate the joint position for + the implementation of this method. + :param joint: The joint to reset the position for. :param joint_position: The new joint pose. + :return: True if the reset was successful, False otherwise """ pass + @validate_multiple_joint_positions @abstractmethod - def reset_object_base_pose(self, obj: Object, pose: Pose): + def set_multiple_joint_positions(self, joint_positions: Dict[Joint, float]) -> bool: + """ + Set the positions of multiple joints of an articulated object. + + .. note:: + It is recommended to use the validate_multiple_joint_positions decorator to validate the + joint positions for the implementation of this method. + + :param joint_positions: A dictionary with joint objects as keys and joint positions as values. + :return: True if the set was successful, False otherwise. + """ + pass + + @abstractmethod + def get_multiple_joint_positions(self, joints: List[Joint]) -> Dict[str, float]: + """ + Get the positions of multiple joints of an articulated object. + + :param joints: The joints as a list of Joint objects. + """ + pass + + @validate_object_pose + @abstractmethod + def reset_object_base_pose(self, obj: Object, pose: Pose) -> bool: """ Reset the world position and orientation of the base of the object instantaneously, not through physics simulation. (x,y,z) position vector and (x,y,z,w) quaternion orientation. + .. note:: + It is recommended to use the validate_object_pose decorator to validate the object pose for the + implementation of this method. + :param obj: The object. :param pose: The new pose as a Pose object. + :return: True if the reset was successful, False otherwise. + """ + pass + + @validate_multiple_object_poses + @abstractmethod + def reset_multiple_objects_base_poses(self, objects: Dict[Object, Pose]) -> bool: + """ + Reset the world position and orientation of the base of multiple objects instantaneously, + not through physics simulation. (x,y,z) position vector and (x,y,z,w) quaternion orientation. + + :param objects: A dictionary with objects as keys and poses as values. + :return: True if the reset was successful, False otherwise. """ pass @@ -512,10 +783,20 @@ def step(self): """ pass + def get_arm_tool_frame_link(self, arm: Arms) -> Link: + """ + Get the tool frame link of the arm of the robot. + + :param arm: The arm for which the tool frame link should be returned. + :return: The tool frame link of the arm. + """ + ee_link_name = self.robot_description.get_arm_tool_frame(arm) + return self.robot.get_link(ee_link_name) + @abstractmethod def set_link_color(self, link: Link, rgba_color: Color): """ - Changes the rgba_color of a link of this object, the rgba_color has to be given as Color object. + Change the rgba_color of a link of this object, the rgba_color has to be given as Color object. :param link: The link which should be colored. :param rgba_color: The rgba_color as Color object with RGBA values between 0 and 1. @@ -545,7 +826,7 @@ def get_colors_of_object_links(self, obj: Object) -> Dict[str, Color]: @abstractmethod def get_object_axis_aligned_bounding_box(self, obj: Object) -> AxisAlignedBoundingBox: """ - Returns the axis aligned bounding box of this object. The return of this method are two points in + Return the axis aligned bounding box of this object. The return of this method are two points in world coordinate frame which define a bounding box. :param obj: The object for which the bounding box should be returned. @@ -556,7 +837,7 @@ def get_object_axis_aligned_bounding_box(self, obj: Object) -> AxisAlignedBoundi @abstractmethod def get_link_axis_aligned_bounding_box(self, link: Link) -> AxisAlignedBoundingBox: """ - Returns the axis aligned bounding box of the link. The return of this method are two points in + Return the axis aligned bounding box of the link. The return of this method are two points in world coordinate frame which define a bounding box. """ pass @@ -564,7 +845,7 @@ def get_link_axis_aligned_bounding_box(self, link: Link) -> AxisAlignedBoundingB @abstractmethod def set_realtime(self, real_time: bool) -> None: """ - Enables the real time simulation of Physics in the World. By default, this is disabled and Physics is only + Enable the real time simulation of Physics in the World. By default, this is disabled and Physics is only simulated to reason about it. :param real_time: Whether the World should simulate Physics in real time. @@ -574,7 +855,7 @@ def set_realtime(self, real_time: bool) -> None: @abstractmethod def set_gravity(self, gravity_vector: List[float]) -> None: """ - Sets the gravity that is used in the World. By default, it is set to the gravity on earth ([0, 0, -9.8]). + Set the gravity that is used in the World. By default, it is set to the gravity on earth ([0, 0, -9.8]). Gravity is given as a vector in x,y,z. Gravity is only applied while simulating Physic. :param gravity_vector: The gravity vector that should be used in the World. @@ -583,7 +864,7 @@ def set_gravity(self, gravity_vector: List[float]) -> None: def set_robot_if_not_set(self, robot: Object) -> None: """ - Sets the robot if it is not set yet. + Set the robot if it is not set yet. :param robot: The Object reference to the Object representing the robot. """ @@ -593,7 +874,7 @@ def set_robot_if_not_set(self, robot: Object) -> None: @staticmethod def set_robot(robot: Union[Object, None]) -> None: """ - Sets the global variable for the robot Object This should be set on spawning the robot. + Set the global variable for the robot Object This should be set on spawning the robot. :param robot: The Object reference to the Object representing the robot. """ @@ -602,17 +883,21 @@ def set_robot(robot: Union[Object, None]) -> None: @staticmethod def robot_is_set() -> bool: """ - Returns whether the robot has been set or not. + Return whether the robot has been set or not. :return: True if the robot has been set, False otherwise. """ return World.robot is not None - def exit(self) -> None: + def exit(self, remove_saved_states: bool = True) -> None: """ - Closes the World as well as the prospection world, also collects any other thread that is running. + Close the World as well as the prospection world, also collects any other thread that is running. + + :param remove_saved_states: Whether to remove the saved states. """ self.exit_prospection_world_if_exists() + self.reset_world(remove_saved_states) + self.remove_all_objects() self.disconnect_from_physics_server() self.reset_robot() self.join_threads() @@ -621,7 +906,7 @@ def exit(self) -> None: def exit_prospection_world_if_exists(self) -> None: """ - Exits the prospection world if it exists. + Exit the prospection world if it exists. """ if self.prospection_world: self.terminate_world_sync() @@ -630,21 +915,21 @@ def exit_prospection_world_if_exists(self) -> None: @abstractmethod def disconnect_from_physics_server(self) -> None: """ - Disconnects the world from the physics server. + Disconnect the world from the physics server. """ pass def reset_current_world(self) -> None: """ - Resets the pose of every object in the World to the pose it was spawned in and sets every joint to 0. + Reset the pose of every object in the World to the pose it was spawned in and sets every joint to 0. """ for obj in self.objects: obj.set_pose(obj.original_pose) - obj.set_joint_positions(dict(zip(list(obj.joint_names), [0] * len(obj.joint_names)))) + obj.set_multiple_joint_positions(dict(zip(list(obj.joint_names), [0] * len(obj.joint_names)))) def reset_robot(self) -> None: """ - Sets the robot class variable to None. + Set the robot class variable to None. """ self.set_robot(None) @@ -657,14 +942,15 @@ def join_threads(self) -> None: def terminate_world_sync(self) -> None: """ - Terminates the world sync thread. + Terminate the world sync thread. """ self.world_sync.terminate = True + self.resume_world_sync() self.world_sync.join() def save_state(self, state_id: Optional[int] = None) -> int: """ - Returns the id of the saved state of the World. The saved state contains the states of all the objects and + Return the id of the saved state of the World. The saved state contains the states of all the objects and the state of the physics simulator. :return: A unique id of the state @@ -677,18 +963,23 @@ def save_state(self, state_id: Optional[int] = None) -> int: @property def current_state(self) -> WorldState: if self._current_state is None: - self._current_state = WorldState(self.save_physics_simulator_state(), self.object_states) - return self._current_state + simulator_state = None if self.conf.use_physics_simulator_state else self.save_physics_simulator_state(True) + self._current_state = WorldState(simulator_state, self.object_states) + return WorldState(self._current_state.simulator_state_id, self.object_states) @current_state.setter def current_state(self, state: WorldState) -> None: - self.restore_physics_simulator_state(state.simulator_state_id) - self.object_states = state.object_states + if self.current_state != state: + if self.conf.use_physics_simulator_state: + self.restore_physics_simulator_state(state.simulator_state_id) + else: + for obj in self.objects: + self.get_object_by_name(obj.name).current_state = state.object_states[obj.name] @property def object_states(self) -> Dict[str, ObjectState]: """ - Returns the states of all objects in the World. + Return the states of all objects in the World. :return: A dictionary with the object id as key and the object state as value. """ @@ -697,14 +988,14 @@ def object_states(self) -> Dict[str, ObjectState]: @object_states.setter def object_states(self, states: Dict[str, ObjectState]) -> None: """ - Sets the states of all objects in the World. + Set the states of all objects in the World. """ for obj_name, obj_state in states.items(): self.get_object_by_name(obj_name).current_state = obj_state def save_objects_state(self, state_id: int) -> None: """ - Saves the state of all objects in the World according to the given state using the unique state id. + Save the state of all objects in the World according to the given state using the unique state id. :param state_id: The unique id representing the state. """ @@ -712,10 +1003,11 @@ def save_objects_state(self, state_id: int) -> None: obj.save_state(state_id) @abstractmethod - def save_physics_simulator_state(self) -> int: + def save_physics_simulator_state(self, use_same_id: bool = False) -> int: """ - Saves the state of the physics simulator and returns the unique id of the state. + Save the state of the physics simulator and returns the unique id of the state. + :param use_same_id: If the same id should be used for the state. :return: The unique id representing the state. """ pass @@ -723,7 +1015,7 @@ def save_physics_simulator_state(self) -> int: @abstractmethod def remove_physics_simulator_state(self, state_id: int) -> None: """ - Removes the state of the physics simulator with the given id. + Remove the state of the physics simulator with the given id. :param state_id: The unique id representing the state. """ @@ -732,7 +1024,7 @@ def remove_physics_simulator_state(self, state_id: int) -> None: @abstractmethod def restore_physics_simulator_state(self, state_id: int) -> None: """ - Restores the objects and environment state in the physics simulator according to + Restore the objects and environment state in the physics simulator according to the given state using the unique state id. :param state_id: The unique id representing the state. @@ -744,7 +1036,7 @@ def get_images_for_target(self, cam_pose: Pose, size: Optional[int] = 256) -> List[np.ndarray]: """ - Calculates the view and projection Matrix and returns 3 images: + Calculate the view and projection Matrix and returns 3 images: 1. An RGB image 2. A depth image @@ -763,7 +1055,7 @@ def register_two_objects_collision_callbacks(self, on_collision_callback: Callable, on_collision_removal_callback: Optional[Callable] = None) -> None: """ - Registers callback methods for contact between two Objects. There can be a callback for when the two Objects + Register callback methods for contact between two Objects. There can be a callback for when the two Objects get in contact and, optionally, for when they are not in contact anymore. :param object_a: An object in the World @@ -775,74 +1067,109 @@ def register_two_objects_collision_callbacks(self, on_collision_removal_callback) @classmethod - def add_resource_path(cls, path: str) -> None: + def get_data_directories(cls) -> List[str]: + """ + The resources directories where the objects, robots, and environments are stored. + """ + return cls.cache_manager.data_directories + + @classmethod + def add_resource_path(cls, path: str, prepend: bool = False) -> None: """ - Adds a resource path in which the World will search for files. This resource directory is searched if an + Add a resource path in which the World will search for files. This resource directory is searched if an Object is spawned only with a filename. :param path: A path in the filesystem in which to search for files. + :param prepend: Put the new path at the beginning of the list such that it is searched first. + """ + if prepend: + cls.cache_manager.data_directories = [path] + cls.cache_manager.data_directories + else: + cls.cache_manager.data_directories.append(path) + + @classmethod + def remove_resource_path(cls, path: str) -> None: + """ + Remove the given path from the data_directories list. + + :param path: The path to remove. + """ + cls.cache_manager.data_directories.remove(path) + + @classmethod + def change_cache_dir_path(cls, path: str) -> None: + """ + Change the cache directory to the given path + + :param path: The new path for the cache directory. """ - cls.data_directory.append(path) + cls.cache_manager.cache_dir = os.path.join(path, cls.conf.cache_dir_name) def get_prospection_object_for_object(self, obj: Object) -> Object: """ - Returns the corresponding object from the prospection world for a given object in the main world. + Return the corresponding object from the prospection world for a given object in the main world. If the given Object is already in the prospection world, it is returned. :param obj: The object for which the corresponding object in the prospection World should be found. :return: The corresponding object in the prospection world. """ - self.world_sync.add_obj_queue.join() - try: - return self.world_sync.object_mapping[obj] - except KeyError: - prospection_world = self if self.is_prospection_world else self.prospection_world - if obj in prospection_world.objects: - return obj - else: - raise ValueError( - f"There is no prospection object for the given object: {obj}, this could be the case if" - f" the object isn't anymore in the main (graphical) World" - f" or if the given object is already a prospection object. ") + with UseProspectionWorld(): + return self.world_sync.get_prospection_object(obj) def get_object_for_prospection_object(self, prospection_object: Object) -> Object: """ - Returns the corresponding object from the main World for a given + Return the corresponding object from the main World for a given object in the prospection world. If the given object is not in the prospection world an error will be raised. :param prospection_object: The object for which the corresponding object in the main World should be found. :return: The object in the main World. """ - object_map = self.world_sync.object_mapping - try: - return list(object_map.keys())[list(object_map.values()).index(prospection_object)] - except ValueError: - raise ValueError("The given object is not in the prospection world.") + with UseProspectionWorld(): + return self.world_sync.get_world_object(prospection_object) + + def remove_all_objects(self, exclude_objects: Optional[List[Object]] = None) -> None: + """ + Remove all objects from the World. - def reset_world(self, remove_saved_states=True) -> None: + :param exclude_objects: A list of objects that should not be removed. """ - Resets the World to the state it was first spawned in. + objs_copy = [obj for obj in self.objects] + exclude_objects = [] if exclude_objects is None else exclude_objects + [self.remove_object(obj) for obj in objs_copy if obj not in exclude_objects] + + def reset_world(self, remove_saved_states=False) -> None: + """ + Reset the World to the state it was first spawned in. All attached objects will be detached, all joints will be set to the default position of 0 and all objects will be set to the position and orientation in which they were spawned. :param remove_saved_states: If the saved states should be removed. """ - + self.restore_state(self.original_state_id) if remove_saved_states: self.remove_saved_states() - - for obj in self.objects: - obj.reset(remove_saved_states) + self.original_state_id = self.save_state() def remove_saved_states(self) -> None: """ - Removes all saved states of the World. + Remove all saved states of the World. """ - for state_id in self.saved_states: - self.remove_physics_simulator_state(state_id) + if self.conf.use_physics_simulator_state: + for state_id in self.saved_states: + self.remove_physics_simulator_state(state_id) + else: + self.remove_objects_saved_states() super().remove_saved_states() + self.original_state_id = None + + def remove_objects_saved_states(self) -> None: + """ + Remove all saved states of the objects in the World. + """ + for obj in self.objects: + obj.remove_saved_states() def update_transforms_for_objects_in_current_world(self) -> None: """ @@ -883,6 +1210,12 @@ def create_visual_shape(self, visual_shape: VisualShape) -> int: :param visual_shape: The visual shape to be created, uses the VisualShape dataclass defined in world_dataclasses :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_visual_shape, visual_shape) + + def _create_visual_shape(self, visual_shape: VisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_visual_shape` + """ raise NotImplementedError def create_multi_body_from_visual_shapes(self, visual_shape_ids: List[int], pose: Pose) -> int: @@ -920,51 +1253,92 @@ def create_multi_body(self, multi_body: MultiBody) -> int: :param multi_body: The multi body to be created, uses the MultiBody dataclass defined in world_dataclasses. :return: The unique id of the created multi body. """ + return self._simulator_object_creator(self._create_multi_body, multi_body) + + def _create_multi_body(self, multi_body: MultiBody) -> int: + """ + See :py:meth:`~pycram.world.World.create_multi_body` + """ raise NotImplementedError def create_box_visual_shape(self, shape_data: BoxVisualShape) -> int: """ Creates a box visual shape in the physics simulator and returns the unique id of the created shape. - :param shape_data: The parameters that define the box visual shape to be created, uses the BoxVisualShape dataclass defined in world_dataclasses. + :param shape_data: The parameters that define the box visual shape to be created, uses the BoxVisualShape + dataclass defined in world_dataclasses. :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_box_visual_shape, shape_data) + + def _create_box_visual_shape(self, shape_data: BoxVisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_box_visual_shape` + """ raise NotImplementedError def create_cylinder_visual_shape(self, shape_data: CylinderVisualShape) -> int: """ Creates a cylinder visual shape in the physics simulator and returns the unique id of the created shape. - :param shape_data: The parameters that define the cylinder visual shape to be created, uses the CylinderVisualShape dataclass defined in world_dataclasses. + :param shape_data: The parameters that define the cylinder visual shape to be created, uses the + CylinderVisualShape dataclass defined in world_dataclasses. :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_cylinder_visual_shape, shape_data) + + def _create_cylinder_visual_shape(self, shape_data: CylinderVisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_cylinder_visual_shape` + """ raise NotImplementedError def create_sphere_visual_shape(self, shape_data: SphereVisualShape) -> int: """ Creates a sphere visual shape in the physics simulator and returns the unique id of the created shape. - :param shape_data: The parameters that define the sphere visual shape to be created, uses the SphereVisualShape dataclass defined in world_dataclasses. + :param shape_data: The parameters that define the sphere visual shape to be created, uses the SphereVisualShape + dataclass defined in world_dataclasses. :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_sphere_visual_shape, shape_data) + + def _create_sphere_visual_shape(self, shape_data: SphereVisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_sphere_visual_shape` + """ raise NotImplementedError def create_capsule_visual_shape(self, shape_data: CapsuleVisualShape) -> int: """ Creates a capsule visual shape in the physics simulator and returns the unique id of the created shape. - :param shape_data: The parameters that define the capsule visual shape to be created, uses the CapsuleVisualShape dataclass defined in world_dataclasses. + :param shape_data: The parameters that define the capsule visual shape to be created, uses the + CapsuleVisualShape dataclass defined in world_dataclasses. :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_capsule_visual_shape, shape_data) + + def _create_capsule_visual_shape(self, shape_data: CapsuleVisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_capsule_visual_shape` + """ raise NotImplementedError def create_plane_visual_shape(self, shape_data: PlaneVisualShape) -> int: """ Creates a plane visual shape in the physics simulator and returns the unique id of the created shape. - :param shape_data: The parameters that define the plane visual shape to be created, uses the PlaneVisualShape dataclass defined in world_dataclasses. + :param shape_data: The parameters that define the plane visual shape to be created, uses the PlaneVisualShape + dataclass defined in world_dataclasses. :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_plane_visual_shape, shape_data) + + def _create_plane_visual_shape(self, shape_data: PlaneVisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_plane_visual_shape` + """ raise NotImplementedError def create_mesh_visual_shape(self, shape_data: MeshVisualShape) -> int: @@ -975,6 +1349,12 @@ def create_mesh_visual_shape(self, shape_data: MeshVisualShape) -> int: uses the MeshVisualShape dataclass defined in world_dataclasses. :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_mesh_visual_shape, shape_data) + + def _create_mesh_visual_shape(self, shape_data: MeshVisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_mesh_visual_shape` + """ raise NotImplementedError def add_text(self, text: str, position: List[float], orientation: Optional[List[float]] = None, size: float = 0.1, @@ -985,14 +1365,26 @@ def add_text(self, text: str, position: List[float], orientation: Optional[List[ :param text: The text to be added. :param position: The position of the text in the world. - :param orientation: By default, debug text will always face the camera, automatically rotation. By specifying a text orientation (quaternion), the orientation will be fixed in world space or local space (when parent is specified). + :param orientation: By default, debug text will always face the camera, automatically rotation. By specifying a + text orientation (quaternion), the orientation will be fixed in world space or local space + (when parent is specified). :param size: The size of the text. :param color: The color of the text. - :param life_time: The lifetime in seconds of the text to remain in the world, if 0 the text will remain in the world until it is removed manually. + :param life_time: The lifetime in seconds of the text to remain in the world, if 0 the text will remain in the + world until it is removed manually. :param parent_object_id: The id of the object to which the text should be attached. :param parent_link_id: The id of the link to which the text should be attached. :return: The id of the added text. """ + return self._simulator_object_creator(self._add_text, text, position, orientation, size, color, life_time, + parent_object_id, parent_link_id) + + def _add_text(self, text: str, position: List[float], orientation: Optional[List[float]] = None, size: float = 0.1, + color: Optional[Color] = Color(), life_time: Optional[float] = 0, + parent_object_id: Optional[int] = None, parent_link_id: Optional[int] = None) -> int: + """ + See :py:meth:`~pycram.world.World.add_text` + """ raise NotImplementedError def remove_text(self, text_id: Optional[int] = None) -> None: @@ -1001,6 +1393,12 @@ def remove_text(self, text_id: Optional[int] = None) -> None: :param text_id: The id of the text to be removed. """ + self._simulator_object_remover(self._remove_text, text_id) + + def _remove_text(self, text_id: Optional[int] = None) -> None: + """ + See :py:meth:`~pycram.world.World.remove_text` + """ raise NotImplementedError def enable_joint_force_torque_sensor(self, obj: Object, fts_joint_idx: int) -> None: @@ -1027,7 +1425,7 @@ def disable_joint_force_torque_sensor(self, obj: Object, joint_id: int) -> None: def get_joint_reaction_force_torque(self, obj: Object, joint_id: int) -> List[float]: """ - Returns the joint reaction forces and torques of the specified joint. + Get the joint reaction forces and torques of the specified joint. :param obj: The object in which the joint is located. :param joint_id: The id of the joint for which the force torque should be returned. @@ -1037,7 +1435,7 @@ def get_joint_reaction_force_torque(self, obj: Object, joint_id: int) -> List[fl def get_applied_joint_motor_torque(self, obj: Object, joint_id: int) -> float: """ - Returns the applied torque by a joint motor. + Get the applied torque by a joint motor. :param obj: The object in which the joint is located. :param joint_id: The id of the joint for which the applied motor torque should be returned. @@ -1045,6 +1443,85 @@ def get_applied_joint_motor_torque(self, obj: Object, joint_id: int) -> float: """ raise NotImplementedError + def pause_world_sync(self) -> None: + """ + Pause the world synchronization. + """ + self.world_sync.sync_lock.acquire() + + def resume_world_sync(self) -> None: + """ + Resume the world synchronization. + """ + self.world_sync.sync_lock.release() + + def add_vis_axis(self, pose: Pose) -> int: + """ + Add a visual axis to the world. + + :param pose: The pose of the visual axis. + :return: The id of the added visual axis. + """ + return self._simulator_object_creator(self._add_vis_axis, pose) + + def _add_vis_axis(self, pose: Pose) -> None: + """ + See :py:meth:`~pycram.world.World.add_vis_axis` + """ + rospy.logwarn(f"Visual axis is not supported in {self.__class__.__name__}") + + def remove_vis_axis(self) -> None: + """ + Remove the visual axis from the world. + """ + self._simulator_object_remover(self._remove_vis_axis) + + def _remove_vis_axis(self) -> None: + """ + See :py:meth:`~pycram.world.World.remove_vis_axis` + """ + rospy.logwarn(f"Visual axis is not supported in {self.__class__.__name__}") + + def _simulator_object_creator(self, creator_func: Callable, *args, **kwargs) -> int: + """ + Create an object in the physics simulator and returns the created object id. + + :param creator_func: The function that creates the object in the physics simulator. + :param args: The arguments for the creator function. + :param kwargs: The keyword arguments for the creator function. + :return: The created object id. + """ + obj_id = creator_func(*args, **kwargs) + self.update_simulator_state_id_in_original_state() + return obj_id + + def _simulator_object_remover(self, remover_func: Callable, *args, **kwargs) -> None: + """ + Remove an object from the physics simulator. + + :param remover_func: The function that removes the object from the physics simulator. + :param args: The arguments for the remover function. + :param kwargs: The keyword arguments for the remover function. + """ + remover_func(*args, **kwargs) + self.update_simulator_state_id_in_original_state() + + def update_simulator_state_id_in_original_state(self, use_same_id: bool = False) -> None: + """ + Update the simulator state id in the original state if use_physics_simulator_state is True in the configuration. + + :param use_same_id: If the same id should be used for the state. + """ + if self.conf.use_physics_simulator_state: + self.original_state.simulator_state_id = self.save_physics_simulator_state(use_same_id) + + @property + def original_state(self) -> WorldState: + """ + The saved original state of the world. + """ + return self.saved_states[self.original_state_id] + def __del__(self): self.exit() @@ -1058,36 +1535,25 @@ class UseProspectionWorld: with UseProspectionWorld(): NavigateAction.Action([[1, 0, 0], [0, 0, 0, 1]]).perform() """ - - WAIT_TIME_FOR_ADDING_QUEUE = 20 + WAIT_TIME_AS_N_SIMULATION_STEPS: int = 20 """ - The time in seconds to wait for the adding queue to be ready. + The time in simulation steps to wait before switching to the prospection world """ def __init__(self): self.prev_world: Optional[World] = None # The previous world is saved to restore it after the with block is exited. - def sync_worlds(self): - """ - Synchronizes the state of the prospection world with the main world. - """ - for world_obj, prospection_obj in World.current_world.world_sync.object_mapping.items(): - prospection_obj.current_state = world_obj.current_state - def __enter__(self): """ This method is called when entering the with block, it will set the current world to the prospection world """ if not World.current_world.is_prospection_world: - time.sleep(self.WAIT_TIME_FOR_ADDING_QUEUE * World.current_world.simulation_time_step) - # blocks until the adding queue is ready - World.current_world.world_sync.add_obj_queue.join() - self.sync_worlds() - self.prev_world = World.current_world - World.current_world.world_sync.pause_sync = True World.current_world = World.current_world.prospection_world + World.current_world.resume_world_sync() + time.sleep(self.WAIT_TIME_AS_N_SIMULATION_STEPS * World.current_world.simulation_time_step) + World.current_world.pause_world_sync() def __exit__(self, *args): """ @@ -1095,7 +1561,6 @@ def __exit__(self, *args): """ if self.prev_world is not None: World.current_world = self.prev_world - World.current_world.world_sync.pause_sync = False class WorldSync(threading.Thread): @@ -1103,12 +1568,15 @@ class WorldSync(threading.Thread): Synchronizes the state between the World and its prospection world. Meaning the cartesian and joint position of everything in the prospection world will be synchronized with the main World. - Adding and removing objects is done via queues, such that loading times of objects - in the prospection world does not affect the World. The class provides the possibility to pause the synchronization, this can be used if reasoning should be done in the prospection world. """ + WAIT_TIME_AS_N_SIMULATION_STEPS = 20 + """ + The time in simulation steps to wait between each iteration of the syncing loop. + """ + def __init__(self, world: World, prospection_world: World): threading.Thread.__init__(self) self.world: World = world @@ -1116,48 +1584,109 @@ def __init__(self, world: World, prospection_world: World): self.prospection_world.world_sync = self self.terminate: bool = False - self.add_obj_queue: Queue = Queue() - self.remove_obj_queue: Queue = Queue() self.pause_sync: bool = False # Maps world to prospection world objects - self.object_mapping: Dict[Object, Object] = {} + self.object_to_prospection_object_map: Dict[Object, Object] = {} + self.prospection_object_to_object_map: Dict[Object, Object] = {} self.equal_states = False + self.sync_lock: threading.Lock = threading.Lock() - def run(self, wait_time_as_n_simulation_steps: Optional[int] = 1): + def run(self): """ Main method of the synchronization, this thread runs in a loop until the terminate flag is set. While this loop runs it continuously checks the cartesian and joint position of every object in the World and updates the corresponding object in the - prospection world. When there are entries in the adding or removing queue the corresponding objects will - be added or removed in the same iteration. - - :param wait_time_as_n_simulation_steps: The time in simulation steps to wait between each iteration of - the syncing loop. + prospection world. """ while not self.terminate: - self.check_for_pause() - while not self.add_obj_queue.empty(): - obj = self.add_obj_queue.get() - # Maps the World object to the prospection world object - self.object_mapping[obj] = copy(obj) - self.add_obj_queue.task_done() - while not self.remove_obj_queue.empty(): - obj = self.remove_obj_queue.get() - # Get prospection world object reference from object mapping - prospection_obj = self.object_mapping[obj] - prospection_obj.remove() - del self.object_mapping[obj] - self.remove_obj_queue.task_done() - self.check_for_pause() - time.sleep(wait_time_as_n_simulation_steps * self.world.simulation_time_step) - - def check_for_pause(self) -> None: - """ - Checks if :py:attr:`~self.pause_sync` is true and sleeps this thread until it isn't anymore. - """ - while self.pause_sync: - time.sleep(0.1) + self.sync_lock.acquire() + if not self.terminate: + self.sync_worlds() + self.sync_lock.release() + time.sleep(WorldSync.WAIT_TIME_AS_N_SIMULATION_STEPS * self.world.simulation_time_step) + + def get_world_object(self, prospection_object: Object) -> Object: + """ + Get the corresponding object from the main World for a given object in the prospection world. + + :param prospection_object: The object for which the corresponding object in the main World should be found. + :return: The object in the main World. + """ + try: + return self.prospection_object_to_object_map[prospection_object] + except KeyError: + if prospection_object in self.world.objects: + return prospection_object + raise WorldObjectNotFound(prospection_object) + + def get_prospection_object(self, obj: Object) -> Object: + """ + Get the corresponding object from the prospection world for a given object in the main world. + + :param obj: The object for which the corresponding object in the prospection World should be found. + :return: The corresponding object in the prospection world. + """ + try: + return self.object_to_prospection_object_map[obj] + except KeyError: + if obj in self.prospection_world.objects: + return obj + raise ProspectionObjectNotFound(obj) + + def sync_worlds(self): + """ + Syncs the prospection world with the main world by adding and removing objects and synchronizing their states. + """ + self.remove_objects_not_in_world() + self.add_objects_not_in_prospection_world() + self.prospection_object_to_object_map = {prospection_obj: obj for obj, prospection_obj in + self.object_to_prospection_object_map.items()} + self.sync_objects_states() + + def remove_objects_not_in_world(self): + """ + Removes all objects that are not in the main world from the prospection world. + """ + obj_map_copy = copy(self.object_to_prospection_object_map) + [self.remove_object(obj) for obj in obj_map_copy.keys() if obj not in self.world.objects] + + def add_objects_not_in_prospection_world(self): + """ + Adds all objects that are in the main world but not in the prospection world to the prospection world. + """ + [self.add_object(obj) for obj in self.world.objects if obj not in self.object_to_prospection_object_map] + + def add_object(self, obj: Object) -> None: + """ + Adds an object to the prospection world. + + :param obj: The object to be added. + """ + self.object_to_prospection_object_map[obj] = obj.copy_to_prospection() + + def remove_object(self, obj: Object) -> None: + """ + Removes an object from the prospection world. + + :param obj: The object to be removed. + """ + prospection_obj = self.object_to_prospection_object_map[obj] + prospection_obj.remove() + del self.object_to_prospection_object_map[obj] + + def sync_objects_states(self) -> None: + """ + Synchronizes the state of all objects in the World with the prospection world. + """ + # Set the pose of the prospection objects to the pose of the world objects + obj_pose_dict = {prospection_obj: obj.pose + for obj, prospection_obj in self.object_to_prospection_object_map.items()} + self.world.prospection_world.reset_multiple_objects_base_poses(obj_pose_dict) + for obj, prospection_obj in self.object_to_prospection_object_map.items(): + prospection_obj.set_attachments(obj.attachments) + prospection_obj.link_states = obj.link_states + prospection_obj.joint_states = obj.joint_states def check_for_equal(self) -> bool: """ @@ -1167,7 +1696,12 @@ def check_for_equal(self) -> bool: :return: True if both Worlds have the same state, False otherwise. """ eql = True - for obj, prospection_obj in self.object_mapping.items(): + prospection_names = self.prospection_world.get_object_names() + eql = eql and [name in prospection_names for name in self.world.get_object_names()] + eql = eql and len(prospection_names) == len(self.world.get_object_names()) + if not eql: + return False + for obj, prospection_obj in self.object_to_prospection_object_map.items(): eql = eql and obj.get_pose().dist(prospection_obj.get_pose()) < 0.001 self.equal_states = eql return eql diff --git a/src/pycram/datastructures/world_entity.py b/src/pycram/datastructures/world_entity.py new file mode 100644 index 000000000..1e7c61e06 --- /dev/null +++ b/src/pycram/datastructures/world_entity.py @@ -0,0 +1,77 @@ +from abc import ABC, abstractmethod + +from typing_extensions import TYPE_CHECKING, Dict + +from .dataclasses import State + +if TYPE_CHECKING: + from ..datastructures.world import World + + +class StateEntity: + """ + The StateEntity class is used to store the state of an object or the physics simulator. This is used to save and + restore the state of the World. + """ + + def __init__(self): + self._saved_states: Dict[int, State] = {} + + @property + def saved_states(self) -> Dict[int, State]: + """ + :return: the saved states of this entity. + """ + return self._saved_states + + def save_state(self, state_id: int) -> int: + """ + Saves the state of this entity with the given state id. + + :param state_id: The unique id of the state. + """ + self._saved_states[state_id] = self.current_state + return state_id + + @property + @abstractmethod + def current_state(self) -> State: + """ + :return: The current state of this entity. + """ + pass + + @current_state.setter + @abstractmethod + def current_state(self, state: State) -> None: + """ + Sets the current state of this entity. + + :param state: The new state of this entity. + """ + pass + + def restore_state(self, state_id: int) -> None: + """ + Restores the state of this entity from a saved state using the given state id. + + :param state_id: The unique id of the state. + """ + self.current_state = self.saved_states[state_id] + + def remove_saved_states(self) -> None: + """ + Removes all saved states of this entity. + """ + self._saved_states = {} + + +class WorldEntity(StateEntity, ABC): + """ + A data class that represents an entity of the world, such as an object or a link. + """ + + def __init__(self, _id: int, world: 'World'): + StateEntity.__init__(self) + self.id = _id + self.world: 'World' = world diff --git a/src/pycram/description.py b/src/pycram/description.py index 0ad05d7f1..bd3002ed7 100644 --- a/src/pycram/description.py +++ b/src/pycram/description.py @@ -1,34 +1,36 @@ from __future__ import annotations import logging +import os import pathlib from abc import ABC, abstractmethod import rospy +import trimesh from geometry_msgs.msg import Point, Quaternion -from typing_extensions import Tuple, Union, Any, List, Optional, Dict, TYPE_CHECKING +from typing_extensions import Tuple, Union, Any, List, Optional, Dict, TYPE_CHECKING, Self, deprecated +from .datastructures.dataclasses import JointState, AxisAlignedBoundingBox, Color, LinkState, VisualShape from .datastructures.enums import JointType -from .local_transformer import LocalTransformer from .datastructures.pose import Pose, Transform -from .datastructures.world import WorldEntity -from .datastructures.dataclasses import JointState, AxisAlignedBoundingBox, Color, LinkState, VisualShape +from .datastructures.world_entity import WorldEntity +from .failures import ObjectDescriptionNotFound +from .local_transformer import LocalTransformer if TYPE_CHECKING: from .world_concepts.world_object import Object class EntityDescription(ABC): - """ - A class that represents a description of an entity. This can be a link, joint or object description. + A description of an entity. This can be a link, joint or object description. """ @property @abstractmethod def origin(self) -> Pose: """ - Returns the origin of this entity. + :return: the origin of this entity. """ pass @@ -36,14 +38,14 @@ def origin(self) -> Pose: @abstractmethod def name(self) -> str: """ - Returns the name of this entity. + :return: the name of this entity. """ pass class LinkDescription(EntityDescription): """ - A class that represents a link description of an object. + A link description of an object. """ def __init__(self, parsed_link_description: Any): @@ -53,7 +55,7 @@ def __init__(self, parsed_link_description: Any): @abstractmethod def geometry(self) -> Union[VisualShape, None]: """ - Returns the geometry type of the collision element of this link. + The geometry type of the collision element of this link. """ pass @@ -63,8 +65,13 @@ class JointDescription(EntityDescription): A class that represents the description of a joint. """ - def __init__(self, parsed_joint_description: Any): + def __init__(self, parsed_joint_description: Optional[Any] = None, is_virtual: bool = False): + """ + :param parsed_joint_description: The parsed description of the joint (e.g. from urdf or mjcf file). + :param is_virtual: True if the joint is virtual (i.e. not a physically existing joint), False otherwise. + """ self.parsed_description = parsed_joint_description + self.is_virtual: Optional[bool] = is_virtual @property @abstractmethod @@ -86,8 +93,6 @@ def axis(self) -> Point: @abstractmethod def has_limits(self) -> bool: """ - Checks if this joint has limits. - :return: True if the joint has limits, False otherwise. """ pass @@ -120,7 +125,7 @@ def upper_limit(self) -> Union[float, None]: @property @abstractmethod - def parent_link_name(self) -> str: + def parent(self) -> str: """ :return: The name of the parent link of this joint. """ @@ -128,7 +133,7 @@ def parent_link_name(self) -> str: @property @abstractmethod - def child_link_name(self) -> str: + def child(self) -> str: """ :return: The name of the child link of this joint. """ @@ -159,6 +164,13 @@ def __init__(self, _id: int, obj: Object): WorldEntity.__init__(self, _id, obj.world) self.object: Object = obj + @property + def object_name(self) -> str: + """ + The name of the object to which this joint belongs. + """ + return self.object.name + @property @abstractmethod def pose(self) -> Pose: @@ -170,7 +182,7 @@ def pose(self) -> Pose: @property def transform(self) -> Transform: """ - Returns the transform of this entity. + The transform of this entity. :return: The transform of this entity. """ @@ -180,7 +192,7 @@ def transform(self) -> Transform: @abstractmethod def tf_frame(self) -> str: """ - Returns the tf frame of this entity. + The tf frame of this entity. :return: The tf frame of this entity. """ @@ -196,7 +208,7 @@ def object_id(self) -> int: class Link(ObjectEntity, LinkDescription, ABC): """ - Represents a link of an Object in the World. + A link of an Object in the World. """ def __init__(self, _id: int, link_description: LinkDescription, obj: Object): @@ -204,7 +216,48 @@ def __init__(self, _id: int, link_description: LinkDescription, obj: Object): LinkDescription.__init__(self, link_description.parsed_description) self.local_transformer: LocalTransformer = LocalTransformer() self.constraint_ids: Dict[Link, int] = {} - self._update_pose() + self._current_pose: Optional[Pose] = None + self.update_pose() + + def set_pose(self, pose: Pose) -> None: + """ + Set the pose of this link to the given pose. + NOTE: This will move the entire object such that the link is at the given pose, it will not consider any joints + that can allow the link to be at the given pose. + + :param pose: The target pose for this link. + """ + self.object.set_pose(self.get_object_pose_given_link_pose(pose)) + + def get_object_pose_given_link_pose(self, pose): + """ + Get the object pose given the link pose, which could be a hypothetical link pose to see what would be the object + pose in that case (assuming that the object itself moved not the joints). + + :param pose: The link pose. + """ + return (pose.to_transform(self.tf_frame) * self.get_transform_to_root_link()).to_pose() + + def get_pose_given_object_pose(self, pose): + """ + Get the link pose given the object pose, which could be a hypothetical object pose to see what would be the link + pose in that case (assuming that the object itself moved not the joints). + + :param pose: The object pose. + """ + return (pose.to_transform(self.object.tf_frame) * self.get_transform_from_root_link()).to_pose() + + def get_transform_from_root_link(self) -> Transform: + """ + Return the transformation from the root link of the object to this link. + """ + return self.get_transform_from_link(self.object.root_link) + + def get_transform_to_root_link(self) -> Transform: + """ + Return the transformation from this link to the root link of the object. + """ + return self.get_transform_to_link(self.object.root_link) @property def current_state(self) -> LinkState: @@ -212,25 +265,28 @@ def current_state(self) -> LinkState: @current_state.setter def current_state(self, link_state: LinkState) -> None: - self.constraint_ids = link_state.constraint_ids + if self.current_state != link_state: + self.constraint_ids = link_state.constraint_ids - def add_fixed_constraint_with_link(self, child_link: 'Link') -> int: + def add_fixed_constraint_with_link(self, child_link: Self, + child_to_parent_transform: Optional[Transform] = None) -> int: """ - Adds a fixed constraint between this link and the given link, used to create attachments for example. + Add a fixed constraint between this link and the given link, to create attachments for example. :param child_link: The child link to which a fixed constraint should be added. + :param child_to_parent_transform: The transformation between the two links. :return: The unique id of the constraint. """ - constraint_id = self.world.add_fixed_constraint(self, - child_link, - child_link.get_transform_from_link(self)) + if child_to_parent_transform is None: + child_to_parent_transform = child_link.get_transform_to_link(self) + constraint_id = self.world.add_fixed_constraint(self, child_link, child_to_parent_transform) self.constraint_ids[child_link] = constraint_id child_link.constraint_ids[self] = constraint_id return constraint_id def remove_constraint_with_link(self, child_link: 'Link') -> None: """ - Removes the constraint between this link and the given link. + Remove the constraint between this link and the given link. :param child_link: The child link of the constraint that should be removed. """ @@ -240,17 +296,22 @@ def remove_constraint_with_link(self, child_link: 'Link') -> None: del child_link.constraint_ids[self] @property - def is_root(self) -> bool: + def is_only_link(self) -> bool: + """ + :return: True if this link is the only link, False otherwise. """ - Returns whether this link is the root link of the object. + return self.object.has_one_link + @property + def is_root(self) -> bool: + """ :return: True if this link is the root link, False otherwise. """ return self.object.get_root_link_id() == self.id def update_transform(self, transform_time: Optional[rospy.Time] = None) -> None: """ - Updates the transformation of this link at the given time. + Update the transformation of this link at the given time. :param transform_time: The time at which the transformation should be updated. """ @@ -258,8 +319,6 @@ def update_transform(self, transform_time: Optional[rospy.Time] = None) -> None: def get_transform_to_link(self, link: 'Link') -> Transform: """ - Returns the transformation from this link to the given link. - :param link: The link to which the transformation should be returned. :return: A Transform object with the transformation from this link to the given link. """ @@ -267,8 +326,6 @@ def get_transform_to_link(self, link: 'Link') -> Transform: def get_transform_from_link(self, link: 'Link') -> Transform: """ - Returns the transformation from the given link to this link. - :param link: The link from which the transformation should be returned. :return: A Transform object with the transformation from the given link to this link. """ @@ -276,8 +333,6 @@ def get_transform_from_link(self, link: 'Link') -> Transform: def get_pose_wrt_link(self, link: 'Link') -> Pose: """ - Returns the pose of this link with respect to the given link. - :param link: The link with respect to which the pose should be returned. :return: A Pose object with the pose of this link with respect to the given link. """ @@ -285,8 +340,6 @@ def get_pose_wrt_link(self, link: 'Link') -> Pose: def get_axis_aligned_bounding_box(self) -> AxisAlignedBoundingBox: """ - Returns the axis aligned bounding box of this link. - :return: An AxisAlignedBoundingBox object with the axis aligned bounding box of this link. """ return self.world.get_link_axis_aligned_bounding_box(self) @@ -294,8 +347,6 @@ def get_axis_aligned_bounding_box(self) -> AxisAlignedBoundingBox: @property def position(self) -> Point: """ - The getter for the position of the link relative to the world frame. - :return: A Point object containing the position of the link relative to the world frame. """ return self.pose.position @@ -303,8 +354,6 @@ def position(self) -> Point: @property def position_as_list(self) -> List[float]: """ - The getter for the position of the link relative to the world frame as a list. - :return: A list containing the position of the link relative to the world frame. """ return self.pose.position_as_list() @@ -312,8 +361,6 @@ def position_as_list(self) -> List[float]: @property def orientation(self) -> Quaternion: """ - The getter for the orientation of the link relative to the world frame. - :return: A Quaternion object containing the orientation of the link relative to the world frame. """ return self.pose.orientation @@ -321,55 +368,58 @@ def orientation(self) -> Quaternion: @property def orientation_as_list(self) -> List[float]: """ - The getter for the orientation of the link relative to the world frame as a list. - :return: A list containing the orientation of the link relative to the world frame. """ return self.pose.orientation_as_list() - def _update_pose(self) -> None: + def update_pose(self) -> None: """ - Updates the current pose of this link from the world. + Update the current pose of this link from the world. """ self._current_pose = self.world.get_link_pose(self) @property def pose(self) -> Pose: """ - The pose of the link relative to the world frame. - :return: A Pose object containing the pose of the link relative to the world frame. """ + if self.world.conf.update_poses_from_sim_on_get: + self.update_pose() return self._current_pose @property def pose_as_list(self) -> List[List[float]]: """ - The pose of the link relative to the world frame as a list. - :return: A list containing the position and orientation of the link relative to the world frame. """ return self.pose.to_list() def get_origin_transform(self) -> Transform: """ - Returns the transformation between the link frame and the origin frame of this link. + :return: the transformation between the link frame and the origin frame of this link. """ return self.origin.to_transform(self.tf_frame) @property def color(self) -> Color: """ - The getter for the rgba_color of this link. - :return: A Color object containing the rgba_color of this link. """ return self.world.get_link_color(self) + @deprecated("Use color property setter instead") + def set_color(self, color: Color) -> None: + """ + Set the color of this link, could be rgb or rgba. + + :param color: The color as a list of floats, either rgb or rgba. + """ + self.color = color + @color.setter def color(self, color: Color) -> None: """ - The setter for the color of this link, could be rgb or rgba. + Set the color of this link, could be rgb or rgba. :param color: The color as a list of floats, either rgb or rgba. """ @@ -401,8 +451,8 @@ def __hash__(self): class RootLink(Link, ABC): """ - Represents the root link of an Object in the World. - It differs from the normal AbstractLink class in that the pose ande the tf_frame is the same as that of the object. + The root link of an Object in the World. + This differs from the normal AbstractLink class in that the pose and the tf_frame is the same as that of the object. """ def __init__(self, obj: Object): @@ -411,12 +461,12 @@ def __init__(self, obj: Object): @property def tf_frame(self) -> str: """ - Returns the tf frame of the root link, which is the same as the tf frame of the object. + :return: the tf frame of the root link, which is the same as the tf frame of the object. """ return self.object.tf_frame - def _update_pose(self) -> None: - self._current_pose = self.object.get_pose() + def update_pose(self) -> None: + self._current_pose = self.world.get_object_pose(self.object) def __copy__(self): return RootLink(self.object) @@ -424,14 +474,16 @@ def __copy__(self): class Joint(ObjectEntity, JointDescription, ABC): """ - Represents a joint of an Object in the World. + Represent a joint of an Object in the World. """ def __init__(self, _id: int, joint_description: JointDescription, - obj: Object): + obj: Object, is_virtual: Optional[bool] = False): ObjectEntity.__init__(self, _id, obj) - JointDescription.__init__(self, joint_description.parsed_description) + JointDescription.__init__(self, joint_description.parsed_description, is_virtual) + self.acceptable_error = (self.world.conf.revolute_joint_position_tolerance if self.type == JointType.REVOLUTE + else self.world.conf.prismatic_joint_position_tolerance) self._update_position() @property @@ -444,38 +496,34 @@ def tf_frame(self) -> str: @property def pose(self) -> Pose: """ - Returns the pose of this joint. The pose is the pose of the child link of this joint. - - :return: The pose of this joint. + :return: The pose of this joint. The pose is the pose of the child link of this joint. """ return self.child_link.pose def _update_position(self) -> None: """ - Updates the current position of the joint from the physics simulator. + Update the current position of the joint from the physics simulator. """ self._current_position = self.world.get_joint_position(self) @property def parent_link(self) -> Link: """ - Returns the parent link of this joint. - :return: The parent link as a AbstractLink object. """ - return self.object.get_link(self.parent_link_name) + return self.object.get_link(self.parent) @property def child_link(self) -> Link: """ - Returns the child link of this joint. - :return: The child link as a AbstractLink object. """ - return self.object.get_link(self.child_link_name) + return self.object.get_link(self.child) @property def position(self) -> float: + if self.world.conf.update_poses_from_sim_on_get: + self._update_position() return self._current_position def reset_position(self, position: float) -> None: @@ -484,8 +532,6 @@ def reset_position(self, position: float) -> None: def get_object_id(self) -> int: """ - Returns the id of the object to which this joint belongs. - :return: The integer id of the object to which this joint belongs. """ return self.object.id @@ -493,8 +539,8 @@ def get_object_id(self) -> int: @position.setter def position(self, joint_position: float) -> None: """ - Sets the position of the given joint to the given joint pose. If the pose is outside the joint limits, - an error will be printed. However, the joint will be set either way. + Set the position of the given joint to the given joint pose. If the pose is outside the joint limits, + issue a warning. However, set the joint either way. :param joint_position: The target pose for this joint """ @@ -524,16 +570,16 @@ def get_applied_motor_torque(self) -> float: @property def current_state(self) -> JointState: - return JointState(self.position) + return JointState(self.position, self.acceptable_error) @current_state.setter def current_state(self, joint_state: JointState) -> None: """ - Updates the current state of this joint from the given joint state if the position is different. + Update the current state of this joint from the given joint state if the position is different. :param joint_state: The joint state to update from. """ - if self._current_position != joint_state.position: + if self.current_state != joint_state: self.position = joint_state.position def __copy__(self): @@ -547,12 +593,11 @@ def __hash__(self): class ObjectDescription(EntityDescription): - """ A class that represents the description of an object. """ - mesh_extensions: Tuple[str] = (".obj", ".stl", ".dae") + mesh_extensions: Tuple[str] = (".obj", ".stl", ".dae", ".ply") """ The file extensions of the mesh files that can be used to generate a description file. """ @@ -570,23 +615,107 @@ def __init__(self, path: Optional[str] = None): """ :param path: The path of the file to update the description data from. """ + + self._links: Optional[List[LinkDescription]] = None + self._joints: Optional[List[JointDescription]] = None + self._link_map: Optional[Dict[str, Any]] = None + self._joint_map: Optional[Dict[str, Any]] = None + if path: self.update_description_from_file(path) else: self._parsed_description = None + self.virtual_joint_names: List[str] = [] + + @property + @abstractmethod + def child_map(self) -> Dict[str, List[Tuple[str, str]]]: + """ + :return: A dictionary mapping the name of a link to its children which are represented as a tuple of the child + joint name and the link name. + """ + pass + + @property + @abstractmethod + def parent_map(self) -> Dict[str, Tuple[str, str]]: + """ + :return: A dictionary mapping the name of a link to its parent joint and link as a tuple. + """ + pass + + @property + @abstractmethod + def link_map(self) -> Dict[str, LinkDescription]: + """ + :return: A dictionary mapping the name of a link to its description. + """ + pass + + @property + @abstractmethod + def joint_map(self) -> Dict[str, JointDescription]: + """ + :return: A dictionary mapping the name of a joint to its description. + """ + pass + + def is_joint_virtual(self, name: str) -> bool: + """ + :param name: The name of the joint. + :return: True if the joint is virtual, False otherwise. + """ + return name in self.virtual_joint_names + + @abstractmethod + def add_joint(self, name: str, child: str, joint_type: JointType, + axis: Point, parent: Optional[str] = None, origin: Optional[Pose] = None, + lower_limit: Optional[float] = None, upper_limit: Optional[float] = None, + is_virtual: Optional[bool] = False) -> None: + """ + Add a joint to this object. + + :param name: The name of the joint. + :param child: The name of the child link. + :param joint_type: The type of the joint. + :param axis: The axis of the joint. + :param parent: The name of the parent link. + :param origin: The origin of the joint. + :param lower_limit: The lower limit of the joint. + :param upper_limit: The upper limit of the joint. + :param is_virtual: True if the joint is virtual, False otherwise. + """ + pass + def update_description_from_file(self, path: str) -> None: """ - Updates the description of this object from the file at the given path. + Update the description of this object from the file at the given path. :param path: The path of the file to update from. """ self._parsed_description = self.load_description(path) + def update_description_from_string(self, description_string: str) -> None: + """ + Update the description of this object from the given description string. + + :param description_string: The description string to update from. + """ + self._parsed_description = self.load_description_from_string(description_string) + + def load_description_from_string(self, description_string: str) -> Any: + """ + Load the description from the given string. + + :param description_string: The description string to load from. + """ + raise NotImplementedError + @property def parsed_description(self) -> Any: """ - Return the object parsed from the description file. + :return: The object parsed from the description file. """ return self._parsed_description @@ -600,46 +729,74 @@ def parsed_description(self, parsed_description: Any): @abstractmethod def load_description(self, path: str) -> Any: """ - Loads the description from the file at the given path. + Load the description from the file at the given path. :param path: The path to the source file, if only a filename is provided then the resources directories will be searched. """ pass - def generate_description_from_file(self, path: str, name: str, extension: str) -> str: + def generate_description_from_file(self, path: str, name: str, extension: str, save_path: str, + scale_mesh: Optional[float] = None) -> None: """ - Generates and preprocesses the description from the file at the given path and returns the preprocessed - description as a string. + Generate and preprocess the description from the file at the given path and save the preprocessed + description. The generated description will be saved at the given save path. :param path: The path of the file to preprocess. :param name: The name of the object. :param extension: The file extension of the file to preprocess. - :return: The processed description string. + :param save_path: The path to save the generated description file. + :param scale_mesh: The scale of the mesh. + :raises ObjectDescriptionNotFound: If the description file could not be found/read. """ - description_string = None if extension in self.mesh_extensions: - description_string = self.generate_from_mesh_file(path, name) + if extension == ".ply": + mesh = trimesh.load(path) + path = path.replace(extension, ".obj") + if scale_mesh is not None: + mesh.apply_scale(scale_mesh) + mesh.export(path) + self.generate_from_mesh_file(path, name, save_path=save_path) elif extension == self.get_file_extension(): - description_string = self.generate_from_description_file(path) + self.generate_from_description_file(path, save_path=save_path) else: try: # Using the description from the parameter server - description_string = self.generate_from_parameter_server(path) + self.generate_from_parameter_server(path, save_path=save_path) except KeyError: - logging.warning(f"Couldn't find dile data in the ROS parameter server") - if description_string is None: - logging.error(f"Could not find file with path {path} in the resources directory nor" - f" in the ros parameter server.") - raise FileNotFoundError + logging.warning(f"Couldn't find file data in the ROS parameter server") - return description_string + if not self.check_description_file_exists_and_can_be_read(save_path): + raise ObjectDescriptionNotFound(name, path, extension) - def get_file_name(self, path_object: pathlib.Path, extension: str, object_name: str) -> str: + @staticmethod + def check_description_file_exists_and_can_be_read(path: str) -> bool: """ - Returns the file name of the description file. + Check if the description file exists at the given path. + :param path: The path to the description file. + :return: True if the file exists, False otherwise. + """ + exists = os.path.exists(path) + if exists: + with open(path, "r") as file: + exists = bool(file.read()) + return exists + + @staticmethod + def write_description_to_file(description_string: str, save_path: str) -> None: + """ + Write the description string to the file at the given path. + + :param description_string: The description string to write. + :param save_path: The path of the file to write to. + """ + with open(save_path, "w") as file: + file.write(description_string) + + def get_file_name(self, path_object: pathlib.Path, extension: str, object_name: str) -> str: + """ :param path_object: The path object of the description file or the mesh file. :param extension: The file extension of the description file or the mesh file. :param object_name: The name of the object. @@ -656,36 +813,39 @@ def get_file_name(self, path_object: pathlib.Path, extension: str, object_name: @classmethod @abstractmethod - def generate_from_mesh_file(cls, path: str, name: str) -> str: + def generate_from_mesh_file(cls, path: str, name: str, save_path: str) -> None: """ - Generates a description file from one of the mesh types defined in the mesh_extensions and - returns the path of the generated file. + Generate a description file from one of the mesh types defined in the mesh_extensions and + return the path of the generated file. The generated file will be saved at the given save_path. :param path: The path to the .obj file. :param name: The name of the object. - :return: The path of the generated description file. + :param save_path: The path to save the generated description file. """ pass @classmethod @abstractmethod - def generate_from_description_file(cls, path: str) -> str: + def generate_from_description_file(cls, path: str, save_path: str, make_mesh_paths_absolute: bool = True) -> None: """ - Preprocesses the given file and returns the preprocessed description string. + Preprocess the given file and return the preprocessed description string. The preprocessed description will be + saved at the given save_path. :param path: The path of the file to preprocess. - :return: The preprocessed description string. + :param save_path: The path to save the preprocessed description file. + :param make_mesh_paths_absolute: Whether to make the mesh paths absolute. """ pass @classmethod @abstractmethod - def generate_from_parameter_server(cls, name: str) -> str: + def generate_from_parameter_server(cls, name: str, save_path: str) -> None: """ - Preprocesses the description from the ROS parameter server and returns the preprocessed description string. + Preprocess the description from the ROS parameter server and return the preprocessed description string. + The preprocessed description will be saved at the given save_path. :param name: The name of the description on the parameter server. - :return: The preprocessed description string. + :param save_path: The path to save the preprocessed description file. """ pass @@ -697,12 +857,11 @@ def links(self) -> List[LinkDescription]: """ pass - @abstractmethod def get_link_by_name(self, link_name: str) -> LinkDescription: """ :return: The link description with the given name. """ - pass + return self.link_map[link_name] @property @abstractmethod @@ -712,12 +871,11 @@ def joints(self) -> List[JointDescription]: """ pass - @abstractmethod def get_joint_by_name(self, joint_name: str) -> JointDescription: """ :return: The joint description with the given name. """ - pass + return self.joint_map[joint_name] @abstractmethod def get_root(self) -> str: @@ -726,8 +884,15 @@ def get_root(self) -> str: """ pass + def get_tip(self) -> str: + """ + :return: the name of the tip link of this object. + """ + raise NotImplementedError + @abstractmethod - def get_chain(self, start_link_name: str, end_link_name: str) -> List[str]: + def get_chain(self, start_link_name: str, end_link_name: str, joints: Optional[bool] = True, + links: Optional[bool] = True, fixed: Optional[bool] = True) -> List[str]: """ :return: the chain of links from 'start_link_name' to 'end_link_name'. """ diff --git a/src/pycram/designator.py b/src/pycram/designator.py index f53716cb9..6ec97d037 100644 --- a/src/pycram/designator.py +++ b/src/pycram/designator.py @@ -365,9 +365,8 @@ def ground(self) -> Any: def get_slots(self) -> List[str]: """ - Returns a list of all slots of this description. Can be used for inspecting different descriptions and debugging. - - :return: A list of all slots. + :return: a list of all slots of this description. Can be used for inspecting different descriptions and + debugging. """ return list(self.__dict__.keys()) @@ -376,7 +375,7 @@ def copy(self) -> DesignatorDescription: def get_default_ontology_concept(self) -> owlready2.Thing | None: """ - Returns the first element of ontology_concept_holders if there is, else None + :return: The first element of ontology_concept_holders if there is, else None """ return self.ontology_concept_holders[0].ontology_concept if self.ontology_concept_holders else None @@ -597,8 +596,6 @@ def insert(self, session: Session) -> ORMObjectDesignator: def frozen_copy(self) -> 'ObjectDesignatorDescription.Object': """ - Returns a copy of this designator containing only the fields. - :return: A copy containing only the fields of this class. The WorldObject attached to this pycram object is not copied. The _pose gets set to a method that statically returns the pose of the object when this method was called. """ result = ObjectDesignatorDescription.Object(self.name, self.obj_type, None) @@ -633,7 +630,7 @@ def __repr__(self): def special_knowledge_adjustment_pose(self, grasp: str, pose: Pose) -> Pose: """ - Returns the adjusted target pose based on special knowledge for "grasp front". + Get the adjusted target pose based on special knowledge for "grasp front". :param grasp: From which side the object should be grasped :param pose: Pose at which the object should be grasped, before adjustment diff --git a/src/pycram/designators/action_designator.py b/src/pycram/designators/action_designator.py index 3029c680d..2c77c29fe 100644 --- a/src/pycram/designators/action_designator.py +++ b/src/pycram/designators/action_designator.py @@ -8,17 +8,14 @@ import numpy as np from sqlalchemy.orm import Session from tf import transformations -from typing_extensions import Any, List, Union, Callable, Optional, Type - -import rospy +from typing_extensions import List, Union, Callable, Optional, Type from .location_designator import CostmapLocation from .motion_designator import MoveJointsMotion, MoveGripperMotion, MoveArmJointsMotion, MoveTCPMotion, MoveMotion, \ LookingMotion, DetectingMotion, OpeningMotion, ClosingMotion from .object_designator import ObjectDesignatorDescription, BelieveObject, ObjectPart from ..local_transformer import LocalTransformer -from ..plan_failures import ObjectUnfetchable, ReachabilityFailure -# from ..robot_descriptions import robot_description +from ..failures import ObjectUnfetchable, ReachabilityFailure from ..robot_description import RobotDescription from ..tasktree import with_tree @@ -874,7 +871,7 @@ def perform(self) -> None: ParkArmsActionPerformable(Arms.BOTH).perform() pickup_loc = CostmapLocation(target=self.object_designator, reachable_for=robot_desig.resolve(), reachable_arm=self.arm) - # Tries to find a pick-up posotion for the robot that uses the given arm + # Tries to find a pick-up position for the robot that uses the given arm pickup_pose = None for pose in pickup_loc: if self.arm in pose.reachable_arms: diff --git a/src/pycram/designators/location_designator.py b/src/pycram/designators/location_designator.py index d59f60a30..5bd50bcfb 100644 --- a/src/pycram/designators/location_designator.py +++ b/src/pycram/designators/location_designator.py @@ -178,6 +178,7 @@ def __iter__(self): if self.visible_for or self.reachable_for: robot_object = self.visible_for.world_object if self.visible_for else self.reachable_for.world_object test_robot = World.current_world.get_prospection_object_for_object(robot_object) + with UseProspectionWorld(): for maybe_pose in PoseGenerator(final_map, number_of_samples=600): res = True @@ -249,7 +250,6 @@ def __iter__(self) -> Location: final_map = occupancy + gaussian - test_robot = World.current_world.get_prospection_object_for_object(self.robot) # Find a Joint of type prismatic which is above the handle in the URDF tree @@ -283,8 +283,10 @@ def __iter__(self) -> Location: valid_goal, arms_goal = reachability_validator(maybe_pose, test_robot, goal_pose, allowed_collision={test_robot: hand_links}) - if valid_init and valid_goal: - yield self.Location(maybe_pose, list(set(arms_init).intersection(set(arms_goal)))) + arms_list = list(set(arms_init).intersection(set(arms_goal))) + + if valid_init and valid_goal and len(arms_list) > 0: + yield self.Location(maybe_pose, arms_list) class SemanticCostmapLocation(LocationDesignatorDescription): diff --git a/src/pycram/designators/motion_designator.py b/src/pycram/designators/motion_designator.py index 5e1174edd..589d5f9ad 100644 --- a/src/pycram/designators/motion_designator.py +++ b/src/pycram/designators/motion_designator.py @@ -5,7 +5,7 @@ from .object_designator import ObjectDesignatorDescription, ObjectPart, RealObject from ..designator import ResolutionError from ..orm.base import ProcessMetaData -from ..plan_failures import PerceptionObjectNotFound +from ..failures import PerceptionObjectNotFound from ..process_module import ProcessModuleManager from ..orm.motion_designator import (MoveMotion as ORMMoveMotion, MoveTCPMotion as ORMMoveTCPMotion, LookingMotion as ORMLookingMotion, diff --git a/src/pycram/designators/specialized_designators/location/giskard_location.py b/src/pycram/designators/specialized_designators/location/giskard_location.py index 1400a8e63..de0d0a6e8 100644 --- a/src/pycram/designators/specialized_designators/location/giskard_location.py +++ b/src/pycram/designators/specialized_designators/location/giskard_location.py @@ -53,7 +53,7 @@ def __iter__(self) -> CostmapLocation.Location: prospection_robot = World.current_world.get_prospection_object_for_object(World.robot) with UseProspectionWorld(): - prospection_robot.set_joint_positions(robot_joint_states) + prospection_robot.set_multiple_joint_positions(robot_joint_states) prospection_robot.set_pose(pose) gripper_pose = prospection_robot.get_link_pose(chain.get_tool_frame()) diff --git a/src/pycram/designators/specialized_designators/probabilistic/probabilistic_action.py b/src/pycram/designators/specialized_designators/probabilistic/probabilistic_action.py index ed4ed0d70..606969672 100644 --- a/src/pycram/designators/specialized_designators/probabilistic/probabilistic_action.py +++ b/src/pycram/designators/specialized_designators/probabilistic/probabilistic_action.py @@ -19,7 +19,7 @@ from ....designator import ActionDesignatorDescription, ObjectDesignatorDescription from ....local_transformer import LocalTransformer from ....orm.views import PickUpWithContextView -from ....plan_failures import ObjectUnreachable, PlanFailure +from ....failures import ObjectUnreachable, PlanFailure class Grasp(SetElement): diff --git a/src/pycram/external_interfaces/giskard.py b/src/pycram/external_interfaces/giskard.py index bf10618cb..269be7e9b 100644 --- a/src/pycram/external_interfaces/giskard.py +++ b/src/pycram/external_interfaces/giskard.py @@ -21,7 +21,7 @@ try: from giskardpy.python_interface.old_python_interface import OldGiskardWrapper as GiskardWrapper - from giskard_msgs.msg import WorldBody, CollisionEntry + from giskard_msgs.msg import WorldBody, MoveResult, CollisionEntry except ModuleNotFoundError as e: rospy.logwarn("Failed to import Giskard messages, the real robot will not be available") @@ -168,7 +168,7 @@ def spawn_object(object: Object) -> None: :param object: World object that should be spawned """ if len(object.link_name_to_id) == 1: - geometry = object.get_link_geometry(object.root_link_name) + geometry = object.get_link_geometry(object.root_link.name) if isinstance(geometry, MeshVisualShape): filename = geometry.file_name spawn_mesh(object.name, filename, object.get_pose()) @@ -590,9 +590,7 @@ def allow_gripper_collision(gripper: str) -> None: @init_giskard_interface def get_gripper_group_names() -> List[str]: """ - Returns a list of groups that are registered in giskard which have 'gripper' in their name. - - :return: The list of gripper groups + :return: The list of groups that are registered in giskard which have 'gripper' in their name. """ groups = giskard_wrapper.get_group_names() return list(filter(lambda elem: "gripper" in elem, groups)) @@ -601,7 +599,7 @@ def get_gripper_group_names() -> List[str]: @init_giskard_interface def add_gripper_groups() -> None: """ - Adds the gripper links as a group for collision avoidance. + Add the gripper links as a group for collision avoidance. :return: Response of the RegisterGroup Service """ @@ -645,7 +643,7 @@ def avoid_collisions(object1: Object, object2: Object) -> None: @init_giskard_interface def make_world_body(object: Object) -> 'WorldBody': """ - Creates a WorldBody message for a World Object. The WorldBody will contain the URDF of the World Object + Create a WorldBody message for a World Object. The WorldBody will contain the URDF of the World Object :param object: The World Object :return: A WorldBody message for the World Object diff --git a/src/pycram/external_interfaces/ik.py b/src/pycram/external_interfaces/ik.py index 17ceca769..83f084ac6 100644 --- a/src/pycram/external_interfaces/ik.py +++ b/src/pycram/external_interfaces/ik.py @@ -14,7 +14,7 @@ from ..local_transformer import LocalTransformer from ..datastructures.pose import Pose from ..robot_description import RobotDescription -from ..plan_failures import IKError +from ..failures import IKError from ..external_interfaces.giskard import projection_cartesian_goal, allow_gripper_collision @@ -234,7 +234,7 @@ def request_giskard_ik(target_pose: Pose, robot: Object, gripper: str) -> Tuple[ robot_joint_states[joint_name] = state with UseProspectionWorld(): - prospection_robot.set_joint_positions(robot_joint_states) + prospection_robot.set_multiple_joint_positions(robot_joint_states) prospection_robot.set_pose(pose) tip_pose = prospection_robot.get_link_pose(gripper) diff --git a/src/pycram/failure_handling.py b/src/pycram/failure_handling.py index 3e409eb6b..1c53061a4 100644 --- a/src/pycram/failure_handling.py +++ b/src/pycram/failure_handling.py @@ -1,6 +1,6 @@ from .datastructures.enums import State from .designator import DesignatorDescription -from .plan_failures import PlanFailure +from .failures import PlanFailure from threading import Lock from typing_extensions import Union, Tuple, Any, List from .language import Language, Monitor diff --git a/src/pycram/plan_failures.py b/src/pycram/failures.py similarity index 78% rename from src/pycram/plan_failures.py rename to src/pycram/failures.py index e8ac32fe1..736adf8ee 100644 --- a/src/pycram/plan_failures.py +++ b/src/pycram/failures.py @@ -1,3 +1,12 @@ +from pathlib import Path + +from typing_extensions import TYPE_CHECKING, List + +if TYPE_CHECKING: + from .world_concepts.world_object import Object + from .datastructures.enums import JointType + + class PlanFailure(Exception): """Implementation of plan failures.""" @@ -127,8 +136,10 @@ def __init__(self, *args, **kwargs): class IKError(PlanFailure): """Thrown when no inverse kinematics solution could be found""" + def __init__(self, pose, base_frame, tip_frame): - self.message = "Position {} in frame '{}' is not reachable for end effector: '{}'".format(pose, base_frame, tip_frame) + self.message = "Position {} in frame '{}' is not reachable for end effector: '{}'".format(pose, base_frame, + tip_frame) super(IKError, self).__init__(self.message) @@ -395,3 +406,66 @@ def __init__(*args, **kwargs): class CollisionError(PlanFailure): def __init__(*args, **kwargs): super().__init__(*args, **kwargs) + + +""" +The following exceptions are used in the PyCRAM framework to handle errors related to the world and the objects in it. +They are usually related to a bug in the code or a misuse of the framework (e.g. logical errors in the code). +""" + + +class ProspectionObjectNotFound(KeyError): + def __init__(self, obj: 'Object'): + super().__init__(f"The given object {obj.name} is not in the prospection world.") + + +class WorldObjectNotFound(KeyError): + def __init__(self, obj: 'Object'): + super().__init__(f"The given object {obj.name} is not in the main world.") + + +class ObjectAlreadyExists(Exception): + def __init__(self, obj: 'Object'): + super().__init__(f"An object with the name {obj.name} already exists in the world.") + + +class ObjectDescriptionNotFound(KeyError): + def __init__(self, object_name: str, path: str, extension: str): + super().__init__(f"{object_name} with path {path} and extension {extension} is not in supported extensions, and" + f" the description data was not found on the ROS parameter server") + + +class WorldMismatchErrorBetweenObjects(Exception): + def __init__(self, obj_1: 'Object', obj_2: 'Object'): + super().__init__(f"World mismatch between the attached objects {obj_1.name} and {obj_2.name}," + f"obj_1.world: {obj_1.world}, obj_2.world: {obj_2.world}") + + +class ObjectFrameNotFoundError(KeyError): + def __init__(self, frame_name: str): + super().__init__(f"Frame {frame_name} does not belong to any of the objects in the world.") + + +class MultiplePossibleTipLinks(Exception): + def __init__(self, object_name: str, start_link: str, tip_links: List[str]): + super().__init__(f"Multiple possible tip links found for object {object_name} with start link {start_link}:" + f" {tip_links}") + + +class UnsupportedFileExtension(Exception): + def __init__(self, object_name: str, path: str): + extension = Path(path).suffix + super().__init__(f"Unsupported file extension for object {object_name} with path {path}" + f"and extension {extension}") + + +class ObjectDescriptionUndefined(Exception): + def __init__(self, object_name: str): + super().__init__(f"Object description for object {object_name} is not defined, eith a path or a description" + f"object should be provided.") + + +class UnsupportedJointType(Exception): + def __init__(self, joint_type: 'JointType'): + super().__init__(f"Unsupported joint type: {joint_type}") + diff --git a/src/pycram/helper.py b/src/pycram/helper.py index 73cf77dbc..d3c39176e 100644 --- a/src/pycram/helper.py +++ b/src/pycram/helper.py @@ -3,6 +3,13 @@ Classes: Singleton -- implementation of singleton metaclass """ +import os + +import rospy +from typing_extensions import Dict, Optional +import xml.etree.ElementTree as ET + + class Singleton(type): """ Metaclass for singletons @@ -16,4 +23,93 @@ class Singleton(type): def __call__(cls, *args, **kwargs): if cls not in cls._instances: cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) - return cls._instances[cls] \ No newline at end of file + return cls._instances[cls] + + +def parse_mjcf_actuators(file_path: str) -> Dict[str, str]: + """ + Parse the actuator elements from an MJCF file. + + :param file_path: The path to the MJCF file. + """ + tree = ET.parse(file_path) + root = tree.getroot() + + joint_actuators = {} + + # Iterate through all actuator elements + for actuator in root.findall(".//actuator/*"): + name = actuator.get('name') + joint = actuator.get('joint') + if name and joint: + joint_actuators[joint] = name + + return joint_actuators + + +def get_robot_mjcf_path(robot_relative_dir: str, robot_name: str, xml_name: Optional[str] = None) -> Optional[str]: + """ + Get the path to the MJCF file of a robot. + + :param robot_relative_dir: The relative directory of the robot in the Multiverse resources/robots directory. + :param robot_name: The name of the robot. + :param xml_name: The name of the XML file of the robot. + :return: The path to the MJCF file of the robot if it exists, otherwise None. + """ + xml_name = xml_name if xml_name is not None else robot_name + if '.xml' not in xml_name: + xml_name = xml_name + '.xml' + multiverse_resources = find_multiverse_resources_path() + try: + robot_folder = os.path.join(multiverse_resources, 'robots', robot_relative_dir, robot_name) + except TypeError: + rospy.logwarn("Multiverse resources path not found.") + return None + if multiverse_resources is not None: + list_dir = os.listdir(robot_folder) + if 'mjcf' in list_dir: + if xml_name in os.listdir(robot_folder + '/mjcf'): + return os.path.join(robot_folder, 'mjcf', xml_name) + elif xml_name in os.listdir(robot_folder): + return os.path.join(robot_folder, xml_name) + return None + + +def find_multiverse_resources_path() -> Optional[str]: + """ + :return: The path to the Multiverse resources directory. + """ + # Get the path to the Multiverse installation + multiverse_path = find_multiverse_path() + + # Check if the path to the Multiverse installation was found + if multiverse_path: + # Construct the path to the resources directory + resources_path = os.path.join(multiverse_path, 'resources') + + # Check if the resources directory exists + if os.path.exists(resources_path): + return resources_path + + return None + + +def find_multiverse_path() -> Optional[str]: + """ + :return: the path to the Multiverse installation. + """ + # Get the value of PYTHONPATH environment variable + pythonpath = os.getenv('PYTHONPATH') + + # Check if PYTHONPATH is set + if pythonpath: + # Split the PYTHONPATH into individual paths using the platform-specific path separator + paths = pythonpath.split(os.pathsep) + + # Iterate through each path and check if 'Multiverse' is in it + for path in paths: + if 'multiverse' in path: + multiverse_path = path.split('multiverse')[0] + return multiverse_path + 'multiverse' + + diff --git a/src/pycram/language.py b/src/pycram/language.py index f63b56a34..5a5f36153 100644 --- a/src/pycram/language.py +++ b/src/pycram/language.py @@ -6,11 +6,11 @@ from typing_extensions import Iterable, Optional, Callable, Dict, Any, List, Union, Tuple from anytree import NodeMixin, Node, PreOrderIter -from pycram.datastructures.enums import State +from .datastructures.enums import State import threading from .fluent import Fluent -from .plan_failures import PlanFailure, NotALanguageExpression +from .failures import PlanFailure, NotALanguageExpression from .external_interfaces import giskard diff --git a/src/pycram/local_transformer.py b/src/pycram/local_transformer.py index f8aef28ac..d4e83018b 100644 --- a/src/pycram/local_transformer.py +++ b/src/pycram/local_transformer.py @@ -8,7 +8,6 @@ from tf import TransformerROS from rospy import Duration -from geometry_msgs.msg import TransformStamped from .datastructures.pose import Pose, Transform from typing_extensions import List, Optional, Union, Iterable @@ -29,6 +28,7 @@ class LocalTransformer(TransformerROS): """ _instance = None + prospection_prefix: str = "prospection/" def __new__(cls, *args, **kwargs): if not cls._instance: @@ -65,18 +65,14 @@ def transform_to_object_frame(self, pose: Pose, target_frame = world_object.tf_frame return self.transform_pose(pose, target_frame) - def update_transforms_for_objects(self, source_object_name: str, target_object_name: str) -> None: + def update_transforms_for_objects(self, object_names: List[str]) -> None: """ Updates the transforms for objects affected by the transformation. The objects are identified by their names. - :param source_object_name: Name of the object of the source frame - :param target_object_name: Name of the object of the target frame + :param object_names: List of object names for which the transforms should be updated """ - source_object = self.world.get_object_by_name(source_object_name) - target_object = self.world.get_object_by_name(target_object_name) - for obj in {source_object, target_object}: - if obj: - obj.update_link_transforms() + objects = list(map(self.world.get_object_by_name, object_names)) + [obj.update_link_transforms() for obj in objects] def transform_pose(self, pose: Pose, target_frame: str) -> Optional[Pose]: """ @@ -86,14 +82,15 @@ def transform_pose(self, pose: Pose, target_frame: str) -> Optional[Pose]: :param target_frame: Name of the TF frame into which the Pose should be transformed :return: A transformed pose in the target frame """ - self.update_transforms_for_objects(self.get_object_name_for_frame(pose.frame), - self.get_object_name_for_frame(target_frame)) + objects = list(map(self.get_object_name_for_frame, [pose.frame, target_frame])) + self.update_transforms_for_objects([obj for obj in objects if obj is not None]) copy_pose = pose.copy() copy_pose.header.stamp = rospy.Time(0) if not self.canTransform(target_frame, pose.frame, rospy.Time(0)): rospy.logerr( - f"Can not transform pose: \n {pose}\n to frame: {target_frame}.\n Maybe try calling 'update_transforms_for_object'") + f"Can not transform pose: \n {pose}\n to frame: {target_frame}." + f"\n Maybe try calling 'update_transforms_for_object'") return new_pose = super().transformPose(target_frame, copy_pose) @@ -103,14 +100,30 @@ def transform_pose(self, pose: Pose, target_frame: str) -> Optional[Pose]: return Pose(*copy_pose.to_list(), frame=new_pose.header.frame_id) - def get_object_name_for_frame(self, frame: str) -> str: + def get_object_name_for_frame(self, frame: str) -> Optional[str]: """ - Returns the name of the object that is associated with the given frame. + Get the name of the object that is associated with the given frame. :param frame: The frame for which the object name should be returned :return: The name of the object associated with the frame """ - return frame.split("/")[0] + world = self.prospection_world if self.prospection_prefix in frame else self.world + if frame == "map": + return None + obj_name = [obj.name for obj in world.objects if frame == obj.tf_frame] + return obj_name[0] if len(obj_name) > 0 else self.get_object_name_for_link_frame(frame) + + def get_object_name_for_link_frame(self, link_frame: str) -> Optional[str]: + """ + Get the name of the object that is associated with the given link frame. + + :param link_frame: The frame of the link for which the object name should be returned + :return: The name of the object associated with the link frame + """ + world = self.prospection_world if self.prospection_prefix in link_frame else self.world + object_name = [obj.name for obj in world.objects for link in obj.links.values() + if link_frame in (link.name, link.tf_frame)] + return object_name[0] if len(object_name) > 0 else None def lookup_transform_from_source_to_target_frame(self, source_frame: str, target_frame: str, time: Optional[rospy.rostime.Time] = None) -> Transform: @@ -123,8 +136,8 @@ def lookup_transform_from_source_to_target_frame(self, source_frame: str, target :param time: Time at which the transform should be looked up :return: The transform from source_frame to target_frame """ - self.update_transforms_for_objects(self.get_object_name_for_frame(source_frame), - self.get_object_name_for_frame(target_frame)) + objects = list(map(self.get_object_name_for_frame, [source_frame, target_frame])) + self.update_transforms_for_objects([obj for obj in objects if obj is not None]) tf_time = time if time else self.getLatestCommonTime(source_frame, target_frame) translation, rotation = self.lookupTransform(source_frame, target_frame, tf_time) @@ -142,9 +155,7 @@ def update_transforms(self, transforms: Iterable[Transform], time: rospy.Time = def get_all_frames(self) -> List[str]: """ - Returns all know coordinate frames as a list with human-readable entries. - - :return: A list of all know coordinate frames. + :return: A list of all known coordinate frames as a list with human-readable entries. """ frames = self.allFramesAsString().split("\n") frames.remove("") diff --git a/src/pycram/object_descriptors/generic.py b/src/pycram/object_descriptors/generic.py index b30b224e9..fb2b456ec 100644 --- a/src/pycram/object_descriptors/generic.py +++ b/src/pycram/object_descriptors/generic.py @@ -1,3 +1,5 @@ +from typing import Optional, Tuple + from typing_extensions import List, Any, Union, Dict from geometry_msgs.msg import Point @@ -9,11 +11,21 @@ ObjectDescription as AbstractObjectDescription +class NamedBoxVisualShape(BoxVisualShape): + def __init__(self, name: str, color: Color, visual_frame_position: List[float], half_extents: List[float]): + super().__init__(color, visual_frame_position, half_extents) + self._name: str = name + + @property + def name(self) -> str: + return self._name + + class LinkDescription(AbstractLinkDescription): - def __init__(self, name: str, visual_frame_position: List[float], half_extents: List[float], color: Color = Color()): - self.parsed_description: BoxVisualShape = BoxVisualShape(color, visual_frame_position, half_extents) - self._name: str = name + def __init__(self, name: str, visual_frame_position: List[float], half_extents: List[float], + color: Color = Color()): + super().__init__(NamedBoxVisualShape(name, color, visual_frame_position, half_extents)) @property def geometry(self) -> Union[VisualShape, None]: @@ -25,7 +37,7 @@ def origin(self) -> Pose: @property def name(self) -> str: - return self._name + return self.parsed_description.name @property def color(self) -> Color: @@ -34,6 +46,14 @@ def color(self) -> Color: class JointDescription(AbstractJointDescription): + @property + def parent(self) -> str: + raise NotImplementedError + + @property + def child(self) -> str: + raise NotImplementedError + @property def type(self) -> JointType: return JointType.UNKNOWN @@ -93,17 +113,39 @@ def load_description(self, path: str) -> Any: ... @classmethod - def generate_from_mesh_file(cls, path: str, name: str) -> str: + def generate_from_mesh_file(cls, path: str, name: str, save_path: str) -> str: raise NotImplementedError @classmethod - def generate_from_description_file(cls, path: str) -> str: + def generate_from_description_file(cls, path: str, save_path: str, make_mesh_paths_absolute: bool = True) -> str: raise NotImplementedError @classmethod - def generate_from_parameter_server(cls, name: str) -> str: + def generate_from_parameter_server(cls, name: str, save_path: str) -> str: raise NotImplementedError + @property + def parent_map(self) -> Dict[str, Tuple[str, str]]: + return {} + + @property + def link_map(self) -> Dict[str, LinkDescription]: + return {self._links[0].name: self._links[0]} + + @property + def joint_map(self) -> Dict[str, JointDescription]: + return {} + + @property + def child_map(self) -> Dict[str, List[Tuple[str, str]]]: + return {} + + def add_joint(self, name: str, child: str, joint_type: JointType, + axis: Point, parent: Optional[str] = None, origin: Optional[Pose] = None, + lower_limit: Optional[float] = None, upper_limit: Optional[float] = None, + is_virtual: Optional[bool] = False) -> None: + ... + @property def shape_data(self) -> List[float]: return self._links[0].geometry.shape_data()['halfExtents'] @@ -130,7 +172,8 @@ def get_joint_by_name(self, joint_name: str) -> JointDescription: def get_root(self) -> str: return self._links[0].name - def get_chain(self, start_link_name: str, end_link_name: str) -> List[str]: + def get_chain(self, start_link_name: str, end_link_name: str, joints: Optional[bool] = True, + links: Optional[bool] = True, fixed: Optional[bool] = True) -> List[str]: raise NotImplementedError("Do Not Do This on generic objects as they have no chains") @staticmethod diff --git a/src/pycram/object_descriptors/mjcf.py b/src/pycram/object_descriptors/mjcf.py new file mode 100644 index 000000000..d4200db1f --- /dev/null +++ b/src/pycram/object_descriptors/mjcf.py @@ -0,0 +1,474 @@ +import pathlib + +import numpy as np +import rospy +from dm_control import mjcf +from geometry_msgs.msg import Point +from typing_extensions import Union, List, Optional, Dict, Tuple + +from ..datastructures.dataclasses import Color, VisualShape, BoxVisualShape, CylinderVisualShape, \ + SphereVisualShape, MeshVisualShape +from ..datastructures.enums import JointType, MJCFGeomType, MJCFJointType +from ..datastructures.pose import Pose +from ..description import JointDescription as AbstractJointDescription, \ + LinkDescription as AbstractLinkDescription, ObjectDescription as AbstractObjectDescription +from ..failures import MultiplePossibleTipLinks + +try: + from multiverse_parser import Configuration, Factory, InertiaSource, GeomBuilder + from multiverse_parser import (WorldBuilder, + GeomType, GeomProperty, + MeshProperty, + MaterialProperty) + from multiverse_parser import MjcfExporter + from pxr import Usd, UsdGeom +except ImportError: + # do not import this module if multiverse is not found + raise ImportError("Multiverse not found.") + + +class LinkDescription(AbstractLinkDescription): + """ + A class that represents a link description of an object. + """ + + def __init__(self, mjcf_description: mjcf.Element): + super().__init__(mjcf_description) + + @property + def geometry(self) -> Union[VisualShape, None]: + """ + :return: The geometry type of the collision element of this link. + """ + return self._get_visual_shape(self.parsed_description.find_all('geom')[0]) + + @staticmethod + def _get_visual_shape(mjcf_geometry) -> Union[VisualShape, None]: + """ + :param mjcf_geometry: The MJCFGeometry to get the visual shape for. + :return: The VisualShape of the given MJCFGeometry object. + """ + if mjcf_geometry.type == MJCFGeomType.BOX.value: + return BoxVisualShape(Color(), [0, 0, 0], mjcf_geometry.size) + if mjcf_geometry.type == MJCFGeomType.CYLINDER.value: + return CylinderVisualShape(Color(), [0, 0, 0], mjcf_geometry.size[0], mjcf_geometry.size[1] * 2) + if mjcf_geometry.type == MJCFGeomType.SPHERE.value: + return SphereVisualShape(Color(), [0, 0, 0], mjcf_geometry.size[0]) + if mjcf_geometry.type == MJCFGeomType.MESH.value: + return MeshVisualShape(Color(), [0, 0, 0], mjcf_geometry.scale, mjcf_geometry.filename) + return None + + @property + def origin(self) -> Union[Pose, None]: + """ + :return: The origin of this link. + """ + return parse_pose_from_body_element(self.parsed_description) + + @property + def name(self) -> str: + return self.parsed_description.name + + +class JointDescription(AbstractJointDescription): + + mjcf_type_map = { + MJCFJointType.HINGE.value: JointType.REVOLUTE, + MJCFJointType.BALL.value: JointType.SPHERICAL, + MJCFJointType.SLIDE.value: JointType.PRISMATIC, + MJCFJointType.FREE.value: JointType.FLOATING + } + """ + A dictionary mapping the MJCF joint types to the PyCRAM joint types. + """ + + pycram_type_map = {pycram_type: mjcf_type for mjcf_type, pycram_type in mjcf_type_map.items()} + """ + A dictionary mapping the PyCRAM joint types to the MJCF joint types. + """ + + def __init__(self, mjcf_description: mjcf.Element, is_virtual: Optional[bool] = False): + super().__init__(mjcf_description, is_virtual=is_virtual) + + @property + def origin(self) -> Pose: + return parse_pose_from_body_element(self.parsed_description) + + @property + def name(self) -> str: + return self.parsed_description.name + + @property + def has_limits(self) -> bool: + return self.parsed_description.limited + + @property + def type(self) -> JointType: + """ + :return: The type of this joint. + """ + if hasattr(self.parsed_description, 'type'): + return self.mjcf_type_map[self.parsed_description.type] + else: + return self.mjcf_type_map[MJCFJointType.FREE.value] + + @property + def axis(self) -> Point: + """ + :return: The axis of this joint, for example the rotation axis for a revolute joint. + """ + return Point(*self.parsed_description.axis) + + @property + def lower_limit(self) -> Union[float, None]: + """ + :return: The lower limit of this joint, or None if the joint has no limits. + """ + if self.has_limits: + return self.parsed_description.range[0] + else: + return None + + @property + def upper_limit(self) -> Union[float, None]: + """ + :return: The upper limit of this joint, or None if the joint has no limits. + """ + if self.has_limits: + return self.parsed_description.range[1] + else: + return None + + @property + def parent(self) -> str: + """ + :return: The name of the parent link of this joint. + """ + return self._parent_link_element.parent.name + + @property + def child(self) -> str: + """ + :return: The name of the child link of this joint. + """ + return self._parent_link_element.name + + @property + def _parent_link_element(self) -> mjcf.Element: + return self.parsed_description.parent + + @property + def damping(self) -> float: + """ + :return: The damping of this joint. + """ + return self.parsed_description.damping + + @property + def friction(self) -> float: + raise NotImplementedError("Friction is not implemented for MJCF joints.") + + +class ObjectFactory(Factory): + """ + Create MJCF object descriptions from mesh files. + """ + def __init__(self, object_name: str, file_path: str, config: Configuration, texture_type: str = "png"): + super().__init__(file_path, config) + + self._world_builder = WorldBuilder(usd_file_path=self.tmp_usd_file_path) + + body_builder = self._world_builder.add_body(body_name=object_name) + + tmp_usd_mesh_file_path, tmp_origin_mesh_file_path = self.import_mesh( + mesh_file_path=file_path, merge_mesh=True) + mesh_stage = Usd.Stage.Open(tmp_usd_mesh_file_path) + for idx, mesh_prim in enumerate([prim for prim in mesh_stage.Traverse() if prim.IsA(UsdGeom.Mesh)]): + mesh_name = mesh_prim.GetName() + mesh_path = mesh_prim.GetPath() + mesh_property = MeshProperty.from_mesh_file_path(mesh_file_path=tmp_usd_mesh_file_path, + mesh_path=mesh_path) + # mesh_property._texture_coordinates = None # TODO: See if needed otherwise remove it. + geom_property = GeomProperty(geom_type=GeomType.MESH, + is_visible=False, + is_collidable=True) + geom_builder = body_builder.add_geom(geom_name=f"SM_{object_name}_mesh_{idx}", + geom_property=geom_property) + geom_builder.add_mesh(mesh_name=mesh_name, mesh_property=mesh_property) + + # Add texture if available + texture_file_path = file_path.replace(pathlib.Path(file_path).suffix, f".{texture_type}") + if pathlib.Path(texture_file_path).exists(): + self.add_material_with_texture(geom_builder=geom_builder, material_name=f"M_{object_name}_{idx}", + texture_file_path=texture_file_path) + + geom_builder.build() + + body_builder.compute_and_set_inertial(inertia_source=InertiaSource.FROM_COLLISION_MESH) + + @staticmethod + def add_material_with_texture(geom_builder: GeomBuilder, material_name: str, texture_file_path: str): + """ + Add a material with a texture to the geom builder. + + :param geom_builder: The geom builder to add the material to. + :param material_name: The name of the material. + :param texture_file_path: The path to the texture file. + """ + material_property = MaterialProperty(diffuse_color=texture_file_path, + opacity=None, + emissive_color=None, + specular_color=None) + geom_builder.add_material(material_name=material_name, + material_property=material_property) + + def export_to_mjcf(self, output_file_path: str): + """ + Export the object to a MJCF file. + + :param output_file_path: The path to the output file. + """ + exporter = MjcfExporter(self, output_file_path) + exporter.build() + exporter.export(keep_usd=False) + + +class ObjectDescription(AbstractObjectDescription): + """ + A class that represents an object description of an object. + """ + + class Link(AbstractObjectDescription.Link, LinkDescription): + ... + + class RootLink(AbstractObjectDescription.RootLink, Link): + ... + + class Joint(AbstractObjectDescription.Joint, JointDescription): + ... + + def __init__(self): + super().__init__() + self._link_map = None + self._joint_map = None + self._child_map = None + self._parent_map = None + self._links = None + self._joints = None + self.virtual_joint_names = [] + + @property + def child_map(self) -> Dict[str, List[Tuple[str, str]]]: + """ + :return: A dictionary mapping the name of a link to its children which are represented as a tuple of the child + joint name and the link name. + """ + if self._child_map is None: + self._child_map = self._construct_child_map() + return self._child_map + + def _construct_child_map(self) -> Dict[str, List[Tuple[str, str]]]: + """ + Construct the child map of the object. + """ + child_map = {} + for joint in self.joints: + if joint.parent not in child_map: + child_map[joint.parent] = [(joint.name, joint.child)] + else: + child_map[joint.parent].append((joint.name, joint.child)) + return child_map + + @property + def parent_map(self) -> Dict[str, Tuple[str, str]]: + """ + :return: A dictionary mapping the name of a link to its parent joint and link as a tuple. + """ + if self._parent_map is None: + self._parent_map = self._construct_parent_map() + return self._parent_map + + def _construct_parent_map(self) -> Dict[str, Tuple[str, str]]: + """ + Construct the parent map of the object. + """ + child_map = self.child_map + parent_map = {} + for parent, children in child_map.items(): + for child in children: + parent_map[child[1]] = (child[0], parent) + return parent_map + + @property + def link_map(self) -> Dict[str, LinkDescription]: + """ + :return: A dictionary mapping the name of a link to its description. + """ + if self._link_map is None: + self._link_map = {link.name: link for link in self.links} + return self._link_map + + @property + def joint_map(self) -> Dict[str, JointDescription]: + """ + :return: A dictionary mapping the name of a joint to its description. + """ + if self._joint_map is None: + self._joint_map = {joint.name: joint for joint in self.joints} + return self._joint_map + + def add_joint(self, name: str, child: str, joint_type: JointType, + axis: Point, parent: Optional[str] = None, origin: Optional[Pose] = None, + lower_limit: Optional[float] = None, upper_limit: Optional[float] = None, + is_virtual: Optional[bool] = False) -> None: + """ + Finds the child link and adds a joint to it in the object description. + for arguments documentation see :meth:`pycram.description.ObjectDescription.add_joint` + """ + + position: Optional[List[float]] = None + quaternion: Optional[List[float]] = None + lower_limit: float = 0.0 if lower_limit is None else lower_limit + upper_limit: float = 0.0 if upper_limit is None else upper_limit + limit = [lower_limit, upper_limit] + + if origin is not None: + position = origin.position_as_list() + quaternion = origin.orientation_as_list() + quaternion = [quaternion[1], quaternion[2], quaternion[3], quaternion[0]] + if axis is not None: + axis = [axis.x, axis.y, axis.z] + self.parsed_description.find(child).add('joint', name=name, type=JointDescription.pycram_type_map[joint_type], + axis=axis, pos=position, quat=quaternion, range=limit) + if is_virtual: + self.virtual_joint_names.append(name) + + def load_description(self, path) -> mjcf.RootElement: + return mjcf.from_file(path, model_dir=pathlib.Path(path).parent) + + def load_description_from_string(self, description_string: str) -> mjcf.RootElement: + return mjcf.from_xml_string(description_string) + + def generate_from_mesh_file(self, path: str, name: str, color: Optional[Color] = Color(), + save_path: Optional[str] = None) -> None: + """ + Generate a mjcf xml file with the given .obj or .stl file as mesh. In addition, use the given rgba_color + to create a material tag in the xml. + + :param path: The path to the mesh file. + :param name: The name of the object. + :param color: The color of the object. + :param save_path: The path to save the generated xml file. + """ + factory = ObjectFactory(object_name=name, file_path=path, + config=Configuration(model_name=name, + fixed_base=False, + default_rgba=np.array(color.get_rgba()))) + factory.export_to_mjcf(output_file_path=save_path) + + def generate_from_description_file(self, path: str, save_path: str, make_mesh_paths_absolute: bool = True) -> None: + mjcf_model = mjcf.from_file(path) + self.write_description_to_file(mjcf_model, save_path) + + def generate_from_parameter_server(self, name: str, save_path: str) -> None: + mjcf_string = rospy.get_param(name) + self.write_description_to_file(mjcf_string, save_path) + + @property + def joints(self) -> List[JointDescription]: + """ + :return: A list of joints descriptions of this object. + """ + if self._joints is None: + self._joints = [JointDescription(joint) for joint in self.parsed_description.find_all('joint')] + return self._joints + + @property + def links(self) -> List[LinkDescription]: + """ + :return: A list of link descriptions of this object. + """ + if self._links is None: + self._links = [LinkDescription(link) for link in self.parsed_description.find_all('body')] + return self._links + + def get_root(self) -> str: + """ + :return: the name of the root link of this object. + """ + if len(self.links) == 1: + return self.links[0].name + elif len(self.links) > 1: + return self.links[1].name + else: + raise ValueError("No links found in the object description.") + + def get_tip(self) -> str: + """ + :return: the name of the tip link of this object. + :raises MultiplePossibleTipLinks: If there are multiple possible tip links. + """ + link = self.get_root() + while link in self.child_map: + children = self.child_map[link] + if len(children) > 1: + # Multiple children, can't decide which one to take (e.g. fingers of a hand) + raise MultiplePossibleTipLinks(self.name, link, [child[1] for child in children]) + else: + child = children[0][1] + link = child + return link + + def get_chain(self, start_link_name: str, end_link_name: str, joints: Optional[bool] = True, + links: Optional[bool] = True, fixed: Optional[bool] = True) -> List[str]: + """ + :param start_link_name: The name of the start link of the chain. + :param end_link_name: The name of the end link of the chain. + :param joints: Whether to include joints in the chain. + :param links: Whether to include links in the chain. + :param fixed: Whether to include fixed joints in the chain (Note: not used in MJCF). + :return: the chain of links from 'start_link_name' to 'end_link_name'. + """ + chain = [] + if links: + chain.append(end_link_name) + link = end_link_name + while link != start_link_name: + (joint, parent) = self.parent_map[link] + if joints: + chain.append(joint) + if links: + chain.append(parent) + link = parent + chain.reverse() + return chain + + @staticmethod + def get_file_extension() -> str: + """ + :return: The file extension of the URDF file. + """ + return '.xml' + + @property + def origin(self) -> Pose: + return parse_pose_from_body_element(self.parsed_description) + + @property + def name(self) -> str: + return self.parsed_description.name + + +def parse_pose_from_body_element(body: mjcf.Element) -> Pose: + """ + Parse the pose from a body element. + + :param body: The body element. + :return: The pose of the body. + """ + position = body.pos + quaternion = body.quat + position = [0, 0, 0] if position is None else position + quaternion = [1, 0, 0, 0] if quaternion is None else quaternion + quaternion = [quaternion[1], quaternion[2], quaternion[3], quaternion[0]] + return Pose(position, quaternion) diff --git a/src/pycram/object_descriptors/urdf.py b/src/pycram/object_descriptors/urdf.py index 75ab98a03..50d11608f 100644 --- a/src/pycram/object_descriptors/urdf.py +++ b/src/pycram/object_descriptors/urdf.py @@ -1,11 +1,13 @@ +import os import pathlib -from xml.etree import ElementTree +import xml.etree.ElementTree as ET +import numpy as np import rospkg import rospy from geometry_msgs.msg import Point -from tf.transformations import quaternion_from_euler -from typing_extensions import Union, List, Optional +from tf.transformations import quaternion_from_euler, euler_from_quaternion +from typing_extensions import Union, List, Optional, Dict, Tuple from urdf_parser_py import urdf from urdf_parser_py.urdf import (URDF, Collision, Box as URDF_Box, Cylinder as URDF_Cylinder, Sphere as URDF_Sphere, Mesh as URDF_Mesh) @@ -16,6 +18,7 @@ LinkDescription as AbstractLinkDescription, ObjectDescription as AbstractObjectDescription from ..datastructures.dataclasses import Color, VisualShape, BoxVisualShape, CylinderVisualShape, \ SphereVisualShape, MeshVisualShape +from ..failures import MultiplePossibleTipLinks from ..utils import suppress_stdout_stderr @@ -30,7 +33,7 @@ def __init__(self, urdf_description: urdf.Link): @property def geometry(self) -> Union[VisualShape, None]: """ - Returns the geometry type of the URDF collision element of this link. + :return: The geometry type of the URDF collision element of this link. """ if self.collision is None: return None @@ -40,10 +43,12 @@ def geometry(self) -> Union[VisualShape, None]: @staticmethod def _get_visual_shape(urdf_geometry) -> Union[VisualShape, None]: """ - Returns the VisualShape of the given URDF geometry. + :param urdf_geometry: The URDFGeometry for which the visual shape is returned. + :return: the VisualShape of the given URDF geometry. """ if isinstance(urdf_geometry, URDF_Box): - return BoxVisualShape(Color(), [0, 0, 0], urdf_geometry.size) + half_extents = np.array(urdf_geometry.size) / 2 + return BoxVisualShape(Color(), [0, 0, 0], half_extents.tolist()) if isinstance(urdf_geometry, URDF_Cylinder): return CylinderVisualShape(Color(), [0, 0, 0], urdf_geometry.radius, urdf_geometry.length) if isinstance(urdf_geometry, URDF_Sphere): @@ -79,8 +84,10 @@ class JointDescription(AbstractJointDescription): 'planar': JointType.PLANAR, 'fixed': JointType.FIXED} - def __init__(self, urdf_description: urdf.Joint): - super().__init__(urdf_description) + pycram_type_map = {pycram_type: urdf_type for urdf_type, pycram_type in urdf_type_map.items()} + + def __init__(self, urdf_description: urdf.Joint, is_virtual: Optional[bool] = False): + super().__init__(urdf_description, is_virtual=is_virtual) @property def origin(self) -> Pose: @@ -130,14 +137,14 @@ def upper_limit(self) -> Union[float, None]: return None @property - def parent_link_name(self) -> str: + def parent(self) -> str: """ :return: The name of the parent link of this joint. """ return self.parsed_description.parent @property - def child_link_name(self) -> str: + def child(self) -> str: """ :return: The name of the child link of this joint. """ @@ -172,21 +179,83 @@ class RootLink(AbstractObjectDescription.RootLink, Link): class Joint(AbstractObjectDescription.Joint, JointDescription): ... + @property + def child_map(self) -> Dict[str, List[Tuple[str, str]]]: + """ + :return: A dictionary mapping the name of a link to its children which are represented as a tuple of the child + joint name and the link name. + """ + return self.parsed_description.child_map + + @property + def parent_map(self) -> Dict[str, Tuple[str, str]]: + """ + :return: A dictionary mapping the name of a link to its parent joint and link as a tuple. + """ + return self.parsed_description.parent_map + + @property + def link_map(self) -> Dict[str, LinkDescription]: + """ + :return: A dictionary mapping the name of a link to its description. + """ + if self._link_map is None: + self._link_map = {link.name: link for link in self.links} + return self._link_map + + @property + def joint_map(self) -> Dict[str, JointDescription]: + """ + :return: A dictionary mapping the name of a joint to its description. + """ + if self._joint_map is None: + self._joint_map = {joint.name: joint for joint in self.joints} + return self._joint_map + + def add_joint(self, name: str, child: str, joint_type: JointType, + axis: Point, parent: Optional[str] = None, origin: Optional[Pose] = None, + lower_limit: Optional[float] = None, upper_limit: Optional[float] = None, + is_virtual: Optional[bool] = False) -> None: + """ + Add a joint to the object description, could be a virtual joint as well. + For documentation of the parameters, see :meth:`pycram.description.ObjectDescription.add_joint`. + """ + if lower_limit is not None or upper_limit is not None: + limit = urdf.JointLimit(lower=lower_limit, upper=upper_limit) + else: + limit = None + if origin is not None: + origin = urdf.Pose(origin.position_as_list(), euler_from_quaternion(origin.orientation_as_list())) + if axis is not None: + axis = [axis.x, axis.y, axis.z] + if parent is None: + parent = self.get_root() + else: + parent = self.get_link_by_name(parent).parsed_description + joint = urdf.Joint(name, + parent, + self.get_link_by_name(child).parsed_description, + JointDescription.pycram_type_map[joint_type], + axis, origin, limit) + self.parsed_description.add_joint(joint) + if is_virtual: + self.virtual_joint_names.append(name) + def load_description(self, path) -> URDF: with open(path, 'r') as file: # Since parsing URDF causes a lot of warning messages which can't be deactivated, we suppress them with suppress_stdout_stderr(): return URDF.from_xml_string(file.read()) - def generate_from_mesh_file(self, path: str, name: str, color: Optional[Color] = Color()) -> str: + def generate_from_mesh_file(self, path: str, name: str, save_path: str, color: Optional[Color] = Color()) -> None: """ - Generates an URDf file with the given .obj or .stl file as mesh. In addition, the given rgba_color will be - used to create a material tag in the URDF. + Generate a URDf file with the given .obj or .stl file as mesh. In addition, use the given rgba_color to create a + material tag in the URDF. The URDF file will be saved to the given save_path. :param path: The path to the mesh file. :param name: The name of the object. + :param save_path: The path to save the URDF file to. :param color: The color of the object. - :return: The absolute path of the created file """ urdf_template = ' \n \ \n \ @@ -211,55 +280,45 @@ def generate_from_mesh_file(self, path: str, name: str, color: Optional[Color] = pathlib_obj = pathlib.Path(path) path = str(pathlib_obj.resolve()) content = urdf_template.replace("~a", name).replace("~b", path).replace("~c", rgb) - return content + self.write_description_to_file(content, save_path) - def generate_from_description_file(self, path: str) -> str: + def generate_from_description_file(self, path: str, save_path: str, make_mesh_paths_absolute: bool = True) -> None: with open(path, mode="r") as f: urdf_string = self.fix_missing_inertial(f.read()) - urdf_string = self.remove_error_tags(urdf_string) - urdf_string = self.fix_link_attributes(urdf_string) - try: - urdf_string = self.correct_urdf_string(urdf_string) - except rospkg.ResourceNotFound as e: - rospy.logerr(f"Could not find resource package linked in this URDF") - raise e - return urdf_string - - def generate_from_parameter_server(self, name: str) -> str: + urdf_string = self.remove_error_tags(urdf_string) + urdf_string = self.fix_link_attributes(urdf_string) + try: + urdf_string = self.replace_relative_references_with_absolute_paths(urdf_string) + urdf_string = self.fix_missing_inertial(urdf_string) + except rospkg.ResourceNotFound as e: + rospy.logerr(f"Could not find resource package linked in this URDF") + raise e + urdf_string = self.make_mesh_paths_absolute(urdf_string, path) if make_mesh_paths_absolute else urdf_string + self.write_description_to_file(urdf_string, save_path) + + def generate_from_parameter_server(self, name: str, save_path: str) -> None: urdf_string = rospy.get_param(name) - return self.correct_urdf_string(urdf_string) - - def get_link_by_name(self, link_name: str) -> LinkDescription: - """ - :return: The link description with the given name. - """ - for link in self.links: - if link.name == link_name: - return link - raise ValueError(f"Link with name {link_name} not found") + urdf_string = self.replace_relative_references_with_absolute_paths(urdf_string) + urdf_string = self.fix_missing_inertial(urdf_string) + self.write_description_to_file(urdf_string, save_path) @property - def links(self) -> List[LinkDescription]: - """ - :return: A list of links descriptions of this object. - """ - return [LinkDescription(link) for link in self.parsed_description.links] - - def get_joint_by_name(self, joint_name: str) -> JointDescription: + def joints(self) -> List[JointDescription]: """ - :return: The joint description with the given name. + :return: A list of joints descriptions of this object. """ - for joint in self.joints: - if joint.name == joint_name: - return joint - raise ValueError(f"Joint with name {joint_name} not found") + if self._joints is None: + self._joints = [JointDescription(joint) for joint in self.parsed_description.joints] + return self._joints @property - def joints(self) -> List[JointDescription]: + def links(self) -> List[LinkDescription]: """ - :return: A list of joints descriptions of this object. + :return: A list of link descriptions of this object. """ - return [JointDescription(joint) for joint in self.parsed_description.joints] + if self._links is None: + self._links = [LinkDescription(link) for link in self.parsed_description.links] + return self._links def get_root(self) -> str: """ @@ -267,16 +326,39 @@ def get_root(self) -> str: """ return self.parsed_description.get_root() - def get_chain(self, start_link_name: str, end_link_name: str) -> List[str]: - """ + def get_tip(self) -> str: + """ + :return: the name of the tip link of this object. + :raises MultiplePossibleTipLinks: If there are multiple possible tip links. + """ + link = self.get_root() + while link in self.parsed_description.child_map: + children = self.parsed_description.child_map[link] + if len(children) > 1: + # Multiple children, can't decide which one to take (e.g. fingers of a hand) + raise MultiplePossibleTipLinks(self.parsed_description.name, link, [child[1] for child in children]) + else: + child = children[0][1] + link = child + return link + + def get_chain(self, start_link_name: str, end_link_name: str, joints: Optional[bool] = True, + links: Optional[bool] = True, fixed: Optional[bool] = True) -> List[str]: + """ + :param start_link_name: The name of the start link of the chain. + :param end_link_name: The name of the end link of the chain. + :param joints: Whether to include joints in the chain. + :param links: Whether to include links in the chain. + :param fixed: Whether to include fixed joints in the chain. :return: the chain of links from 'start_link_name' to 'end_link_name'. """ - return self.parsed_description.get_chain(start_link_name, end_link_name) + return self.parsed_description.get_chain(start_link_name, end_link_name, joints, links, fixed) - def correct_urdf_string(self, urdf_string: str) -> str: + @staticmethod + def replace_relative_references_with_absolute_paths(urdf_string: str) -> str: """ - Changes paths for files in the URDF from ROS paths to paths in the file system. Since World (PyBullet legacy) - can't deal with ROS package paths. + Change paths for files in the URDF from ROS paths and file dir references to paths in the file system. Since + World (PyBullet legacy) can't deal with ROS package paths. :param urdf_string: The name of the URDf on the parameter server :return: The URDF string with paths in the filesystem instead of ROS packages @@ -289,9 +371,37 @@ def correct_urdf_string(self, urdf_string: str) -> str: s1 = s[1].split('/') path = r.get_path(s1[0]) line = line.replace("package://" + s1[0], path) + if 'file://' in line: + line = line.replace("file://", './') new_urdf_string += line + '\n' - return self.fix_missing_inertial(new_urdf_string) + return new_urdf_string + + @staticmethod + def make_mesh_paths_absolute(urdf_string: str, urdf_file_path: str) -> str: + """ + Convert all relative mesh paths in the URDF to absolute paths. + + :param urdf_string: The URDF description as string + :param urdf_file_path: The path to the URDF file + :returns: The new URDF description as string. + """ + # Parse the URDF file + root = ET.fromstring(urdf_string) + + # Iterate through all mesh tags + for mesh in root.findall('.//mesh'): + filename = mesh.attrib.get('filename', '') + if filename: + # If the filename is a relative path, convert it to an absolute path + if not os.path.isabs(filename): + # Deduce the base path from the relative path + base_path = os.path.dirname( + os.path.abspath(os.path.join(os.path.dirname(urdf_file_path), filename))) + abs_path = os.path.abspath(os.path.join(base_path, os.path.basename(filename))) + mesh.set('filename', abs_path) + + return ET.tostring(root, encoding='unicode') @staticmethod def fix_missing_inertial(urdf_string: str) -> str: @@ -303,10 +413,10 @@ def fix_missing_inertial(urdf_string: str) -> str: :returns: The new, corrected URDF description as string. """ - inertia_tree = ElementTree.ElementTree(ElementTree.Element("inertial")) - inertia_tree.getroot().append(ElementTree.Element("mass", {"value": "0.1"})) - inertia_tree.getroot().append(ElementTree.Element("origin", {"rpy": "0 0 0", "xyz": "0 0 0"})) - inertia_tree.getroot().append(ElementTree.Element("inertia", {"ixx": "0.01", + inertia_tree = ET.ElementTree(ET.Element("inertial")) + inertia_tree.getroot().append(ET.Element("mass", {"value": "0.1"})) + inertia_tree.getroot().append(ET.Element("origin", {"rpy": "0 0 0", "xyz": "0 0 0"})) + inertia_tree.getroot().append(ET.Element("inertia", {"ixx": "0.01", "ixy": "0", "ixz": "0", "iyy": "0.01", @@ -314,48 +424,48 @@ def fix_missing_inertial(urdf_string: str) -> str: "izz": "0.01"})) # create tree from string - tree = ElementTree.ElementTree(ElementTree.fromstring(urdf_string)) + tree = ET.ElementTree(ET.fromstring(urdf_string)) for link_element in tree.iter("link"): inertial = [*link_element.iter("inertial")] if len(inertial) == 0: link_element.append(inertia_tree.getroot()) - return ElementTree.tostring(tree.getroot(), encoding='unicode') + return ET.tostring(tree.getroot(), encoding='unicode') @staticmethod def remove_error_tags(urdf_string: str) -> str: """ - Removes all tags in the removing_tags list from the URDF since these tags are known to cause errors with the + Remove all tags in the removing_tags list from the URDF since these tags are known to cause errors with the URDF_parser :param urdf_string: String of the URDF from which the tags should be removed :return: The URDF string with the tags removed """ - tree = ElementTree.ElementTree(ElementTree.fromstring(urdf_string)) + tree = ET.ElementTree(ET.fromstring(urdf_string)) removing_tags = ["gazebo", "transmission"] for tag_name in removing_tags: all_tags = tree.findall(tag_name) for tag in all_tags: tree.getroot().remove(tag) - return ElementTree.tostring(tree.getroot(), encoding='unicode') + return ET.tostring(tree.getroot(), encoding='unicode') @staticmethod def fix_link_attributes(urdf_string: str) -> str: """ - Removes the attribute 'type' from links since this is not parsable by the URDF parser. + Remove the attribute 'type' from links since this is not parsable by the URDF parser. :param urdf_string: The string of the URDF from which the attributes should be removed :return: The URDF string with the attributes removed """ - tree = ElementTree.ElementTree(ElementTree.fromstring(urdf_string)) + tree = ET.ElementTree(ET.fromstring(urdf_string)) for link in tree.iter("link"): if "type" in link.attrib.keys(): del link.attrib["type"] - return ElementTree.tostring(tree.getroot(), encoding='unicode') + return ET.tostring(tree.getroot(), encoding='unicode') @staticmethod def get_file_extension() -> str: diff --git a/src/pycram/pose_generator_and_validator.py b/src/pycram/pose_generator_and_validator.py index a1aee7d78..6672b6c6c 100644 --- a/src/pycram/pose_generator_and_validator.py +++ b/src/pycram/pose_generator_and_validator.py @@ -1,16 +1,16 @@ import numpy as np import tf -from typing_extensions import Tuple, List, Union, Dict, Iterable -from .datastructures.pose import Pose, Transform from .datastructures.world import World -from .external_interfaces.ik import request_ik -from .local_transformer import LocalTransformer -from .plan_failures import IKError -from .robot_description import RobotDescription from .world_concepts.world_object import Object from .world_reasoning import contact from .costmaps import Costmap +from .local_transformer import LocalTransformer +from .datastructures.pose import Pose, Transform +from .robot_description import RobotDescription +from .external_interfaces.ik import request_ik +from .failures import IKError +from typing_extensions import Tuple, List, Union, Dict, Iterable class PoseGenerator: @@ -119,13 +119,13 @@ def visibility_validator(pose: Pose, robot_pose = robot.get_pose() if isinstance(object_or_pose, Object): robot.set_pose(pose) - camera_pose = robot.get_link_pose(RobotDescription.current_robot_description.get_camera_frame()) + camera_pose = robot.get_link_pose(RobotDescription.current_robot_description.get_camera_link()) robot.set_pose(Pose([100, 100, 0], [0, 0, 0, 1])) ray = world.ray_test(camera_pose.position_as_list(), object_or_pose.get_position_as_list()) res = ray == object_or_pose.id else: robot.set_pose(pose) - camera_pose = robot.get_link_pose(RobotDescription.current_robot_description.get_camera_frame()) + camera_pose = robot.get_link_pose(RobotDescription.current_robot_description.get_camera_link()) robot.set_pose(Pose([100, 100, 0], [0, 0, 0, 1])) # TODO: Check if this is correct ray = world.ray_test(camera_pose.position_as_list(), object_or_pose) @@ -203,14 +203,14 @@ def reachability_validator(pose: Pose, # test the possible solution and apply it to the robot pose, joint_states = request_ik(target, robot, joints, tool_frame) robot.set_pose(pose) - robot.set_joint_positions(joint_states) + robot.set_multiple_joint_positions(joint_states) # _apply_ik(robot, resp, joints) in_contact = collision_check(robot, allowed_collision) if not in_contact: # only check for retract pose if pose worked pose, joint_states = request_ik(retract_target_pose, robot, joints, tool_frame) robot.set_pose(pose) - robot.set_joint_positions(joint_states) + robot.set_multiple_joint_positions(joint_states) # _apply_ik(robot, resp, joints) in_contact = collision_check(robot, allowed_collision) if not in_contact: @@ -218,7 +218,7 @@ def reachability_validator(pose: Pose, except IKError: pass finally: - robot.set_joint_positions(joint_state_before_ik) + robot.set_multiple_joint_positions(joint_state_before_ik) if arms: res = True return res, arms @@ -246,5 +246,6 @@ def collision_check(robot: Object, allowed_collision: Dict[Object, List]): if obj.name == "floor": continue in_contact = _in_contact(robot, obj, allowed_collision, allowed_robot_links) - + if in_contact: + break return in_contact diff --git a/src/pycram/process_module.py b/src/pycram/process_module.py index eca953db6..2e89e303f 100644 --- a/src/pycram/process_module.py +++ b/src/pycram/process_module.py @@ -277,8 +277,6 @@ def __init__(self, robot_name): @staticmethod def get_manager() -> Union[ProcessModuleManager, None]: """ - Returns the Process Module manager for the currently loaded robot or None if there is no Manager. - :return: ProcessModuleManager instance of the current robot """ manager = None @@ -307,7 +305,7 @@ def get_manager() -> Union[ProcessModuleManager, None]: def navigate(self) -> Type[ProcessModule]: """ - Returns the Process Module for navigating the robot with respect to + Get the Process Module for navigating the robot with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for navigating @@ -317,7 +315,7 @@ def navigate(self) -> Type[ProcessModule]: def pick_up(self) -> Type[ProcessModule]: """ - Returns the Process Module for picking up with respect to the :py:attr:`~ProcessModuleManager.execution_type` + Get the Process Module for picking up with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for picking up an object """ @@ -326,7 +324,7 @@ def pick_up(self) -> Type[ProcessModule]: def place(self) -> Type[ProcessModule]: """ - Returns the Process Module for placing with respect to the :py:attr:`~ProcessModuleManager.execution_type` + Get the Process Module for placing with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for placing an Object """ @@ -335,7 +333,7 @@ def place(self) -> Type[ProcessModule]: def looking(self) -> Type[ProcessModule]: """ - Returns the Process Module for looking at a point with respect to + Get the Process Module for looking at a point with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for looking at a specific point @@ -345,7 +343,7 @@ def looking(self) -> Type[ProcessModule]: def detecting(self) -> Type[ProcessModule]: """ - Returns the Process Module for detecting an object with respect to + Get the Process Module for detecting an object with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for detecting an object @@ -355,7 +353,7 @@ def detecting(self) -> Type[ProcessModule]: def move_tcp(self) -> Type[ProcessModule]: """ - Returns the Process Module for moving the Tool Center Point with respect to + Get the Process Module for moving the Tool Center Point with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for moving the TCP @@ -365,7 +363,7 @@ def move_tcp(self) -> Type[ProcessModule]: def move_arm_joints(self) -> Type[ProcessModule]: """ - Returns the Process Module for moving the joints of the robot arm + Get the Process Module for moving the joints of the robot arm with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for moving the arm joints @@ -375,7 +373,7 @@ def move_arm_joints(self) -> Type[ProcessModule]: def world_state_detecting(self) -> Type[ProcessModule]: """ - Returns the Process Module for detecting an object using the world state with respect to the + Get the Process Module for detecting an object using the world state with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for world state detecting @@ -385,7 +383,7 @@ def world_state_detecting(self) -> Type[ProcessModule]: def move_joints(self) -> Type[ProcessModule]: """ - Returns the Process Module for moving any joint of the robot with respect to the + Get the Process Module for moving any joint of the robot with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for moving joints @@ -395,7 +393,7 @@ def move_joints(self) -> Type[ProcessModule]: def move_gripper(self) -> Type[ProcessModule]: """ - Returns the Process Module for moving the gripper with respect to + Get the Process Module for moving the gripper with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for moving the gripper @@ -405,7 +403,7 @@ def move_gripper(self) -> Type[ProcessModule]: def open(self) -> Type[ProcessModule]: """ - Returns the Process Module for opening drawers with respect to + Get the Process Module for opening drawers with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for opening drawers @@ -415,7 +413,7 @@ def open(self) -> Type[ProcessModule]: def close(self) -> Type[ProcessModule]: """ - Returns the Process Module for closing drawers with respect to + Get the Process Module for closing drawers with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for closing drawers diff --git a/src/pycram/process_modules/boxy_process_modules.py b/src/pycram/process_modules/boxy_process_modules.py index b8abbb0b2..4cb8b6e89 100644 --- a/src/pycram/process_modules/boxy_process_modules.py +++ b/src/pycram/process_modules/boxy_process_modules.py @@ -89,13 +89,13 @@ def _execute(self, desig): pose_in_shoulder = local_transformer.transform_pose(target, robot.get_link_tf_frame("neck_shoulder_link")) if pose_in_shoulder.position.x >= 0 and pose_in_shoulder.position.x >= abs(pose_in_shoulder.position.y): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "front")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "front")) if pose_in_shoulder.position.y >= 0 and pose_in_shoulder.position.y >= abs(pose_in_shoulder.position.x): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "neck_right")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "neck_right")) if pose_in_shoulder.position.x <= 0 and abs(pose_in_shoulder.position.x) > abs(pose_in_shoulder.position.y): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "back")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "back")) if pose_in_shoulder.position.y <= 0 and abs(pose_in_shoulder.position.y) > abs(pose_in_shoulder.position.x): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "neck_left")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "neck_left")) pose_in_shoulder = local_transformer.transform_pose(target, robot.get_link_tf_frame("neck_shoulder_link")) @@ -115,7 +115,7 @@ def _execute(self, desig): robot = World.robot gripper = desig.gripper motion = desig.motion - robot.set_joint_positions(RobotDescription.current_robot_description.kinematic_chains[gripper].get_static_gripper_state(motion)) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.kinematic_chains[gripper].get_static_gripper_state(motion)) class BoxyDetecting(ProcessModule): @@ -160,9 +160,9 @@ def _execute(self, desig: MoveArmJointsMotion): robot = World.robot if desig.right_arm_poses: - robot.set_joint_positions(desig.right_arm_poses) + robot.set_multiple_joint_positions(desig.right_arm_poses) if desig.left_arm_poses: - robot.set_joint_positions(desig.left_arm_poses) + robot.set_multiple_joint_positions(desig.left_arm_poses) class BoxyWorldStateDetecting(ProcessModule): diff --git a/src/pycram/process_modules/donbot_process_modules.py b/src/pycram/process_modules/donbot_process_modules.py index e3fcec033..e39b5bcb3 100644 --- a/src/pycram/process_modules/donbot_process_modules.py +++ b/src/pycram/process_modules/donbot_process_modules.py @@ -69,13 +69,13 @@ def _execute(self, desig): pose_in_shoulder = local_transformer.transform_pose(target, robot.get_link_tf_frame("ur5_shoulder_link")) if pose_in_shoulder.position.x >= 0 and pose_in_shoulder.position.x >= abs(pose_in_shoulder.position.y): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "front")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "front")) if pose_in_shoulder.position.y >= 0 and pose_in_shoulder.position.y >= abs(pose_in_shoulder.position.x): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "arm_right")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "arm_right")) if pose_in_shoulder.position.x <= 0 and abs(pose_in_shoulder.position.x) > abs(pose_in_shoulder.position.y): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "back")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "back")) if pose_in_shoulder.position.y <= 0 and abs(pose_in_shoulder.position.y) > abs(pose_in_shoulder.position.x): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "arm_left")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "arm_left")) pose_in_shoulder = local_transformer.transform_pose(target, robot.get_link_tf_frame("ur5_shoulder_link")) @@ -94,7 +94,7 @@ def _execute(self, desig): robot = World.robot gripper = desig.gripper motion = desig.motion - robot.set_joint_positions(RobotDescription.current_robot_description.get_arm_chain(gripper).get_static_gripper_state(motion)) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_arm_chain(gripper).get_static_gripper_state(motion)) class DonbotMoveTCP(ProcessModule): @@ -118,7 +118,7 @@ class DonbotMoveJoints(ProcessModule): def _execute(self, desig: MoveArmJointsMotion): robot = World.robot if desig.left_arm_poses: - robot.set_joint_positions(desig.left_arm_poses) + robot.set_multiple_joint_positions(desig.left_arm_poses) class DonbotWorldStateDetecting(ProcessModule): diff --git a/src/pycram/process_modules/pr2_process_modules.py b/src/pycram/process_modules/pr2_process_modules.py index 7e7592cb7..c87accdfe 100644 --- a/src/pycram/process_modules/pr2_process_modules.py +++ b/src/pycram/process_modules/pr2_process_modules.py @@ -87,7 +87,7 @@ def _execute(self, desig: DetectingMotion): robot = World.robot object_type = desig.object_type # Should be "wide_stereo_optical_frame" - cam_frame_name = RobotDescription.current_robot_description.get_camera_frame() + camera_link_name = RobotDescription.current_robot_description.get_camera_link() # should be [0, 0, 1] camera_description = RobotDescription.current_robot_description.cameras[ list(RobotDescription.current_robot_description.cameras.keys())[0]] @@ -95,7 +95,7 @@ def _execute(self, desig: DetectingMotion): objects = World.current_world.get_object_by_type(object_type) for obj in objects: - if btr.visible(obj, robot.get_link_pose(cam_frame_name), front_facing_axis): + if btr.visible(obj, robot.get_link_pose(camera_link_name), front_facing_axis): return obj @@ -121,9 +121,9 @@ def _execute(self, desig: MoveArmJointsMotion): robot = World.robot if desig.right_arm_poses: - robot.set_joint_positions(desig.right_arm_poses) + robot.set_multiple_joint_positions(desig.right_arm_poses) if desig.left_arm_poses: - robot.set_joint_positions(desig.left_arm_poses) + robot.set_multiple_joint_positions(desig.left_arm_poses) class PR2MoveJoints(ProcessModule): @@ -133,7 +133,7 @@ class PR2MoveJoints(ProcessModule): def _execute(self, desig: MoveJointsMotion): robot = World.robot - robot.set_joint_positions(dict(zip(desig.names, desig.positions))) + robot.set_multiple_joint_positions(dict(zip(desig.names, desig.positions))) class Pr2WorldStateDetecting(ProcessModule): diff --git a/src/pycram/process_modules/stretch_process_modules.py b/src/pycram/process_modules/stretch_process_modules.py index dac9bb24c..9a7fa9212 100644 --- a/src/pycram/process_modules/stretch_process_modules.py +++ b/src/pycram/process_modules/stretch_process_modules.py @@ -132,7 +132,7 @@ def _move_arm_tcp(target: Pose, robot: Object, arm: Arms) -> None: # inv = request_ik(target, robot, joints, gripper) pose, joint_states = request_giskard_ik(target, robot, gripper) robot.set_pose(pose) - robot.set_joint_positions(joint_states) + robot.set_multiple_joint_positions(joint_states) ########################################################### diff --git a/src/pycram/robot_description.py b/src/pycram/robot_description.py index ce9805395..c6e41f9d3 100644 --- a/src/pycram/robot_description.py +++ b/src/pycram/robot_description.py @@ -3,10 +3,12 @@ import rospy from typing_extensions import List, Dict, Union, Optional -from urdf_parser_py.urdf import URDF +from .datastructures.dataclasses import VirtualMobileBaseJoints +from .datastructures.enums import Arms, Grasp, GripperState, GripperType, JointType +from .object_descriptors.urdf import ObjectDescription as URDFObject from .utils import suppress_stdout_stderr -from .datastructures.enums import Arms, Grasp, GripperState, GripperType +from .helper import parse_mjcf_actuators class RobotDescriptionManager: @@ -42,7 +44,12 @@ def load_description(self, name: str): RobotDescription.current_robot_description = self.descriptions[name] return self.descriptions[name] else: - rospy.logerr(f"Robot description {name} not found") + for key in self.descriptions.keys(): + if key in name.lower(): + RobotDescription.current_robot_description = self.descriptions[key] + return self.descriptions[key] + else: + rospy.logerr(f"Robot description {name} not found") def register_description(self, description: RobotDescription): """ @@ -81,7 +88,7 @@ class RobotDescription: """ Torso joint of the robot """ - urdf_object: URDF + urdf_object: URDFObject """ Parsed URDF of the robot """ @@ -105,8 +112,14 @@ class RobotDescription: """ All joints defined in the URDF, by default fixed joints are not included """ + virtual_mobile_base_joints: Optional[VirtualMobileBaseJoints] = None + """ + Virtual mobile base joint names for mobile robots, these joints are not part of the URDF, however they are used to + move the robot in the simulation (e.g. set_pose for the robot would actually move these joints) + """ - def __init__(self, name: str, base_link: str, torso_link: str, torso_joint: str, urdf_path: str): + def __init__(self, name: str, base_link: str, torso_link: str, torso_joint: str, urdf_path: str, + virtual_mobile_base_joints: Optional[VirtualMobileBaseJoints] = None, mjcf_path: Optional[str] = None): """ Initialize the RobotDescription. The URDF is loaded from the given path and used as basis for the kinematic chains. @@ -116,6 +129,8 @@ def __init__(self, name: str, base_link: str, torso_link: str, torso_joint: str, :param torso_link: Torso link of the robot :param torso_joint: Torso joint of the robot, this is the joint that moves the torso upwards if there is one :param urdf_path: Path to the URDF file of the robot + :param virtual_mobile_base_joints: Virtual mobile base joint names for mobile robots + :param mjcf_path: Path to the MJCF file of the robot """ self.name = name self.base_link = base_link @@ -123,12 +138,35 @@ def __init__(self, name: str, base_link: str, torso_link: str, torso_joint: str, self.torso_joint = torso_joint with suppress_stdout_stderr(): # Since parsing URDF causes a lot of warning messages which can't be deactivated, we suppress them - self.urdf_object = URDF.from_xml_file(urdf_path) + self.urdf_object = URDFObject(urdf_path) + self.joint_types = {joint.name: joint.type for joint in self.urdf_object.joints} + self.joint_actuators: Optional[Dict] = parse_mjcf_actuators(mjcf_path) if mjcf_path is not None else None self.kinematic_chains: Dict[str, KinematicChainDescription] = {} self.cameras: Dict[str, CameraDescription] = {} self.grasps: Dict[Grasp, List[float]] = {} self.links: List[str] = [l.name for l in self.urdf_object.links] self.joints: List[str] = [j.name for j in self.urdf_object.joints] + self.virtual_mobile_base_joints: Optional[VirtualMobileBaseJoints] = virtual_mobile_base_joints + + @property + def has_actuators(self): + """ + Property to check if the robot has actuators defined in the MJCF file. + + :return: True if the robot has actuators, False otherwise + """ + return self.joint_actuators is not None + + def get_actuator_for_joint(self, joint: str) -> Optional[str]: + """ + Get the actuator name for a given joint. + + :param joint: Name of the joint + :return: Name of the actuator + """ + if self.has_actuators: + return self.joint_actuators.get(joint) + return None def add_kinematic_chain_description(self, chain: KinematicChainDescription): """ @@ -200,7 +238,7 @@ def add_grasp_orientations(self, orientations: Dict[Grasp, List[float]]): def get_manipulator_chains(self) -> List[KinematicChainDescription]: """ - Returns a list of all manipulator chains of the robot which posses an end effector. + Get a list of all manipulator chains of the robot which posses an end effector. :return: A list of KinematicChainDescription objects """ @@ -210,7 +248,7 @@ def get_manipulator_chains(self) -> List[KinematicChainDescription]: result.append(chain) return result - def get_camera_frame(self) -> str: + def get_camera_link(self) -> str: """ Quick method to get the name of a link of a camera. Uses the first camera in the list of cameras. @@ -218,9 +256,17 @@ def get_camera_frame(self) -> str: """ return self.cameras[list(self.cameras.keys())[0]].link_name + def get_camera_frame(self) -> str: + """ + Quick method to get the name of a link of a camera. Uses the first camera in the list of cameras. + + :return: A name of the link of a camera + """ + return f"{self.name}/{self.cameras[list(self.cameras.keys())[0]].link_name}" + def get_default_camera(self) -> CameraDescription: """ - Returns the first camera in the list of cameras. + Get the first camera in the list of cameras. :return: A CameraDescription object """ @@ -228,7 +274,7 @@ def get_default_camera(self) -> CameraDescription: def get_static_joint_chain(self, kinematic_chain_name: str, configuration_name: str): """ - Returns the static joint states of a kinematic chain for a specific configuration. When trying to access one of + Get the static joint states of a kinematic chain for a specific configuration. When trying to access one of the robot arms the function `:func: get_arm_chain` should be used. :param kinematic_chain_name: @@ -246,7 +292,7 @@ def get_static_joint_chain(self, kinematic_chain_name: str, configuration_name: def get_parent(self, name: str) -> str: """ - Returns the parent of a link or joint in the URDF. Always returns the imeadiate parent, for a link this is a joint + Get the parent of a link or joint in the URDF. Always returns the imeadiate parent, for a link this is a joint and vice versa. :param name: Name of the link or joint in the URDF @@ -266,7 +312,7 @@ def get_parent(self, name: str) -> str: def get_child(self, name: str, return_multiple_children: bool = False) -> Union[str, List[str]]: """ - Returns the child of a link or joint in the URDF. Always returns the immediate child, for a link this is a joint + Get the child of a link or joint in the URDF. Always returns the immediate child, for a link this is a joint and vice versa. Since a link can have multiple children, the return_multiple_children parameter can be set to True to get a list of all children. @@ -293,9 +339,19 @@ def get_child(self, name: str, return_multiple_children: bool = False) -> Union[ child_link = self.urdf_object.joint_map[name].child return child_link + def get_arm_tool_frame(self, arm: Arms) -> str: + """ + Get the name of the tool frame of a specific arm. + + :param arm: Arm for which the tool frame should be returned + :return: The name of the link of the tool frame in the URDF. + """ + chain = self.get_arm_chain(arm) + return chain.get_tool_frame() + def get_arm_chain(self, arm: Arms) -> Union[KinematicChainDescription, List[KinematicChainDescription]]: """ - Returns the kinematic chain of a specific arm. If the arm is set to BOTH, all kinematic chains are returned. + Get the kinematic chain of a specific arm. If the arm is set to BOTH, all kinematic chains are returned. :param arm: Arm for which the chain should be returned :return: KinematicChainDescription object of the arm @@ -329,7 +385,7 @@ class KinematicChainDescription: """ Last link of the chain """ - urdf_object: URDF + urdf_object: URDFObject """ Parsed URDF of the robot """ @@ -358,7 +414,7 @@ class KinematicChainDescription: Dictionary of static joint states for the chain """ - def __init__(self, name: str, start_link: str, end_link: str, urdf_object: URDF, arm_type: Arms = None, + def __init__(self, name: str, start_link: str, end_link: str, urdf_object: URDFObject, arm_type: Arms = None, include_fixed_joints=False): """ Initialize the KinematicChainDescription object. @@ -373,7 +429,7 @@ def __init__(self, name: str, start_link: str, end_link: str, urdf_object: URDF, self.name: str = name self.start_link: str = start_link self.end_link: str = end_link - self.urdf_object: URDF = urdf_object + self.urdf_object: URDFObject = urdf_object self.include_fixed_joints: bool = include_fixed_joints self.link_names: List[str] = [] self.joint_names: List[str] = [] @@ -395,11 +451,12 @@ def _init_joints(self): Initializes the joints of the chain by getting the chain from the URDF object. """ joints = self.urdf_object.get_chain(self.start_link, self.end_link, links=False) - self.joint_names = list(filter(lambda j: self.urdf_object.joint_map[j].type != "fixed" or self.include_fixed_joints, joints)) + self.joint_names = list(filter(lambda j: self.urdf_object.joint_map[j].type != JointType.FIXED + or self.include_fixed_joints, joints)) def get_joints(self) -> List[str]: """ - Returns a list of all joints of the chain. + Get a list of all joints of the chain. :return: List of joint names """ @@ -407,9 +464,7 @@ def get_joints(self) -> List[str]: def get_links(self) -> List[str]: """ - Returns a list of all links of the chain. - - :return: List of link names + :return: A list of all links of the chain. """ return self.link_names @@ -445,7 +500,7 @@ def add_static_joint_states(self, name: str, states: dict): def get_static_joint_states(self, name: str) -> Dict[str, float]: """ - Returns the dictionary of static joint states for a given name of the static joint states. + Get the dictionary of static joint states for a given name of the static joint states. :param name: Name of the static joint states :return: Dictionary of joint names and their values @@ -457,7 +512,7 @@ def get_static_joint_states(self, name: str) -> Dict[str, float]: def get_tool_frame(self) -> str: """ - Returns the name of the tool frame of the end effector of this chain, if it has an end effector. + Get the name of the tool frame of the end effector of this chain, if it has an end effector. :return: The name of the link of the tool frame in the URDF. """ @@ -468,7 +523,7 @@ def get_tool_frame(self) -> str: def get_static_gripper_state(self, state: GripperState) -> Dict[str, float]: """ - Returns the static joint states for the gripper of the chain. + Get the static joint states for the gripper of the chain. :param state: Name of the static joint states :return: Dictionary of joint names and their values @@ -552,7 +607,7 @@ class EndEffectorDescription: """ Name of the tool frame link in the URDf """ - urdf_object: URDF + urdf_object: URDFObject """ Parsed URDF of the robot """ @@ -577,7 +632,7 @@ class EndEffectorDescription: Distance the gripper can open, in cm """ - def __init__(self, name: str, start_link: str, tool_frame: str, urdf_object: URDF): + def __init__(self, name: str, start_link: str, tool_frame: str, urdf_object: URDFObject): """ Initialize the EndEffectorDescription object. @@ -589,7 +644,7 @@ def __init__(self, name: str, start_link: str, tool_frame: str, urdf_object: URD self.name: str = name self.start_link: str = start_link self.tool_frame: str = tool_frame - self.urdf_object: URDF = urdf_object + self.urdf_object: URDFObject = urdf_object self.link_names: List[str] = [] self.joint_names: List[str] = [] self.static_joint_states: Dict[GripperState, Dict[str, float]] = {} diff --git a/src/pycram/robot_descriptions/__init__.py b/src/pycram/robot_descriptions/__init__.py index 33dc54ca9..1ec759998 100644 --- a/src/pycram/robot_descriptions/__init__.py +++ b/src/pycram/robot_descriptions/__init__.py @@ -9,7 +9,8 @@ class DeprecatedRobotDescription: def raise_error(self): - raise DeprecationWarning("Robot description moved, please use RobotDescription.current_robot_description from pycram.robot_description") + raise DeprecationWarning("Robot description moved, please use RobotDescription.current_robot_description from" + " pycram.robot_description") @property def name(self): diff --git a/src/pycram/robot_descriptions/pr2_description.py b/src/pycram/robot_descriptions/pr2_description.py index 402125f2a..b7fe30ad7 100644 --- a/src/pycram/robot_descriptions/pr2_description.py +++ b/src/pycram/robot_descriptions/pr2_description.py @@ -1,13 +1,18 @@ +from ..datastructures.dataclasses import VirtualMobileBaseJoints from ..robot_description import RobotDescription, KinematicChainDescription, EndEffectorDescription, \ RobotDescriptionManager, CameraDescription from ..datastructures.enums import Arms, Grasp, GripperState, GripperType import rospkg +from ..helper import get_robot_mjcf_path + rospack = rospkg.RosPack() filename = rospack.get_path('pycram') + '/resources/robots/' + "pr2" + '.urdf' +mjcf_filename = get_robot_mjcf_path("", "pr2") + pr2_description = RobotDescription("pr2", "base_link", "torso_lift_link", "torso_lift_joint", - filename) + filename, virtual_mobile_base_joints=VirtualMobileBaseJoints(), mjcf_path=mjcf_filename) ################################## Left Arm ################################## left_arm = KinematicChainDescription("left", "torso_lift_link", "l_wrist_roll_link", diff --git a/src/pycram/robot_descriptions/tiago_description.py b/src/pycram/robot_descriptions/tiago_description.py index 6a92d47ec..37b1ec29e 100644 --- a/src/pycram/robot_descriptions/tiago_description.py +++ b/src/pycram/robot_descriptions/tiago_description.py @@ -1,13 +1,20 @@ import rospkg + +from ..datastructures.dataclasses import VirtualMobileBaseJoints +from ..datastructures.enums import GripperState, Arms, Grasp from ..robot_description import RobotDescription, KinematicChainDescription, EndEffectorDescription, \ RobotDescriptionManager, CameraDescription -from ..datastructures.enums import GripperState, Arms, Grasp +from ..helper import get_robot_mjcf_path rospack = rospkg.RosPack() filename = rospack.get_path('pycram') + '/resources/robots/' + "tiago_dual" + '.urdf' +mjcf_filename = get_robot_mjcf_path("pal_robotics", "tiago_dual") + tiago_description = RobotDescription("tiago_dual", "base_link", "torso_lift_link", "torso_lift_joint", - filename) + filename, + virtual_mobile_base_joints=VirtualMobileBaseJoints(), + mjcf_path=mjcf_filename) ################################## Left Arm ################################## left_arm = KinematicChainDescription("left_arm", "torso_lift_link", "arm_left_7_link", diff --git a/src/pycram/ros/viz_marker_publisher.py b/src/pycram/ros/viz_marker_publisher.py index 0aa149e9b..39e519b3d 100644 --- a/src/pycram/ros/viz_marker_publisher.py +++ b/src/pycram/ros/viz_marker_publisher.py @@ -3,6 +3,7 @@ import time from typing import List, Optional, Tuple +import numpy as np import rospy from geometry_msgs.msg import Vector3 from std_msgs.msg import ColorRGBA @@ -35,7 +36,7 @@ def __init__(self, topic_name="/pycram/viz_marker", interval=0.1): self.thread = threading.Thread(target=self._publish) self.kill_event = threading.Event() self.main_world = World.current_world if not World.current_world.is_prospection_world else World.current_world.world_sync.world - + self.lock = self.main_world.object_lock self.thread.start() atexit.register(self._stop_publishing) @@ -44,8 +45,9 @@ def _publish(self) -> None: Constantly publishes the Marker Array. To the given topic name at a fixed rate. """ while not self.kill_event.is_set(): + self.lock.acquire() marker_array = self._make_marker_array() - + self.lock.release() self.pub.publish(marker_array) time.sleep(self.interval) @@ -79,7 +81,7 @@ def _make_marker_array(self) -> MarkerArray: link_pose_with_origin = link_pose * link_origin msg.pose = link_pose_with_origin.to_pose().pose - color = [1, 1, 1, 1] if obj.link_name_to_id[link] == -1 else obj.get_link_color(link).get_rgba() + color = obj.get_link_color(link).get_rgba() msg.color = ColorRGBA(*color) msg.lifetime = rospy.Duration(1) @@ -94,7 +96,8 @@ def _make_marker_array(self) -> MarkerArray: msg.scale = Vector3(geom.radius * 2, geom.radius * 2, geom.length) elif isinstance(geom, BoxVisualShape): msg.type = Marker.CUBE - msg.scale = Vector3(*geom.size) + size = np.array(geom.size) * 2 + msg.scale = Vector3(size[0], size[1], size[2]) elif isinstance(geom, SphereVisualShape): msg.type = Marker.SPHERE msg.scale = Vector3(geom.radius * 2, geom.radius * 2, geom.radius * 2) diff --git a/src/pycram/tasktree.py b/src/pycram/tasktree.py index 94abf932e..06440829e 100644 --- a/src/pycram/tasktree.py +++ b/src/pycram/tasktree.py @@ -14,7 +14,7 @@ from .orm.action_designator import Action from .orm.tasktree import TaskTreeNode as ORMTaskTreeNode from .orm.base import ProcessMetaData -from .plan_failures import PlanFailure +from .failures import PlanFailure from .datastructures.enums import TaskStatus from .datastructures.dataclasses import Color diff --git a/src/pycram/utils.py b/src/pycram/utils.py index 28b5109ac..30bd10538 100644 --- a/src/pycram/utils.py +++ b/src/pycram/utils.py @@ -7,14 +7,20 @@ GeneratorList -- implementation of generator list wrappers. """ from inspect import isgeneratorfunction -from typing_extensions import List, Tuple, Callable - import os +import math + +import numpy as np +from matplotlib import pyplot as plt +import matplotlib.colors as mcolors +from typing_extensions import Tuple, Callable, List, Dict, TYPE_CHECKING from .datastructures.pose import Pose -import math +from .local_transformer import LocalTransformer -from typing_extensions import Dict +if TYPE_CHECKING: + from .world_concepts.world_object import Object + from .robot_description import CameraDescription class bcolors: @@ -36,7 +42,7 @@ class bcolors: UNDERLINE = '\033[4m' -def _apply_ik(robot: 'pycram.world_concepts.WorldObject', pose_and_joint_poses: Tuple[Pose, Dict[str, float]]) -> None: +def _apply_ik(robot: 'Object', pose_and_joint_poses: Tuple[Pose, Dict[str, float]]) -> None: """ Apllies a list of joint poses calculated by an inverse kinematics solver to a robot @@ -46,7 +52,7 @@ def _apply_ik(robot: 'pycram.world_concepts.WorldObject', pose_and_joint_poses: """ pose, joint_states = pose_and_joint_poses robot.set_pose(pose) - robot.set_joint_positions(joint_states) + robot.set_multiple_joint_positions(joint_states) class GeneratorList: @@ -113,7 +119,7 @@ def axis_angle_to_quaternion(axis: List, angle: float) -> Tuple: z = normalized_axis[2] * math.sin(angle / 2) w = math.cos(angle / 2) - return (x, y, z, w) + return tuple((x, y, z, w)) class suppress_stdout_stderr(object): @@ -130,7 +136,7 @@ class suppress_stdout_stderr(object): def __init__(self): # Open a pair of null files - self.null_fds = [os.open(os.devnull, os.O_RDWR) for x in range(2)] + self.null_fds = [os.open(os.devnull, os.O_RDWR) for _ in range(2)] # Save the actual stdout (1) and stderr (2) file descriptors. self.save_fds = [os.dup(1), os.dup(2)] @@ -148,3 +154,257 @@ def __exit__(self, *_): # Close all file descriptors for fd in self.null_fds + self.save_fds: os.close(fd) + + +class RayTestUtils: + + def __init__(self, ray_test_batch: Callable, object_id_to_name: Dict = None): + """ + Initialize the ray test helper. + """ + self.local_transformer = LocalTransformer() + self.ray_test_batch = ray_test_batch + self.object_id_to_name = object_id_to_name + + def get_images_for_target(self, cam_pose: Pose, + camera_description: 'CameraDescription', + camera_frame: str, + size: int = 256, + camera_min_distance: float = 0.1, + camera_max_distance: int = 3, + plot: bool = False) -> List[np.ndarray]: + """ + Note: The returned color image is a repeated depth image in 3 channels. + """ + + # get the list of start positions of the rays. + rays_start_positions = self.get_camera_rays_start_positions(camera_description, camera_frame, cam_pose, size, + camera_min_distance).tolist() + + # get the list of end positions of the rays + rays_end_positions = self.get_camera_rays_end_positions(camera_description, camera_frame, cam_pose, size, + camera_max_distance).tolist() + + # apply the ray test + object_ids, distances = self.ray_test_batch(rays_start_positions, rays_end_positions, return_distance=True) + + # construct the images/masks + segmentation_mask = self.construct_segmentation_mask_from_ray_test_object_ids(object_ids, size) + depth_image = self.construct_depth_image_from_ray_test_distances(distances, size) + camera_min_distance + color_depth_image = self.construct_color_image_from_depth_image(depth_image) + + if plot: + self.plot_segmentation_mask(segmentation_mask) + self.plot_depth_image(depth_image) + + return [color_depth_image, depth_image, segmentation_mask] + + @staticmethod + def construct_segmentation_mask_from_ray_test_object_ids(object_ids: List[int], size: int) -> np.ndarray: + """ + Construct a segmentation mask from the object ids returned by the ray test. + + :param object_ids: The object ids. + :param size: The size of the grid. + :return: The segmentation mask. + """ + return np.array(object_ids).squeeze(axis=1).reshape(size, size) + + @staticmethod + def construct_depth_image_from_ray_test_distances(distances: List[float], size: int) -> np.ndarray: + """ + Construct a depth image from the distances returned by the ray test. + + :param distances: The distances. + :param size: The size of the grid. + :return: The depth image. + """ + return np.array(distances).reshape(size, size) + + @staticmethod + def construct_color_image_from_depth_image(depth_image: np.ndarray) -> np.ndarray: + """ + Construct a color image from the depth image. + + :param depth_image: The depth image. + :return: The color image. + """ + min_distance = np.min(depth_image) + max_distance = np.max(depth_image) + normalized_depth_image = (depth_image - min_distance) * 255 / (max_distance - min_distance) + return np.repeat(normalized_depth_image[:, :, np.newaxis], 3, axis=2).astype(np.uint8) + + def get_camera_rays_start_positions(self, camera_description: 'CameraDescription', camera_frame: str, + camera_pose: Pose, size: int, + camera_min_distance: float) -> np.ndarray: + + # get the start pose of the rays from the camera pose and minimum distance. + start_pose = self.get_camera_rays_start_pose(camera_description, camera_frame, camera_pose, camera_min_distance) + + # get the list of start positions of the rays. + return np.repeat(np.array([start_pose.position_as_list()]), size * size, axis=0) + + def get_camera_rays_start_pose(self, camera_description: 'CameraDescription', camera_frame: str, camera_pose: Pose, + camera_min_distance: float) -> Pose: + """ + Get the start position of the camera rays, which is the camera pose shifted by the minimum distance of the + camera. + + :param camera_description: The camera description. + :param camera_frame: The camera tf frame. + :param camera_pose: The camera pose. + :param camera_min_distance: The minimum distance from which the camera can see. + """ + camera_pose_in_camera_frame = self.local_transformer.transform_pose(camera_pose, camera_frame) + start_position = (np.array(camera_description.front_facing_axis) * camera_min_distance + + np.array(camera_pose_in_camera_frame.position_as_list())) + start_pose = Pose(start_position.tolist(), camera_pose_in_camera_frame.orientation_as_list(), camera_frame) + return self.local_transformer.transform_pose(start_pose, "map") + + def get_camera_rays_end_positions(self, camera_description: 'CameraDescription', camera_frame: str, + camera_pose: Pose, size: int, camera_max_distance: float = 3.0) -> np.ndarray: + """ + Get the end positions of the camera rays. + + :param camera_description: The camera description. + :param camera_frame: The camera frame. + :param camera_pose: The camera pose. + :param size: The size of the grid. + :param camera_max_distance: The maximum distance of the camera. + :return: The end positions of the camera rays. + """ + rays_horizontal_angles, rays_vertical_angles = self.construct_grid_of_camera_rays_angles(camera_description, + size) + rays_end_positions = self.get_end_positions_of_rays_from_angles_and_distance(rays_vertical_angles, + rays_horizontal_angles, + camera_max_distance) + return self.transform_points_from_camera_frame_to_world_frame(camera_pose, camera_frame, rays_end_positions) + + @staticmethod + def transform_points_from_camera_frame_to_world_frame(camera_pose: Pose, camera_frame: str, + points: np.ndarray) -> np.ndarray: + """ + Transform points from the camera frame to the world frame. + + :param camera_pose: The camera pose. + :param camera_frame: The camera frame. + :param points: The points to transform. + :return: The transformed points. + """ + cam_to_world_transform = camera_pose.to_transform(camera_frame) + return cam_to_world_transform.apply_transform_to_array_of_points(points) + + @staticmethod + def get_end_positions_of_rays_from_angles_and_distance(vertical_angles: np.ndarray, horizontal_angles: np.ndarray, + distance: float) -> np.ndarray: + """ + Get the end positions of the rays from the angles and the distance. + + :param vertical_angles: The vertical angles of the rays. + :param horizontal_angles: The horizontal angles of the rays. + :param distance: The distance of the rays. + :return: The end positions of the rays. + """ + rays_end_positions_x = distance * np.cos(vertical_angles) * np.sin(horizontal_angles) + rays_end_positions_x = rays_end_positions_x.reshape(-1) + rays_end_positions_z = distance * np.cos(vertical_angles) * np.cos(horizontal_angles) + rays_end_positions_z = rays_end_positions_z.reshape(-1) + rays_end_positions_y = distance * np.sin(vertical_angles) + rays_end_positions_y = rays_end_positions_y.reshape(-1) + return np.stack((rays_end_positions_x, rays_end_positions_y, rays_end_positions_z), axis=1) + + @staticmethod + def construct_grid_of_camera_rays_angles(camera_description: 'CameraDescription', + size: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Construct a 2D grid of camera rays angles. + + :param camera_description: The camera description. + :param size: The size of the grid. + :return: The 2D grid of the horizontal and the vertical angles of the camera rays. + """ + # get the camera fov angles + camera_horizontal_fov = camera_description.horizontal_angle + camera_vertical_fov = camera_description.vertical_angle + + # construct a 2d grid of rays angles + rays_horizontal_angles = np.linspace(-camera_horizontal_fov / 2, camera_horizontal_fov / 2, size) + rays_horizontal_angles = np.tile(rays_horizontal_angles, (size, 1)) + rays_vertical_angles = np.linspace(-camera_vertical_fov / 2, camera_vertical_fov / 2, size) + rays_vertical_angles = np.tile(rays_vertical_angles, (size, 1)).T + return rays_horizontal_angles, rays_vertical_angles + + @staticmethod + def plot_segmentation_mask(segmentation_mask, + object_id_to_name: Dict[int, str] = None): + """ + Plot the segmentation mask with different colors for each object. + + :param segmentation_mask: The segmentation mask. + :param object_id_to_name: The mapping from object id to object name. + """ + if object_id_to_name is None: + object_id_to_name = {uid: str(uid) for uid in np.unique(segmentation_mask)} + + # Create a custom color map + unique_ids = np.unique(segmentation_mask) + unique_ids = unique_ids[unique_ids != -1] # Exclude -1 values + + # Create a color map that assigns a unique color to each ID + colors = plt.cm.get_cmap('tab20', len(unique_ids)) # Use tab20 colormap for distinct colors + color_dict = {uid: colors(i) for i, uid in enumerate(unique_ids)} + + # Map each ID to its corresponding color + mask_shape = segmentation_mask.shape + segmentation_colored = np.zeros((mask_shape[0], mask_shape[1], 3)) + + for uid in unique_ids: + segmentation_colored[segmentation_mask == uid] = color_dict[uid][:3] # Ignore the alpha channel + + # Create a colormap for the color bar + cmap = mcolors.ListedColormap([color_dict[uid][:3] for uid in unique_ids]) + norm = mcolors.BoundaryNorm(boundaries=np.arange(len(unique_ids) + 1) - 0.5, ncolors=len(unique_ids)) + + # Plot the colored segmentation mask + fig, ax = plt.subplots() + _ = ax.imshow(segmentation_colored) + ax.axis('off') # Hide axes + ax.set_title('Segmentation Mask with Different Colors for Each Object') + + # Create color bar + cbar = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, ticks=np.arange(len(unique_ids))) + cbar.ax.set_yticklabels( + [object_id_to_name[uid] for uid in unique_ids]) # Label the color bar with object IDs + cbar.set_label('Object Name') + + plt.show() + + @staticmethod + def plot_depth_image(depth_image): + # Plot the depth image + fig, ax = plt.subplots() + cax = ax.imshow(depth_image, cmap='viridis', vmin=0, vmax=np.max(depth_image)) + ax.axis('off') # Hide axes + ax.set_title('Depth Image') + + # Create color bar + cbar = fig.colorbar(cax, ax=ax) + cbar.set_label('Depth Value') + + plt.show() + + +def wxyz_to_xyzw(wxyz: List[float]) -> List[float]: + """ + Convert a quaternion from WXYZ to XYZW format. + """ + return [wxyz[1], wxyz[2], wxyz[3], wxyz[0]] + + +def xyzw_to_wxyz(xyzw: List[float]) -> List[float]: + """ + Convert a quaternion from XYZW to WXYZ format. + + :param xyzw: The quaternion in XYZW format. + """ + return [xyzw[3], *xyzw[:3]] \ No newline at end of file diff --git a/src/pycram/validation/__init__.py b/src/pycram/validation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pycram/validation/error_checkers.py b/src/pycram/validation/error_checkers.py new file mode 100644 index 000000000..f594b6bd6 --- /dev/null +++ b/src/pycram/validation/error_checkers.py @@ -0,0 +1,351 @@ +from abc import ABC, abstractmethod +from collections.abc import Iterable + +import numpy as np +from tf.transformations import quaternion_multiply, quaternion_inverse +from typing_extensions import List, Union, Optional, Any, Sized, Iterable as T_Iterable, TYPE_CHECKING, Tuple + +from ..datastructures.enums import JointType +if TYPE_CHECKING: + from ..datastructures.pose import Pose + + +class ErrorChecker(ABC): + """ + An abstract class that resembles an error checker. It has two main methods, one for calculating the error between + two values and another for checking if the error is acceptable. + """ + def __init__(self, acceptable_error: Union[float, T_Iterable[float]], is_iterable: Optional[bool] = False): + """ + Initialize the error checker. + + :param acceptable_error: The acceptable error. + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + self._acceptable_error: np.ndarray = np.array(acceptable_error) + self.tiled_acceptable_error: Optional[np.ndarray] = None + self.is_iterable = is_iterable + + def reset(self) -> None: + """ + Reset the error checker. + """ + self.tiled_acceptable_error = None + + @property + def acceptable_error(self) -> np.ndarray: + return self._acceptable_error + + @acceptable_error.setter + def acceptable_error(self, new_acceptable_error: Union[float, T_Iterable[float]]) -> None: + self._acceptable_error = np.array(new_acceptable_error) + + def update_acceptable_error(self, new_acceptable_error: Optional[T_Iterable[float]] = None, + tile_to_match: Optional[Sized] = None,) -> None: + """ + Update the acceptable error with a new value, and tile it to match the length of the error if needed. + + :param new_acceptable_error: The new acceptable error. + :param tile_to_match: The iterable to match the length of the error with. + """ + if new_acceptable_error is not None: + self.acceptable_error = new_acceptable_error + if tile_to_match is not None and self.is_iterable: + self.update_tiled_acceptable_error(tile_to_match) + + def update_tiled_acceptable_error(self, tile_to_match: Sized) -> None: + """ + Tile the acceptable error to match the length of the error. + + :param tile_to_match: The object to match the length of the error. + :return: The tiled acceptable error. + """ + self.tiled_acceptable_error = np.tile(self.acceptable_error.flatten(), + len(tile_to_match) // self.acceptable_error.size) + + @abstractmethod + def _calculate_error(self, value_1: Any, value_2: Any) -> Union[float, List[float]]: + """ + Calculate the error between two values. + + :param value_1: The first value. + :param value_2: The second value. + :return: The error between the two values. + """ + pass + + def calculate_error(self, value_1: Any, value_2: Any) -> Union[float, List[float]]: + """ + Calculate the error between two values. + + :param value_1: The first value. + :param value_2: The second value. + :return: The error between the two values. + """ + if self.is_iterable: + return [self._calculate_error(v1, v2) for v1, v2 in zip(value_1, value_2)] + else: + return self._calculate_error(value_1, value_2) + + def is_error_acceptable(self, value_1: Any, value_2: Any) -> bool: + """ + Check if the error is acceptable. + + :param value_1: The first value. + :param value_2: The second value. + :return: Whether the error is acceptable. + """ + error = self.calculate_error(value_1, value_2) + if self.is_iterable: + error = np.array(error).flatten() + if self.tiled_acceptable_error is None or\ + len(error) != len(self.tiled_acceptable_error): + self.update_tiled_acceptable_error(error) + return np.all(error <= self.tiled_acceptable_error) + else: + return is_error_acceptable(error, self.acceptable_error) + + +class PoseErrorChecker(ErrorChecker): + + def __init__(self, acceptable_error: Union[Tuple[float], T_Iterable[Tuple[float]]] = (1e-3, np.pi / 180), + is_iterable: Optional[bool] = False): + """ + Initialize the pose error checker. + + :param acceptable_error: The acceptable pose error (position error, orientation error). + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + super().__init__(acceptable_error, is_iterable) + + def _calculate_error(self, value_1: Any, value_2: Any) -> List[float]: + """ + Calculate the error between two poses. + + :param value_1: The first pose. + :param value_2: The second pose. + """ + return calculate_pose_error(value_1, value_2) + + +class PositionErrorChecker(ErrorChecker): + + def __init__(self, acceptable_error: Optional[float] = 1e-3, is_iterable: Optional[bool] = False): + """ + Initialize the position error checker. + + :param acceptable_error: The acceptable position error. + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + super().__init__(acceptable_error, is_iterable) + + def _calculate_error(self, value_1: Any, value_2: Any) -> float: + """ + Calculate the error between two positions. + + :param value_1: The first position. + :param value_2: The second position. + :return: The error between the two positions. + """ + return calculate_position_error(value_1, value_2) + + +class OrientationErrorChecker(ErrorChecker): + + def __init__(self, acceptable_error: Optional[float] = np.pi / 180, is_iterable: Optional[bool] = False): + """ + Initialize the orientation error checker. + + :param acceptable_error: The acceptable orientation error. + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + super().__init__(acceptable_error, is_iterable) + + def _calculate_error(self, value_1: Any, value_2: Any) -> float: + """ + Calculate the error between two quaternions. + + :param value_1: The first quaternion. + :param value_2: The second quaternion. + :return: The error between the two quaternions. + """ + return calculate_orientation_error(value_1, value_2) + + +class SingleValueErrorChecker(ErrorChecker): + + def __init__(self, acceptable_error: Optional[float] = 1e-3, is_iterable: Optional[bool] = False): + """ + Initialize the single value error checker. + + :param acceptable_error: The acceptable error between two values. + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + super().__init__(acceptable_error, is_iterable) + + def _calculate_error(self, value_1: Any, value_2: Any) -> float: + """ + Calculate the error between two values. + + :param value_1: The first value. + :param value_2: The second value. + :return: The error between the two values. + """ + return abs(value_1 - value_2) + + +class RevoluteJointPositionErrorChecker(SingleValueErrorChecker): + + def __init__(self, acceptable_error: Optional[float] = np.pi / 180, is_iterable: Optional[bool] = False): + """ + Initialize the revolute joint position error checker. + + :param acceptable_error: The acceptable revolute joint position error. + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + super().__init__(acceptable_error, is_iterable) + + +class PrismaticJointPositionErrorChecker(SingleValueErrorChecker): + + def __init__(self, acceptable_error: Optional[float] = 1e-3, is_iterable: Optional[bool] = False): + """ + Initialize the prismatic joint position error checker. + + :param acceptable_error: The acceptable prismatic joint position error. + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + super().__init__(acceptable_error, is_iterable) + + +class IterableErrorChecker(ErrorChecker): + + def __init__(self, acceptable_error: Optional[T_Iterable[float]] = None): + """ + Initialize the iterable error checker. + + :param acceptable_error: The acceptable error between two values. + """ + super().__init__(acceptable_error, True) + + def _calculate_error(self, value_1: Any, value_2: Any) -> float: + """ + Calculate the error between two values. + + :param value_1: The first value. + :param value_2: The second value. + :return: The error between the two values. + """ + return abs(value_1 - value_2) + + +class MultiJointPositionErrorChecker(IterableErrorChecker): + + def __init__(self, joint_types: List[JointType], acceptable_error: Optional[T_Iterable[float]] = None): + """ + Initialize the multi-joint position error checker. + + :param joint_types: The types of the joints. + :param acceptable_error: The acceptable error between two joint positions. + """ + self.joint_types = joint_types + if acceptable_error is None: + acceptable_error = [np.pi/180 if jt == JointType.REVOLUTE else 1e-3 for jt in joint_types] + super().__init__(acceptable_error) + + def _calculate_error(self, value_1: Any, value_2: Any) -> float: + """ + Calculate the error between two joint positions. + + :param value_1: The first joint position. + :param value_2: The second joint position. + :return: The error between the two joint positions. + """ + return calculate_joint_position_error(value_1, value_2) + + +def calculate_pose_error(pose_1: 'Pose', pose_2: 'Pose') -> List[float]: + """ + Calculate the error between two poses. + + :param pose_1: The first pose. + :param pose_2: The second pose. + :return: The error between the two poses. + """ + return [calculate_position_error(pose_1.position_as_list(), pose_2.position_as_list()), + calculate_orientation_error(pose_1.orientation_as_list(), pose_2.orientation_as_list())] + + +def calculate_position_error(position_1: List[float], position_2: List[float]) -> float: + """ + Calculate the error between two positions. + + :param position_1: The first position. + :param position_2: The second position. + :return: The error between the two positions. + """ + return np.linalg.norm(np.array(position_1) - np.array(position_2)) + + +def calculate_orientation_error(quat_1: List[float], quat_2: List[float]) -> float: + """ + Calculate the error between two quaternions. + + :param quat_1: The first quaternion. + :param quat_2: The second quaternion. + :return: The error between the two quaternions. + """ + return calculate_angle_between_quaternions(quat_1, quat_2) + + +def calculate_joint_position_error(joint_position_1: float, joint_position_2: float) -> float: + """ + Calculate the error between two joint positions. + + :param joint_position_1: The first joint position. + :param joint_position_2: The second joint position. + :return: The error between the two joint positions. + """ + return abs(joint_position_1 - joint_position_2) + + +def is_error_acceptable(error: Union[float, T_Iterable[float]], + acceptable_error: Union[float, T_Iterable[float]]) -> bool: + """ + Check if the error is acceptable. + + :param error: The error. + :param acceptable_error: The acceptable error. + :return: Whether the error is acceptable. + """ + if isinstance(error, Iterable): + return all([error_i <= acceptable_error_i for error_i, acceptable_error_i in zip(error, acceptable_error)]) + else: + return error <= acceptable_error + + +def calculate_angle_between_quaternions(quat_1: List[float], quat_2: List[float]) -> float: + """ + Calculates the angle between two quaternions. + + :param quat_1: The first quaternion. + :param quat_2: The second quaternion. + :return: A float value that represents the angle between the two quaternions. + """ + quat_diff = calculate_quaternion_difference(quat_1, quat_2) + quat_diff_angle = 2 * np.arctan2(np.linalg.norm(quat_diff[0:3]), quat_diff[3]) + if quat_diff_angle > np.pi: + quat_diff_angle = 2 * np.pi - quat_diff_angle + return quat_diff_angle + + +def calculate_quaternion_difference(quat_1: List[float], quat_2: List[float]) -> List[float]: + """ + Calculates the quaternion difference. + + :param quat_1: The quaternion of the object at the first time step. + :param quat_2: The quaternion of the object at the second time step. + :return: A list of float values that represent the quaternion difference. + """ + quat_diff = quaternion_multiply(quaternion_inverse(quat_1), quat_2) + return quat_diff diff --git a/src/pycram/validation/goal_validator.py b/src/pycram/validation/goal_validator.py new file mode 100644 index 000000000..a70f2b438 --- /dev/null +++ b/src/pycram/validation/goal_validator.py @@ -0,0 +1,550 @@ +from time import sleep, time + +import numpy as np +import rospy +from typing_extensions import Any, Callable, Optional, Union, Iterable, Dict, TYPE_CHECKING, Tuple + +from ..datastructures.enums import JointType +from .error_checkers import ErrorChecker, PoseErrorChecker, PositionErrorChecker, \ + OrientationErrorChecker, SingleValueErrorChecker + +if TYPE_CHECKING: + from ..datastructures.world import World + from ..world_concepts.world_object import Object + from ..datastructures.pose import Pose + from ..description import ObjectDescription + + Joint = ObjectDescription.Joint + Link = ObjectDescription.Link + +OptionalArgCallable = Union[Callable[[], Any], Callable[[Any], Any]] + + +class GoalValidator: + """ + A class to validate the goal by tracking the goal achievement progress. + """ + + raise_error: Optional[bool] = False + """ + Whether to raise an error if the goal is not achieved. + """ + + def __init__(self, error_checker: ErrorChecker, current_value_getter: OptionalArgCallable, + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8): + """ + Initialize the goal validator. + + :param error_checker: The error checker. + :param current_value_getter: The current value getter function which takes an optional input and returns the + current value. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved, if given, will be + used to check if this percentage is achieved instead of the complete goal. + """ + self.error_checker: ErrorChecker = error_checker + self.current_value_getter: Callable[[Optional[Any]], Any] = current_value_getter + self.acceptable_percentage_of_goal_achieved: Optional[float] = acceptable_percentage_of_goal_achieved + self.goal_value: Optional[Any] = None + self.initial_error: Optional[np.ndarray] = None + self.current_value_getter_input: Optional[Any] = None + + def register_goal_and_wait_until_achieved(self, goal_value: Any, + current_value_getter_input: Optional[Any] = None, + initial_value: Optional[Any] = None, + acceptable_error: Optional[Union[float, Iterable[float]]] = None, + max_wait_time: Optional[float] = 1, + time_per_read: Optional[float] = 0.01) -> None: + """ + Register the goal value and wait until the target is reached. + + :param goal_value: The goal value. + :param current_value_getter_input: The values that are used as input to the current value getter. + :param initial_value: The initial value. + :param acceptable_error: The acceptable error. + :param max_wait_time: The maximum time to wait. + :param time_per_read: The time to wait between each read. + """ + self.register_goal(goal_value, current_value_getter_input, initial_value, acceptable_error) + self.wait_until_goal_is_achieved(max_wait_time, time_per_read) + + def wait_until_goal_is_achieved(self, max_wait_time: Optional[float] = 2, + time_per_read: Optional[float] = 0.01) -> None: + """ + Wait until the target is reached. + + :param max_wait_time: The maximum time to wait. + :param time_per_read: The time to wait between each read. + """ + if self.goal_value is None: + return # Skip if goal value is None + start_time = time() + current = self.current_value + while not self.goal_achieved: + sleep(time_per_read) + if time() - start_time > max_wait_time: + msg = f"Failed to achieve goal from initial error {self.initial_error} with" \ + f" goal {self.goal_value} within {max_wait_time}" \ + f" seconds, the current value is {current}, error is {self.current_error}, percentage" \ + f" of goal achieved is {self.percentage_of_goal_achieved}" + if self.raise_error: + rospy.logerr(msg) + raise TimeoutError(msg) + else: + rospy.logwarn(msg) + break + current = self.current_value + self.reset() + + def reset(self) -> None: + """ + Reset the goal validator. + """ + self.goal_value = None + self.initial_error = None + self.current_value_getter_input = None + self.error_checker.reset() + + @property + def _acceptable_error(self) -> np.ndarray: + """ + The acceptable error. + """ + if self.error_checker.is_iterable: + return self.tiled_acceptable_error + else: + return self.acceptable_error + + @property + def acceptable_error(self) -> np.ndarray: + """ + The acceptable error. + """ + return self.error_checker.acceptable_error + + @property + def tiled_acceptable_error(self) -> Optional[np.ndarray]: + """ + The tiled acceptable error. + """ + return self.error_checker.tiled_acceptable_error + + def register_goal(self, goal_value: Any, + current_value_getter_input: Optional[Any] = None, + initial_value: Optional[Any] = None, + acceptable_error: Optional[Union[float, Iterable[float]]] = None): + """ + Register the goal value. + + :param goal_value: The goal value. + :param current_value_getter_input: The values that are used as input to the current value getter. + :param initial_value: The initial value. + :param acceptable_error: The acceptable error. + """ + if goal_value is None or (hasattr(goal_value, '__len__') and len(goal_value) == 0): + return # Skip if goal value is None or empty + self.goal_value = goal_value + self.current_value_getter_input = current_value_getter_input + self.update_initial_error(goal_value, initial_value=initial_value) + self.error_checker.update_acceptable_error(acceptable_error, self.initial_error) + + def update_initial_error(self, goal_value: Any, initial_value: Optional[Any] = None) -> None: + """ + Calculate the initial error. + + :param goal_value: The goal value. + :param initial_value: The initial value. + """ + if initial_value is None: + self.initial_error: np.ndarray = self.current_error + else: + self.initial_error: np.ndarray = self.calculate_error(goal_value, initial_value) + + @property + def current_value(self) -> Any: + """ + The current value of the monitored variable. + """ + if self.current_value_getter_input is not None: + return self.current_value_getter(self.current_value_getter_input) + else: + return self.current_value_getter() + + @property + def current_error(self) -> np.ndarray: + """ + The current error. + """ + return self.calculate_error(self.goal_value, self.current_value) + + def calculate_error(self, value_1: Any, value_2: Any) -> np.ndarray: + """ + Calculate the error between two values. + + :param value_1: The first value. + :param value_2: The second value. + :return: The error. + """ + return np.array(self.error_checker.calculate_error(value_1, value_2)).flatten() + + @property + def percentage_of_goal_achieved(self) -> float: + """ + The relative (relative to the acceptable error) achieved percentage of goal. + """ + percent_array = 1 - self.relative_current_error / self.relative_initial_error + percent_array_filtered = percent_array[self.relative_initial_error > self._acceptable_error] + if len(percent_array_filtered) == 0: + return 1 + else: + return np.mean(percent_array_filtered) + + @property + def actual_percentage_of_goal_achieved(self) -> float: + """ + The percentage of goal achieved. + """ + percent_array = 1 - self.current_error / np.maximum(self.initial_error, 1e-3) + percent_array_filtered = percent_array[self.initial_error > self._acceptable_error] + if len(percent_array_filtered) == 0: + return 1 + else: + return np.mean(percent_array_filtered) + + @property + def relative_current_error(self) -> np.ndarray: + """ + The relative current error (relative to the acceptable error). + """ + return self.get_relative_error(self.current_error, threshold=0) + + @property + def relative_initial_error(self) -> np.ndarray: + """ + The relative initial error (relative to the acceptable error). + """ + return np.maximum(self.initial_error, 1e-3) + + def get_relative_error(self, error: Any, threshold: Optional[float] = 1e-3) -> np.ndarray: + """ + Get the relative error by comparing the error with the acceptable error and filtering out the errors that are + less than the threshold. + + :param error: The error. + :param threshold: The threshold. + :return: The relative error. + """ + return np.maximum(error - self._acceptable_error, threshold) + + @property + def goal_achieved(self) -> bool: + """ + Check if the goal is achieved. + """ + if self.acceptable_percentage_of_goal_achieved is None: + return self.is_current_error_acceptable + else: + return self.percentage_of_goal_achieved >= self.acceptable_percentage_of_goal_achieved + + @property + def is_current_error_acceptable(self) -> bool: + """ + Check if the error is acceptable. + """ + return self.error_checker.is_error_acceptable(self.current_value, self.goal_value) + + +class PoseGoalValidator(GoalValidator): + """ + A class to validate the pose goal by tracking the goal achievement progress. + """ + + def __init__(self, current_pose_getter: OptionalArgCallable = None, + acceptable_error: Union[Tuple[float], Iterable[Tuple[float]]] = (1e-3, np.pi / 180), + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8, + is_iterable: Optional[bool] = False): + """ + Initialize the pose goal validator. + + :param current_pose_getter: The current pose getter function which takes an optional input and returns the + current pose. + :param acceptable_error: The acceptable error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + """ + super().__init__(PoseErrorChecker(acceptable_error, is_iterable=is_iterable), current_pose_getter, + acceptable_percentage_of_goal_achieved=acceptable_percentage_of_goal_achieved) + + +class MultiPoseGoalValidator(PoseGoalValidator): + """ + A class to validate the multi-pose goal by tracking the goal achievement progress. + """ + + def __init__(self, current_poses_getter: OptionalArgCallable = None, + acceptable_error: Union[Tuple[float], Iterable[Tuple[float]]] = (1e-2, 5 * np.pi / 180), + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8): + """ + Initialize the multi-pose goal validator. + + :param current_poses_getter: The current poses getter function which takes an optional input and returns the + current poses. + :param acceptable_error: The acceptable error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + """ + super().__init__(current_poses_getter, acceptable_error, acceptable_percentage_of_goal_achieved, + is_iterable=True) + + +class PositionGoalValidator(GoalValidator): + """ + A class to validate the position goal by tracking the goal achievement progress. + """ + + def __init__(self, current_position_getter: OptionalArgCallable = None, + acceptable_error: Optional[float] = 1e-3, + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8, + is_iterable: Optional[bool] = False): + """ + Initialize the position goal validator. + + :param current_position_getter: The current position getter function which takes an optional input and + returns the current position. + :param acceptable_error: The acceptable error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + :param is_iterable: Whether it is a sequence of position vectors. + """ + super().__init__(PositionErrorChecker(acceptable_error, is_iterable=is_iterable), current_position_getter, + acceptable_percentage_of_goal_achieved=acceptable_percentage_of_goal_achieved) + + +class MultiPositionGoalValidator(PositionGoalValidator): + """ + A class to validate the multi-position goal by tracking the goal achievement progress. + """ + + def __init__(self, current_positions_getter: OptionalArgCallable = None, + acceptable_error: Optional[float] = 1e-3, + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8): + """ + Initialize the multi-position goal validator. + + :param current_positions_getter: The current positions getter function which takes an optional input and + returns the current positions. + :param acceptable_error: The acceptable error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + """ + super().__init__(current_positions_getter, acceptable_error, acceptable_percentage_of_goal_achieved, + is_iterable=True) + + +class OrientationGoalValidator(GoalValidator): + """ + A class to validate the orientation goal by tracking the goal achievement progress. + """ + + def __init__(self, current_orientation_getter: OptionalArgCallable = None, + acceptable_error: Optional[float] = np.pi / 180, + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8, + is_iterable: Optional[bool] = False): + """ + Initialize the orientation goal validator. + + :param current_orientation_getter: The current orientation getter function which takes an optional input and + returns the current orientation. + :param acceptable_error: The acceptable error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + :param is_iterable: Whether it is a sequence of quaternions. + """ + super().__init__(OrientationErrorChecker(acceptable_error, is_iterable=is_iterable), current_orientation_getter, + acceptable_percentage_of_goal_achieved=acceptable_percentage_of_goal_achieved) + + +class MultiOrientationGoalValidator(OrientationGoalValidator): + """ + A class to validate the multi-orientation goal by tracking the goal achievement progress. + """ + + def __init__(self, current_orientations_getter: OptionalArgCallable = None, + acceptable_error: Optional[float] = np.pi / 180, + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8): + """ + Initialize the multi-orientation goal validator. + + :param current_orientations_getter: The current orientations getter function which takes an optional input and + returns the current orientations. + :param acceptable_error: The acceptable error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + """ + super().__init__(current_orientations_getter, acceptable_error, acceptable_percentage_of_goal_achieved, + is_iterable=True) + + +class JointPositionGoalValidator(GoalValidator): + """ + A class to validate the joint position goal by tracking the goal achievement progress. + """ + + def __init__(self, current_position_getter: OptionalArgCallable = None, + acceptable_error: Optional[float] = None, + acceptable_revolute_joint_position_error: float = np.pi / 180, + acceptable_prismatic_joint_position_error: float = 1e-3, + acceptable_percentage_of_goal_achieved: float = 0.8, + is_iterable: bool = False): + """ + Initialize the joint position goal validator. + + :param current_position_getter: The current position getter function which takes an optional input and returns + the current position. + :param acceptable_error: The acceptable error. + :param acceptable_revolute_joint_position_error: The acceptable orientation error. + :param acceptable_prismatic_joint_position_error: The acceptable position error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + :param is_iterable: Whether it is a sequence of joint positions. + """ + super().__init__(SingleValueErrorChecker(acceptable_error, is_iterable=is_iterable), current_position_getter, + acceptable_percentage_of_goal_achieved=acceptable_percentage_of_goal_achieved) + self.acceptable_orientation_error = acceptable_revolute_joint_position_error + self.acceptable_position_error = acceptable_prismatic_joint_position_error + + def register_goal(self, goal_value: Any, joint_type: JointType, + current_value_getter_input: Optional[Any] = None, + initial_value: Optional[Any] = None, + acceptable_error: Optional[float] = None): + """ + Register the goal value. + + :param goal_value: The goal value. + :param joint_type: The joint type (e.g. REVOLUTE, PRISMATIC). + :param current_value_getter_input: The values that are used as input to the current value getter. + :param initial_value: The initial value. + :param acceptable_error: The acceptable error. + """ + if acceptable_error is None: + self.error_checker.acceptable_error = self.acceptable_orientation_error if joint_type == JointType.REVOLUTE\ + else self.acceptable_position_error + super().register_goal(goal_value, current_value_getter_input, initial_value, acceptable_error) + + +class MultiJointPositionGoalValidator(GoalValidator): + """ + A class to validate the multi-joint position goal by tracking the goal achievement progress. + """ + + def __init__(self, current_positions_getter: OptionalArgCallable = None, + acceptable_error: Optional[Iterable[float]] = None, + acceptable_revolute_joint_position_error: float = np.pi / 180, + acceptable_prismatic_joint_position_error: float = 1e-3, + acceptable_percentage_of_goal_achieved: float = 0.8): + """ + Initialize the multi-joint position goal validator. + + :param current_positions_getter: The current positions getter function which takes an optional input and + returns the current positions. + :param acceptable_error: The acceptable error. + :param acceptable_revolute_joint_position_error: The acceptable orientation error. + :param acceptable_prismatic_joint_position_error: The acceptable position error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + """ + super().__init__(SingleValueErrorChecker(acceptable_error, is_iterable=True), current_positions_getter, + acceptable_percentage_of_goal_achieved) + self.acceptable_orientation_error = acceptable_revolute_joint_position_error + self.acceptable_position_error = acceptable_prismatic_joint_position_error + + def register_goal(self, goal_value: Any, joint_type: Iterable[JointType], + current_value_getter_input: Optional[Any] = None, + initial_value: Optional[Any] = None, + acceptable_error: Optional[Iterable[float]] = None): + if acceptable_error is None: + self.error_checker.acceptable_error = [self.acceptable_orientation_error if jt == JointType.REVOLUTE + else self.acceptable_position_error for jt in joint_type] + super().register_goal(goal_value, current_value_getter_input, initial_value, acceptable_error) + + +def validate_object_pose(pose_setter_func): + """ + A decorator to validate the object pose. + + :param pose_setter_func: The function to set the pose of the object. + """ + + def wrapper(world: 'World', obj: 'Object', pose: 'Pose'): + + world.pose_goal_validator.register_goal(pose, obj) + + if not pose_setter_func(world, obj, pose): + world.pose_goal_validator.reset() + return False + + world.pose_goal_validator.wait_until_goal_is_achieved() + return True + + return wrapper + + +def validate_multiple_object_poses(pose_setter_func): + """ + A decorator to validate multiple object poses. + + :param pose_setter_func: The function to set multiple poses of the objects. + """ + + def wrapper(world: 'World', object_poses: Dict['Object', 'Pose']): + + world.multi_pose_goal_validator.register_goal(list(object_poses.values()), + list(object_poses.keys())) + + if not pose_setter_func(world, object_poses): + world.multi_pose_goal_validator.reset() + return False + + world.multi_pose_goal_validator.wait_until_goal_is_achieved() + return True + + return wrapper + + +def validate_joint_position(position_setter_func): + """ + A decorator to validate the joint position. + + :param position_setter_func: The function to set the joint position. + """ + + def wrapper(world: 'World', joint: 'Joint', position: float): + + joint_type = joint.type + world.joint_position_goal_validator.register_goal(position, joint_type, joint) + + if not position_setter_func(world, joint, position): + world.joint_position_goal_validator.reset() + return False + + world.joint_position_goal_validator.wait_until_goal_is_achieved() + return True + + return wrapper + + +def validate_multiple_joint_positions(position_setter_func): + """ + A decorator to validate the joint positions, this function does not validate the virtual joints, + as in multiverse the virtual joints take command velocities and not positions, so after their goals + are set, they are zeroed thus can't be validated. (They are actually validated by the robot pose in case + of virtual mobile base joints) + + :param position_setter_func: The function to set the joint positions. + """ + + def wrapper(world: 'World', joint_positions: Dict['Joint', float]): + joint_positions_to_validate = {joint: position for joint, position in joint_positions.items() + if not joint.is_virtual} + joint_types = [joint.type for joint in joint_positions_to_validate.keys()] + world.multi_joint_position_goal_validator.register_goal(list(joint_positions_to_validate.values()), joint_types, + list(joint_positions_to_validate.keys())) + if not position_setter_func(world, joint_positions): + world.multi_joint_position_goal_validator.reset() + return False + + world.multi_joint_position_goal_validator.wait_until_goal_is_achieved() + return True + + return wrapper diff --git a/src/pycram/world_concepts/constraints.py b/src/pycram/world_concepts/constraints.py index aa41e3bdf..8de36bda1 100644 --- a/src/pycram/world_concepts/constraints.py +++ b/src/pycram/world_concepts/constraints.py @@ -2,7 +2,7 @@ import numpy as np from geometry_msgs.msg import Point -from typing_extensions import Union, List, Optional, TYPE_CHECKING +from typing_extensions import Union, List, Optional, TYPE_CHECKING, Self from ..datastructures.enums import JointType from ..datastructures.pose import Transform, Pose @@ -30,6 +30,44 @@ def __init__(self, self.child_to_constraint = child_to_constraint self._parent_to_child = None + def get_child_object_pose(self) -> Pose: + """ + :return: The pose of the child object. + """ + return self.child_link.object.pose + + def get_child_object_pose_given_parent(self, pose: Pose) -> Pose: + """ + Get the pose of the child object given the parent pose. + + :param pose: The parent object pose. + :return: The pose of the child object. + """ + pose = self.parent_link.get_pose_given_object_pose(pose) + child_link_pose = self.get_child_link_target_pose_given_parent(pose) + return self.child_link.get_object_pose_given_link_pose(child_link_pose) + + def set_child_link_pose(self): + """ + Set the target pose of the child object to the current pose of the child object in the parent object frame. + """ + self.child_link.set_pose(self.get_child_link_target_pose()) + + def get_child_link_target_pose(self) -> Pose: + """ + :return: The target pose of the child object. (The pose of the child object in the parent object frame) + """ + return self.parent_to_child_transform.to_pose() + + def get_child_link_target_pose_given_parent(self, parent_pose: Pose) -> Pose: + """ + Get the target pose of the child object link given the parent link pose. + + :param parent_pose: The parent link pose. + :return: The target pose of the child object link. + """ + return (parent_pose.to_transform(self.parent_link.tf_frame) * self.parent_to_child_transform).to_pose() + @property def parent_to_child_transform(self) -> Union[Transform, None]: if self._parent_to_child is None: @@ -44,8 +82,6 @@ def parent_to_child_transform(self, transform: Transform) -> None: @property def parent_object_id(self) -> int: """ - Returns the id of the parent object of the constraint. - :return: The id of the parent object of the constraint """ return self.parent_link.object_id @@ -53,8 +89,6 @@ def parent_object_id(self) -> int: @property def child_object_id(self) -> int: """ - Returns the id of the child object of the constraint. - :return: The id of the child object of the constraint """ return self.child_link.object_id @@ -62,8 +96,6 @@ def child_object_id(self) -> int: @property def parent_link_id(self) -> int: """ - Returns the id of the parent link of the constraint. - :return: The id of the parent link of the constraint """ return self.parent_link.id @@ -71,8 +103,6 @@ def parent_link_id(self) -> int: @property def child_link_id(self) -> int: """ - Returns the id of the child link of the constraint. - :return: The id of the child link of the constraint """ return self.child_link.id @@ -80,8 +110,6 @@ def child_link_id(self) -> int: @property def position_wrt_parent_as_list(self) -> List[float]: """ - Returns the constraint frame pose with respect to the parent origin as a list. - :return: The constraint frame pose with respect to the parent origin as a list """ return self.pose_wrt_parent.position_as_list() @@ -89,8 +117,6 @@ def position_wrt_parent_as_list(self) -> List[float]: @property def orientation_wrt_parent_as_list(self) -> List[float]: """ - Returns the constraint frame orientation with respect to the parent origin as a list. - :return: The constraint frame orientation with respect to the parent origin as a list """ return self.pose_wrt_parent.orientation_as_list() @@ -98,8 +124,6 @@ def orientation_wrt_parent_as_list(self) -> List[float]: @property def pose_wrt_parent(self) -> Pose: """ - Returns the joint frame pose with respect to the parent origin. - :return: The joint frame pose with respect to the parent origin """ return self.parent_to_constraint.to_pose() @@ -107,8 +131,6 @@ def pose_wrt_parent(self) -> Pose: @property def position_wrt_child_as_list(self) -> List[float]: """ - Returns the constraint frame pose with respect to the child origin as a list. - :return: The constraint frame pose with respect to the child origin as a list """ return self.pose_wrt_child.position_as_list() @@ -116,8 +138,6 @@ def position_wrt_child_as_list(self) -> List[float]: @property def orientation_wrt_child_as_list(self) -> List[float]: """ - Returns the constraint frame orientation with respect to the child origin as a list. - :return: The constraint frame orientation with respect to the child origin as a list """ return self.pose_wrt_child.orientation_as_list() @@ -125,8 +145,6 @@ def orientation_wrt_child_as_list(self) -> List[float]: @property def pose_wrt_child(self) -> Pose: """ - Returns the joint frame pose with respect to the child origin. - :return: The joint frame pose with respect to the child origin """ return self.child_to_constraint.to_pose() @@ -151,8 +169,6 @@ def __init__(self, @property def axis_as_list(self) -> List[float]: """ - Returns the axis of this constraint as a list. - :return: The axis of this constraint as a list of xyz """ return [self.axis.x, self.axis.y, self.axis.z] @@ -162,9 +178,10 @@ class Attachment(AbstractConstraint): def __init__(self, parent_link: Link, child_link: Link, - bidirectional: Optional[bool] = False, + bidirectional: bool = False, parent_to_child_transform: Optional[Transform] = None, - constraint_id: Optional[int] = None): + constraint_id: Optional[int] = None, + is_inverse: bool = False): """ Creates an attachment between the parent object link and the child object link. This could be a bidirectional attachment, meaning that both objects will move when one moves. @@ -180,48 +197,61 @@ def __init__(self, self.id = constraint_id self.bidirectional: bool = bidirectional self._loose: bool = False + self.is_inverse: bool = is_inverse - if self.parent_to_child_transform is None: + if parent_to_child_transform is not None: + self.parent_to_child_transform = parent_to_child_transform + + elif self.parent_to_child_transform is None: self.update_transform() if self.id is None: self.add_fixed_constraint() + @property + def parent_object(self): + return self.parent_link.object + + @property + def child_object(self): + return self.child_link.object + def update_transform_and_constraint(self) -> None: """ - Updates the transform and constraint of this attachment. + Update the transform and constraint of this attachment. """ self.update_transform() self.update_constraint() def update_transform(self) -> None: """ - Updates the transform of this attachment by calculating the transform from the parent link to the child link. + Update the transform of this attachment by calculating the transform from the parent link to the child link. """ self.parent_to_child_transform = self.calculate_transform() def update_constraint(self) -> None: """ - Updates the constraint of this attachment by removing the old constraint if one exists and adding a new one. + Update the constraint of this attachment by removing the old constraint if one exists and adding a new one. """ self.remove_constraint_if_exists() self.add_fixed_constraint() def add_fixed_constraint(self) -> None: """ - Adds a fixed constraint between the parent link and the child link. + Add a fixed constraint between the parent link and the child link. """ - self.id = self.parent_link.add_fixed_constraint_with_link(self.child_link) + self.id = self.parent_link.add_fixed_constraint_with_link(self.child_link, + self.parent_to_child_transform.invert()) def calculate_transform(self) -> Transform: """ - Calculates the transform from the parent link to the child link. + Calculate the transform from the parent link to the child link. """ return self.parent_link.get_transform_to_link(self.child_link) def remove_constraint_if_exists(self) -> None: """ - Removes the constraint between the parent and the child links if one exists. + Remove the constraint between the parent and the child links if one exists. """ if self.child_link in self.parent_link.constraint_ids: self.parent_link.remove_constraint_with_link(self.child_link) @@ -231,7 +261,7 @@ def get_inverse(self) -> 'Attachment': :return: A new Attachment object with the parent and child links swapped. """ attachment = Attachment(self.child_link, self.parent_link, self.bidirectional, - constraint_id=self.id) + constraint_id=self.id, is_inverse=not self.is_inverse) attachment.loose = not self._loose return attachment @@ -245,22 +275,15 @@ def loose(self) -> bool: @loose.setter def loose(self, loose: bool) -> None: """ - Sets the loose property of this attachment. + Set the loose property of this attachment. :param loose: If true, then the child object will not move when parent moves. """ self._loose = loose and not self.bidirectional - @property - def is_reversed(self) -> bool: - """ - :return: True if the parent and child links are swapped. - """ - return self.loose - def __del__(self) -> None: """ - Removes the constraint between the parent and the child links if one exists when the attachment is deleted. + Remove the constraint between the parent and the child links if one exists when the attachment is deleted. """ self.remove_constraint_if_exists() @@ -272,10 +295,11 @@ def __eq__(self, other): return (self.parent_link.name == other.parent_link.name and self.child_link.name == other.child_link.name and self.bidirectional == other.bidirectional + and self.loose == other.loose and np.allclose(self.parent_to_child_transform.translation_as_list(), - other.parent_to_child_transform.translation_as_list(), rtol=0, atol=1e-4) + other.parent_to_child_transform.translation_as_list(), rtol=0, atol=1e-3) and np.allclose(self.parent_to_child_transform.rotation_as_list(), - other.parent_to_child_transform.rotation_as_list(), rtol=0, atol=1e-4)) + other.parent_to_child_transform.rotation_as_list(), rtol=0, atol=1e-3)) def __hash__(self): return hash((self.parent_link.name, self.child_link.name, self.bidirectional, self.parent_to_child_transform)) diff --git a/src/pycram/world_concepts/world_object.py b/src/pycram/world_concepts/world_object.py index 986a26ebd..4975fa983 100644 --- a/src/pycram/world_concepts/world_object.py +++ b/src/pycram/world_concepts/world_object.py @@ -2,24 +2,34 @@ import logging import os +from pathlib import Path import numpy as np import rospy +from deprecated import deprecated from geometry_msgs.msg import Point, Quaternion from typing_extensions import Type, Optional, Dict, Tuple, List, Union -from ..description import ObjectDescription, LinkDescription, Joint -from ..object_descriptors.urdf import ObjectDescription as URDFObject -from ..object_descriptors.generic import ObjectDescription as GenericObjectDescription -from ..robot_descriptions import robot_description -from ..datastructures.world import WorldEntity, World -from ..world_concepts.constraints import Attachment from ..datastructures.dataclasses import (Color, ObjectState, LinkState, JointState, - AxisAlignedBoundingBox, VisualShape) + AxisAlignedBoundingBox, VisualShape, ClosestPointsList, + ContactPointsList) from ..datastructures.enums import ObjectType, JointType -from ..local_transformer import LocalTransformer from ..datastructures.pose import Pose, Transform -from ..robot_description import RobotDescriptionManager +from ..datastructures.world import World +from ..datastructures.world_entity import WorldEntity +from ..description import ObjectDescription, LinkDescription, Joint +from ..failures import ObjectAlreadyExists, WorldMismatchErrorBetweenObjects, UnsupportedFileExtension, \ + ObjectDescriptionUndefined +from ..local_transformer import LocalTransformer +from ..object_descriptors.generic import ObjectDescription as GenericObjectDescription +from ..object_descriptors.urdf import ObjectDescription as URDF + +try: + from ..object_descriptors.mjcf import ObjectDescription as MJCF +except ImportError: + MJCF = None +from ..robot_description import RobotDescriptionManager, RobotDescription +from ..world_concepts.constraints import Attachment Link = ObjectDescription.Link @@ -29,18 +39,23 @@ class Object(WorldEntity): Represents a spawned Object in the World. """ - prospection_world_prefix: str = "prospection/" + tf_prospection_world_prefix: str = "prospection/" + """ + The prefix for the tf frame of objects in the prospection world. + """ + + extension_to_description_type: Dict[str, Type[ObjectDescription]] = {URDF.get_file_extension(): URDF} """ - The ObjectDescription of the object, this contains the name and type of the object as well as the path to the source - file. + A dictionary that maps the file extension to the corresponding ObjectDescription type. """ - def __init__(self, name: str, obj_type: ObjectType, path: str, - description: Optional[Type[ObjectDescription]] = URDFObject, + def __init__(self, name: str, obj_type: ObjectType, path: Optional[str] = None, + description: Optional[ObjectDescription] = None, pose: Optional[Pose] = None, world: Optional[World] = None, - color: Optional[Color] = Color(), - ignore_cached_files: Optional[bool] = False): + color: Color = Color(), + ignore_cached_files: bool = False, + scale_mesh: Optional[float] = None): """ The constructor loads the description file into the given World, if no World is specified the :py:attr:`~World.current_world` will be used. It is also possible to load .obj and .stl file into the World. @@ -49,37 +64,45 @@ def __init__(self, name: str, obj_type: ObjectType, path: str, :param name: The name of the object :param obj_type: The type of the object as an ObjectType enum. - :param path: The path to the source file, if only a filename is provided then the resources directories will be searched. + :param path: The path to the source file, if only a filename is provided then the resources directories will be + searched, it could be None in some cases when for example it is a generic object. :param description: The ObjectDescription of the object, this contains the joints and links of the object. :param pose: The pose at which the Object should be spawned - :param world: The World in which the object should be spawned, if no world is specified the :py:attr:`~World.current_world` will be used. + :param world: The World in which the object should be spawned, if no world is specified the + :py:attr:`~World.current_world` will be used. :param color: The rgba_color with which the object should be spawned. :param ignore_cached_files: If true the file will be spawned while ignoring cached files. + :param scale_mesh: The scale of the mesh. """ - super().__init__(-1, world) + super().__init__(-1, world if world is not None else World.current_world) + + pose = Pose() if pose is None else pose - if pose is None: - pose = Pose() - if name in [obj.name for obj in self.world.objects]: - rospy.logerr(f"An object with the name {name} already exists in the world.") - return None self.name: str = name + self.path: Optional[str] = path self.obj_type: ObjectType = obj_type self.color: Color = color - self.description = description() + self._resolve_description(path, description) self.cache_manager = self.world.cache_manager self.local_transformer = LocalTransformer() self.original_pose = self.local_transformer.transform_pose(pose, "map") self._current_pose = self.original_pose - self.id, self.path = self._load_object_and_get_id(path, ignore_cached_files) + if path is not None: + self.path = self.world.preprocess_object_file_and_get_its_cache_path(path, ignore_cached_files, + self.description, self.name, + scale_mesh=scale_mesh) + + self.description.update_description_from_file(self.path) + + if self.obj_type == ObjectType.ROBOT and not self.world.is_prospection_world: + self._update_world_robot_and_description() - self.description.update_description_from_file(self.path) + self.id = self._spawn_object_and_get_id() - self.tf_frame = ((self.prospection_world_prefix if self.world.is_prospection_world else "") - + f"{self.name}") + self.tf_frame = (self.tf_prospection_world_prefix if self.world.is_prospection_world else "") + self.name self._init_joint_name_and_id_map() self._init_link_name_and_id_map() @@ -89,26 +112,178 @@ def __init__(self, name: str, obj_type: ObjectType, path: str, self.attachments: Dict[Object, Attachment] = {} - if not self.world.is_prospection_world: - self._add_to_world_sync_obj_queue() + self.world.add_object(self) - self.world.objects.append(self) + def _resolve_description(self, path: Optional[str] = None, description: Optional[ObjectDescription] = None) -> None: + """ + Find the correct description type of the object and initialize it and set the description of this object to it. - if self.obj_type == ObjectType.ROBOT and not self.world.is_prospection_world: - rdm = RobotDescriptionManager() - rdm.load_description(self.description.name) - World.robot = self + :param path: The path to the source file. + :param description: The ObjectDescription of the object. + """ + if description is not None: + self.description = description + return + if path is None: + raise ObjectDescriptionUndefined(self.name) + extension = Path(path).suffix + if extension in self.extension_to_description_type: + self.description = self.extension_to_description_type[extension]() + elif extension in ObjectDescription.mesh_extensions: + self.description = self.world.conf.default_description_type() + else: + raise UnsupportedFileExtension(self.name, path) + + def set_mobile_robot_pose(self, pose: Pose) -> None: + """ + Set the goal for the mobile base joints of a mobile robot to reach a target pose. This is used for example when + the simulator does not support setting the pose of the robot directly (e.g. MuJoCo). + + :param pose: The target pose. + """ + goal = self.get_mobile_base_joint_goal(pose) + self.set_multiple_joint_positions(goal) + + def get_mobile_base_joint_goal(self, pose: Pose) -> Dict[str, float]: + """ + Get the goal for the mobile base joints of a mobile robot to reach a target pose. + + :param pose: The target pose. + :return: The goal for the mobile base joints. + """ + target_translation, target_angle = self.get_mobile_base_pose_difference(pose) + # Get the joints of the base link + mobile_base_joints = self.world.get_robot_mobile_base_joints() + return {mobile_base_joints.translation_x: target_translation.x, + mobile_base_joints.translation_y: target_translation.y, + mobile_base_joints.angular_z: target_angle} + + def get_mobile_base_pose_difference(self, pose: Pose) -> Tuple[Point, float]: + """ + Get the difference between the current and the target pose of the mobile base. + + :param pose: The target pose. + :return: The difference between the current and the target pose of the mobile base. + """ + return self.original_pose.get_position_diff(pose), self.original_pose.get_z_angle_difference(pose) + + @property + def joint_actuators(self) -> Optional[Dict[str, str]]: + """ + The joint actuators of the robot. + """ + if self.obj_type == ObjectType.ROBOT: + return self.robot_description.joint_actuators + return None + + @property + def has_actuators(self) -> bool: + """ + True if the object has actuators, otherwise False. + """ + return self.robot_description.has_actuators + + @property + def robot_description(self) -> RobotDescription: + """ + The current robot description. + """ + return self.world.robot_description + + def get_actuator_for_joint(self, joint: Joint) -> Optional[str]: + """ + Get the actuator name for a joint. + + :param joint: The joint object for which to get the actuator. + :return: The name of the actuator. + """ + return self.robot_description.get_actuator_for_joint(joint.name) + + def get_multiple_link_positions(self, links: List[Link]) -> Dict[str, List[float]]: + """ + Get the positions of multiple links of the object. + + :param links: The link objects of which to get the positions. + :return: The positions of the links. + """ + return self.world.get_multiple_link_positions(links) + + def get_multiple_link_orientations(self, links: List[Link]) -> Dict[str, List[float]]: + """ + Get the orientations of multiple links of the object. + + :param links: The link objects of which to get the orientations. + :return: The orientations of the links. + """ + return self.world.get_multiple_link_orientations(links) + + def get_multiple_link_poses(self, links: List[Link]) -> Dict[str, Pose]: + """ + Get the poses of multiple links of the object. + + :param links: The link objects of which to get the poses. + :return: The poses of the links. + """ + return self.world.get_multiple_link_poses(links) + + def get_poses_of_attached_objects(self) -> Dict[Object, Pose]: + """ + Get the poses of the attached objects. + + :return: The poses of the attached objects + """ + return {child_object: attachment.get_child_object_pose() + for child_object, attachment in self.attachments.items() if not attachment.loose} + + def get_target_poses_of_attached_objects_given_parent(self, pose: Pose) -> Dict[Object, Pose]: + """ + Get the target poses of the attached objects of an object. Given the pose of the parent object. (i.e. the poses + to which the attached objects will move when the parent object is at the given pose) + + :param pose: The pose of the parent object. + :return: The target poses of the attached objects + """ + return {child_object: attachment.get_child_object_pose_given_parent(pose) for child_object, attachment + in self.attachments.items() if not attachment.loose} + + @property + def name(self): + """ + The name of the object. + """ + return self._name + + @name.setter + def name(self, name: str): + """ + Set the name of the object. + """ + self._name = name + if name in [obj.name for obj in self.world.objects]: + raise ObjectAlreadyExists(self) @property def pose(self): + """ + The current pose of the object. + """ return self.get_pose() @pose.setter def pose(self, pose: Pose): + """ + Set the pose of the object. + """ self.set_pose(pose) - def _load_object_and_get_id(self, path: Optional[str] = None, - ignore_cached_files: Optional[bool] = False) -> Tuple[int, Union[str, None]]: + @property + def transform(self): + """ + The current transform of the object. + """ + return self.get_pose().to_transform(self.tf_frame) + + def _spawn_object_and_get_id(self) -> int: """ Loads an object to the given World with the given position and orientation. The rgba_color will only be used when an .obj or .stl file is given. @@ -116,31 +291,18 @@ def _load_object_and_get_id(self, path: Optional[str] = None, and this URDf file will be loaded instead. When spawning a URDf file a new file will be created in the cache directory, if there exists none. This new file will have resolved mesh file paths, meaning there will be no references - to ROS packges instead there will be absolute file paths. + to ROS packages instead there will be absolute file paths. - :param path: The path to the description file, if None then no file will be loaded, this is useful when the PyCRAM is not responsible for loading the file but another system is. - :param ignore_cached_files: Whether to ignore files in the cache directory. :return: The unique id of the object and the path of the file that was loaded. """ if isinstance(self.description, GenericObjectDescription): - return self.world.load_generic_object_and_get_id(self.description), path + return self.world.load_generic_object_and_get_id(self.description, pose=self._current_pose) - if path is not None: - try: - path = self.world.update_cache_dir_with_object(path, ignore_cached_files, self) - except FileNotFoundError as e: - logging.error("Could not generate description from file.") - raise e + path = self.path if self.world.conf.let_pycram_handle_spawning else self.name try: - simulator_object_path = path - if simulator_object_path is None: - # This is useful when the object is already loaded in the simulator so it would use its name instead of - # its path - simulator_object_path = self.name - obj_id = self.world.load_object_and_get_id(simulator_object_path, Pose(self.get_position_as_list(), - self.get_orientation_as_list())) - return obj_id, path + obj_id = self.world.load_object_and_get_id(path, self._current_pose, self.obj_type) + return obj_id except Exception as e: logging.error( @@ -149,9 +311,31 @@ def _load_object_and_get_id(self, path: Optional[str] = None, os.remove(path) raise e + def _update_world_robot_and_description(self): + """ + Initialize the robot description of the object, load the description from the RobotDescriptionManager and set + the robot as the current robot in the World. Also add the virtual mobile base joints to the robot. + """ + rdm = RobotDescriptionManager() + rdm.load_description(self.description.name) + World.robot = self + self._add_virtual_move_base_joints() + + def _add_virtual_move_base_joints(self): + """ + Add the virtual mobile base joints to the robot description. + """ + virtual_joints = self.robot_description.virtual_mobile_base_joints + if virtual_joints is None: + return + child_link = self.description.get_root() + axes = virtual_joints.get_axes() + for joint_name, joint_type in virtual_joints.get_types().items(): + self.description.add_joint(joint_name, child_link, joint_type, axes[joint_name], is_virtual=True) + def _init_joint_name_and_id_map(self) -> None: """ - Creates a dictionary which maps the joint names to their unique ids and vice versa. + Create a dictionary which maps the joint names to their unique ids and vice versa. """ n_joints = len(self.joint_names) self.joint_name_to_id = dict(zip(self.joint_names, range(n_joints))) @@ -159,7 +343,7 @@ def _init_joint_name_and_id_map(self) -> None: def _init_link_name_and_id_map(self) -> None: """ - Creates a dictionary which maps the link names to their unique ids and vice versa. + Create a dictionary which maps the link names to their unique ids and vice versa. """ n_links = len(self.link_names) self.link_name_to_id: Dict[str, int] = dict(zip(self.link_names, range(n_links))) @@ -168,7 +352,7 @@ def _init_link_name_and_id_map(self) -> None: def _init_links_and_update_transforms(self) -> None: """ - Initializes the link objects from the URDF file and creates a dictionary which maps the link names to the + Initialize the link objects from the URDF file and creates a dictionary which maps the link names to the corresponding link objects. """ self.links = {} @@ -188,32 +372,54 @@ def _init_joints(self): """ self.joints = {} for joint_name, joint_id in self.joint_name_to_id.items(): - joint_description = self.description.get_joint_by_name(joint_name) - self.joints[joint_name] = self.description.Joint(joint_id, joint_description, self) + parsed_joint_description = self.description.get_joint_by_name(joint_name) + is_virtual = self.is_joint_virtual(joint_name) + self.joints[joint_name] = self.description.Joint(joint_id, parsed_joint_description, self, is_virtual) - def _add_to_world_sync_obj_queue(self) -> None: + def is_joint_virtual(self, name: str): """ - Adds this object to the objects queue of the WorldSync object of the World. + Check if a joint is virtual. """ - self.world.world_sync.add_obj_queue.put(self) + return self.description.is_joint_virtual(name) + + @property + def virtual_joint_names(self): + """ + The names of the virtual joints. + """ + return self.description.virtual_joint_names + + @property + def virtual_joints(self): + """ + The virtual joints as a list. + """ + return [joint for joint in self.joints.values() if joint.is_virtual] + + @property + def has_one_link(self) -> bool: + """ + True if the object has only one link, otherwise False. + """ + return len(self.links) == 1 @property def link_names(self) -> List[str]: """ - :return: The name of each link as a list. + The names of the links as a list. """ return self.world.get_object_link_names(self) @property def joint_names(self) -> List[str]: """ - :return: The name of each joint as a list. + The names of the joints as a list. """ return self.world.get_object_joint_names(self) def get_link(self, link_name: str) -> ObjectDescription.Link: """ - Returns the link object with the given name. + Return the link object with the given name. :param link_name: The name of the link. :return: The link object. @@ -222,7 +428,7 @@ def get_link(self, link_name: str) -> ObjectDescription.Link: def get_link_pose(self, link_name: str) -> Pose: """ - Returns the pose of the link with the given name. + Return the pose of the link with the given name. :param link_name: The name of the link. :return: The pose of the link. @@ -231,7 +437,7 @@ def get_link_pose(self, link_name: str) -> Pose: def get_link_position(self, link_name: str) -> Point: """ - Returns the position of the link with the given name. + Return the position of the link with the given name. :param link_name: The name of the link. :return: The position of the link. @@ -240,7 +446,7 @@ def get_link_position(self, link_name: str) -> Point: def get_link_position_as_list(self, link_name: str) -> List[float]: """ - Returns the position of the link with the given name. + Return the position of the link with the given name. :param link_name: The name of the link. :return: The position of the link. @@ -249,7 +455,7 @@ def get_link_position_as_list(self, link_name: str) -> List[float]: def get_link_orientation(self, link_name: str) -> Quaternion: """ - Returns the orientation of the link with the given name. + Return the orientation of the link with the given name. :param link_name: The name of the link. :return: The orientation of the link. @@ -258,7 +464,7 @@ def get_link_orientation(self, link_name: str) -> Quaternion: def get_link_orientation_as_list(self, link_name: str) -> List[float]: """ - Returns the orientation of the link with the given name. + Return the orientation of the link with the given name. :param link_name: The name of the link. :return: The orientation of the link. @@ -267,7 +473,7 @@ def get_link_orientation_as_list(self, link_name: str) -> List[float]: def get_link_tf_frame(self, link_name: str) -> str: """ - Returns the tf frame of the link with the given name. + Return the tf frame of the link with the given name. :param link_name: The name of the link. :return: The tf frame of the link. @@ -276,7 +482,7 @@ def get_link_tf_frame(self, link_name: str) -> str: def get_link_axis_aligned_bounding_box(self, link_name: str) -> AxisAlignedBoundingBox: """ - Returns the axis aligned bounding box of the link with the given name. + Return the axis aligned bounding box of the link with the given name. :param link_name: The name of the link. :return: The axis aligned bounding box of the link. @@ -285,7 +491,7 @@ def get_link_axis_aligned_bounding_box(self, link_name: str) -> AxisAlignedBound def get_transform_between_links(self, from_link: str, to_link: str) -> Transform: """ - Returns the transform between two links. + Return the transform between two links. :param from_link: The name of the link from which the transform should be calculated. :param to_link: The name of the link to which the transform should be calculated. @@ -294,7 +500,7 @@ def get_transform_between_links(self, from_link: str, to_link: str) -> Transform def get_link_color(self, link_name: str) -> Color: """ - Returns the color of the link with the given name. + Return the color of the link with the given name. :param link_name: The name of the link. :return: The color of the link. @@ -303,7 +509,7 @@ def get_link_color(self, link_name: str) -> Color: def set_link_color(self, link_name: str, color: List[float]) -> None: """ - Sets the color of the link with the given name. + Set the color of the link with the given name. :param link_name: The name of the link. :param color: The new color of the link. @@ -312,7 +518,7 @@ def set_link_color(self, link_name: str, color: List[float]) -> None: def get_link_geometry(self, link_name: str) -> Union[VisualShape, None]: """ - Returns the geometry of the link with the given name. + Return the geometry of the link with the given name. :param link_name: The name of the link. :return: The geometry of the link. @@ -321,7 +527,7 @@ def get_link_geometry(self, link_name: str) -> Union[VisualShape, None]: def get_link_transform(self, link_name: str) -> Transform: """ - Returns the transform of the link with the given name. + Return the transform of the link with the given name. :param link_name: The name of the link. :return: The transform of the link. @@ -330,7 +536,7 @@ def get_link_transform(self, link_name: str) -> Transform: def get_link_origin(self, link_name: str) -> Pose: """ - Returns the origin of the link with the given name. + Return the origin of the link with the given name. :param link_name: The name of the link. :return: The origin of the link as a 'Pose'. @@ -339,7 +545,7 @@ def get_link_origin(self, link_name: str) -> Pose: def get_link_origin_transform(self, link_name: str) -> Transform: """ - Returns the origin transform of the link with the given name. + Return the origin transform of the link with the given name. :param link_name: The name of the link. :return: The origin transform of the link. @@ -362,16 +568,16 @@ def __repr__(self): def remove(self) -> None: """ - Removes this object from the World it currently resides in. + Remove this object from the World it currently resides in. For the object to be removed it has to be detached from all objects it - is currently attached to. After this is done a call to world remove object is done + is currently attached to. After this call world remove object to remove this Object from the simulation/world. """ self.world.remove_object(self) - def reset(self, remove_saved_states=True) -> None: + def reset(self, remove_saved_states=False) -> None: """ - Resets the Object to the state it was first spawned in. + Reset the Object to the state it was first spawned in. All attached objects will be detached, all joints will be set to the default position of 0 and the object will be set to the position and orientation in which it was spawned. @@ -384,13 +590,23 @@ def reset(self, remove_saved_states=True) -> None: if remove_saved_states: self.remove_saved_states() + def has_type_environment(self) -> bool: + """ + Check if the object is of type environment. + + :return: True if the object is of type environment, False otherwise. + """ + return self.obj_type == ObjectType.ENVIRONMENT + def attach(self, child_object: Object, parent_link: Optional[str] = None, child_link: Optional[str] = None, - bidirectional: Optional[bool] = True) -> None: + bidirectional: bool = True, + coincide_the_objects: bool = False, + parent_to_child_transform: Optional[Transform] = None) -> None: """ - Attaches another object to this object. This is done by + Attach another object to this object. This is done by saving the transformation between the given link, if there is one, and the base pose of the other object. Additionally, the name of the link, to which the object is attached, will be saved. @@ -403,11 +619,15 @@ def attach(self, :param parent_link: The link name of this object. :param child_link: The link name of the other object. :param bidirectional: If the attachment should be a loose attachment. + :param coincide_the_objects: If True the object frames will be coincided. + :param parent_to_child_transform: The transform from the parent to the child object. """ parent_link = self.links[parent_link] if parent_link else self.root_link child_link = child_object.links[child_link] if child_link else child_object.root_link - attachment = Attachment(parent_link, child_link, bidirectional) + if coincide_the_objects and parent_to_child_transform is None: + parent_to_child_transform = Transform() + attachment = Attachment(parent_link, child_link, bidirectional, parent_to_child_transform) self.attachments[child_object] = attachment child_object.attachments[self] = attachment.get_inverse() @@ -416,7 +636,7 @@ def attach(self, def detach(self, child_object: Object) -> None: """ - Detaches another object from this object. This is done by + Detache another object from this object. This is done by deleting the attachment from the attachments dictionary of both objects and deleting the constraint of the simulator. Afterward the detachment event of the corresponding World will be fired. @@ -441,7 +661,7 @@ def update_attachment_with_object(self, child_object: Object): def get_position(self) -> Point: """ - Returns the position of this Object as a list of xyz. + Return the position of this Object as a list of xyz. :return: The current position of this object """ @@ -449,7 +669,7 @@ def get_position(self) -> Point: def get_orientation(self) -> Pose.orientation: """ - Returns the orientation of this object as a list of xyzw, representing a quaternion. + Return the orientation of this object as a list of xyzw, representing a quaternion. :return: A list of xyzw """ @@ -457,7 +677,7 @@ def get_orientation(self) -> Pose.orientation: def get_position_as_list(self) -> List[float]: """ - Returns the position of this Object as a list of xyz. + Return the position of this Object as a list of xyz. :return: The current position of this object """ @@ -465,7 +685,7 @@ def get_position_as_list(self) -> List[float]: def get_base_position_as_list(self) -> List[float]: """ - Returns the position of this Object as a list of xyz. + Return the position of this Object as a list of xyz. :return: The current position of this object """ @@ -473,7 +693,7 @@ def get_base_position_as_list(self) -> List[float]: def get_orientation_as_list(self) -> List[float]: """ - Returns the orientation of this object as a list of xyzw, representing a quaternion. + Return the orientation of this object as a list of xyzw, representing a quaternion. :return: A list of xyzw """ @@ -481,15 +701,17 @@ def get_orientation_as_list(self) -> List[float]: def get_pose(self) -> Pose: """ - Returns the position of this object as a list of xyz. Alias for :func:`~Object.get_position`. + Return the position of this object as a list of xyz. Alias for :func:`~Object.get_position`. :return: The current pose of this object """ + if self.world.conf.update_poses_from_sim_on_get: + self.update_pose() return self._current_pose - def set_pose(self, pose: Pose, base: Optional[bool] = False, set_attachments: Optional[bool] = True) -> None: + def set_pose(self, pose: Pose, base: bool = False, set_attachments: bool = True) -> None: """ - Sets the Pose of the object. + Set the Pose of the object. :param pose: New Pose for the object :param base: If True places the object base instead of origin at the specified position and orientation @@ -505,23 +727,24 @@ def set_pose(self, pose: Pose, base: Optional[bool] = False, set_attachments: Op self._set_attached_objects_poses() def reset_base_pose(self, pose: Pose): - self.world.reset_object_base_pose(self, pose) - self.update_pose() + if self.world.reset_object_base_pose(self, pose): + self.update_pose() def update_pose(self): """ - Updates the current pose of this object from the world, and updates the poses of all links. + Update the current pose of this object from the world, and updates the poses of all links. """ self._current_pose = self.world.get_object_pose(self) + # TODO: Probably not needed, need to test self._update_all_links_poses() self.update_link_transforms() def _update_all_links_poses(self): """ - Updates the poses of all links by getting them from the simulator. + Update the poses of all links by getting them from the simulator. """ for link in self.links.values(): - link._update_pose() + link.update_pose() def move_base_to_origin_pose(self) -> None: """ @@ -532,7 +755,7 @@ def move_base_to_origin_pose(self) -> None: def save_state(self, state_id) -> None: """ - Saves the state of this object by saving the state of all links and attachments. + Save the state of this object by saving the state of all links and attachments. :param state_id: The unique id of the state. """ @@ -542,7 +765,7 @@ def save_state(self, state_id) -> None: def save_links_states(self, state_id: int) -> None: """ - Saves the state of all links of this object. + Save the state of all links of this object. :param state_id: The unique id of the state. """ @@ -551,7 +774,7 @@ def save_links_states(self, state_id: int) -> None: def save_joints_states(self, state_id: int) -> None: """ - Saves the state of all joints of this object. + Save the state of all joints of this object. :param state_id: The unique id of the state. """ @@ -560,40 +783,107 @@ def save_joints_states(self, state_id: int) -> None: @property def current_state(self) -> ObjectState: - return ObjectState(self.get_pose().copy(), self.attachments.copy(), self.link_states.copy(), self.joint_states.copy()) + """ + The current state of this object as an ObjectState. + """ + return ObjectState(self.get_pose().copy(), self.attachments.copy(), self.link_states.copy(), + self.joint_states.copy(), self.world.conf.get_pose_tolerance()) @current_state.setter def current_state(self, state: ObjectState) -> None: - if self.get_pose().dist(state.pose) != 0.0: + """ + Set the current state of this object to the given state. + """ + if self.current_state != state: self.set_pose(state.pose, base=False, set_attachments=False) - - self.set_attachments(state.attachments) - self.link_states = state.link_states - self.joint_states = state.joint_states + self.set_attachments(state.attachments) + self.link_states = state.link_states + self.joint_states = state.joint_states def set_attachments(self, attachments: Dict[Object, Attachment]) -> None: """ - Sets the attachments of this object to the given attachments. + Set the attachments of this object to the given attachments. + + :param attachments: A dictionary with the object as key and the attachment as value. + """ + self.detach_objects_not_in_attachments(attachments) + self.attach_objects_in_attachments(attachments) + + def detach_objects_not_in_attachments(self, attachments: Dict[Object, Attachment]) -> None: + """ + Detach objects that are not in the attachments list and are in the current attachments list. + + :param attachments: A dictionary with the object as key and the attachment as value. + """ + copy_of_attachments = self.attachments.copy() + for obj, attachment in copy_of_attachments.items(): + original_obj = obj + if self.world.is_prospection_world and len(attachments) > 0 \ + and not list(attachments.keys())[0].world.is_prospection_world: + obj = self.world.get_object_for_prospection_object(obj) + if obj not in attachments: + if attachment.is_inverse: + original_obj.detach(self) + else: + self.detach(original_obj) + + def attach_objects_in_attachments(self, attachments: Dict[Object, Attachment]) -> None: + """ + Attach objects that are in the given attachments list but not in the current attachments list. :param attachments: A dictionary with the object as key and the attachment as value. """ for obj, attachment in attachments.items(): - if self.world.is_prospection_world and not obj.world.is_prospection_world: - # In case this object is in the prospection world and the other object is not, the attachment will no - # be set. - continue + is_prospection = self.world.is_prospection_world and not obj.world.is_prospection_world + if is_prospection: + obj = self.world.get_prospection_object_for_object(obj) if obj in self.attachments: if self.attachments[obj] != attachment: - self.detach(obj) + if attachment.is_inverse: + obj.detach(self) + else: + self.detach(obj) else: continue - self.attach(obj, attachment.parent_link.name, attachment.child_link.name, - attachment.bidirectional) + self.mimic_attachment_with_object(attachment, obj) + + def mimic_attachment_with_object(self, attachment: Attachment, child_object: Object) -> None: + """ + Mimic the given attachment for this and the given child objects. + + :param attachment: The attachment to mimic. + :param child_object: The child object. + """ + att_transform = self.get_attachment_transform_with_object(attachment, child_object) + if attachment.is_inverse: + child_object.attach(self, attachment.child_link.name, attachment.parent_link.name, + attachment.bidirectional, + parent_to_child_transform=att_transform.invert()) + else: + self.attach(child_object, attachment.parent_link.name, attachment.child_link.name, + attachment.bidirectional, parent_to_child_transform=att_transform) + + def get_attachment_transform_with_object(self, attachment: Attachment, child_object: Object) -> Transform: + """ + Return the attachment transform for the given parent and child objects, taking into account the prospection + world. + + :param attachment: The attachment. + :param child_object: The child object. + :return: The attachment transform. + """ + if self.world != child_object.world: + raise WorldMismatchErrorBetweenObjects(self, child_object) + att_transform = attachment.parent_to_child_transform.copy() + if self.world.is_prospection_world and not attachment.parent_object.world.is_prospection_world: + att_transform.frame = self.tf_prospection_world_prefix + att_transform.frame + att_transform.child_frame_id = self.tf_prospection_world_prefix + att_transform.child_frame_id + return att_transform @property def link_states(self) -> Dict[int, LinkState]: """ - Returns the current state of all links of this object. + The current state of all links of this object. :return: A dictionary with the link id as key and the current state of the link as value. """ @@ -602,7 +892,7 @@ def link_states(self) -> Dict[int, LinkState]: @link_states.setter def link_states(self, link_states: Dict[int, LinkState]) -> None: """ - Sets the current state of all links of this object. + Set the current state of all links of this object. :param link_states: A dictionary with the link id as key and the current state of the link as value. """ @@ -612,7 +902,7 @@ def link_states(self, link_states: Dict[int, LinkState]) -> None: @property def joint_states(self) -> Dict[int, JointState]: """ - Returns the current state of all joints of this object. + The current state of all joints of this object. :return: A dictionary with the joint id as key and the current state of the joint as value. """ @@ -621,16 +911,20 @@ def joint_states(self) -> Dict[int, JointState]: @joint_states.setter def joint_states(self, joint_states: Dict[int, JointState]) -> None: """ - Sets the current state of all joints of this object. + Set the current state of all joints of this object. :param joint_states: A dictionary with the joint id as key and the current state of the joint as value. """ for joint in self.joints.values(): - joint.current_state = joint_states[joint.id] + if joint.name not in self.robot_virtual_move_base_joints_names(): + joint.current_state = joint_states[joint.id] + + def robot_virtual_move_base_joints_names(self): + return self.robot_description.virtual_mobile_base_joints.names def remove_saved_states(self) -> None: """ - Removes all saved states of this object. + Remove all saved states of this object. """ super().remove_saved_states() self.remove_links_saved_states() @@ -638,28 +932,30 @@ def remove_saved_states(self) -> None: def remove_links_saved_states(self) -> None: """ - Removes all saved states of the links of this object. + Remove all saved states of the links of this object. """ for link in self.links.values(): link.remove_saved_states() def remove_joints_saved_states(self) -> None: """ - Removes all saved states of the joints of this object. + Remove all saved states of the joints of this object. """ for joint in self.joints.values(): joint.remove_saved_states() def _set_attached_objects_poses(self, already_moved_objects: Optional[List[Object]] = None) -> None: """ - Updates the positions of all attached objects. This is done + Update the positions of all attached objects. This is done by calculating the new pose in world coordinate frame and setting the base pose of the attached objects to this new pose. - After this the _set_attached_objects method of all attached objects - will be called. + After this call _set_attached_objects method for all attached objects. - :param already_moved_objects: A list of Objects that were already moved, these will be excluded to prevent loops in the update. + :param already_moved_objects: A list of Objects that were already moved, these will be excluded to prevent loops + in the update. """ + if not self.world.conf.let_pycram_move_attached_objects: + return if already_moved_objects is None: already_moved_objects = [] @@ -675,13 +971,12 @@ def _set_attached_objects_poses(self, already_moved_objects: Optional[List[Objec child.update_attachment_with_object(self) else: - link_to_object = attachment.parent_to_child_transform - child.set_pose(link_to_object.to_pose(), set_attachments=False) + child.set_pose(attachment.get_child_link_target_pose(), set_attachments=False) child._set_attached_objects_poses(already_moved_objects + [self]) def set_position(self, position: Union[Pose, Point, List], base=False) -> None: """ - Sets this Object to the given position, if base is true the bottom of the Object will be placed at the position + Set this Object to the given position, if base is true, place the bottom of the Object at the position instead of the origin in the center of the Object. The given position can either be a Pose, in this case only the position is used or a geometry_msgs.msg/Point which is the position part of a Pose. @@ -694,10 +989,13 @@ def set_position(self, position: Union[Pose, Point, List], base=False) -> None: pose.frame = position.frame elif isinstance(position, Point): target_position = position - elif isinstance(position, list): - target_position = position + elif isinstance(position, List): + if len(position) == 3: + target_position = Point(*position) + else: + raise ValueError("The given position has to be a list of 3 values.") else: - raise TypeError("The given position has to be a Pose, Point or a list of xyz.") + raise TypeError("The given position has to be a Pose, Point or an iterable of xyz values.") pose.position = target_position pose.orientation = self.get_orientation() @@ -705,7 +1003,7 @@ def set_position(self, position: Union[Pose, Point, List], base=False) -> None: def set_orientation(self, orientation: Union[Pose, Quaternion, List, Tuple, np.ndarray]) -> None: """ - Sets the orientation of the Object to the given orientation. Orientation can either be a Pose, in this case only + Set the orientation of the Object to the given orientation. Orientation can either be a Pose, in this case only the orientation of this pose is used or a geometry_msgs.msg/Quaternion which is the orientation of a Pose. :param orientation: Target orientation given as a list of xyzw. @@ -728,7 +1026,7 @@ def set_orientation(self, orientation: Union[Pose, Quaternion, List, Tuple, np.n def get_joint_id(self, name: str) -> int: """ - Returns the unique id for a joint name. As used by the world/simulator. + Return the unique id for a joint name. As used by the world/simulator. :param name: The joint name :return: The unique id @@ -737,35 +1035,35 @@ def get_joint_id(self, name: str) -> int: def get_root_link_description(self) -> LinkDescription: """ - Returns the root link of the URDF of this object. + Return the root link of the URDF of this object. :return: The root link as defined in the URDF of this object. """ for link_description in self.description.links: - if link_description.name == self.root_link_name: + if link_description.name == self.description.get_root(): return link_description @property def root_link(self) -> ObjectDescription.Link: """ - Returns the root link of this object. + The root link of this object. :return: The root link of this object. """ return self.links[self.description.get_root()] @property - def root_link_name(self) -> str: + def tip_link(self) -> ObjectDescription.Link: """ - Returns the name of the root link of this object. + The tip link of this object. - :return: The name of the root link of this object. + :return: The tip link of this object. """ - return self.description.get_root() + return self.links[self.description.get_tip()] def get_root_link_id(self) -> int: """ - Returns the unique id of the root link of this object. + Return the unique id of the root link of this object. :return: The unique id of the root link of this object. """ @@ -773,7 +1071,7 @@ def get_root_link_id(self) -> int: def get_link_id(self, link_name: str) -> int: """ - Returns a unique id for a link name. + Return a unique id for a link name. :param link_name: The name of the link. :return: The unique id of the link. @@ -782,7 +1080,7 @@ def get_link_id(self, link_name: str) -> int: def get_link_by_id(self, link_id: int) -> ObjectDescription.Link: """ - Returns the link for a given unique link id + Return the link for a given unique link id :param link_id: The unique id of the link. :return: The link object. @@ -791,40 +1089,50 @@ def get_link_by_id(self, link_id: int) -> ObjectDescription.Link: def reset_all_joints_positions(self) -> None: """ - Sets the current position of all joints to 0. This is useful if the joints should be reset to their default + Set the current position of all joints to 0. This is useful if the joints should be reset to their default """ - joint_names = list(self.joint_name_to_id.keys()) + joint_names = [joint.name for joint in self.joints.values()] + if len(joint_names) == 0: + return joint_positions = [0] * len(joint_names) - self.set_joint_positions(dict(zip(joint_names, joint_positions))) + self.set_multiple_joint_positions(dict(zip(joint_names, joint_positions))) - def set_joint_positions(self, joint_poses: dict) -> None: + def set_joint_position(self, joint_name: str, joint_position: float) -> None: """ - Sets the current position of multiple joints at once, this method should be preferred when setting - multiple joints at once instead of running :func:`~Object.set_joint_position` in a loop. + Set the position of the given joint to the given joint pose and updates the poses of all attached objects. - :param joint_poses: + :param joint_name: The name of the joint + :param joint_position: The target pose for this joint """ - for joint_name, joint_position in joint_poses.items(): - self.joints[joint_name].position = joint_position - # self.update_pose() - self._update_all_links_poses() - self.update_link_transforms() - self._set_attached_objects_poses() + if self.world.reset_joint_position(self.joints[joint_name], joint_position): + self._update_on_joint_position_change() - def set_joint_position(self, joint_name: str, joint_position: float) -> None: + @deprecated("Use set_multiple_joint_positions instead") + def set_joint_positions(self, joint_positions: Dict[str, float]) -> None: + self.set_multiple_joint_positions(joint_positions) + + def set_multiple_joint_positions(self, joint_positions: Dict[str, float]) -> None: """ - Sets the position of the given joint to the given joint pose and updates the poses of all attached objects. + Set the current position of multiple joints at once, this method should be preferred when setting + multiple joints at once instead of running :func:`~Object.set_joint_position` in a loop. - :param joint_name: The name of the joint - :param joint_position: The target pose for this joint + :param joint_positions: A dictionary with the joint names as keys and the target positions as values. """ - self.joints[joint_name].position = joint_position + joint_positions = {self.joints[joint_name]: joint_position + for joint_name, joint_position in joint_positions.items()} + if self.world.set_multiple_joint_positions(joint_positions): + self._update_on_joint_position_change() + + def _update_on_joint_position_change(self): + self.update_pose() self._update_all_links_poses() self.update_link_transforms() self._set_attached_objects_poses() def get_joint_position(self, joint_name: str) -> float: """ + Return the current position of the given joint. + :param joint_name: The name of the joint :return: The current position of the given joint """ @@ -832,6 +1140,8 @@ def get_joint_position(self, joint_name: str) -> float: def get_joint_damping(self, joint_name: str) -> float: """ + Return the damping of the given joint (friction). + :param joint_name: The name of the joint :return: The damping of the given joint """ @@ -839,6 +1149,8 @@ def get_joint_damping(self, joint_name: str) -> float: def get_joint_upper_limit(self, joint_name: str) -> float: """ + Return the upper limit of the given joint. + :param joint_name: The name of the joint :return: The upper limit of the given joint """ @@ -846,6 +1158,8 @@ def get_joint_upper_limit(self, joint_name: str) -> float: def get_joint_lower_limit(self, joint_name: str) -> float: """ + Return the lower limit of the given joint. + :param joint_name: The name of the joint :return: The lower limit of the given joint """ @@ -853,6 +1167,8 @@ def get_joint_lower_limit(self, joint_name: str) -> float: def get_joint_axis(self, joint_name: str) -> Point: """ + Return the axis of the given joint. + :param joint_name: The name of the joint :return: The axis of the given joint """ @@ -860,6 +1176,8 @@ def get_joint_axis(self, joint_name: str) -> Point: def get_joint_type(self, joint_name: str) -> JointType: """ + Return the type of the given joint. + :param joint_name: The name of the joint :return: The type of the given joint """ @@ -867,6 +1185,8 @@ def get_joint_type(self, joint_name: str) -> JointType: def get_joint_limits(self, joint_name: str) -> Tuple[float, float]: """ + Return the lower and upper limits of the given joint. + :param joint_name: The name of the joint :return: The lower and upper limits of the given joint """ @@ -874,6 +1194,8 @@ def get_joint_limits(self, joint_name: str) -> Tuple[float, float]: def get_joint_child_link(self, joint_name: str) -> ObjectDescription.Link: """ + Return the child link of the given joint. + :param joint_name: The name of the joint :return: The child link of the given joint """ @@ -881,6 +1203,8 @@ def get_joint_child_link(self, joint_name: str) -> ObjectDescription.Link: def get_joint_parent_link(self, joint_name: str) -> ObjectDescription.Link: """ + Return the parent link of the given joint. + :param joint_name: The name of the joint :return: The parent link of the given joint """ @@ -888,7 +1212,7 @@ def get_joint_parent_link(self, joint_name: str) -> ObjectDescription.Link: def find_joint_above_link(self, link_name: str, joint_type: JointType) -> str: """ - Traverses the chain from 'link' to the URDF origin and returns the first joint that is of type 'joint_type'. + Traverse the chain from 'link' to the URDF origin and return the first joint that is of type 'joint_type'. :param link_name: AbstractLink name above which the joint should be found :param joint_type: Joint type that should be searched for @@ -905,9 +1229,18 @@ def find_joint_above_link(self, link_name: str, joint_type: JointType) -> str: rospy.logwarn(f"No joint of type {joint_type} found above link {link_name}") return container_joint + def get_multiple_joint_positions(self, joint_names: List[str]) -> Dict[str, float]: + """ + Return the positions of multiple joints at once. + + :param joint_names: A list of joint names. + :return: A dictionary with the joint names as keys and the joint positions as values. + """ + return self.world.get_multiple_joint_positions([self.joints[joint_name] for joint_name in joint_names]) + def get_positions_of_all_joints(self) -> Dict[str, float]: """ - Returns the positions of all joints of the object as a dictionary of joint names and joint positions. + Return the positions of all joints of the object as a dictionary of joint names and joint positions. :return: A dictionary with all joints positions'. """ @@ -915,22 +1248,24 @@ def get_positions_of_all_joints(self) -> Dict[str, float]: def update_link_transforms(self, transform_time: Optional[rospy.Time] = None) -> None: """ - Updates the transforms of all links of this object using time 'transform_time' or the current ros time. + Update the transforms of all links of this object using time 'transform_time' or the current ros time. + + :param transform_time: The time to use for the transform update. """ for link in self.links.values(): link.update_transform(transform_time) - def contact_points(self) -> List: + def contact_points(self) -> ContactPointsList: """ - Returns a list of contact points of this Object with other Objects. + Return a list of contact points of this Object with other Objects. :return: A list of all contact points with other objects """ return self.world.get_object_contact_points(self) - def contact_points_simulated(self) -> List: + def contact_points_simulated(self) -> ContactPointsList: """ - Returns a list of all contact points between this Object and other Objects after stepping the simulation once. + Return a list of all contact points between this Object and other Objects after stepping the simulation once. :return: A list of contact points between this Object and other Objects """ @@ -940,9 +1275,28 @@ def contact_points_simulated(self) -> List: self.world.restore_state(state_id) return contact_points + def closest_points(self, max_distance: float) -> ClosestPointsList: + """ + Return a list of closest points between this Object and other Objects. + + :param max_distance: The maximum distance between the closest points + :return: A list of closest points between this Object and other Objects + """ + return self.world.get_object_closest_points(self, max_distance) + + def closest_points_with_obj(self, other_object: Object, max_distance: float) -> ClosestPointsList: + """ + Return a list of closest points between this Object and another Object. + + :param other_object: The other object + :param max_distance: The maximum distance between the closest points + :return: A list of closest points between this Object and the other Object + """ + return self.world.get_closest_points_between_objects(self, other_object, max_distance) + def set_color(self, rgba_color: Color) -> None: """ - Changes the color of this object, the color has to be given as a list + Change the color of this object, the color has to be given as a list of RGBA values. :param rgba_color: The color as Color object with RGBA values between 0 and 1 @@ -957,7 +1311,7 @@ def set_color(self, rgba_color: Color) -> None: def get_color(self) -> Union[Color, Dict[str, Color]]: """ - This method returns the rgba_color of this object. The return is either: + Return the rgba_color of this object. The return is either: 1. A Color object with RGBA values, this is the case if the object only has one link (this happens for example if the object is spawned from a .obj or .stl file) @@ -965,7 +1319,8 @@ def get_color(self) -> Union[Color, Dict[str, Color]]: Please keep in mind that not every link may have a rgba_color. This is dependent on the URDF from which the object is spawned. - :return: The rgba_color as Color object with RGBA values between 0 and 1 or a dict with the link name as key and the rgba_color as value. + :return: The rgba_color as Color object with RGBA values between 0 and 1 or a dict with the link name as key and + the rgba_color as value. """ link_to_color_dict = self.links_colors @@ -983,12 +1338,16 @@ def links_colors(self) -> Dict[str, Color]: def get_axis_aligned_bounding_box(self) -> AxisAlignedBoundingBox: """ + Return the axis aligned bounding box of this object. + :return: The axis aligned bounding box of this object. """ return self.world.get_object_axis_aligned_bounding_box(self) def get_base_origin(self) -> Pose: """ + Return the origin of the base/bottom of this object. + :return: the origin of the base/bottom of this object. """ aabb = self.get_axis_aligned_bounding_box() @@ -999,49 +1358,43 @@ def get_base_origin(self) -> Pose: def get_joint_by_id(self, joint_id: int) -> Joint: """ - Returns the joint object with the given id. + Return the joint object with the given id. :param joint_id: The unique id of the joint. :return: The joint object. """ return dict([(joint.id, joint) for joint in self.joints.values()])[joint_id] - def copy_to_prospection(self) -> Object: - """ - Copies this object to the prospection world. - - :return: The copied object in the prospection world. - """ - obj = Object(self.name, self.obj_type, self.path, type(self.description), self.get_pose(), - self.world.prospection_world, self.color) - obj.current_state = self.current_state - return obj - def get_link_for_attached_objects(self) -> Dict[Object, ObjectDescription.Link]: """ - Returns a dictionary which maps attached object to the link of this object to which the given object is attached. + Return a dictionary which maps attached object to the link of this object to which the given object is attached. :return: The link of this object to which the given object is attached. """ return {obj: attachment.parent_link for obj, attachment in self.attachments.items()} + def copy_to_prospection(self) -> Object: + """ + Copy this object to the prospection world. + + :return: The copied object in the prospection world. + """ + return self.copy_to_world(self.world.prospection_world) - def __copy__(self) -> Object: + def copy_to_world(self, world: World) -> Object: """ - Returns a copy of this object. The copy will have the same name, type, path, description, pose, world and color. + Copy this object to the given world. - :return: A copy of this object. + :param world: The world to which the object should be copied. + :return: The copied object in the given world. """ - obj = Object(self.name, self.obj_type, self.path, type(self.description), self.get_pose(), - self.world.prospection_world, self.color) - obj.current_state = self.current_state + obj = Object(self.name, self.obj_type, self.path, self.description, self.get_pose(), + world, self.color) return obj def __eq__(self, other): - if not isinstance(other, Object): - return False - return (self.id == other.id and self.world == other.world and self.name == other.name - and self.obj_type == other.obj_type) + return (isinstance(other, Object) and self.id == other.id and self.name == other.name + and self.world == other.world) def __hash__(self): - return hash((self.name, self.obj_type, self.id, self.world.id)) + return hash((self.id, self.name, self.world)) diff --git a/src/pycram/world_reasoning.py b/src/pycram/world_reasoning.py index f4892676b..8fd0501b9 100644 --- a/src/pycram/world_reasoning.py +++ b/src/pycram/world_reasoning.py @@ -1,12 +1,14 @@ -from typing_extensions import List, Tuple, Optional, Union, Dict - import numpy as np +from typing_extensions import List, Tuple, Optional, Union, Dict -from .external_interfaces.ik import try_to_reach, try_to_reach_with_grasp +from .datastructures.dataclasses import ContactPointsList from .datastructures.pose import Pose, Transform +from .datastructures.world import World, UseProspectionWorld +from .external_interfaces.ik import try_to_reach, try_to_reach_with_grasp from .robot_description import RobotDescription +from .utils import RayTestUtils from .world_concepts.world_object import Object -from .datastructures.world import World, UseProspectionWorld +from .config import world_conf as conf def stable(obj: Object) -> bool: @@ -49,41 +51,44 @@ def contact( with UseProspectionWorld(): prospection_obj1 = World.current_world.get_prospection_object_for_object(object1) prospection_obj2 = World.current_world.get_prospection_object_for_object(object2) - World.current_world.perform_collision_detection() - con_points = World.current_world.get_contact_points_between_two_objects(prospection_obj1, prospection_obj2) - + con_points: ContactPointsList = World.current_world.get_contact_points_between_two_objects(prospection_obj1, + prospection_obj2) + objects_are_in_contact = len(con_points) > 0 if return_links: - contact_links = [] - for point in con_points: - contact_links.append((prospection_obj1.get_link_by_id(point[3]), - prospection_obj2.get_link_by_id(point[4]))) - return con_points != (), contact_links - + contact_links = [(point.link_a, point.link_b) for point in con_points] + return objects_are_in_contact, contact_links else: - return con_points != () + return objects_are_in_contact def get_visible_objects( camera_pose: Pose, - front_facing_axis: Optional[List[float]] = None) -> Tuple[np.ndarray, Pose]: + front_facing_axis: Optional[List[float]] = None, + plot_segmentation_mask: bool = False) -> Tuple[np.ndarray, Pose]: """ - Returns a segmentation mask of the objects that are visible from the given camera pose and the front facing axis. + Return a segmentation mask of the objects that are visible from the given camera pose and the front facing axis. :param camera_pose: The pose of the camera in world coordinate frame. :param front_facing_axis: The axis, of the camera frame, which faces to the front of the robot. Given as list of xyz + :param plot_segmentation_mask: If the segmentation mask should be plotted :return: A segmentation mask of the objects that are visible and the pose of the point at exactly 2 meters in front of the camera in the direction of the front facing axis with respect to the world coordinate frame. """ - front_facing_axis = RobotDescription.current_robot_description.get_default_camera().front_facing_axis + if front_facing_axis is None: + front_facing_axis = RobotDescription.current_robot_description.get_default_camera().front_facing_axis - world_to_cam = camera_pose.to_transform("camera") + camera_frame = RobotDescription.current_robot_description.get_camera_frame() + world_to_cam = camera_pose.to_transform(camera_frame) - cam_to_point = Transform(list(np.multiply(front_facing_axis, 2)), [0, 0, 0, 1], "camera", + cam_to_point = Transform(list(np.multiply(front_facing_axis, 2)), [0, 0, 0, 1], camera_frame, "point") target_point = (world_to_cam * cam_to_point).to_pose() seg_mask = World.current_world.get_images_for_target(target_point, camera_pose)[2] + if plot_segmentation_mask: + RayTestUtils.plot_segmentation_mask(seg_mask) + return seg_mask, target_point @@ -91,7 +96,8 @@ def visible( obj: Object, camera_pose: Pose, front_facing_axis: Optional[List[float]] = None, - threshold: float = 0.8) -> bool: + threshold: float = 0.8, + plot_segmentation_mask: bool = False) -> bool: """ Checks if an object is visible from a given position. This will be achieved by rendering the object alone and counting the visible pixel, then rendering the complete scene and compare the visible pixels with the @@ -101,6 +107,7 @@ def visible( :param camera_pose: The pose of the camera in map frame :param front_facing_axis: The axis, of the camera frame, which faces to the front of the robot. Given as list of xyz :param threshold: The minimum percentage of the object that needs to be visible for this method to return true. + :param plot_segmentation_mask: If the segmentation mask should be plotted. :return: True if the object is visible from the camera_position False if not """ with UseProspectionWorld(): @@ -115,7 +122,7 @@ def visible( else: obj.set_pose(Pose([100, 100, 0], [0, 0, 0, 1]), set_attachments=False) - seg_mask, target_point = get_visible_objects(camera_pose, front_facing_axis) + seg_mask, target_point = get_visible_objects(camera_pose, front_facing_axis, plot_segmentation_mask) max_pixel = np.array(seg_mask == prospection_obj.id).sum() World.current_world.restore_state(state_id) @@ -133,7 +140,8 @@ def visible( def occluding( obj: Object, camera_pose: Pose, - front_facing_axis: Optional[List[float]] = None) -> List[Object]: + front_facing_axis: Optional[List[float]] = None, + plot_segmentation_mask: bool = False) -> List[Object]: """ Lists all objects which are occluding the given object. This works similar to 'visible'. First the object alone will be rendered and the position of the pixels of the object in the picture will be saved. @@ -143,6 +151,7 @@ def occluding( :param obj: The object for which occlusion should be checked :param camera_pose: The pose of the camera in world coordinate frame :param front_facing_axis: The axis, of the camera frame, which faces to the front of the robot. Given as list of xyz + :param plot_segmentation_mask: If the segmentation mask should be plotted :return: A list of occluding objects """ @@ -156,7 +165,7 @@ def occluding( else: other_obj.set_pose(Pose([100, 100, 0], [0, 0, 0, 1])) - seg_mask, target_point = get_visible_objects(camera_pose, front_facing_axis) + seg_mask, target_point = get_visible_objects(camera_pose, front_facing_axis, plot_segmentation_mask) # All indices where the object that could be occluded is in the image # [0] at the end is to reduce by one dimension because dstack adds an unnecessary dimension @@ -224,17 +233,15 @@ def blocking( :return: A list of objects the robot is in collision with when reaching for the specified object or None if the pose or object is not reachable. """ - prospection_robot = World.current_world.get_prospection_object_for_object(robot) with UseProspectionWorld(): + prospection_robot = World.current_world.get_prospection_object_for_object(robot) if grasp: try_to_reach_with_grasp(pose_or_object, prospection_robot, gripper_name, grasp) else: try_to_reach(pose_or_object, prospection_robot, gripper_name) - block = [] - for obj in World.current_world.objects: - if contact(prospection_robot, obj): - block.append(World.current_world.get_object_for_prospection_object(obj)) + block = [World.current_world.get_object_for_prospection_object(obj) for obj in World.current_world.objects + if contact(prospection_robot, obj)] return block @@ -257,7 +264,7 @@ def link_pose_for_joint_config( joint_config: Dict[str, float], link_name: str) -> Pose: """ - Returns the pose a link would be in if the given joint configuration would be applied to the object. + Get the pose a link would be in if the given joint configuration would be applied to the object. This is done by using the respective object in the prospection world and applying the joint configuration to this one. After applying the joint configuration the link position is taken from there. diff --git a/src/pycram/worlds/bullet_world.py b/src/pycram/worlds/bullet_world.py index b3f93cb5a..c80e925f1 100755 --- a/src/pycram/worlds/bullet_world.py +++ b/src/pycram/worlds/bullet_world.py @@ -9,14 +9,17 @@ import rosgraph import rospy from geometry_msgs.msg import Point -from typing_extensions import List, Optional, Dict +from typing_extensions import List, Optional, Dict, Any -from ..datastructures.dataclasses import Color, AxisAlignedBoundingBox, MultiBody, VisualShape, BoxVisualShape +from ..datastructures.dataclasses import Color, AxisAlignedBoundingBox, MultiBody, VisualShape, BoxVisualShape, \ + ClosestPoint, LateralFriction, ContactPoint, ContactPointsList, ClosestPointsList from ..datastructures.enums import ObjectType, WorldMode, JointType from ..datastructures.pose import Pose from ..datastructures.world import World from ..object_descriptors.generic import ObjectDescription as GenericObjectDescription from ..object_descriptors.urdf import ObjectDescription +from ..validation.goal_validator import (validate_multiple_joint_positions, validate_joint_position, + validate_object_pose, validate_multiple_object_poses) from ..world_concepts.constraints import Constraint from ..world_concepts.world_object import Object @@ -32,13 +35,11 @@ class is the main interface to the Bullet Physics Engine and should be used to s manipulate the Bullet World. """ - extension: str = ObjectDescription.get_file_extension() - # Check is for sphinx autoAPI to be able to work in a CI workflow if rosgraph.is_master_online(): # and "/pycram" not in rosnode.get_node_names(): rospy.init_node('pycram') - def __init__(self, mode: WorldMode = WorldMode.DIRECT, is_prospection_world: bool = False, sim_frequency=240): + def __init__(self, mode: WorldMode = WorldMode.DIRECT, is_prospection_world: bool = False): """ Creates a new simulation, the type decides of the simulation should be a rendered window or just run in the background. There can only be one rendered simulation. @@ -47,7 +48,7 @@ def __init__(self, mode: WorldMode = WorldMode.DIRECT, is_prospection_world: boo :param mode: Can either be "GUI" for rendered window or "DIRECT" for non-rendered. The default is "GUI" :param is_prospection_world: For internal usage, decides if this BulletWorld should be used as a shadow world. """ - super().__init__(mode=mode, is_prospection_world=is_prospection_world, simulation_frequency=sim_frequency) + super().__init__(mode=mode, is_prospection_world=is_prospection_world) # This disables file caching from PyBullet, since this would also cache # files that can not be loaded @@ -61,7 +62,7 @@ def __init__(self, mode: WorldMode = WorldMode.DIRECT, is_prospection_world: boo self.set_gravity([0, 0, -9.8]) if not is_prospection_world: - _ = Object("floor", ObjectType.ENVIRONMENT, "plane" + self.extension, + _ = Object("floor", ObjectType.ENVIRONMENT, "plane.urdf", world=self) def _init_world(self, mode: WorldMode): @@ -69,26 +70,30 @@ def _init_world(self, mode: WorldMode): self._gui_thread.start() time.sleep(0.1) - def load_generic_object_and_get_id(self, description: GenericObjectDescription) -> int: + def load_generic_object_and_get_id(self, description: GenericObjectDescription, + pose: Optional[Pose] = None) -> int: """ Creates a visual and collision box in the simulation. """ # Create visual shape vis_shape = p.createVisualShape(p.GEOM_BOX, halfExtents=description.shape_data, - rgbaColor=description.color.get_rgba()) + rgbaColor=description.color.get_rgba(), physicsClientId=self.id) # Create collision shape - col_shape = p.createCollisionShape(p.GEOM_BOX, halfExtents=description.shape_data) + col_shape = p.createCollisionShape(p.GEOM_BOX, halfExtents=description.shape_data, physicsClientId=self.id) # Create MultiBody with both visual and collision shapes obj_id = p.createMultiBody(baseMass=1.0, baseCollisionShapeIndex=col_shape, baseVisualShapeIndex=vis_shape, basePosition=description.origin.position_as_list(), - baseOrientation=description.origin.orientation_as_list()) + baseOrientation=description.origin.orientation_as_list(), physicsClientId=self.id) + if pose is not None: + self._set_object_pose_by_id(obj_id, pose) # Assuming you have a list to keep track of created objects return obj_id - def load_object_and_get_id(self, path: Optional[str] = None, pose: Optional[Pose] = None) -> int: + def load_object_and_get_id(self, path: Optional[str] = None, pose: Optional[Pose] = None, + obj_type: Optional[ObjectType] = None) -> int: if pose is None: pose = Pose() return self._load_object_and_get_id(path, pose) @@ -100,11 +105,21 @@ def _load_object_and_get_id(self, path: str, pose: Pose) -> int: basePosition=pose.position_as_list(), baseOrientation=pose.orientation_as_list(), physicsClientId=self.id) - def remove_object_from_simulator(self, obj: Object) -> None: - p.removeBody(obj.id, self.id) + def _remove_visual_object(self, obj_id: int) -> bool: + self._remove_body(obj_id) + return True + + def remove_object_from_simulator(self, obj: Object) -> bool: + self._remove_body(obj.id) + return True + + def _remove_body(self, body_id: int) -> Any: + """ + Remove a body from PyBullet using the body id. - def remove_object_by_id(self, obj_id: int) -> None: - p.removeBody(obj_id, self.id) + :param body_id: The id of the body. + """ + return p.removeBody(body_id, self.id) def add_constraint(self, constraint: Constraint) -> int: @@ -131,6 +146,21 @@ def get_object_joint_names(self, obj: Object) -> List[str]: return [p.getJointInfo(obj.id, i, physicsClientId=self.id)[1].decode('utf-8') for i in range(self.get_object_number_of_joints(obj))] + def get_multiple_link_poses(self, links: List[Link]) -> Dict[str, Pose]: + return {link.name: self.get_link_pose(link) for link in links} + + def get_multiple_link_positions(self, links: List[Link]) -> Dict[str, List[float]]: + return {link.name: self.get_link_position(link) for link in links} + + def get_multiple_link_orientations(self, links: List[Link]) -> Dict[str, List[float]]: + return {link.name: self.get_link_orientation(link) for link in links} + + def get_link_position(self, link: Link) -> List[float]: + return self.get_link_pose(link).position_as_list() + + def get_link_orientation(self, link: Link) -> List[float]: + return self.get_link_pose(link).orientation_as_list() + def get_link_pose(self, link: ObjectDescription.Link) -> Pose: bullet_link_state = p.getLinkState(link.object_id, link.id, physicsClientId=self.id) return Pose(*bullet_link_state[4:6]) @@ -148,29 +178,94 @@ def get_object_number_of_links(self, obj: Object) -> int: def perform_collision_detection(self) -> None: p.performCollisionDetection(physicsClientId=self.id) - def get_object_contact_points(self, obj: Object) -> List: + def get_object_contact_points(self, obj: Object) -> ContactPointsList: """ - For a more detailed explanation of the - returned list please look at: - `PyBullet Doc `_ + Get the contact points of the object with akk other objects in the world. The contact points are returned as a + ContactPointsList object. + + :param obj: The object for which the contact points should be returned. + :return: The contact points of the object with all other objects in the world. """ self.perform_collision_detection() - return p.getContactPoints(obj.id, physicsClientId=self.id) + points_list = p.getContactPoints(obj.id, physicsClientId=self.id) + return ContactPointsList([ContactPoint(**self.parse_points_list_to_args(point)) for point in points_list + if len(point) > 0]) - def get_contact_points_between_two_objects(self, obj1: Object, obj2: Object) -> List: + def get_contact_points_between_two_objects(self, obj_a: Object, obj_b: Object) -> ContactPointsList: self.perform_collision_detection() - return p.getContactPoints(obj1.id, obj2.id, physicsClientId=self.id) + points_list = p.getContactPoints(obj_a.id, obj_b.id, physicsClientId=self.id) + return ContactPointsList([ContactPoint(**self.parse_points_list_to_args(point)) for point in points_list + if len(point) > 0]) + + def get_closest_points_between_objects(self, obj_a: Object, obj_b: Object, distance: float) -> ClosestPointsList: + points_list = p.getClosestPoints(obj_a.id, obj_b.id, distance, physicsClientId=self.id) + return ClosestPointsList([ClosestPoint(**self.parse_points_list_to_args(point)) for point in points_list + if len(point) > 0]) - def reset_joint_position(self, joint: ObjectDescription.Joint, joint_position: str) -> None: + def parse_points_list_to_args(self, point: List) -> Dict: + """ + Parses the list of points to a list of dictionaries with the keys as the names of the arguments of the + ContactPoint class. + + :param point: The list of points. + """ + return {"link_a": self.get_object_by_id(point[1]).get_link_by_id(point[3]), + "link_b": self.get_object_by_id(point[2]).get_link_by_id(point[4]), + "position_on_object_a": point[5], + "position_on_object_b": point[6], + "normal_on_b": point[7], + "distance": point[8], + "normal_force": point[9], + "lateral_friction_1": LateralFriction(point[10], point[11]), + "lateral_friction_2": LateralFriction(point[12], point[13])} + + @validate_multiple_joint_positions + def set_multiple_joint_positions(self, joint_positions: Dict[Joint, float]) -> bool: + for joint, joint_position in joint_positions.items(): + self.reset_joint_position(joint, joint_position) + return True + + @validate_joint_position + def reset_joint_position(self, joint: Joint, joint_position: float) -> bool: p.resetJointState(joint.object_id, joint.id, joint_position, physicsClientId=self.id) + return True + + def get_multiple_joint_positions(self, joints: List[Joint]) -> Dict[str, float]: + return {joint.name: self.get_joint_position(joint) for joint in joints} + + @validate_multiple_object_poses + def reset_multiple_objects_base_poses(self, objects: Dict[Object, Pose]) -> bool: + for obj, pose in objects.items(): + self.reset_object_base_pose(obj, pose) + return True + + @validate_object_pose + def reset_object_base_pose(self, obj: Object, pose: Pose) -> bool: + return self._set_object_pose_by_id(obj.id, pose) - def reset_object_base_pose(self, obj: Object, pose: Pose) -> None: - p.resetBasePositionAndOrientation(obj.id, pose.position_as_list(), pose.orientation_as_list(), + def _set_object_pose_by_id(self, obj_id: int, pose: Pose) -> bool: + p.resetBasePositionAndOrientation(obj_id, pose.position_as_list(), pose.orientation_as_list(), physicsClientId=self.id) + return True def step(self): p.stepSimulation(physicsClientId=self.id) + def get_multiple_object_poses(self, objects: List[Object]) -> Dict[str, Pose]: + return {obj.name: self.get_object_pose(obj) for obj in objects} + + def get_multiple_object_positions(self, objects: List[Object]) -> Dict[str, List[float]]: + return {obj.name: self.get_object_pose(obj).position_as_list() for obj in objects} + + def get_multiple_object_orientations(self, objects: List[Object]) -> Dict[str, List[float]]: + return {obj.name: self.get_object_pose(obj).orientation_as_list() for obj in objects} + + def get_object_position(self, obj: Object) -> List[float]: + return self.get_object_pose(obj).position_as_list() + + def get_object_orientation(self, obj: Object) -> List[float]: + return self.get_object_pose(obj).orientation_as_list() + def get_object_pose(self, obj: Object) -> Pose: return Pose(*p.getBasePositionAndOrientation(obj.id, physicsClientId=self.id)) @@ -213,7 +308,7 @@ def join_gui_thread_if_exists(self): if self._gui_thread: self._gui_thread.join() - def save_physics_simulator_state(self) -> int: + def save_physics_simulator_state(self, use_same_id: bool = False) -> int: return p.saveState(physicsClientId=self.id) def restore_physics_simulator_state(self, state_id): @@ -222,14 +317,15 @@ def restore_physics_simulator_state(self, state_id): def remove_physics_simulator_state(self, state_id: int): p.removeState(state_id, physicsClientId=self.id) - def add_vis_axis(self, pose: Pose, - length: Optional[float] = 0.2) -> None: + def _add_vis_axis(self, pose: Pose, + length: Optional[float] = 0.2) -> int: """ Creates a Visual object which represents the coordinate frame at the given position and orientation. There can be an unlimited amount of vis axis objects. :param pose: The pose at which the axis should be spawned :param length: Optional parameter to configure the length of the axes + :return: The id of the spawned object """ pose_in_map = self.local_transformer.transform_pose(pose, "map") @@ -251,9 +347,11 @@ def add_vis_axis(self, pose: Pose, link_joint_axis=[Point(1, 0, 0), Point(0, 1, 0), Point(0, 0, 1)], link_collision_shape_indices=[-1, -1, -1]) - self.vis_axis.append(self.create_multi_body(multibody)) + body_id = self._create_multi_body(multibody) + self.vis_axis.append(body_id) + return body_id - def remove_vis_axis(self) -> None: + def _remove_vis_axis(self) -> None: """ Removes all spawned vis axis objects that are currently in this BulletWorld. """ @@ -270,13 +368,13 @@ def ray_test_batch(self, from_positions: List[List[float]], to_positions: List[L return p.rayTestBatch(from_positions, to_positions, numThreads=num_threads, physicsClientId=self.id) - def create_visual_shape(self, visual_shape: VisualShape) -> int: + def _create_visual_shape(self, visual_shape: VisualShape) -> int: return p.createVisualShape(visual_shape.visual_geometry_type.value, rgbaColor=visual_shape.rgba_color.get_rgba(), visualFramePosition=visual_shape.visual_frame_position, physicsClientId=self.id, **visual_shape.shape_data()) - def create_multi_body(self, multi_body: MultiBody) -> int: + def _create_multi_body(self, multi_body: MultiBody) -> int: return p.createMultiBody(baseVisualShapeIndex=-multi_body.base_visual_shape_index, linkVisualShapeIndices=multi_body.link_visual_shape_indices, basePosition=multi_body.base_pose.position_as_list(), @@ -309,9 +407,9 @@ def get_images_for_target(self, return list(p.getCameraImage(size, size, view_matrix, projection_matrix, physicsClientId=self.id))[2:5] - def add_text(self, text: str, position: List[float], orientation: Optional[List[float]] = None, - size: Optional[float] = None, color: Optional[Color] = Color(), life_time: Optional[float] = 0, - parent_object_id: Optional[int] = None, parent_link_id: Optional[int] = None) -> int: + def _add_text(self, text: str, position: List[float], orientation: Optional[List[float]] = None, + size: Optional[float] = None, color: Optional[Color] = Color(), life_time: Optional[float] = 0, + parent_object_id: Optional[int] = None, parent_link_id: Optional[int] = None) -> int: args = {} if orientation: args["textOrientation"] = orientation @@ -325,7 +423,7 @@ def add_text(self, text: str, position: List[float], orientation: Optional[List[ args["parentLinkIndex"] = parent_link_id return p.addUserDebugText(text, position, color.get_rgb(), physicsClientId=self.id, **args) - def remove_text(self, text_id: Optional[int] = None) -> None: + def _remove_text(self, text_id: Optional[int] = None) -> None: if text_id is not None: p.removeUserDebugItem(text_id, physicsClientId=self.id) else: @@ -415,7 +513,7 @@ def run(self): width, height, dist = (p.getDebugVisualizerCamera()[0], p.getDebugVisualizerCamera()[1], p.getDebugVisualizerCamera()[10]) - #print("width: ", width, "height: ", height, "dist: ", dist) + # print("width: ", width, "height: ", height, "dist: ", dist) camera_target_position = p.getDebugVisualizerCamera(self.world.id)[11] # Get vectors used for movement on x,y,z Vector @@ -570,5 +668,6 @@ def run(self): cameraTargetPosition=camera_target_position, physicsClientId=self.world.id) if visible == 0: camera_target_position = (0.0, -50, 50) - p.resetBasePositionAndOrientation(sphere_uid, camera_target_position, [0, 0, 0, 1], physicsClientId=self.world.id) + p.resetBasePositionAndOrientation(sphere_uid, camera_target_position, [0, 0, 0, 1], + physicsClientId=self.world.id) time.sleep(1. / 80.) diff --git a/src/pycram/worlds/multiverse.py b/src/pycram/worlds/multiverse.py new file mode 100644 index 000000000..fd174155b --- /dev/null +++ b/src/pycram/worlds/multiverse.py @@ -0,0 +1,656 @@ +import logging +from time import sleep + +import numpy as np +import rospy +from tf.transformations import quaternion_matrix +from typing_extensions import List, Dict, Optional, Union, Tuple + +from .multiverse_communication.client_manager import MultiverseClientManager +from .multiverse_communication.clients import MultiverseController, MultiverseReader, MultiverseWriter, MultiverseAPI +from ..config.multiverse_conf import MultiverseConfig +from ..datastructures.dataclasses import AxisAlignedBoundingBox, Color, ContactPointsList, ContactPoint +from ..datastructures.enums import WorldMode, JointType, ObjectType, MultiverseBodyProperty, MultiverseJointPosition, \ + MultiverseJointCMD +from ..datastructures.pose import Pose +from ..datastructures.world import World +from ..description import Link, Joint, ObjectDescription +from ..object_descriptors.mjcf import ObjectDescription as MJCF +from ..robot_description import RobotDescription +from ..utils import RayTestUtils, wxyz_to_xyzw, xyzw_to_wxyz +from ..validation.goal_validator import validate_object_pose, validate_multiple_joint_positions, \ + validate_joint_position, validate_multiple_object_poses +from ..world_concepts.constraints import Constraint +from ..world_concepts.world_object import Object + + +class Multiverse(World): + """ + This class implements an interface between Multiverse and PyCRAM. + """ + + conf: MultiverseConfig = MultiverseConfig + """ + The Multiverse configuration. + """ + + supported_joint_types = (JointType.REVOLUTE, JointType.CONTINUOUS, JointType.PRISMATIC) + """ + A Tuple for the supported pycram joint types in Multiverse. + """ + + added_multiverse_resources: bool = False + """ + A flag to check if the multiverse resources have been added. + """ + + simulation: Optional[str] = None + """ + The simulation name to be used in the Multiverse world (this is the name defined in + the multiverse configuration file). + """ + + Object.extension_to_description_type[MJCF.get_file_extension()] = MJCF + """ + Add the MJCF description extension to the extension to description type mapping for the objects. + """ + + def __init__(self, mode: Optional[WorldMode] = WorldMode.DIRECT, + is_prospection: Optional[bool] = False, + simulation_name: str = "pycram_test", + clear_cache: bool = False): + """ + Initialize the Multiverse Socket and the PyCram World. + + :param mode: The mode of the world (DIRECT or GUI). + :param is_prospection: Whether the world is prospection or not. + :param simulation_name: The name of the simulation. + :param clear_cache: Whether to clear the cache or not. + """ + + self.latest_save_id: Optional[int] = None + self.saved_simulator_states: Dict = {} + self._make_sure_multiverse_resources_are_added(clear_cache=clear_cache) + + if Multiverse.simulation is None: + if simulation_name is None: + logging.error("Simulation name not provided") + raise ValueError("Simulation name not provided") + Multiverse.simulation = simulation_name + + self.simulation = (self.conf.prospection_world_prefix if is_prospection else "") + Multiverse.simulation + self.client_manager = MultiverseClientManager(self.conf.simulation_wait_time_factor) + self._init_clients(is_prospection=is_prospection) + + World.__init__(self, mode, is_prospection) + + self._init_constraint_and_object_id_name_map_collections() + + self.ray_test_utils = RayTestUtils(self.ray_test_batch, self.object_id_to_name) + + if not self.is_prospection_world: + self._spawn_floor() + + if self.conf.use_static_mode: + self.api_requester.pause_simulation() + + def _init_clients(self, is_prospection: bool = False): + """ + Initialize the Multiverse clients that will be used to communicate with the Multiverse server. + Each client is responsible for a specific task, e.g. reading data from the server, writing data to the serve, + calling the API, or controlling the robot joints. + + :param is_prospection: Whether the world is prospection or not. + """ + self.reader: MultiverseReader = self.client_manager.create_reader( + is_prospection_world=is_prospection) + self.writer: MultiverseWriter = self.client_manager.create_writer( + self.simulation, + is_prospection_world=is_prospection) + self.api_requester: MultiverseAPI = self.client_manager.create_api_requester( + self.simulation, + is_prospection_world=is_prospection) + if self.conf.use_controller: + self.joint_controller: MultiverseController = self.client_manager.create_controller( + is_prospection_world=is_prospection) + + def _init_constraint_and_object_id_name_map_collections(self): + self.last_object_id: int = -1 + self.last_constraint_id: int = -1 + self.constraints: Dict[int, Constraint] = {} + self.object_name_to_id: Dict[str, int] = {} + self.object_id_to_name: Dict[int, str] = {} + + def _init_world(self, mode: WorldMode): + pass + + def _make_sure_multiverse_resources_are_added(self, clear_cache: bool = False): + """ + Add the multiverse resources to the pycram world resources, and change the data directory and cache manager. + + :param clear_cache: Whether to clear the cache or not. + """ + if not self.added_multiverse_resources: + if clear_cache: + World.cache_manager.clear_cache() + World.add_resource_path(self.conf.resources_path, prepend=True) + World.change_cache_dir_path(self.conf.resources_path) + self.added_multiverse_resources = True + + def remove_multiverse_resources(self): + """ + Remove the multiverse resources from the pycram world resources. + """ + if self.added_multiverse_resources: + World.remove_resource_path(self.conf.resources_path) + World.change_cache_dir_path(self.conf.cache_dir) + self.added_multiverse_resources = False + + def _spawn_floor(self): + """ + Spawn the plane in the simulator. + """ + self.floor = Object("floor", ObjectType.ENVIRONMENT, "plane.urdf", + world=self) + + def get_images_for_target(self, target_pose: Pose, + cam_pose: Pose, + size: int = 256, + camera_min_distance: float = 0.1, + camera_max_distance: int = 3, + plot: bool = False) -> List[np.ndarray]: + """ + Uses ray test to get the images for the target object. (target_pose is currently not used) + """ + camera_description = RobotDescription.current_robot_description.get_default_camera() + camera_frame = RobotDescription.current_robot_description.get_camera_frame() + return self.ray_test_utils.get_images_for_target(cam_pose, camera_description, camera_frame, + size, camera_min_distance, camera_max_distance, plot) + + @staticmethod + def get_joint_position_name(joint: Joint) -> MultiverseJointPosition: + """ + Get the attribute name of the joint position in the Multiverse from the pycram joint type. + + :param joint: The joint. + """ + return MultiverseJointPosition.from_pycram_joint_type(joint.type) + + def spawn_robot_with_controller(self, name: str, pose: Pose) -> None: + """ + Spawn the robot in the simulator. + + :param name: The name of the robot. + :param pose: The pose of the robot. + """ + actuator_joint_commands = { + actuator_name: [self.get_joint_cmd_name(self.robot_description.joint_types[joint_name]).value] + for joint_name, actuator_name in self.robot_joint_actuators.items() + } + self.joint_controller.init_controller(actuator_joint_commands) + self.writer.spawn_robot_with_actuators(name, pose.position_as_list(), + xyzw_to_wxyz(pose.orientation_as_list()), + actuator_joint_commands) + + def load_object_and_get_id(self, name: Optional[str] = None, + pose: Optional[Pose] = None, + obj_type: Optional[ObjectType] = None) -> int: + """ + Spawn the object in the simulator and return the object id. Object name has to be unique and has to be same as + the name of the object in the description file. + + :param name: The name of the object to be loaded. + :param pose: The pose of the object. + :param obj_type: The type of the object. + """ + if pose is None: + pose = Pose() + + # Do not spawn objects with type environment as they should be already present in the simulator through the + # multiverse description file (.muv file). + if not obj_type == ObjectType.ENVIRONMENT: + self.spawn_object(name, obj_type, pose) + + return self._update_object_id_name_maps_and_get_latest_id(name) + + def spawn_object(self, name: str, object_type: ObjectType, pose: Pose) -> None: + """ + Spawn the object in the simulator. + + :param name: The name of the object. + :param object_type: The type of the object. + :param pose: The pose of the object. + """ + if object_type == ObjectType.ROBOT and self.conf.use_controller: + self.spawn_robot_with_controller(name, pose) + else: + self._set_body_pose(name, pose) + + def _update_object_id_name_maps_and_get_latest_id(self, name: str) -> int: + """ + Update the object id name maps and return the latest object id. + + :param name: The name of the object. + :return: The latest object id. + """ + self.last_object_id += 1 + self.object_name_to_id[name] = self.last_object_id + self.object_id_to_name[self.last_object_id] = name + return self.last_object_id + + def get_object_joint_names(self, obj: Object) -> List[str]: + return [joint.name for joint in obj.description.joints if joint.type in self.supported_joint_types] + + def get_object_link_names(self, obj: Object) -> List[str]: + return [link.name for link in obj.description.links] + + def get_link_position(self, link: Link) -> List[float]: + return self.reader.get_body_position(link.name) + + def get_link_orientation(self, link: Link) -> List[float]: + return self.reader.get_body_orientation(link.name) + + def get_multiple_link_positions(self, links: List[Link]) -> Dict[str, List[float]]: + return self.reader.get_multiple_body_positions([link.name for link in links]) + + def get_multiple_link_orientations(self, links: List[Link]) -> Dict[str, List[float]]: + return self.reader.get_multiple_body_orientations([link.name for link in links]) + + @validate_joint_position + def reset_joint_position(self, joint: Joint, joint_position: float) -> bool: + if self.conf.use_controller and self.joint_has_actuator(joint): + self._reset_joint_position_using_controller(joint, joint_position) + else: + self._set_multiple_joint_positions_without_controller({joint: joint_position}) + return True + + def _reset_joint_position_using_controller(self, joint: Joint, joint_position: float) -> bool: + """ + Reset the position of a joint in the simulator using the controller. + + :param joint: The joint. + :param joint_position: The position of the joint. + :return: True if the joint position is reset successfully. + """ + self.joint_controller.set_body_property(self.get_actuator_for_joint(joint), + self.get_joint_cmd_name(joint.type), + [joint_position]) + return True + + @validate_multiple_joint_positions + def set_multiple_joint_positions(self, joint_positions: Dict[Joint, float]) -> bool: + """ + Set the positions of multiple joints in the simulator. Also check if the joint is controlled by an actuator + and use the controller to set the joint position if the joint is controlled. + + :param joint_positions: The dictionary of joints and positions. + :return: True if the joint positions are set successfully (this means that the joint positions are set without + errors, but not necessarily that the joint positions are set to the specified values). + """ + + if self.conf.use_controller: + controlled_joints = self.get_controlled_joints(list(joint_positions.keys())) + if len(controlled_joints) > 0: + controlled_joint_positions = {joint: joint_positions[joint] for joint in controlled_joints} + self._set_multiple_joint_positions_using_controller(controlled_joint_positions) + joint_positions = {joint: joint_positions[joint] for joint in joint_positions.keys() + if joint not in controlled_joints} + if len(joint_positions) > 0: + self._set_multiple_joint_positions_without_controller(joint_positions) + + return True + + def get_controlled_joints(self, joints: Optional[List[Joint]] = None) -> List[Joint]: + """ + Get the joints that are controlled by an actuator from the list of joints. + + :param joints: The list of joints to check. + :return: The list of controlled joints. + """ + joints = self.robot.joints if joints is None else joints + return [joint for joint in joints if self.joint_has_actuator(joint)] + + def _set_multiple_joint_positions_without_controller(self, joint_positions: Dict[Joint, float]) -> None: + """ + Set the positions of multiple joints in the simulator without using the controller. + + :param joint_positions: The dictionary of joints and positions. + """ + joints_data = {joint.name: {self.get_joint_position_name(joint): [position]} + for joint, position in joint_positions.items()} + self.writer.send_multiple_body_data_to_server(joints_data) + + def _set_multiple_joint_positions_using_controller(self, joint_positions: Dict[Joint, float]) -> bool: + """ + Set the positions of multiple joints in the simulator using the controller. + + :param joint_positions: The dictionary of joints and positions. + """ + controlled_joints_data = {self.get_actuator_for_joint(joint): + {self.get_joint_cmd_name(joint.type): [position]} + for joint, position in joint_positions.items()} + self.joint_controller.send_multiple_body_data_to_server(controlled_joints_data) + return True + + def get_joint_position(self, joint: Joint) -> Optional[float]: + joint_position_name = self.get_joint_position_name(joint) + data = self.reader.get_body_data(joint.name, [joint_position_name]) + if data is not None: + return data[joint_position_name.value][0] + + def get_multiple_joint_positions(self, joints: List[Joint]) -> Optional[Dict[str, float]]: + joint_names = [joint.name for joint in joints] + data = self.reader.get_multiple_body_data(joint_names, {joint.name: [self.get_joint_position_name(joint)] + for joint in joints}) + if data is not None: + return {name: list(value.values())[0][0] for name, value in data.items()} + + @staticmethod + def get_joint_cmd_name(joint_type: JointType) -> MultiverseJointCMD: + """ + Get the attribute name of the joint command in the Multiverse from the pycram joint type. + + :param joint_type: The pycram joint type. + """ + return MultiverseJointCMD.from_pycram_joint_type(joint_type) + + def get_link_pose(self, link: Link) -> Optional[Pose]: + return self._get_body_pose(link.name) + + def get_multiple_link_poses(self, links: List[Link]) -> Dict[str, Pose]: + return self._get_multiple_body_poses([link.name for link in links]) + + def get_object_pose(self, obj: Object) -> Pose: + if obj.has_type_environment(): + return Pose() + return self._get_body_pose(obj.name) + + def get_multiple_object_poses(self, objects: List[Object]) -> Dict[str, Pose]: + """ + Set the poses of multiple objects in the simulator. If the object is of type environment, the pose will be + the default pose. + + :param objects: The list of objects. + :return: The dictionary of object names and poses. + """ + non_env_objects = [obj for obj in objects if not obj.has_type_environment()] + all_poses = self._get_multiple_body_poses([obj.name for obj in non_env_objects]) + all_poses.update({obj.name: Pose() for obj in objects if obj.has_type_environment()}) + return all_poses + + @validate_object_pose + def reset_object_base_pose(self, obj: Object, pose: Pose) -> bool: + if obj.has_type_environment(): + return False + + if (obj.obj_type == ObjectType.ROBOT and + RobotDescription.current_robot_description.virtual_mobile_base_joints is not None): + obj.set_mobile_robot_pose(pose) + else: + self._set_body_pose(obj.name, pose) + + return True + + @validate_multiple_object_poses + def reset_multiple_objects_base_poses(self, objects: Dict[Object, Pose]) -> None: + """ + Reset the poses of multiple objects in the simulator. + + :param objects: The dictionary of objects and poses. + """ + for obj in objects.keys(): + if (obj.obj_type == ObjectType.ROBOT and + RobotDescription.current_robot_description.virtual_mobile_base_joints is not None): + obj.set_mobile_robot_pose(objects[obj]) + objects = {obj: pose for obj, pose in objects.items() if obj.obj_type not in [ObjectType.ENVIRONMENT, + ObjectType.ROBOT]} + self._set_multiple_body_poses({obj.name: pose for obj, pose in objects.items()}) + + def _set_body_pose(self, body_name: str, pose: Pose) -> None: + """ + Reset the pose of a body (object, link, or joint) in the simulator. + + :param body_name: The name of the body. + :param pose: The pose of the body. + """ + self._set_multiple_body_poses({body_name: pose}) + + def _set_multiple_body_poses(self, body_poses: Dict[str, Pose]) -> None: + """ + Reset the poses of multiple bodies in the simulator. + + :param body_poses: The dictionary of body names and poses. + """ + self.writer.set_multiple_body_poses({name: {MultiverseBodyProperty.POSITION: pose.position_as_list(), + MultiverseBodyProperty.ORIENTATION: + xyzw_to_wxyz(pose.orientation_as_list()), + MultiverseBodyProperty.RELATIVE_VELOCITY: [0.0] * 6} + for name, pose in body_poses.items()}) + + def _get_body_pose(self, body_name: str, wait: Optional[bool] = True) -> Optional[Pose]: + """ + Get the pose of a body in the simulator. + + :param body_name: The name of the body. + :param wait: Whether to wait until the pose is received. + :return: The pose of the body. + """ + data = self.reader.get_body_pose(body_name, wait) + return Pose(data[MultiverseBodyProperty.POSITION.value], + wxyz_to_xyzw(data[MultiverseBodyProperty.ORIENTATION.value])) + + def _get_multiple_body_poses(self, body_names: List[str]) -> Dict[str, Pose]: + """ + Get the poses of multiple bodies in the simulator. + + :param body_names: The list of body names. + """ + return self.reader.get_multiple_body_poses(body_names) + + def get_multiple_object_positions(self, objects: List[Object]) -> Dict[str, List[float]]: + return self.reader.get_multiple_body_positions([obj.name for obj in objects]) + + def get_object_position(self, obj: Object) -> List[float]: + return self.reader.get_body_position(obj.name) + + def get_multiple_object_orientations(self, objects: List[Object]) -> Dict[str, List[float]]: + return self.reader.get_multiple_body_orientations([obj.name for obj in objects]) + + def get_object_orientation(self, obj: Object) -> List[float]: + return self.reader.get_body_orientation(obj.name) + + def multiverse_reset_world(self): + """ + Reset the world using the Multiverse API. + """ + self.writer.reset_world() + + def disconnect_from_physics_server(self) -> None: + MultiverseClientManager.stop_all_clients() + + def join_threads(self) -> None: + self.reader.stop_thread = True + self.reader.join() + + def _remove_visual_object(self, obj_id: int) -> bool: + rospy.logwarn("Currently multiverse does not create visual objects") + return False + + def remove_object_from_simulator(self, obj: Object) -> bool: + if obj.obj_type != ObjectType.ENVIRONMENT: + self.writer.remove_body(obj.name) + return True + rospy.logwarn("Cannot remove environment objects") + return False + + def add_constraint(self, constraint: Constraint) -> int: + + if constraint.type != JointType.FIXED: + logging.error("Only fixed constraints are supported in Multiverse") + raise ValueError + + if not self.conf.let_pycram_move_attached_objects: + self.api_requester.attach(constraint) + + return self._update_constraint_collection_and_get_latest_id(constraint) + + def _update_constraint_collection_and_get_latest_id(self, constraint: Constraint) -> int: + """ + Update the constraint collection and return the latest constraint id. + + :param constraint: The constraint to be added. + :return: The latest constraint id. + """ + self.last_constraint_id += 1 + self.constraints[self.last_constraint_id] = constraint + return self.last_constraint_id + + def remove_constraint(self, constraint_id) -> None: + constraint = self.constraints.pop(constraint_id) + self.api_requester.detach(constraint) + + def perform_collision_detection(self) -> None: + ... + + def get_object_contact_points(self, obj: Object) -> ContactPointsList: + """ + Note: Currently Multiverse only gets one contact point per contact objects. + """ + multiverse_contact_points = self.api_requester.get_contact_points(obj) + contact_points = ContactPointsList([]) + body_link = None + for point in multiverse_contact_points: + if point.body_name == "world": + point.body_name = "floor" + body_object = self.get_object_by_name(point.body_name) + if body_object is None: + for obj in self.objects: + for link in obj.links.values(): + if link.name == point.body_name: + body_link = link + break + else: + body_link = body_object.root_link + if body_link is None: + logging.error(f"Body link not found: {point.body_name}") + raise ValueError(f"Body link not found: {point.body_name}") + contact_points.append(ContactPoint(obj.root_link, body_link)) + contact_points[-1].force_x_in_world_frame = point.contact_force[0] + contact_points[-1].force_y_in_world_frame = point.contact_force[1] + contact_points[-1].force_z_in_world_frame = point.contact_force[2] + contact_points[-1].normal_on_b = point.contact_force[2] + contact_points[-1].normal_force = point.contact_force[2] + return contact_points + + @staticmethod + def _get_normal_force_on_object_from_contact_force(obj: Object, contact_force: List[float]) -> float: + """ + Get the normal force on an object from the contact force exerted by another object that is expressed in the + world frame. Thus transforming the contact force to the object frame is necessary. + + :param obj: The object. + :param contact_force: The contact force. + :return: The normal force on the object. + """ + obj_quat = obj.get_orientation_as_list() + obj_rot_matrix = quaternion_matrix(obj_quat)[:3, :3] + # invert the rotation matrix to get the transformation from world to object frame + obj_rot_matrix = np.linalg.inv(obj_rot_matrix) + contact_force_array = obj_rot_matrix @ np.array(contact_force).reshape(3, 1) + return contact_force_array.flatten().tolist()[2] + + def get_contact_points_between_two_objects(self, obj1: Object, obj2: Object) -> ContactPointsList: + obj1_contact_points = self.get_object_contact_points(obj1) + return obj1_contact_points.get_points_of_object(obj2) + + def ray_test(self, from_position: List[float], to_position: List[float]) -> Optional[int]: + ray_test_result = self.ray_test_batch([from_position], [to_position])[0] + return ray_test_result[0] if ray_test_result[0] != -1 else None + + def ray_test_batch(self, from_positions: List[List[float]], + to_positions: List[List[float]], + num_threads: int = 1, + return_distance: bool = False) -> Union[List, Tuple[List, List[float]]]: + """ + Note: Currently, num_threads is not used in Multiverse. + """ + ray_results = self.api_requester.get_objects_intersected_with_rays(from_positions, to_positions) + results = [] + distances = [] + for ray_result in ray_results: + results.append([]) + if ray_result.intersected(): + body_name = ray_result.body_name + if body_name == "world": + results[-1].append(0) # The floor id, which is always 0 since the floor is spawned first. + elif body_name in self.object_name_to_id.keys(): + results[-1].append(self.object_name_to_id[body_name]) + else: + for obj in self.objects: + if body_name in obj.links.keys(): + results[-1].append(obj.id) + break + else: + results[-1].append(-1) + if return_distance: + distances.append(ray_result.distance) + if return_distance: + return results, distances + else: + return results + + def step(self): + """ + Perform a simulation step in the simulator, this is useful when use_static_mode is True. + """ + if self.conf.use_static_mode: + self.api_requester.unpause_simulation() + sleep(self.simulation_time_step) + self.api_requester.pause_simulation() + + def save_physics_simulator_state(self, use_same_id: bool = False) -> int: + self.latest_save_id = 0 if self.latest_save_id is None else self.latest_save_id + int(not use_same_id) + save_name = f"save_{self.latest_save_id}" + self.saved_simulator_states[self.latest_save_id] = self.api_requester.save(save_name) + return self.latest_save_id + + def remove_physics_simulator_state(self, state_id: int) -> None: + self.saved_simulator_states.pop(state_id) + + def restore_physics_simulator_state(self, state_id: int) -> None: + self.api_requester.load(self.saved_simulator_states[state_id]) + + def set_link_color(self, link: Link, rgba_color: Color): + logging.warning("set_link_color is not implemented in Multiverse") + + def get_link_color(self, link: Link) -> Color: + logging.warning("get_link_color is not implemented in Multiverse") + return Color() + + def get_colors_of_object_links(self, obj: Object) -> Dict[str, Color]: + logging.warning("get_colors_of_object_links is not implemented in Multiverse") + return {} + + def get_object_axis_aligned_bounding_box(self, obj: Object) -> AxisAlignedBoundingBox: + logging.error("get_object_axis_aligned_bounding_box is not implemented in Multiverse") + raise NotImplementedError + + def get_link_axis_aligned_bounding_box(self, link: Link) -> AxisAlignedBoundingBox: + logging.error("get_link_axis_aligned_bounding_box is not implemented in Multiverse") + raise NotImplementedError + + def set_realtime(self, real_time: bool) -> None: + logging.warning("set_realtime is not implemented as an API in Multiverse, it is configured in the" + "multiverse configuration file (.muv file) as rtf_required where a value of 1 means real-time") + + def set_gravity(self, gravity_vector: List[float]) -> None: + logging.warning("set_gravity is not implemented in Multiverse") + + def check_object_exists(self, obj: Object) -> bool: + """ + Check if the object exists in the Multiverse world. + + :param obj: The object. + :return: True if the object exists, False otherwise. + """ + return self.api_requester.check_object_exists(obj) diff --git a/src/pycram/worlds/multiverse_communication/__init__.py b/src/pycram/worlds/multiverse_communication/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pycram/worlds/multiverse_communication/client_manager.py b/src/pycram/worlds/multiverse_communication/client_manager.py new file mode 100644 index 000000000..a1b768172 --- /dev/null +++ b/src/pycram/worlds/multiverse_communication/client_manager.py @@ -0,0 +1,97 @@ +from typing_extensions import Optional, Type, Union, Dict + +from ...worlds.multiverse_communication.clients import MultiverseWriter, MultiverseAPI, MultiverseClient, \ + MultiverseReader, MultiverseController + +from ...config.multiverse_conf import MultiverseConfig as Conf + + +class MultiverseClientManager: + BASE_PORT: int = Conf.BASE_CLIENT_PORT + """ + The base port of the Multiverse client. + """ + clients: Optional[Dict[str, MultiverseClient]] = {} + """ + The list of Multiverse clients. + """ + last_used_port: int = BASE_PORT + + def __init__(self, simulation_wait_time_factor: Optional[float] = 1.0): + """ + Initialize the Multiverse client manager. + + :param simulation_wait_time_factor: The simulation wait time factor. + """ + self.simulation_wait_time_factor = simulation_wait_time_factor + + def create_reader(self, is_prospection_world: Optional[bool] = False) -> MultiverseReader: + """ + Create a Multiverse reader client. + + :param is_prospection_world: Whether the reader is connected to the prospection world. + """ + return self.create_client(MultiverseReader, "reader", is_prospection_world) + + def create_writer(self, simulation: str, is_prospection_world: Optional[bool] = False) -> MultiverseWriter: + """ + Create a Multiverse writer client. + + :param simulation: The name of the simulation that the writer is connected to + (usually the name defined in the .muv file). + :param is_prospection_world: Whether the writer is connected to the prospection world. + """ + return self.create_client(MultiverseWriter, "writer", is_prospection_world, + simulation=simulation) + + def create_controller(self, is_prospection_world: Optional[bool] = False) -> MultiverseController: + """ + Create a Multiverse controller client. + + :param is_prospection_world: Whether the controller is connected to the prospection world. + """ + return self.create_client(MultiverseController, "controller", is_prospection_world) + + def create_api_requester(self, simulation: str, is_prospection_world: Optional[bool] = False) -> MultiverseAPI: + """ + Create a Multiverse API client. + + :param simulation: The name of the simulation that the API is connected to + (usually the name defined in the .muv file). + :param is_prospection_world: Whether the API is connected to the prospection world. + """ + return self.create_client(MultiverseAPI, "api_requester", is_prospection_world, simulation=simulation) + + def create_client(self, + client_type: Type[MultiverseClient], + name: Optional[str] = None, + is_prospection_world: Optional[bool] = False, + **kwargs) -> Union[MultiverseClient, MultiverseAPI, + MultiverseReader, MultiverseWriter, MultiverseController]: + """ + Create a Multiverse client. + + :param client_type: The type of the client to create. + :param name: The name of the client. + :param is_prospection_world: Whether the client is connected to the prospection world. + :param kwargs: Any other keyword arguments that should be passed to the client constructor. + """ + MultiverseClientManager.last_used_port += 1 + name = (name or client_type.__name__) + f"_{self.last_used_port}" + client = client_type(name, self.last_used_port, is_prospection_world=is_prospection_world, + simulation_wait_time_factor=self.simulation_wait_time_factor, **kwargs) + self.clients[name] = client + return client + + @classmethod + def stop_all_clients(cls): + """ + Stop all clients. + """ + for client in cls.clients: + if isinstance(client, MultiverseReader): + client.stop_thread = True + client.join() + elif isinstance(client, MultiverseClient): + client.stop() + cls.clients = {} diff --git a/src/pycram/worlds/multiverse_communication/clients.py b/src/pycram/worlds/multiverse_communication/clients.py new file mode 100644 index 000000000..d98983ab1 --- /dev/null +++ b/src/pycram/worlds/multiverse_communication/clients.py @@ -0,0 +1,832 @@ +import datetime +import logging +import os +import threading +from time import time, sleep + +import rospy +from typing_extensions import List, Dict, Tuple, Optional, Callable, Union + +from .socket import MultiverseSocket, MultiverseMetaData +from ...config.multiverse_conf import MultiverseConfig as Conf +from ...datastructures.dataclasses import RayResult, MultiverseContactPoint +from ...datastructures.enums import (MultiverseAPIName as API, MultiverseBodyProperty as BodyProperty, + MultiverseProperty as Property) +from ...datastructures.pose import Pose +from ...utils import wxyz_to_xyzw +from ...world_concepts.constraints import Constraint +from ...world_concepts.world_object import Object, Link + + +class MultiverseClient(MultiverseSocket): + + def __init__(self, name: str, port: int, is_prospection_world: bool = False, + simulation_wait_time_factor: float = 1.0, **kwargs): + """ + Initialize the Multiverse client, which connects to the Multiverse server. + + :param name: The name of the client. + :param port: The port of the client. + :param is_prospection_world: Whether the client is connected to the prospection world. + :param simulation_wait_time_factor: The simulation wait time factor (default is 1.0), which can be used to + increase or decrease the wait time for the simulation. + """ + meta_data = MultiverseMetaData() + meta_data.simulation_name = (Conf.prospection_world_prefix if is_prospection_world else "") + name + meta_data.world_name = ((Conf.prospection_world_prefix if is_prospection_world else "") + + meta_data.world_name) + self.is_prospection_world = is_prospection_world + super().__init__(port=str(port), meta_data=meta_data) + self.simulation_wait_time_factor = simulation_wait_time_factor + self.run() + + +class MultiverseReader(MultiverseClient): + MAX_WAIT_TIME_FOR_DATA: datetime.timedelta = Conf.READER_MAX_WAIT_TIME_FOR_DATA + """ + The maximum wait time for the data in seconds. + """ + + def __init__(self, name: str, port: int, is_prospection_world: bool = False, + simulation_wait_time_factor: float = 1.0, **kwargs): + """ + Initialize the Multiverse reader, which reads the data from the Multiverse server in a separate thread. + This class provides methods to get data (e.g., position, orientation) from the Multiverse server. + + :param port: The port of the Multiverse reader client. + :param is_prospection_world: Whether the reader is connected to the prospection world. + :param simulation_wait_time_factor: The simulation wait time factor. + """ + super().__init__(name, port, is_prospection_world, simulation_wait_time_factor=simulation_wait_time_factor) + + self.request_meta_data["receive"][""] = [""] + + self.data_lock = threading.Lock() + self.thread = threading.Thread(target=self.receive_all_data_from_server) + self.stop_thread = False + + self.thread.start() + + def get_body_pose(self, name: str, wait: bool = False) -> Optional[Dict[str, List[float]]]: + """ + Get the body pose from the multiverse server. + + :param name: The name of the body. + :param wait: Whether to wait for the data. + :return: The position and orientation of the body. + """ + return self.get_body_data(name, [BodyProperty.POSITION, BodyProperty.ORIENTATION], wait=wait) + + def get_multiple_body_poses(self, body_names: List[str], wait: bool = False) -> Optional[Dict[str, Pose]]: + """ + Get the body poses from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param wait: Whether to wait for the data. + :return: The positions and orientations of the bodies as a dictionary. + """ + data = self.get_multiple_body_data(body_names, + {name: [BodyProperty.POSITION, BodyProperty.ORIENTATION] + for name in body_names + }, + wait=wait) + if data is not None: + return {name: Pose(data[name][BodyProperty.POSITION.value], + wxyz_to_xyzw(data[name][BodyProperty.ORIENTATION.value])) + for name in body_names} + + def get_body_position(self, name: str, wait: bool = False) -> Optional[List[float]]: + """ + Get the body position from the multiverse server. + + :param name: The name of the body. + :param wait: Whether to wait for the data. + :return: The position of the body. + """ + return self.get_body_property(name, BodyProperty.POSITION, wait=wait) + + def get_multiple_body_positions(self, body_names: List[str], + wait: bool = False) -> Optional[Dict[str, List[float]]]: + """ + Get the body positions from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param wait: Whether to wait for the data. + :return: The positions of the bodies as a dictionary. + """ + return self.get_multiple_body_properties(body_names, [BodyProperty.POSITION], wait=wait) + + def get_body_orientation(self, name: str, wait: bool = False) -> Optional[List[float]]: + """ + Get the body orientation from the multiverse server. + + :param name: The name of the body. + :param wait: Whether to wait for the data. + :return: The orientation of the body. + """ + return self.get_body_property(name, BodyProperty.ORIENTATION, wait=wait) + + def get_multiple_body_orientations(self, body_names: List[str], + wait: bool = False) -> Optional[Dict[str, List[float]]]: + """ + Get the body orientations from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param wait: Whether to wait for the data. + :return: The orientations of the bodies as a dictionary. + """ + data = self.get_multiple_body_properties(body_names, [BodyProperty.ORIENTATION], wait=wait) + if data is not None: + return {name: wxyz_to_xyzw(data[name][BodyProperty.ORIENTATION.value]) for name in body_names} + + def get_body_property(self, name: str, property_: Property, wait: bool = False) -> Optional[List[float]]: + """ + Get the body property from the multiverse server. + + :param name: The name of the body. + :param property_: The property of the body as a Property. + :param wait: Whether to wait for the data. + :return: The property of the body. + """ + data = self.get_body_data(name, [property_], wait=wait) + if data is not None: + return data[property_.value] + + def get_multiple_body_properties(self, body_names: List[str], properties: List[Property], + wait: bool = False) -> Optional[Dict[str, Dict[str, List[float]]]]: + """ + Get the body properties from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param properties: The properties of the bodies. + :param wait: Whether to wait for the data. + :return: The properties of the bodies as a dictionary. + """ + return self.get_multiple_body_data(body_names, {name: properties for name in body_names}, wait=wait) + + def get_body_data(self, name: str, + properties: Optional[List[Property]] = None, + wait: bool = False) -> Optional[Dict]: + """ + Get the body data from the multiverse server. + + :param name: The name of the body. + :param properties: The properties of the body. + :param wait: Whether to wait for the data. + :return: The body data as a dictionary. + """ + if wait: + return self.wait_for_body_data(name, properties) + + data = self.get_received_data() + if self.check_for_body_data(name, data, properties): + return data[name] + + def get_multiple_body_data(self, body_names: List[str], + properties: Optional[Dict[str, List[Property]]] = None, + wait: bool = False) -> Optional[Dict]: + """ + Get the body data from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param properties: The properties of the bodies. + :param wait: Whether to wait for the data. + :return: The body data as a dictionary. + """ + + if wait: + return self.wait_for_multiple_body_data(body_names, properties) + + data = self.get_received_data() + if self.check_multiple_body_data(body_names, data, properties): + return {name: data[name] for name in body_names} + + def wait_for_body_data(self, name: str, properties: Optional[List[Property]] = None) -> Dict: + """ + Wait for the body data from the multiverse server. + + :param name: The name of the body. + :param properties: The properties of the body. + :return: The body data as a dictionary. + """ + return self._wait_for_body_data_template(name, self.check_for_body_data, properties)[name] + + def wait_for_multiple_body_data(self, body_names: List[str], + properties: Optional[Dict[str, List[Property]]] = None) -> Dict: + """ + Wait for the body data from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param properties: The properties of the bodies. + :return: The body data as a dictionary. + """ + return self._wait_for_body_data_template(body_names, self.check_multiple_body_data, properties) + + def _wait_for_body_data_template(self, body_names: Union[str, List[str]], + check_func: Callable[[Union[str, List[str]], Dict, Union[Dict, List]], bool], + properties: Optional[Union[Dict, List]] = None) -> Dict: + """ + Wait for the body data from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param properties: The properties of the bodies. + :param check_func: The function to check if the data is received. + :return: The body data as a dictionary. + """ + start = time() + data_received_flag = False + while time() - start < self.MAX_WAIT_TIME_FOR_DATA.total_seconds(): + received_data = self.get_received_data() + data_received_flag = check_func(body_names, received_data, properties) + if data_received_flag: + return received_data + if not data_received_flag: + properties_str = "Data" if properties is None else f"Properties {properties}" + msg = f"{properties_str} for {body_names} not received within {self.MAX_WAIT_TIME_FOR_DATA} seconds" + logging.error(msg) + raise ValueError(msg) + + def check_multiple_body_data(self, body_names: List[str], data: Dict, + properties: Optional[Dict[str, List[Property]]] = None) -> bool: + """ + Check if the body data is received from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param data: The data received from the multiverse server. + :param properties: The properties of the bodies. + :return: Whether the body data is received. + """ + if properties is None: + return all([self.check_for_body_data(name, data) for name in body_names]) + else: + return all([self.check_for_body_data(name, data, properties[name]) for name in body_names]) + + @staticmethod + def check_for_body_data(name: str, data: Dict, properties: Optional[List[Property]] = None) -> bool: + """ + Check if the body data is received from the multiverse server. + + :param name: The name of the body. + :param data: The data received from the multiverse server. + :param properties: The properties of the body. + :return: Whether the body data is received. + """ + if properties is None: + return name in data + else: + return name in data and all([prop.value in data[name] and None not in data[name][prop.value] + for prop in properties]) + + def get_received_data(self): + """ + Get the latest received data from the multiverse server. + """ + self.data_lock.acquire() + data = self.response_meta_data["receive"] + self.data_lock.release() + return data + + def receive_all_data_from_server(self): + """ + Get all data from the multiverse server. + """ + while not self.stop_thread: + self.request_meta_data["receive"][""] = [""] + self.data_lock.acquire() + self.send_and_receive_meta_data() + self.data_lock.release() + sleep(0.01) + self.stop() + + def join(self): + self.thread.join() + + +class MultiverseWriter(MultiverseClient): + + def __init__(self, name: str, port: int, simulation: Optional[str] = None, + is_prospection_world: bool = False, + simulation_wait_time_factor: float = 1.0, **kwargs): + """ + Initialize the Multiverse writer, which writes the data to the Multiverse server. + This class provides methods to send data (e.g., position, orientation) to the Multiverse server. + + :param port: The port of the Multiverse writer client. + :param simulation: The name of the simulation that the writer is connected to + (usually the name defined in the .muv file). + :param is_prospection_world: Whether the writer is connected to the prospection world. + :param simulation_wait_time_factor: The wait time factor for the simulation (default is 1.0), which can be used + to increase or decrease the wait time for the simulation. + """ + super().__init__(name, port, is_prospection_world, simulation_wait_time_factor=simulation_wait_time_factor) + self.simulation = simulation + + def spawn_robot_with_actuators(self, robot_name: str, position: List[float], orientation: List[float], + actuator_joint_commands: Optional[Dict[str, List[str]]] = None) -> None: + """ + Spawn the robot with controlled actuators in the simulation. + + :param robot_name: The name of the robot. + :param position: The position of the robot. + :param orientation: The orientation of the robot. + :param actuator_joint_commands: A dictionary mapping actuator names to joint command names. + """ + send_meta_data = {robot_name: [BodyProperty.POSITION.value, BodyProperty.ORIENTATION.value, + BodyProperty.RELATIVE_VELOCITY.value]} + relative_velocity = [0.0] * 6 + data = [self.sim_time, *position, *orientation, *relative_velocity] + self.send_data_to_server(data, send_meta_data=send_meta_data, receive_meta_data=actuator_joint_commands) + + def _reset_request_meta_data(self, set_simulation_name: bool = True): + """ + Reset the request metadata. + + :param set_simulation_name: Whether to set the simulation name to the value of self.simulation_name. + """ + self.request_meta_data = { + "meta_data": self._meta_data.__dict__.copy(), + "send": {}, + "receive": {}, + } + if self.simulation is not None and set_simulation_name: + self.request_meta_data["meta_data"]["simulation_name"] = self.simulation + + def set_body_pose(self, body_name: str, position: List[float], orientation: List[float]) -> None: + """ + Set the body pose in the simulation. + + :param body_name: The name of the body. + :param position: The position of the body. + :param orientation: The orientation of the body. + """ + self.send_body_data_to_server(body_name, + {BodyProperty.POSITION: position, + BodyProperty.ORIENTATION: orientation, + BodyProperty.RELATIVE_VELOCITY: [0.0] * 6}) + + def set_multiple_body_poses(self, body_data: Dict[str, Dict[BodyProperty, List[float]]]) -> None: + """ + Set the body poses in the simulation for multiple bodies. + + :param body_data: The data to be sent for multiple bodies. + """ + self.send_multiple_body_data_to_server(body_data) + + def set_body_position(self, body_name: str, position: List[float]) -> None: + """ + Set the body position in the simulation. + + :param body_name: The name of the body. + :param position: The position of the body. + """ + self.set_body_property(body_name, BodyProperty.POSITION, position) + + def set_body_orientation(self, body_name: str, orientation: List[float]) -> None: + """ + Set the body orientation in the simulation. + + :param body_name: The name of the body. + :param orientation: The orientation of the body. + """ + self.set_body_property(body_name, BodyProperty.ORIENTATION, orientation) + + def set_body_property(self, body_name: str, property_: Property, value: List[float]) -> None: + """ + Set the body property in the simulation. + + :param body_name: The name of the body. + :param property_: The property of the body. + :param value: The value of the property. + """ + self.send_body_data_to_server(body_name, {property_: value}) + + def remove_body(self, body_name: str) -> None: + """ + Remove the body from the simulation. + + :param body_name: The name of the body. + """ + self.send_data_to_server([self.sim_time], + send_meta_data={body_name: []}, + receive_meta_data={body_name: []}) + + def reset_world(self) -> None: + """ + Reset the world in the simulation. + """ + self.send_data_to_server([0], set_simulation_name=False) + + def send_body_data_to_server(self, body_name: str, body_data: Dict[Property, List[float]]) -> Dict: + """ + Send data to the multiverse server. + + :param body_name: The name of the body. + :param body_data: The data to be sent. + :return: The response from the server. + """ + send_meta_data = {body_name: list(map(str, body_data.keys()))} + flattened_data = [value for data in body_data.values() for value in data] + return self.send_data_to_server([self.sim_time, *flattened_data], send_meta_data=send_meta_data) + + def send_multiple_body_data_to_server(self, body_data: Dict[str, Dict[Property, List[float]]]) -> Dict: + """ + Send data to the multiverse server for multiple bodies. + + :param body_data: The data to be sent for multiple bodies. + :return: The response from the server. + """ + send_meta_data = {body_name: list(map(str, data.keys())) for body_name, data in body_data.items()} + response_meta_data = self.send_meta_data_and_get_response(send_meta_data) + body_names = list(response_meta_data["send"].keys()) + flattened_data = [value for body_name in body_names for data in body_data[body_name].values() + for value in data] + self.send_data = [self.sim_time, *flattened_data] + self.send_and_receive_data() + return self.response_meta_data + + def send_meta_data_and_get_response(self, send_meta_data: Dict) -> Dict: + """ + Send metadata to the multiverse server and get the response. + + :param send_meta_data: The metadata to be sent. + :return: The response from the server. + """ + self._reset_request_meta_data() + self.request_meta_data["send"] = send_meta_data + self.send_and_receive_meta_data() + return self.response_meta_data + + def send_data_to_server(self, data: List, + send_meta_data: Optional[Dict] = None, + receive_meta_data: Optional[Dict] = None, + set_simulation_name: bool = True) -> Dict: + """ + Send data to the multiverse server. + + :param data: The data to be sent. + :param send_meta_data: The metadata to be sent. + :param receive_meta_data: The metadata to be received. + :param set_simulation_name: Whether to set the simulation name to the value of self.simulation. + :return: The response from the server. + """ + self._reset_request_meta_data(set_simulation_name=set_simulation_name) + if send_meta_data: + self.request_meta_data["send"] = send_meta_data + if receive_meta_data: + self.request_meta_data["receive"] = receive_meta_data + self.send_and_receive_meta_data() + self.send_data = data + self.send_and_receive_data() + return self.response_meta_data + + +class MultiverseController(MultiverseWriter): + + def __init__(self, name: str, port: int, is_prospection_world: bool = False, **kwargs): + """ + Initialize the Multiverse controller, which controls the robot in the simulation. + This class provides methods to send controller data to the Multiverse server. + + :param port: The port of the Multiverse controller client. + :param is_prospection_world: Whether the controller is connected to the prospection world. + """ + super().__init__(name, port, is_prospection_world=is_prospection_world) + + def init_controller(self, actuator_joint_commands: Dict[str, List[str]]) -> None: + """ + Initialize the controller by sending the controller data to the multiverse server. + + :param actuator_joint_commands: A dictionary mapping actuator names to joint command names. + """ + self.send_data_to_server([self.sim_time] + [0.0] * len(actuator_joint_commands), + send_meta_data=actuator_joint_commands) + + +class MultiverseAPI(MultiverseClient): + API_REQUEST_WAIT_TIME: datetime.timedelta = datetime.timedelta(milliseconds=200) + """ + The wait time for the API request in seconds. + """ + APIs_THAT_NEED_WAIT_TIME: List[API] = [API.ATTACH] + + def __init__(self, name: str, port: int, simulation: str, is_prospection_world: bool = False, + simulation_wait_time_factor: float = 1.0): + """ + Initialize the Multiverse API, which sends API requests to the Multiverse server. + This class provides methods like attach and detach objects, get contact points, and other API requests. + + :param port: The port of the Multiverse API client. + :param simulation: The name of the simulation that the API is connected to + (usually the name defined in the .muv file). + :param is_prospection_world: Whether the API is connected to the prospection world. + :param simulation_wait_time_factor: The simulation wait time factor, which can be used to increase or decrease + the wait time for the simulation. + """ + super().__init__(name, port, is_prospection_world, simulation_wait_time_factor=simulation_wait_time_factor) + self.simulation = simulation + self.wait: bool = False # Whether to wait after sending the API request. + + def save(self, save_name: str, save_directory: Optional[str] = None) -> str: + """ + Save the current state of the simulation. + + :param save_name: The name of the save. + :param save_directory: The path to save the simulation, can be relative or absolute. If the path is relative, + it will be saved in the saved folder in multiverse. + :return: The save path. + """ + response = self._request_single_api_callback(API.SAVE, self.get_save_path(save_name, save_directory)) + return response[0] + + def load(self, save_name: str, save_directory: Optional[str] = None) -> None: + """ + Load the saved state of the simulation. + + :param save_name: The name of the save. + :param save_directory: The path to load the simulation, can be relative or absolute. If the path is relative, + it will be loaded from the saved folder in multiverse. + """ + self._request_single_api_callback(API.LOAD, self.get_save_path(save_name, save_directory)) + + @staticmethod + def get_save_path(save_name: str, save_directory: Optional[str] = None) -> str: + """ + Get the save path. + + :param save_name: The save name. + :param save_directory: The save directory. + :return: The save path. + """ + return save_name if save_directory is None else os.path.join(save_directory, save_name) + + def attach(self, constraint: Constraint) -> None: + """ + Request to attach the child link to the parent link. + + :param constraint: The constraint. + """ + self.wait = True + parent_link_name, child_link_name = self.get_constraint_link_names(constraint) + attachment_pose = self._get_attachment_pose_as_string(constraint) + self._attach(child_link_name, parent_link_name, attachment_pose) + + def _attach(self, child_link_name: str, parent_link_name: str, attachment_pose: str) -> None: + """ + Attach the child link to the parent link. + + :param child_link_name: The name of the child link. + :param parent_link_name: The name of the parent link. + :param attachment_pose: The attachment pose. + """ + self._request_single_api_callback(API.ATTACH, child_link_name, parent_link_name, + attachment_pose) + + def get_constraint_link_names(self, constraint: Constraint) -> Tuple[str, str]: + """ + Get the link names of the constraint. + + :param constraint: The constraint. + :return: The link names of the constraint. + """ + return self.get_parent_link_name(constraint), self.get_constraint_child_link_name(constraint) + + def get_parent_link_name(self, constraint: Constraint) -> str: + """ + Get the parent link name of the constraint. + + :param constraint: The constraint. + :return: The parent link name of the constraint. + """ + return self.get_link_name_for_constraint(constraint.parent_link) + + def get_constraint_child_link_name(self, constraint: Constraint) -> str: + """ + Get the child link name of the constraint. + + :param constraint: The constraint. + :return: The child link name of the constraint. + """ + return self.get_link_name_for_constraint(constraint.child_link) + + @staticmethod + def get_link_name_for_constraint(link: Link) -> str: + """ + Get the link name from link object, if the link belongs to a one link object, return the object name. + + :param link: The link. + :return: The link name. + """ + return link.name if not link.is_only_link else link.object.name + + def detach(self, constraint: Constraint) -> None: + """ + Request to detach the child link from the parent link. + + :param constraint: The constraint. + """ + parent_link_name, child_link_name = self.get_constraint_link_names(constraint) + self._detach(child_link_name, parent_link_name) + + def _detach(self, child_link_name: str, parent_link_name: str) -> None: + """ + Detach the child link from the parent link. + + :param child_link_name: The name of the child link. + :param parent_link_name: The name of the parent link. + """ + self._request_single_api_callback(API.DETACH, child_link_name, parent_link_name) + + def _get_attachment_pose_as_string(self, constraint: Constraint) -> str: + """ + Get the attachment pose as a string. + + :param constraint: The constraint. + :return: The attachment pose as a string. + """ + pose = constraint.parent_to_child_transform.to_pose() + return self._pose_to_string(pose) + + @staticmethod + def _pose_to_string(pose: Pose) -> str: + """ + Convert the pose to a string. + + :param pose: The pose. + :return: The pose as a string. + """ + return f"{pose.position.x} {pose.position.y} {pose.position.z} {pose.orientation.w} {pose.orientation.x} " \ + f"{pose.orientation.y} {pose.orientation.z}" + + def check_object_exists(self, obj: Object) -> bool: + """ + Check if the object exists in the simulation. + + :param obj: The object. + :return: Whether the object exists in the simulation. + """ + return self._request_single_api_callback(API.EXIST, obj.name)[0] == 'yes' + + def get_contact_points(self, obj: Object) -> List[MultiverseContactPoint]: + """ + Request the contact points of an object, this includes the object names and the contact forces and torques. + + :param obj: The object. + :return: The contact points of the object as a list of MultiverseContactPoint. + """ + api_response_data = self._get_contact_points(obj.name) + body_names = api_response_data[API.GET_CONTACT_BODIES] + contact_efforts = self._parse_constraint_effort(api_response_data[API.GET_CONSTRAINT_EFFORT]) + return [MultiverseContactPoint(body_names[i], contact_efforts[:3], contact_efforts[3:]) + for i in range(len(body_names))] + + def get_objects_intersected_with_rays(self, from_positions: List[List[float]], + to_positions: List[List[float]]) -> List[RayResult]: + """ + Get the rays intersections with the objects from the from_positions to the to_positions. + + :param from_positions: The starting positions of the rays. + :param to_positions: The ending positions of the rays. + :return: The rays intersections with the objects as a list of RayResult. + """ + get_rays_response = self._get_rays(from_positions, to_positions) + return self._parse_get_rays_response(get_rays_response) + + def _get_rays(self, from_positions: List[List[float]], + to_positions: List[List[float]]) -> List[str]: + """ + Get the rays intersections with the objects from the from_positions to the to_positions. + + :param from_positions: The starting positions of the rays. + :param to_positions: The ending positions of the rays. + :return: The rays intersections with the objects as a dictionary. + """ + from_positions = self.list_of_positions_to_string(from_positions) + to_positions = self.list_of_positions_to_string(to_positions) + return self._request_single_api_callback(API.GET_RAYS, from_positions, to_positions) + + @staticmethod + def _parse_get_rays_response(response: List[str]) -> List[RayResult]: + """ + Parse the response of the get rays API. + + :param response: The response of the get rays API as a list of strings. + :return: The rays as a list of lists of floats. + """ + get_rays_results = [] + for ray_response in response: + if ray_response == "None": + get_rays_results.append(RayResult("", -1)) + else: + result = ray_response.split() + result[1] = float(result[1]) + get_rays_results.append(RayResult(*result)) + return get_rays_results + + @staticmethod + def list_of_positions_to_string(positions: List[List[float]]) -> str: + """ + Convert the list of positions to a string. + + :param positions: The list of positions. + :return: The list of positions as a string. + """ + return " ".join([f"{position[0]} {position[1]} {position[2]}" for position in positions]) + + @staticmethod + def _parse_constraint_effort(contact_effort: List[str]) -> List[float]: + """ + Parse the contact effort of an object. + + :param contact_effort: The contact effort of the object as a list of strings. + :return: The contact effort of the object as a list of floats. + """ + contact_effort = contact_effort[0].split() + if 'failed' in contact_effort: + rospy.logwarn("Failed to get contact effort") + return [0.0] * 6 + return list(map(float, contact_effort)) + + def _get_contact_points(self, object_name) -> Dict[API, List]: + """ + Request the contact points of an object. + + :param object_name: The name of the object. + :return: The contact points api response as a dictionary. + """ + return self._request_apis_callbacks({API.GET_CONTACT_BODIES: [object_name], + API.GET_CONSTRAINT_EFFORT: [object_name] + }) + + def pause_simulation(self) -> None: + """ + Pause the simulation. + """ + self._request_single_api_callback(API.PAUSE) + + def unpause_simulation(self) -> None: + """ + Unpause the simulation. + """ + self._request_single_api_callback(API.UNPAUSE) + + def _request_single_api_callback(self, api_name: API, *params) -> List[str]: + """ + Request a single API callback from the server. + + :param api_data: The API data to request the callback. + :return: The API response as a list of strings. + """ + response = self._request_apis_callbacks({api_name: list(params)}) + return response[api_name] + + def _request_apis_callbacks(self, api_data: Dict[API, List]) -> Dict[API, List[str]]: + """ + Request the API callbacks from the server. + + :param api_data: The API data to add to the request metadata. + :return: The API response as a list of strings. + """ + self._reset_api_callback() + for api_name, params in api_data.items(): + self._add_api_request(api_name.value, *params) + self._send_api_request() + responses = self._get_all_apis_responses() + if self.wait: + sleep(self.API_REQUEST_WAIT_TIME.total_seconds() * self.simulation_wait_time_factor) + self.wait = False + return responses + + def _get_all_apis_responses(self) -> Dict[API, List[str]]: + """ + Get all the API responses from the server. + + :return: The API responses as a list of APIData. + """ + list_of_api_responses = self.response_meta_data["api_callbacks_response"][self.simulation] + return {API[api_name.upper()]: response for api_response in list_of_api_responses + for api_name, response in api_response.items()} + + def _add_api_request(self, api_name: str, *params): + """ + Add an API request to the request metadata. + + :param api_name: The name of the API. + :param params: The parameters of the API. + """ + self.request_meta_data["api_callbacks"][self.simulation].append({api_name: list(params)}) + + def _send_api_request(self): + """ + Send the API request to the server. + """ + if "api_callbacks" not in self.request_meta_data: + logging.error("No API request to send") + raise ValueError + self.send_and_receive_meta_data() + self.request_meta_data.pop("api_callbacks") + + def _reset_api_callback(self): + """ + Initialize the API callback in the request metadata. + """ + self.request_meta_data["api_callbacks"] = {self.simulation: []} diff --git a/src/pycram/worlds/multiverse_communication/socket.py b/src/pycram/worlds/multiverse_communication/socket.py new file mode 100644 index 000000000..9692f3e86 --- /dev/null +++ b/src/pycram/worlds/multiverse_communication/socket.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 + +"""Multiverse Client base class.""" + +import rospy +from multiverse_client_pybind import MultiverseClientPybind # noqa +from typing_extensions import Optional, List, Dict, Callable, TypeVar + +from ...datastructures.dataclasses import MultiverseMetaData +from ...config.multiverse_conf import MultiverseConfig as Conf + +T = TypeVar("T") + + +class MultiverseSocket: + + def __init__( + self, + port: str, + host: str = Conf.HOST, + meta_data: MultiverseMetaData = MultiverseMetaData(), + ) -> None: + """ + Initialize the MultiverseSocket, connect to the Multiverse Server and start the communication. + + :param port: The port of the client. + :param host: The host of the client. + :param meta_data: The metadata for the Multiverse Client as MultiverseMetaData. + """ + if not isinstance(port, str) or port == "": + raise ValueError(f"Must specify client port for {self.__class__.__name__}") + self._send_data = None + self.port = port + self.host = host + self._meta_data = meta_data + self.client_name = self._meta_data.simulation_name + self._multiverse_socket = MultiverseClientPybind( + f"{Conf.SERVER_HOST}:{Conf.SERVER_PORT}" + ) + self.request_meta_data = { + "meta_data": self._meta_data.__dict__, + "send": {}, + "receive": {}, + } + self._api_callbacks: Optional[Dict] = None + + self._start_time = 0.0 + + def run(self) -> None: + """Run the client.""" + self.log_info("Start") + self._run() + + def _run(self) -> None: + """Run the client, should call the _connect_and_start() method. It's left to the user to implement this method + in threaded or non-threaded fashion. + """ + self._connect_and_start() + + def stop(self) -> None: + """Stop the client.""" + self._disconnect() + + @property + def request_meta_data(self) -> Dict: + """The request_meta_data which is sent to the server. + """ + return self._request_meta_data + + @request_meta_data.setter + def request_meta_data(self, request_meta_data: Dict) -> None: + """Set the request_meta_data, make sure to clear the `send` and `receive` field before setting the request + """ + self._request_meta_data = request_meta_data + self._multiverse_socket.set_request_meta_data(self._request_meta_data) + + @property + def response_meta_data(self) -> Dict: + """Get the response_meta_data. + + :return: The response_meta_data as a dictionary. + """ + response_meta_data = self._multiverse_socket.get_response_meta_data() + assert isinstance(response_meta_data, dict) + if response_meta_data == {}: + message = f"[Client {self.port}] Receive empty response meta data." + self.log_warn(message) + return response_meta_data + + def send_and_receive_meta_data(self): + """ + Send and receive the metadata, this should be called before sending and receiving data. + """ + self._communicate(True) + + def send_and_receive_data(self): + """ + Send and receive the data, this should be called after sending and receiving the metadata. + """ + self._communicate(False) + + @property + def send_data(self) -> List[float]: + """Get the send_data.""" + return self._send_data + + @send_data.setter + def send_data(self, send_data: List[float]) -> None: + """Set the send_data, the first element should be the current simulation time, + the rest should be the data to send with the following order: + double -> uint8_t -> uint16_t + + :param send_data: The data to send. + """ + assert isinstance(send_data, list) + self._send_data = send_data + self._multiverse_socket.set_send_data(self._send_data) + + @property + def receive_data(self) -> List[float]: + """Get the receive_data, the first element should be the current simulation time, + the rest should be the received data with the following order: + double -> uint8_t -> uint16_t + + :return: The received data. + """ + receive_data = self._multiverse_socket.get_receive_data() + assert isinstance(receive_data, list) + return receive_data + + @property + def api_callbacks(self) -> Dict[str, Callable[[List[str]], List[str]]]: + """Get the api_callbacks. + + :return: The api_callbacks as a dictionary of function names and their respective callbacks. + """ + return self._api_callbacks + + @api_callbacks.setter + def api_callbacks(self, api_callbacks: Dict[str, Callable[[List[str]], List[str]]]) -> None: + """Set the api_callbacks. + + :param api_callbacks: The api_callbacks as a dictionary of function names and their respective callbacks. + """ + self._multiverse_socket.set_api_callbacks(api_callbacks) + self._api_callbacks = api_callbacks + + def _bind_request_meta_data(self, request_meta_data: T) -> T: + """Bind the request_meta_data before sending it to the server. + + :param request_meta_data: The request_meta_data to bind. + :return: The bound request_meta_data. + """ + pass + + def _bind_response_meta_data(self, response_meta_data: T) -> T: + """Bind the response_meta_data after receiving it from the server. + + :param response_meta_data: The response_meta_data to bind. + :return: The bound response_meta_data. + """ + pass + + def _bind_send_data(self, send_data: T) -> T: + """Bind the send_data before sending it to the server. + + :param send_data: The send_data to bind. + :return: The bound send_data. + """ + pass + + def _bind_receive_data(self, receive_data: T) -> T: + """Bind the receive_data after receiving it from the server. + + :param receive_data: The receive_data to bind. + :return: The bound receive_data. + """ + pass + + def _connect_and_start(self) -> None: + """Connect to the server and start the client. + """ + self._multiverse_socket.connect(self.host, self.port) + self._multiverse_socket.start() + self._start_time = self._multiverse_socket.get_time_now() + + def _disconnect(self) -> None: + """Disconnect from the server. + """ + self._multiverse_socket.disconnect() + + def _communicate(self, resend_request_meta_data: bool = False) -> bool: + """Communicate with the server. + + :param resend_request_meta_data: Resend the request metadata. + :return: True if the communication was successful, False otherwise. + """ + return self._multiverse_socket.communicate(resend_request_meta_data) + + def _restart(self) -> None: + """Restart the client. + """ + self._disconnect() + self._connect_and_start() + + def log_info(self, message: str) -> None: + """Log information. + + :param message: The message to log. + """ + rospy.loginfo(self._message_template(message)) + + def log_warn(self, message: str) -> None: + """Warn the user. + + :param message: The message to warn about. + """ + rospy.logwarn(self._message_template(message)) + + def _message_template(self, message: str) -> str: + return (f"[{self.__class__.__name__}:{self.port}]: {message} : sim time {self.sim_time}," + f" world time {self.world_time}") + + @property + def world_time(self) -> float: + """Get the world time from the server. + + :return: The world time. + """ + return self._multiverse_socket.get_world_time() + + @property + def sim_time(self) -> float: + """Get the current simulation time. + + :return: The current simulation time. + """ + return self._multiverse_socket.get_time_now() - self._start_time diff --git a/test/bullet_world_testcase.py b/test/bullet_world_testcase.py index f6bbaba43..31e6ae6f6 100644 --- a/test/bullet_world_testcase.py +++ b/test/bullet_world_testcase.py @@ -2,6 +2,7 @@ import unittest import pycram.tasktree +from pycram.datastructures.world import UseProspectionWorld from pycram.worlds.bullet_world import BulletWorld from pycram.world_concepts.world_object import Object from pycram.datastructures.pose import Pose @@ -29,13 +30,15 @@ def setUpClass(cls): RobotDescription.current_robot_description.name + cls.extension) cls.kitchen = Object("kitchen", ObjectType.ENVIRONMENT, "kitchen" + cls.extension) cls.cereal = Object("cereal", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl", - ObjectDescription, pose=Pose([1.3, 0.7, 0.95])) + pose=Pose([1.3, 0.7, 0.95])) ProcessModule.execution_delay = False cls.viz_marker_publisher = VizMarkerPublisher() OntologyManager(SOMA_ONTOLOGY_IRI) def setUp(self): - self.world.reset_world() + self.world.reset_world(remove_saved_states=True) + with UseProspectionWorld(): + pass # DO NOT WRITE TESTS HERE!!! # Test related to the BulletWorld should be written in test_bullet_world.py @@ -44,7 +47,9 @@ def setUp(self): def tearDown(self): pycram.tasktree.task_tree.reset_tree() time.sleep(0.05) - self.world.reset_world() + self.world.reset_world(remove_saved_states=True) + with UseProspectionWorld(): + pass @classmethod def tearDownClass(cls): @@ -67,7 +72,7 @@ def setUpClass(cls): RobotDescription.current_robot_description.name + cls.extension) cls.kitchen = Object("kitchen", ObjectType.ENVIRONMENT, "kitchen" + cls.extension) cls.cereal = Object("cereal", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl", - ObjectDescription, pose=Pose([1.3, 0.7, 0.95])) + pose=Pose([1.3, 0.7, 0.95])) ProcessModule.execution_delay = False cls.viz_marker_publisher = VizMarkerPublisher() diff --git a/test/test_action_designator.py b/test/test_action_designator.py index 55f328535..1c7fb5a3d 100644 --- a/test/test_action_designator.py +++ b/test/test_action_designator.py @@ -20,7 +20,8 @@ def test_move_torso(self): self.assertEqual(description.ground().position, 0.3) with simulated_robot: description.resolve().perform() - self.assertEqual(self.world.robot.get_joint_position(RobotDescription.current_robot_description.torso_joint), 0.3) + self.assertEqual(self.world.robot.get_joint_position(RobotDescription.current_robot_description.torso_joint), + 0.3) def test_set_gripper(self): description = action_designator.SetGripperAction([Arms.LEFT], [GripperState.OPEN, GripperState.CLOSE]) @@ -146,7 +147,3 @@ def test_facing(self): FaceAtPerformable(self.milk.pose).perform() milk_in_robot_frame = LocalTransformer().transform_to_object_frame(self.milk.pose, self.robot) self.assertAlmostEqual(milk_in_robot_frame.position.y, 0.) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/test_attachment.py b/test/test_attachment.py index d521d752a..a6491eece 100644 --- a/test/test_attachment.py +++ b/test/test_attachment.py @@ -20,6 +20,19 @@ def test_detach(self): self.assertTrue(self.robot not in self.milk.attachments) self.assertTrue(self.milk not in self.robot.attachments) + def test_detach_sync_in_prospection_world(self): + self.milk.attach(self.robot) + with UseProspectionWorld(): + pass + self.milk.detach(self.robot) + with UseProspectionWorld(): + self.assertTrue(self.milk not in self.robot.attachments) + self.assertTrue(self.robot not in self.milk.attachments) + prospection_milk = self.world.get_prospection_object_for_object(self.milk) + prospection_robot = self.world.get_prospection_object_for_object(self.robot) + self.assertTrue(prospection_milk not in prospection_robot.attachments) + self.assertTrue(prospection_robot not in prospection_milk.attachments) + def test_attachment_behavior(self): self.robot.attach(self.milk) @@ -52,27 +65,28 @@ def test_prospection_object_attachments_not_changed_with_real_object(self): time.sleep(0.05) milk_2.attach(cereal_2) time.sleep(0.05) - prospection_milk = self.world.get_prospection_object_for_object(milk_2) - # self.assertTrue(cereal_2 not in prospection_milk.attachments) - prospection_cereal = self.world.get_prospection_object_for_object(cereal_2) - # self.assertTrue(prospection_cereal in prospection_milk.attachments) - self.assertTrue(prospection_milk.attachments == {}) - - # Assert that when prospection object is moved, the real object is not moved with UseProspectionWorld(): + prospection_milk = self.world.get_prospection_object_for_object(milk_2) + # self.assertTrue(cereal_2 not in prospection_milk.attachments) + prospection_cereal = self.world.get_prospection_object_for_object(cereal_2) + # self.assertTrue(prospection_cereal in prospection_milk.attachments) + self.assertTrue(prospection_cereal in prospection_milk.attachments.keys()) + + # Assert that when prospection object is moved, the real object is not moved prospection_milk_pos = prospection_milk.get_position() cereal_pos = cereal_2.get_position() - prospection_cereal_pos = prospection_cereal.get_position() + estimated_prospection_cereal_pos = prospection_cereal.get_position() + estimated_prospection_cereal_pos.x += 1 # Move prospection milk object prospection_milk_pos.x += 1 prospection_milk.set_position(prospection_milk_pos) - # Prospection object should not move + # Prospection cereal should move since it is attached to prospection milk new_prospection_cereal_pose = prospection_cereal.get_position() - self.assertTrue(new_prospection_cereal_pose == prospection_cereal_pos) + self.assertAlmostEqual(new_prospection_cereal_pose.x, estimated_prospection_cereal_pos.x, delta=0.01) - # Real cereal object should not move + # Also Real cereal object should not move since it is not affected by prospection milk new_cereal_pos = cereal_2.get_position() assumed_cereal_pos = cereal_pos self.assertTrue(new_cereal_pos == assumed_cereal_pos) @@ -80,22 +94,6 @@ def test_prospection_object_attachments_not_changed_with_real_object(self): self.world.remove_object(milk_2) self.world.remove_object(cereal_2) - def test_no_attachment_in_prospection_world(self): - milk_2 = Object("milk_2", ObjectType.MILK, "milk.stl", pose=Pose([1.3, 1, 0.9])) - cereal_2 = Object("cereal_2", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl", - pose=Pose([1.3, 0.7, 0.95])) - - milk_2.attach(cereal_2) - - prospection_milk = self.world.get_prospection_object_for_object(milk_2) - prospection_cereal = self.world.get_prospection_object_for_object(cereal_2) - - self.assertTrue(prospection_milk.attachments == {}) - self.assertTrue(prospection_cereal.attachments == {}) - - self.world.remove_object(milk_2) - self.world.remove_object(cereal_2) - def test_attaching_to_robot_and_moving(self): self.robot.attach(self.milk) milk_pos = self.milk.get_position() diff --git a/test/test_bullet_world.py b/test/test_bullet_world.py index ec398df7a..565e98342 100644 --- a/test/test_bullet_world.py +++ b/test/test_bullet_world.py @@ -11,7 +11,7 @@ from pycram.object_descriptors.urdf import ObjectDescription from pycram.datastructures.dataclasses import Color from pycram.world_concepts.world_object import Object -from pycram.datastructures.world import UseProspectionWorld +from pycram.datastructures.world import UseProspectionWorld, World fix_missing_inertial = ObjectDescription.fix_missing_inertial @@ -53,8 +53,7 @@ def test_remove_object(self): self.assertTrue(milk_id in [obj.id for obj in self.world.objects]) self.world.remove_object(self.milk) self.assertTrue(milk_id not in [obj.id for obj in self.world.objects]) - BulletWorldTest.milk = Object("milk", ObjectType.MILK, "milk.stl", - ObjectDescription, pose=Pose([1.3, 1, 0.9])) + BulletWorldTest.milk = Object("milk", ObjectType.MILK, "milk.stl", pose=Pose([1.3, 1, 0.9])) def test_remove_robot(self): robot_id = self.robot.id @@ -65,7 +64,7 @@ def test_remove_robot(self): RobotDescription.current_robot_description.name + self.extension) def test_get_joint_position(self): - self.assertEqual(self.robot.get_joint_position("head_pan_joint"), 0.0) + self.assertAlmostEqual(self.robot.get_joint_position("head_pan_joint"), 0.0, delta=0.01) def test_get_object_contact_points(self): self.assertEqual(len(self.robot.contact_points()), 0) @@ -136,51 +135,34 @@ def test_equal_world_states(self): time.sleep(2.5) self.robot.set_pose(Pose([1, 0, 0], [0, 0, 0, 1])) self.assertFalse(self.world.world_sync.check_for_equal()) - self.world.prospection_world.object_states = self.world.current_state.object_states - time.sleep(0.05) - self.assertTrue(self.world.world_sync.check_for_equal()) + with UseProspectionWorld(): + self.assertTrue(self.world.world_sync.check_for_equal()) def test_add_resource_path(self): self.world.add_resource_path("test") - self.assertTrue("test" in self.world.data_directory) + self.assertTrue("test" in self.world.get_data_directories()) def test_no_prospection_object_found_for_given_object(self): milk_2 = Object("milk_2", ObjectType.MILK, "milk.stl", pose=Pose([1.3, 1, 0.9])) - time.sleep(0.05) try: prospection_milk_2 = self.world.get_prospection_object_for_object(milk_2) self.world.remove_object(milk_2) - time.sleep(0.1) self.world.get_prospection_object_for_object(milk_2) self.assertFalse(True) - except ValueError as e: - self.assertTrue(True) - - def test_no_object_found_for_given_prospection_object(self): - milk_2 = Object("milk_2", ObjectType.MILK, "milk.stl", pose=Pose([1.3, 1, 0.9])) - time.sleep(0.05) - prospection_milk = self.world.get_prospection_object_for_object(milk_2) - self.assertTrue(self.world.get_object_for_prospection_object(prospection_milk) == milk_2) - try: - self.world.remove_object(milk_2) - self.world.get_object_for_prospection_object(prospection_milk) - time.sleep(0.1) - self.assertFalse(True) - except ValueError as e: + except KeyError as e: self.assertTrue(True) - time.sleep(0.05) def test_real_object_position_does_not_change_with_prospection_object(self): milk_2_pos = [1.3, 1, 0.9] milk_2 = Object("milk_3", ObjectType.MILK, "milk.stl", pose=Pose(milk_2_pos)) time.sleep(0.05) milk_2_pos = milk_2.get_position() - prospection_milk = self.world.get_prospection_object_for_object(milk_2) - prospection_milk_pos = prospection_milk.get_position() - self.assertTrue(prospection_milk_pos == milk_2_pos) # Assert that when prospection object is moved, the real object is not moved with UseProspectionWorld(): + prospection_milk = self.world.get_prospection_object_for_object(milk_2) + prospection_milk_pos = prospection_milk.get_position() + self.assertTrue(prospection_milk_pos == milk_2_pos) prospection_milk_pos.x += 1 prospection_milk.set_position(prospection_milk_pos) self.assertTrue(prospection_milk.get_position() != milk_2.get_position()) @@ -191,32 +173,32 @@ def test_prospection_object_position_does_not_change_with_real_object(self): milk_2 = Object("milk_4", ObjectType.MILK, "milk.stl", pose=Pose(milk_2_pos)) time.sleep(0.05) milk_2_pos = milk_2.get_position() - prospection_milk = self.world.get_prospection_object_for_object(milk_2) - prospection_milk_pos = prospection_milk.get_position() - self.assertTrue(prospection_milk_pos == milk_2_pos) # Assert that when real object is moved, the prospection object is not moved with UseProspectionWorld(): + prospection_milk = self.world.get_prospection_object_for_object(milk_2) + prospection_milk_pos = prospection_milk.get_position() + self.assertTrue(prospection_milk_pos == milk_2_pos) milk_2_pos.x += 1 milk_2.set_position(milk_2_pos) self.assertTrue(prospection_milk.get_position() != milk_2.get_position()) self.world.remove_object(milk_2) def test_add_vis_axis(self): - self.world.add_vis_axis(self.robot.get_link_pose(RobotDescription.current_robot_description.get_camera_frame())) + self.world.add_vis_axis(self.robot.get_link_pose(RobotDescription.current_robot_description.get_camera_link())) self.assertTrue(len(self.world.vis_axis) == 1) self.world.remove_vis_axis() self.assertTrue(len(self.world.vis_axis) == 0) def test_add_text(self): - link: ObjectDescription.Link = self.robot.get_link(RobotDescription.current_robot_description.get_camera_frame()) + link: ObjectDescription.Link = self.robot.get_link(RobotDescription.current_robot_description.get_camera_link()) text_id = self.world.add_text("test", link.position_as_list, link.orientation_as_list, 1, Color(1, 0, 0, 1), 3, link.object_id, link.id) if self.world.mode == WorldMode.GUI: time.sleep(4) def test_remove_text(self): - link: ObjectDescription.Link = self.robot.get_link(RobotDescription.current_robot_description.get_camera_frame()) + link: ObjectDescription.Link = self.robot.get_link(RobotDescription.current_robot_description.get_camera_link()) text_id_1 = self.world.add_text("test 1", link.pose.position_as_list(), link.pose.orientation_as_list(), 1, Color(1, 0, 0, 1), 0, link.object_id, link.id) text_id = self.world.add_text("test 2", link.pose.position_as_list(), link.pose.orientation_as_list(), 1, @@ -229,7 +211,7 @@ def test_remove_text(self): time.sleep(3) def test_remove_all_text(self): - link: ObjectDescription.Link = self.robot.get_link(RobotDescription.current_robot_description.get_camera_frame()) + link: ObjectDescription.Link = self.robot.get_link(RobotDescription.current_robot_description.get_camera_link()) text_id_1 = self.world.add_text("test 1", link.pose.position_as_list(), link.pose.orientation_as_list(), 1, Color(1, 0, 0, 1), 0, link.object_id, link.id) text_id = self.world.add_text("test 2", link.pose.position_as_list(), link.pose.orientation_as_list(), 1, diff --git a/test/test_bullet_world_reasoning.py b/test/test_bullet_world_reasoning.py index 3fafe27d4..8d8c1061b 100644 --- a/test/test_bullet_world_reasoning.py +++ b/test/test_bullet_world_reasoning.py @@ -20,28 +20,35 @@ def test_visible(self): self.milk.set_pose(Pose([1.5, 0, 1.2])) self.robot.set_pose(Pose()) time.sleep(1) - camera_frame = RobotDescription.current_robot_description.get_camera_frame() - self.world.add_vis_axis(self.robot.get_link_pose(camera_frame)) - self.assertTrue(btr.visible(self.milk, self.robot.get_link_pose(camera_frame), + camera_link = RobotDescription.current_robot_description.get_camera_link() + self.world.add_vis_axis(self.robot.get_link_pose(camera_link)) + self.assertTrue(btr.visible(self.milk, self.robot.get_link_pose(camera_link), RobotDescription.current_robot_description.get_default_camera().front_facing_axis)) def test_occluding(self): self.milk.set_pose(Pose([3, 0, 1.2])) self.robot.set_pose(Pose()) - self.assertTrue(btr.occluding(self.milk, self.robot.get_link_pose(RobotDescription.current_robot_description.get_camera_frame()), + self.assertTrue(btr.occluding(self.milk, self.robot.get_link_pose( + RobotDescription.current_robot_description.get_camera_link()), RobotDescription.current_robot_description.get_default_camera().front_facing_axis) != []) def test_reachable(self): self.robot.set_pose(Pose()) time.sleep(1) - self.assertTrue(btr.reachable(Pose([0.5, -0.7, 1]), self.robot, RobotDescription.current_robot_description.kinematic_chains["right"].get_tool_frame())) - self.assertFalse(btr.reachable(Pose([2, 2, 1]), self.robot, RobotDescription.current_robot_description.kinematic_chains["right"].get_tool_frame())) + self.assertTrue(btr.reachable(Pose([0.5, -0.7, 1]), self.robot, + RobotDescription.current_robot_description.kinematic_chains[ + "right"].get_tool_frame())) + self.assertFalse(btr.reachable(Pose([2, 2, 1]), self.robot, + RobotDescription.current_robot_description.kinematic_chains[ + "right"].get_tool_frame())) def test_blocking(self): self.milk.set_pose(Pose([0.5, -0.7, 1])) self.robot.set_pose(Pose()) time.sleep(2) - self.assertTrue(btr.blocking(Pose([0.5, -0.7, 1]), self.robot, RobotDescription.current_robot_description.kinematic_chains["right"].get_tool_frame()) != []) + blocking = btr.blocking(Pose([0.5, -0.7, 1]), self.robot, + RobotDescription.current_robot_description.kinematic_chains["right"].get_tool_frame()) + self.assertTrue(blocking != []) def test_supporting(self): self.milk.set_pose(Pose([1.3, 0, 0.9])) diff --git a/test/test_cache_manager.py b/test/test_cache_manager.py index bad803f22..9f208d90f 100644 --- a/test/test_cache_manager.py +++ b/test/test_cache_manager.py @@ -1,20 +1,18 @@ +import os from pathlib import Path from bullet_world_testcase import BulletWorldTestCase -from pycram.datastructures.enums import ObjectType -from pycram.world_concepts.world_object import Object -import pathlib +from pycram.object_descriptors.urdf import ObjectDescription as URDFObject +from pycram.config import world_conf as conf class TestCacheManager(BulletWorldTestCase): def test_generate_description_and_write_to_cache(self): cache_manager = self.world.cache_manager - file_path = pathlib.Path(__file__).parent.resolve() - path = str(file_path) + "/../resources/apartment.urdf" + path = os.path.join(self.world.conf.resources_path, "objects/apartment.urdf") extension = Path(path).suffix - cache_path = self.world.cache_dir + "apartment.urdf" - apartment = Object("apartment", ObjectType.ENVIRONMENT, path) - cache_manager.generate_description_and_write_to_cache(path, apartment.name, extension, cache_path, - apartment.description) - self.assertTrue(cache_manager.is_cached(path, apartment.description)) + cache_path = os.path.join(self.world.conf.cache_dir, "apartment.urdf") + apartment = URDFObject(path) + apartment.generate_description_from_file(path, "apartment", extension, cache_path) + self.assertTrue(cache_manager.is_cached(path, apartment)) diff --git a/test/test_database_resolver.py b/test/test_database_resolver.py index 3fc21284a..7392dbac2 100644 --- a/test/test_database_resolver.py +++ b/test/test_database_resolver.py @@ -2,7 +2,7 @@ import unittest import sqlalchemy import sqlalchemy.orm -import pycram.plan_failures +import pycram.failures from pycram.world_concepts.world_object import Object from pycram.datastructures.world import World from pycram.designators import action_designator diff --git a/test/test_description.py b/test/test_description.py index a0d12b3b2..e324384ac 100644 --- a/test/test_description.py +++ b/test/test_description.py @@ -1,3 +1,4 @@ +import os.path import pathlib from bullet_world_testcase import BulletWorldTestCase @@ -22,10 +23,16 @@ def test_joint_child_link(self): def test_generate_description_from_mesh(self): file_path = pathlib.Path(__file__).parent.resolve() - self.assertTrue(self.milk.description.generate_description_from_file(str(file_path) + "/../resources/cached/milk.stl", - "milk", ".stl")) + cache_path = self.world.cache_manager.cache_dir + cache_path = os.path.join(cache_path, f"{self.milk.description.name}.urdf") + self.milk.description.generate_from_mesh_file(str(file_path) + "/../resources/milk.stl", "milk", cache_path) + self.assertTrue(self.world.cache_manager.is_cached(f"{self.milk.name}", self.milk.description)) def test_generate_description_from_description_file(self): file_path = pathlib.Path(__file__).parent.resolve() - self.assertTrue(self.milk.description.generate_description_from_file(str(file_path) + "/../resources/cached/milk.urdf", - "milk", ".urdf")) + file_extension = self.robot.description.get_file_extension() + pr2_path = str(file_path) + f"/../resources/robots/{self.robot.description.name}{file_extension}" + cache_path = self.world.cache_manager.cache_dir + cache_path = os.path.join(cache_path, f"{self.robot.description.name}.urdf") + self.robot.description.generate_from_description_file(pr2_path, cache_path) + self.assertTrue(self.world.cache_manager.is_cached(self.robot.name, self.robot.description)) diff --git a/test/test_error_checkers.py b/test/test_error_checkers.py new file mode 100644 index 000000000..63bf06416 --- /dev/null +++ b/test/test_error_checkers.py @@ -0,0 +1,131 @@ +from unittest import TestCase + +import numpy as np +from tf.transformations import quaternion_from_euler + +from pycram.datastructures.enums import JointType +from pycram.validation.error_checkers import calculate_angle_between_quaternions, \ + PoseErrorChecker, PositionErrorChecker, OrientationErrorChecker, RevoluteJointPositionErrorChecker, \ + PrismaticJointPositionErrorChecker, MultiJointPositionErrorChecker + +from pycram.datastructures.pose import Pose + + +class TestErrorCheckers(TestCase): + @classmethod + def setUpClass(cls): + pass + + @classmethod + def tearDownClass(cls): + pass + + def tearDown(self): + pass + + def test_calculate_quaternion_error(self): + quat_1 = [0.0, 0.0, 0.0, 1.0] + quat_2 = [0.0, 0.0, 0.0, 1.0] + error = calculate_angle_between_quaternions(quat_1, quat_2) + self.assertEqual(error, 0.0) + quat_2 = quaternion_from_euler(0, 0, np.pi/2) + error = calculate_angle_between_quaternions(quat_1, quat_2) + self.assertEqual(error, np.pi/2) + + def test_pose_error_checker(self): + pose_1 = Pose([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]) + pose_2 = Pose([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]) + error_checker = PoseErrorChecker() + error = error_checker.calculate_error(pose_1, pose_2) + self.assertEqual(error, [0.0, 0.0]) + self.assertTrue(error_checker.is_error_acceptable(pose_1, pose_2)) + quat = quaternion_from_euler(0, np.pi/2, 0) + pose_2 = Pose([0, 1, np.sqrt(3)], quat) + error = error_checker.calculate_error(pose_1, pose_2) + self.assertAlmostEqual(error[0], 2, places=5) + self.assertEqual(error[1], np.pi/2) + self.assertFalse(error_checker.is_error_acceptable(pose_1, pose_2)) + quat = quaternion_from_euler(0, 0, np.pi/360) + pose_2 = Pose([0, 0.0001, 0.0001], quat) + self.assertTrue(error_checker.is_error_acceptable(pose_1, pose_2)) + quat = quaternion_from_euler(0, 0, np.pi / 179) + pose_2 = Pose([0, 0.0001, 0.0001], quat) + self.assertFalse(error_checker.is_error_acceptable(pose_1, pose_2)) + + def test_position_error_checker(self): + position_1 = [0.0, 0.0, 0.0] + position_2 = [0.0, 0.0, 0.0] + error_checker = PositionErrorChecker() + error = error_checker.calculate_error(position_1, position_2) + self.assertEqual(error, 0.0) + self.assertTrue(error_checker.is_error_acceptable(position_1, position_2)) + position_2 = [1.0, 1.0, 1.0] + error = error_checker.calculate_error(position_1, position_2) + self.assertAlmostEqual(error, np.sqrt(3), places=5) + self.assertFalse(error_checker.is_error_acceptable(position_1, position_2)) + + def test_orientation_error_checker(self): + quat_1 = [0.0, 0.0, 0.0, 1.0] + quat_2 = [0.0, 0.0, 0.0, 1.0] + error_checker = OrientationErrorChecker() + error = error_checker.calculate_error(quat_1, quat_2) + self.assertEqual(error, 0.0) + self.assertTrue(error_checker.is_error_acceptable(quat_1, quat_2)) + quat_2 = quaternion_from_euler(0, 0, np.pi/2) + error = error_checker.calculate_error(quat_1, quat_2) + self.assertEqual(error, np.pi/2) + self.assertFalse(error_checker.is_error_acceptable(quat_1, quat_2)) + + def test_revolute_joint_position_error_checker(self): + position_1 = 0.0 + position_2 = 0.0 + error_checker = RevoluteJointPositionErrorChecker() + error = error_checker.calculate_error(position_1, position_2) + self.assertEqual(error, 0.0) + self.assertTrue(error_checker.is_error_acceptable(position_1, position_2)) + position_2 = np.pi/2 + error = error_checker.calculate_error(position_1, position_2) + self.assertEqual(error, np.pi/2) + self.assertFalse(error_checker.is_error_acceptable(position_1, position_2)) + + def test_prismatic_joint_position_error_checker(self): + position_1 = 0.0 + position_2 = 0.0 + error_checker = PrismaticJointPositionErrorChecker() + error = error_checker.calculate_error(position_1, position_2) + self.assertEqual(error, 0.0) + self.assertTrue(error_checker.is_error_acceptable(position_1, position_2)) + position_2 = 1.0 + error = error_checker.calculate_error(position_1, position_2) + self.assertEqual(error, 1.0) + self.assertFalse(error_checker.is_error_acceptable(position_1, position_2)) + + def test_list_of_poses_error_checker(self): + poses_1 = [Pose([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]), + Pose([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0])] + poses_2 = [Pose([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]), + Pose([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0])] + error_checker = PoseErrorChecker(is_iterable=True) + error = error_checker.calculate_error(poses_1, poses_2) + self.assertEqual(error, [[0.0, 0.0], [0.0, 0.0]]) + self.assertTrue(error_checker.is_error_acceptable(poses_1, poses_2)) + quat = quaternion_from_euler(0, np.pi/2, 0) + poses_2 = [Pose([0, 1, np.sqrt(3)], quat), + Pose([0, 1, np.sqrt(3)], quat)] + error = error_checker.calculate_error(poses_1, poses_2) + self.assertAlmostEqual(error[0][0], 2, places=5) + self.assertEqual(error[0][1], np.pi/2) + self.assertAlmostEqual(error[1][0], 2, places=5) + self.assertEqual(error[1][1], np.pi/2) + self.assertFalse(error_checker.is_error_acceptable(poses_1, poses_2)) + + def test_multi_joint_error_checker(self): + positions_1 = [0.0, 0.0] + positions_2 = [np.pi/2, 0.1] + joint_types = [JointType.REVOLUTE, JointType.PRISMATIC] + error_checker = MultiJointPositionErrorChecker(joint_types) + error = error_checker.calculate_error(positions_1, positions_2) + self.assertEqual(error, [np.pi/2, 0.1]) + self.assertFalse(error_checker.is_error_acceptable(positions_1, positions_2)) + positions_2 = [np.pi/180, 0.0001] + self.assertTrue(error_checker.is_error_acceptable(positions_1, positions_2)) diff --git a/test/test_failure_handling.py b/test/test_failure_handling.py index b28420f96..190a48922 100644 --- a/test/test_failure_handling.py +++ b/test/test_failure_handling.py @@ -7,7 +7,7 @@ from pycram.designators.action_designator import ParkArmsAction from pycram.datastructures.enums import ObjectType, Arms, WorldMode from pycram.failure_handling import Retry -from pycram.plan_failures import PlanFailure +from pycram.failures import PlanFailure from pycram.process_module import ProcessModule, simulated_robot from pycram.robot_description import RobotDescription from pycram.object_descriptors.urdf import ObjectDescription @@ -33,8 +33,8 @@ class FailureHandlingTest(unittest.TestCase): @classmethod def setUpClass(cls): cls.world = BulletWorld(WorldMode.DIRECT) - cls.robot = Object(RobotDescription.current_robot_description.name, ObjectType.ROBOT, RobotDescription.current_robot_description.name + extension, - ObjectDescription) + cls.robot = Object(RobotDescription.current_robot_description.name, ObjectType.ROBOT, + RobotDescription.current_robot_description.name + extension) ProcessModule.execution_delay = True def setUp(self): diff --git a/test/test_goal_validator.py b/test/test_goal_validator.py new file mode 100644 index 000000000..9b79cc114 --- /dev/null +++ b/test/test_goal_validator.py @@ -0,0 +1,321 @@ +import numpy as np +from tf.transformations import quaternion_from_euler +from typing_extensions import Optional, List + +from bullet_world_testcase import BulletWorldTestCase +from pycram.datastructures.enums import JointType +from pycram.datastructures.pose import Pose +from pycram.robot_description import RobotDescription +from pycram.validation.error_checkers import PoseErrorChecker, PositionErrorChecker, \ + OrientationErrorChecker, RevoluteJointPositionErrorChecker, PrismaticJointPositionErrorChecker, \ + MultiJointPositionErrorChecker +from pycram.validation.goal_validator import GoalValidator, PoseGoalValidator, \ + PositionGoalValidator, OrientationGoalValidator, JointPositionGoalValidator, MultiJointPositionGoalValidator, \ + MultiPoseGoalValidator, MultiPositionGoalValidator, MultiOrientationGoalValidator + + +class TestGoalValidator(BulletWorldTestCase): + + def test_single_pose_goal(self): + pose_goal_validators = PoseGoalValidator(self.milk.get_pose) + self.validate_pose_goal(pose_goal_validators) + + def test_single_pose_goal_generic(self): + pose_goal_validators = GoalValidator(PoseErrorChecker(), self.milk.get_pose) + self.validate_pose_goal(pose_goal_validators) + + def validate_pose_goal(self, goal_validator): + milk_goal_pose = Pose([1.3, 1.5, 0.9]) + goal_validator.register_goal(milk_goal_pose) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], 0.5, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], 0, places=5) + self.milk.set_pose(milk_goal_pose) + self.assertEqual(self.milk.get_pose(), milk_goal_pose) + self.assertTrue(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 1) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], 0, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], 0, places=5) + + def test_single_position_goal_generic(self): + goal_validator = GoalValidator(PositionErrorChecker(), self.cereal.get_position_as_list) + self.validate_position_goal(goal_validator) + + def test_single_position_goal(self): + goal_validator = PositionGoalValidator(self.cereal.get_position_as_list) + self.validate_position_goal(goal_validator) + + def validate_position_goal(self, goal_validator): + cereal_goal_position = [1.3, 1.5, 0.95] + goal_validator.register_goal(cereal_goal_position) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertEqual(goal_validator.current_error, 0.8) + self.cereal.set_position(cereal_goal_position) + self.assertEqual(self.cereal.get_position_as_list(), cereal_goal_position) + self.assertTrue(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 1) + self.assertEqual(goal_validator.current_error, 0) + + def test_single_orientation_goal_generic(self): + goal_validator = GoalValidator(OrientationErrorChecker(), self.cereal.get_orientation_as_list) + self.validate_orientation_goal(goal_validator) + + def test_single_orientation_goal(self): + goal_validator = OrientationGoalValidator(self.cereal.get_orientation_as_list) + self.validate_orientation_goal(goal_validator) + + def validate_orientation_goal(self, goal_validator): + cereal_goal_orientation = quaternion_from_euler(0, 0, np.pi / 2) + goal_validator.register_goal(cereal_goal_orientation) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertEqual(goal_validator.current_error, [np.pi / 2]) + self.cereal.set_orientation(cereal_goal_orientation) + for v1, v2 in zip(self.cereal.get_orientation_as_list(), cereal_goal_orientation.tolist()): + self.assertAlmostEqual(v1, v2, places=5) + self.assertTrue(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, 1, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], 0, places=5) + + def test_single_revolute_joint_position_goal_generic(self): + goal_validator = GoalValidator(RevoluteJointPositionErrorChecker(), self.robot.get_joint_position) + self.validate_revolute_joint_position_goal(goal_validator) + + def test_single_revolute_joint_position_goal(self): + goal_validator = JointPositionGoalValidator(self.robot.get_joint_position) + self.validate_revolute_joint_position_goal(goal_validator, JointType.REVOLUTE) + + def validate_revolute_joint_position_goal(self, goal_validator, joint_type: Optional[JointType] = None): + goal_joint_position = -np.pi / 4 + joint_name = 'l_shoulder_lift_joint' + if joint_type is not None: + goal_validator.register_goal(goal_joint_position, joint_type, joint_name) + else: + goal_validator.register_goal(goal_joint_position, joint_name) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertEqual(goal_validator.current_error, abs(goal_joint_position)) + + for percent in [0.5, 1]: + self.robot.set_joint_position('l_shoulder_lift_joint', goal_joint_position * percent) + self.assertEqual(self.robot.get_joint_position('l_shoulder_lift_joint'), goal_joint_position * percent) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], abs(goal_joint_position) * (1 - percent), + places=5) + + def test_single_prismatic_joint_position_goal_generic(self): + goal_validator = GoalValidator(PrismaticJointPositionErrorChecker(), self.robot.get_joint_position) + self.validate_prismatic_joint_position_goal(goal_validator) + + def test_single_prismatic_joint_position_goal(self): + goal_validator = JointPositionGoalValidator(self.robot.get_joint_position) + self.validate_prismatic_joint_position_goal(goal_validator, JointType.PRISMATIC) + + def validate_prismatic_joint_position_goal(self, goal_validator, joint_type: Optional[JointType] = None): + goal_joint_position = 0.2 + torso = RobotDescription.current_robot_description.torso_joint + if joint_type is not None: + goal_validator.register_goal(goal_joint_position, joint_type, torso) + else: + goal_validator.register_goal(goal_joint_position, torso) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertEqual(goal_validator.current_error, abs(goal_joint_position)) + + for percent in [0.5, 1]: + self.robot.set_joint_position(torso, goal_joint_position * percent) + self.assertEqual(self.robot.get_joint_position(torso), goal_joint_position * percent) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], abs(goal_joint_position) * (1 - percent), + places=5) + + def test_multi_joint_goal_generic(self): + joint_types = [JointType.PRISMATIC, JointType.REVOLUTE] + goal_validator = GoalValidator(MultiJointPositionErrorChecker(joint_types), + lambda x: list(self.robot.get_multiple_joint_positions(x).values())) + self.validate_multi_joint_goal(goal_validator) + + def test_multi_joint_goal(self): + joint_types = [JointType.PRISMATIC, JointType.REVOLUTE] + goal_validator = MultiJointPositionGoalValidator( + lambda x: list(self.robot.get_multiple_joint_positions(x).values())) + self.validate_multi_joint_goal(goal_validator, joint_types) + + def validate_multi_joint_goal(self, goal_validator, joint_types: Optional[List[JointType]] = None): + goal_joint_positions = np.array([0.2, -np.pi / 4]) + joint_names = ['torso_lift_joint', 'l_shoulder_lift_joint'] + if joint_types is not None: + goal_validator.register_goal(goal_joint_positions, joint_types, joint_names) + else: + goal_validator.register_goal(goal_joint_positions, joint_names) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertTrue(np.allclose(goal_validator.current_error, np.array([0.2, abs(-np.pi / 4)]), atol=0.001)) + + for percent in [0.5, 1]: + current_joint_positions = goal_joint_positions * percent + self.robot.set_multiple_joint_positions(dict(zip(joint_names, current_joint_positions.tolist()))) + self.assertTrue(np.allclose(self.robot.get_joint_position('torso_lift_joint'), current_joint_positions[0], + atol=0.001)) + self.assertTrue( + np.allclose(self.robot.get_joint_position('l_shoulder_lift_joint'), current_joint_positions[1], + atol=0.001)) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], abs(0.2) * (1 - percent), places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], abs(-np.pi / 4) * (1 - percent), places=5) + + def test_list_of_poses_goal_generic(self): + goal_validator = GoalValidator(PoseErrorChecker(is_iterable=True), + lambda: [self.robot.get_pose(), self.robot.get_pose()]) + self.validate_list_of_poses_goal(goal_validator) + + def test_list_of_poses_goal(self): + goal_validator = MultiPoseGoalValidator(lambda: [self.robot.get_pose(), self.robot.get_pose()]) + self.validate_list_of_poses_goal(goal_validator) + + def validate_list_of_poses_goal(self, goal_validator): + position_goal = [0.0, 1.0, 0.0] + orientation_goal = np.array([0, 0, np.pi / 2]) + poses_goal = [Pose(position_goal, quaternion_from_euler(*orientation_goal.tolist())), + Pose(position_goal, quaternion_from_euler(*orientation_goal.tolist()))] + goal_validator.register_goal(poses_goal) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertTrue( + np.allclose(goal_validator.current_error, np.array([1.0, np.pi / 2, 1.0, np.pi / 2]), atol=0.001)) + + for percent in [0.5, 1]: + current_orientation_goal = orientation_goal * percent + current_pose_goal = Pose([0.0, 1.0 * percent, 0.0], + quaternion_from_euler(*current_orientation_goal.tolist())) + self.robot.set_pose(current_pose_goal) + self.assertTrue(np.allclose(self.robot.get_position_as_list(), current_pose_goal.position_as_list(), + atol=0.001)) + self.assertTrue(np.allclose(self.robot.get_orientation_as_list(), current_pose_goal.orientation_as_list(), + atol=0.001)) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], 1 - percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], np.pi * (1 - percent) / 2, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[2], (1 - percent), places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[3], np.pi * (1 - percent) / 2, places=5) + + def test_list_of_positions_goal_generic(self): + goal_validator = GoalValidator(PositionErrorChecker(is_iterable=True), + lambda: [self.robot.get_position_as_list(), self.robot.get_position_as_list()]) + self.validate_list_of_positions_goal(goal_validator) + + def test_list_of_positions_goal(self): + goal_validator = MultiPositionGoalValidator(lambda: [self.robot.get_position_as_list(), + self.robot.get_position_as_list()]) + self.validate_list_of_positions_goal(goal_validator) + + def validate_list_of_positions_goal(self, goal_validator): + position_goal = [0.0, 1.0, 0.0] + positions_goal = [position_goal, position_goal] + goal_validator.register_goal(positions_goal) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertTrue(np.allclose(goal_validator.current_error, np.array([1.0, 1.0]), atol=0.001)) + + for percent in [0.5, 1]: + current_position_goal = [0.0, 1.0 * percent, 0.0] + self.robot.set_position(current_position_goal) + self.assertTrue(np.allclose(self.robot.get_position_as_list(), current_position_goal, atol=0.001)) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], 1 - percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], 1 - percent, places=5) + + def test_list_of_orientations_goal_generic(self): + goal_validator = GoalValidator(OrientationErrorChecker(is_iterable=True), + lambda: [self.robot.get_orientation_as_list(), + self.robot.get_orientation_as_list()]) + self.validate_list_of_orientations_goal(goal_validator) + + def test_list_of_orientations_goal(self): + goal_validator = MultiOrientationGoalValidator(lambda: [self.robot.get_orientation_as_list(), + self.robot.get_orientation_as_list()]) + self.validate_list_of_orientations_goal(goal_validator) + + def validate_list_of_orientations_goal(self, goal_validator): + orientation_goal = np.array([0, 0, np.pi / 2]) + orientations_goals = [quaternion_from_euler(*orientation_goal.tolist()), + quaternion_from_euler(*orientation_goal.tolist())] + goal_validator.register_goal(orientations_goals) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertTrue(np.allclose(goal_validator.current_error, np.array([np.pi / 2, np.pi / 2]), atol=0.001)) + + for percent in [0.5, 1]: + current_orientation_goal = orientation_goal * percent + self.robot.set_orientation(quaternion_from_euler(*current_orientation_goal.tolist())) + self.assertTrue(np.allclose(self.robot.get_orientation_as_list(), + quaternion_from_euler(*current_orientation_goal.tolist()), + atol=0.001)) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], np.pi * (1 - percent) / 2, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], np.pi * (1 - percent) / 2, places=5) + + def test_list_of_revolute_joint_positions_goal_generic(self): + goal_validator = GoalValidator(RevoluteJointPositionErrorChecker(is_iterable=True), + lambda x: list(self.robot.get_multiple_joint_positions(x).values())) + self.validate_list_of_revolute_joint_positions_goal(goal_validator) + + def test_list_of_revolute_joint_positions_goal(self): + goal_validator = MultiJointPositionGoalValidator( + lambda x: list(self.robot.get_multiple_joint_positions(x).values())) + self.validate_list_of_revolute_joint_positions_goal(goal_validator, [JointType.REVOLUTE, JointType.REVOLUTE]) + + def validate_list_of_revolute_joint_positions_goal(self, goal_validator, + joint_types: Optional[List[JointType]] = None): + goal_joint_position = -np.pi / 4 + goal_joint_positions = np.array([goal_joint_position, goal_joint_position]) + joint_names = ['l_shoulder_lift_joint', 'r_shoulder_lift_joint'] + if joint_types is not None: + goal_validator.register_goal(goal_joint_positions, joint_types, joint_names) + else: + goal_validator.register_goal(goal_joint_positions, joint_names) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertTrue(np.allclose(goal_validator.current_error, + np.array([abs(goal_joint_position), abs(goal_joint_position)]), atol=0.001)) + + for percent in [0.5, 1]: + current_joint_position = goal_joint_positions * percent + self.robot.set_multiple_joint_positions(dict(zip(joint_names, current_joint_position))) + self.assertTrue(np.allclose(list(self.robot.get_multiple_joint_positions(joint_names).values()), + current_joint_position, atol=0.001)) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], abs(goal_joint_position) * (1 - percent), + places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], abs(goal_joint_position) * (1 - percent), + places=5) diff --git a/test/test_language.py b/test/test_language.py index bfa41d647..362db9c0e 100644 --- a/test/test_language.py +++ b/test/test_language.py @@ -6,7 +6,7 @@ from pycram.datastructures.enums import ObjectType, State from pycram.failure_handling import RetryMonitor from pycram.fluent import Fluent -from pycram.plan_failures import PlanFailure, NotALanguageExpression +from pycram.failures import PlanFailure, NotALanguageExpression from pycram.datastructures.pose import Pose from pycram.language import Sequential, Language, Parallel, TryAll, TryInOrder, Monitor, Code from pycram.process_module import simulated_robot diff --git a/test/test_mjcf.py b/test/test_mjcf.py new file mode 100644 index 000000000..bcb9e2262 --- /dev/null +++ b/test/test_mjcf.py @@ -0,0 +1,40 @@ +from unittest import TestCase, skipIf +from dm_control import mjcf +try: + from pycram.object_descriptors.mjcf import ObjectDescription as MJCFObjDesc +except ImportError: + MJCFObjDesc = None + + +@skipIf(MJCFObjDesc is None, "Multiverse not found.") +class TestMjcf(TestCase): + model: MJCFObjDesc + + @classmethod + def setUpClass(cls): + # Example usage + model = mjcf.RootElement("test") + + model.default.dclass = 'default' + + # Define a simple model with bodies and joints + body1 = model.worldbody.add('body', name='body1') + body2 = body1.add('body', name='body2') + joint1 = body2.add('joint', name='joint1', type='hinge') + + body3 = body2.add('body', name='body3') + joint2 = body3.add('joint', name='joint2', type='slide') + + cls.model = MJCFObjDesc() + print(model.to_xml_string()) + cls.model.update_description_from_string(model.to_xml_string()) + + def test_child_map(self): + self.assertEqual(self.model.child_map, {'body1': [('joint1', 'body2')], 'body2': [('joint2', 'body3')]}) + + def test_parent_map(self): + self.assertEqual(self.model.parent_map, {'body2': ('joint1', 'body1'), 'body3': ('joint2', 'body2')}) + + def test_get_chain(self): + self.assertEqual(self.model.get_chain('body1', 'body3'), + ['body1', 'joint1', 'body2', 'joint2', 'body3']) diff --git a/test/test_move_and_pick_up.py b/test/test_move_and_pick_up.py index 013a8708d..2c4268950 100644 --- a/test/test_move_and_pick_up.py +++ b/test/test_move_and_pick_up.py @@ -9,7 +9,7 @@ from pycram.designators.action_designator import MoveTorsoActionPerformable from pycram.designators.specialized_designators.probabilistic.probabilistic_action import (MoveAndPickUp, GaussianCostmapModel) -from pycram.plan_failures import PlanFailure +from pycram.failures import PlanFailure from pycram.process_module import simulated_robot diff --git a/test/test_multiverse.py b/test/test_multiverse.py new file mode 100644 index 000000000..4c1426211 --- /dev/null +++ b/test/test_multiverse.py @@ -0,0 +1,406 @@ +#!/usr/bin/env python3 +import os +import unittest + +import numpy as np +import psutil +from tf.transformations import quaternion_from_euler, quaternion_multiply +from typing_extensions import Optional, List + +from pycram.datastructures.dataclasses import ContactPointsList, ContactPoint +from pycram.datastructures.enums import ObjectType, Arms, JointType +from pycram.datastructures.pose import Pose +from pycram.robot_description import RobotDescriptionManager +from pycram.world_concepts.world_object import Object +from pycram.validation.error_checkers import calculate_angle_between_quaternions +from pycram.helper import get_robot_mjcf_path, parse_mjcf_actuators + +multiverse_installed = True +try: + from pycram.worlds.multiverse import Multiverse +except ImportError: + multiverse_installed = False + +processes = psutil.process_iter() +process_names = [p.name() for p in processes] +multiverse_running = True +mujoco_running = True +if 'multiverse_server' not in process_names: + multiverse_running = False +if 'mujoco' not in process_names: + mujoco_running = False + + +@unittest.skipIf(not multiverse_installed, "Multiverse is not installed.") +@unittest.skipIf(not multiverse_running, "Multiverse server is not running.") +@unittest.skipIf(not mujoco_running, "Mujoco is not running.") +class MultiversePyCRAMTestCase(unittest.TestCase): + if multiverse_installed: + multiverse: Multiverse + big_bowl: Optional[Object] = None + + @classmethod + def setUpClass(cls): + if not multiverse_installed: + return + cls.multiverse = Multiverse() + + @classmethod + def tearDownClass(cls): + cls.multiverse.exit(remove_saved_states=True) + cls.multiverse.remove_multiverse_resources() + + def tearDown(self): + self.multiverse.remove_all_objects() + + def test_spawn_mesh_object(self): + milk = Object("milk", ObjectType.MILK, "milk.stl", pose=Pose([1, 1, 0.1])) + self.assert_poses_are_equal(milk.get_pose(), Pose([1, 1, 0.1])) + self.multiverse.simulate(0.2) + contact_points = milk.contact_points() + self.assertTrue(len(contact_points) > 0) + + def test_parse_mjcf_actuators(self): + mjcf_file = get_robot_mjcf_path("pal_robotics", "tiago_dual") + self.assertTrue(os.path.exists(mjcf_file)) + joint_actuators = parse_mjcf_actuators(mjcf_file) + self.assertIsInstance(joint_actuators, dict) + self.assertTrue(len(joint_actuators) > 0) + self.assertTrue("arm_left_1_joint" in joint_actuators) + self.assertTrue("arm_right_1_joint" in joint_actuators) + self.assertTrue(joint_actuators["arm_right_1_joint"] == "arm_right_1_actuator") + + def test_get_actuator_for_joint(self): + robot = self.spawn_robot() + joint_name = "arm_right_1_joint" + actuator_name = robot.get_actuator_for_joint(robot.joints[joint_name]) + self.assertEqual(actuator_name, "arm_right_1_actuator") + + def test_get_images_for_target(self): + robot = self.spawn_robot(robot_name='pr2') + camera_description = self.multiverse.robot_description.get_default_camera() + camera_link_name = camera_description.link_name + camera_pose = robot.get_link_pose(camera_link_name) + camera_frame = self.multiverse.robot_description.get_camera_frame() + camera_front_facing_axis = camera_description.front_facing_axis + milk_spawn_position = np.array(camera_front_facing_axis) * 0.5 + orientation = camera_pose.to_transform(camera_frame).invert().rotation_as_list() + milk = self.spawn_milk(milk_spawn_position.tolist(), orientation, frame=camera_frame) + _, depth, segmentation_mask = self.multiverse.get_images_for_target(milk.pose, camera_pose, plot=False) + self.assertIsInstance(depth, np.ndarray) + self.assertIsInstance(segmentation_mask, np.ndarray) + self.assertTrue(depth.shape == (256, 256)) + self.assertTrue(segmentation_mask.shape == (256, 256)) + self.assertTrue(milk.id in np.unique(segmentation_mask).flatten().tolist()) + avg_depth_of_milk = np.mean(depth[segmentation_mask == milk.id]) + self.assertAlmostEqual(avg_depth_of_milk, 0.5, delta=0.1) + + def test_reset_world(self): + set_position = [1, 1, 0.1] + milk = self.spawn_milk(set_position) + milk.set_position(set_position) + milk_position = milk.get_position_as_list() + self.assert_list_is_equal(milk_position[:2], set_position[:2], delta=self.multiverse.conf.position_tolerance) + self.multiverse.reset_world() + milk_pose = milk.get_pose() + self.assert_list_is_equal(milk_pose.position_as_list()[:2], + milk.original_pose.position_as_list()[:2], + delta=self.multiverse.conf.position_tolerance) + self.assert_orientation_is_equal(milk_pose.orientation_as_list(), milk.original_pose.orientation_as_list()) + + def test_spawn_robot_with_actuators_directly_from_multiverse(self): + if self.multiverse.conf.use_controller: + robot_name = "tiago_dual" + rdm = RobotDescriptionManager() + rdm.load_description(robot_name) + self.multiverse.spawn_robot_with_controller(robot_name, Pose([-2, -2, 0.001])) + + def test_spawn_object(self): + milk = self.spawn_milk([1, 1, 0.1]) + self.assertIsInstance(milk, Object) + milk_pose = milk.get_pose() + self.assert_list_is_equal(milk_pose.position_as_list()[:2], [1, 1], + delta=self.multiverse.conf.position_tolerance) + self.assert_orientation_is_equal(milk_pose.orientation_as_list(), milk.original_pose.orientation_as_list()) + + def test_remove_object(self): + milk = self.spawn_milk([1, 1, 0.1]) + milk.remove() + self.assertTrue(milk not in self.multiverse.objects) + self.assertFalse(self.multiverse.check_object_exists(milk)) + + def test_check_object_exists(self): + milk = self.spawn_milk([1, 1, 0.1]) + self.assertTrue(self.multiverse.check_object_exists(milk)) + + def test_set_position(self): + milk = self.spawn_milk([1, 1, 0.1]) + original_milk_position = milk.get_position_as_list() + original_milk_position[0] += 1 + milk.set_position(original_milk_position) + milk_position = milk.get_position_as_list() + self.assert_list_is_equal(milk_position[:2], original_milk_position[:2], + delta=self.multiverse.conf.position_tolerance) + + def test_update_position(self): + milk = self.spawn_milk([1, 1, 0.1]) + milk.update_pose() + milk_position = milk.get_position_as_list() + self.assert_list_is_equal(milk_position[:2], [1, 1], delta=self.multiverse.conf.position_tolerance) + + def test_set_joint_position(self): + if self.multiverse.robot is None: + robot = self.spawn_robot() + else: + robot = self.multiverse.robot + step = 0.2 + for joint in ['torso_lift_joint']: + joint_type = robot.joints[joint].type + original_joint_position = robot.get_joint_position(joint) + robot.set_joint_position(joint, original_joint_position + step) + joint_position = robot.get_joint_position(joint) + if not self.multiverse.conf.use_controller: + delta = self.multiverse.conf.prismatic_joint_position_tolerance if joint_type == JointType.PRISMATIC \ + else self.multiverse.conf.revolute_joint_position_tolerance + else: + delta = 0.18 + self.assertAlmostEqual(joint_position, original_joint_position + step, delta=delta) + + def test_spawn_robot(self): + if self.multiverse.robot is not None: + robot = self.multiverse.robot + else: + robot = self.spawn_robot(robot_name="pr2") + self.assertIsInstance(robot, Object) + self.assertTrue(robot in self.multiverse.objects) + self.assertTrue(self.multiverse.robot.name == robot.name) + + def test_destroy_robot(self): + if self.multiverse.robot is None: + self.spawn_robot() + self.assertTrue(self.multiverse.robot in self.multiverse.objects) + self.multiverse.robot.remove() + self.assertTrue(self.multiverse.robot not in self.multiverse.objects) + + def test_respawn_robot(self): + self.spawn_robot() + self.assertTrue(self.multiverse.robot in self.multiverse.objects) + self.multiverse.robot.remove() + self.assertTrue(self.multiverse.robot not in self.multiverse.objects) + self.spawn_robot() + self.assertTrue(self.multiverse.robot in self.multiverse.objects) + + def test_set_robot_position(self): + step = -1 + for i in range(3): + self.spawn_robot() + new_position = [-3 + step * i, -3 + step * i, 0.001] + self.multiverse.robot.set_position(new_position) + robot_position = self.multiverse.robot.get_position_as_list() + self.assert_list_is_equal(robot_position[:2], new_position[:2], + delta=self.multiverse.conf.position_tolerance) + self.tearDown() + + def test_set_robot_orientation(self): + self.spawn_robot() + for i in range(3): + current_quaternion = self.multiverse.robot.get_orientation_as_list() + # rotate by 45 degrees without using euler angles + rotation_quaternion = quaternion_from_euler(0, 0, np.pi / 4) + new_quaternion = quaternion_multiply(current_quaternion, rotation_quaternion) + self.multiverse.robot.set_orientation(new_quaternion) + robot_orientation = self.multiverse.robot.get_orientation_as_list() + quaternion_difference = calculate_angle_between_quaternions(new_quaternion, robot_orientation) + self.assertAlmostEqual(quaternion_difference, 0, delta=self.multiverse.conf.orientation_tolerance) + + def test_set_robot_pose(self): + self.spawn_robot(orientation=quaternion_from_euler(0, 0, np.pi / 4)) + position_step = -1 + angle_step = np.pi / 4 + num_steps = 10 + self.step_robot_pose(self.multiverse.robot, position_step, angle_step, num_steps) + position_step = 1 + angle_step = -np.pi / 4 + self.step_robot_pose(self.multiverse.robot, position_step, angle_step, num_steps) + + def step_robot_pose(self, robot, position_step, angle_step, num_steps): + original_position = robot.get_position_as_list() + original_orientation = robot.get_orientation_as_list() + for i in range(num_steps): + new_position = [original_position[0] + position_step * (i + 1), + original_position[1] + position_step * (i + 1), original_position[2]] + rotation_quaternion = quaternion_from_euler(0, 0, angle_step * (i + 1)) + new_quaternion = quaternion_multiply(original_orientation, rotation_quaternion) + new_pose = Pose(new_position, new_quaternion) + self.multiverse.robot.set_pose(new_pose) + robot_pose = self.multiverse.robot.get_pose() + self.assert_poses_are_equal(new_pose, robot_pose, + position_delta=self.multiverse.conf.position_tolerance, + orientation_delta=self.multiverse.conf.orientation_tolerance) + + def test_get_environment_pose(self): + apartment = Object("apartment", ObjectType.ENVIRONMENT, f"apartment.urdf") + pose = apartment.get_pose() + self.assertIsInstance(pose, Pose) + + def test_attach_object(self): + for _ in range(3): + milk = self.spawn_milk([1, 0.1, 0.1]) + cup = self.spawn_cup([1, 1.1, 0.1]) + milk.attach(cup) + self.assertTrue(cup in milk.attachments) + milk_position = milk.get_position_as_list() + milk_position[0] += 1 + cup_position = cup.get_position_as_list() + estimated_cup_position = cup_position.copy() + estimated_cup_position[0] += 1 + milk.set_position(milk_position) + new_cup_position = cup.get_position_as_list() + self.assert_list_is_equal(new_cup_position[:2], estimated_cup_position[:2], + self.multiverse.conf.position_tolerance) + self.tearDown() + + def test_detach_object(self): + for i in range(2): + milk = self.spawn_milk([1, 0, 0.1]) + cup = self.spawn_cup([1, 1, 0.1]) + milk.attach(cup) + self.assertTrue(cup in milk.attachments) + milk.detach(cup) + self.assertTrue(cup not in milk.attachments) + milk_position = milk.get_position_as_list() + milk_position[0] += 1 + cup_position = cup.get_position_as_list() + estimated_cup_position = cup_position.copy() + milk.set_position(milk_position) + new_milk_position = milk.get_position_as_list() + new_cup_position = cup.get_position_as_list() + self.assert_list_is_equal(new_milk_position[:2], milk_position[:2], + self.multiverse.conf.position_tolerance) + self.assert_list_is_equal(new_cup_position[:2], estimated_cup_position[:2], + self.multiverse.conf.position_tolerance) + self.tearDown() + + def test_attach_with_robot(self): + milk = self.spawn_milk([-1, -1, 0.1]) + robot = self.spawn_robot() + ee_link = self.multiverse.get_arm_tool_frame_link(Arms.RIGHT) + # Get position of milk relative to robot end effector + robot.attach(milk, ee_link.name, coincide_the_objects=False) + self.assertTrue(robot in milk.attachments) + milk_initial_pose = milk.root_link.get_pose_wrt_link(ee_link) + robot_position = 1.57 + robot.set_joint_position("arm_right_2_joint", robot_position) + milk_pose = milk.root_link.get_pose_wrt_link(ee_link) + self.assert_poses_are_equal(milk_initial_pose, milk_pose) + + def test_get_object_contact_points(self): + for i in range(10): + milk = self.spawn_milk([1, 1, 0.01], [0, -0.707, 0, 0.707]) + contact_points = self.multiverse.get_object_contact_points(milk) + self.assertIsInstance(contact_points, ContactPointsList) + self.assertEqual(len(contact_points), 1) + self.assertIsInstance(contact_points[0], ContactPoint) + self.assertTrue(contact_points[0].link_b.object, self.multiverse.floor) + cup = self.spawn_cup([1, 1, 0.2]) + # This is needed because the cup is spawned in the air, so it needs to fall + # to get in contact with the milk + self.multiverse.simulate(0.3) + contact_points = self.multiverse.get_object_contact_points(cup) + self.assertIsInstance(contact_points, ContactPointsList) + self.assertEqual(len(contact_points), 1) + self.assertIsInstance(contact_points[0], ContactPoint) + self.assertTrue(contact_points[0].link_b.object, milk) + self.tearDown() + + def test_get_contact_points_between_two_objects(self): + for i in range(3): + milk = self.spawn_milk([1, 1, 0.01], [0, -0.707, 0, 0.707]) + cup = self.spawn_cup([1, 1, 0.2]) + # This is needed because the cup is spawned in the air so it needs to fall + # to get in contact with the milk + self.multiverse.simulate(0.3) + contact_points = self.multiverse.get_contact_points_between_two_objects(milk, cup) + self.assertIsInstance(contact_points, ContactPointsList) + self.assertEqual(len(contact_points), 1) + self.assertIsInstance(contact_points[0], ContactPoint) + self.assertTrue(contact_points[0].link_a.object, milk) + self.assertTrue(contact_points[0].link_b.object, cup) + self.tearDown() + + def test_get_one_ray(self): + milk = self.spawn_milk([1, 1, 0.1]) + intersected_object = self.multiverse.ray_test([1, 2, 0.1], [1, 1.5, 0.1]) + self.assertTrue(intersected_object is None) + intersected_object = self.multiverse.ray_test([1, 2, 0.1], [1, 1, 0.1]) + self.assertTrue(intersected_object == milk.id) + + def test_get_rays(self): + milk = self.spawn_milk([1, 1, 0.1]) + intersected_objects = self.multiverse.ray_test_batch([[1, 2, 0.1], [1, 2, 0.1]], + [[1, 1.5, 0.1], [1, 1, 0.1]]) + self.assertTrue(intersected_objects[0][0] == -1) + self.assertTrue(intersected_objects[1][0] == milk.id) + + @staticmethod + def spawn_big_bowl() -> Object: + big_bowl = Object("big_bowl", ObjectType.GENERIC_OBJECT, "BigBowl.obj", + pose=Pose([2, 2, 0.1], [0, 0, 0, 1])) + return big_bowl + + @staticmethod + def spawn_milk(position: List, orientation: Optional[List] = None, frame="map") -> Object: + if orientation is None: + orientation = [0, 0, 0, 1] + milk = Object("milk_box", ObjectType.MILK, "milk_box.urdf", + pose=Pose(position, orientation, frame=frame)) + return milk + + def spawn_robot(self, position: Optional[List[float]] = None, + orientation: Optional[List[float]] = None, + robot_name: Optional[str] = 'tiago_dual', + replace: Optional[bool] = True) -> Object: + if position is None: + position = [-2, -2, 0.001] + if orientation is None: + orientation = [0, 0, 0, 1] + if self.multiverse.robot is None or replace: + if self.multiverse.robot is not None: + self.multiverse.robot.remove() + robot = Object(robot_name, ObjectType.ROBOT, f"{robot_name}.urdf", + pose=Pose(position, [0, 0, 0, 1])) + else: + robot = self.multiverse.robot + robot.set_position(position) + return robot + + @staticmethod + def spawn_cup(position: List) -> Object: + cup = Object("cup", ObjectType.GENERIC_OBJECT, "Cup.obj", + pose=Pose(position, [0, 0, 0, 1])) + return cup + + def assert_poses_are_equal(self, pose1: Pose, pose2: Pose, + position_delta: Optional[float] = None, orientation_delta: Optional[float] = None): + if position_delta is None: + position_delta = self.multiverse.conf.position_tolerance + if orientation_delta is None: + orientation_delta = self.multiverse.conf.orientation_tolerance + self.assert_position_is_equal(pose1.position_as_list(), pose2.position_as_list(), delta=position_delta) + self.assert_orientation_is_equal(pose1.orientation_as_list(), pose2.orientation_as_list(), + delta=orientation_delta) + + def assert_position_is_equal(self, position1: List[float], position2: List[float], delta: Optional[float] = None): + if delta is None: + delta = self.multiverse.conf.position_tolerance + self.assert_list_is_equal(position1, position2, delta=delta) + + def assert_orientation_is_equal(self, orientation1: List[float], orientation2: List[float], + delta: Optional[float] = None): + if delta is None: + delta = self.multiverse.conf.orientation_tolerance + self.assertAlmostEqual(calculate_angle_between_quaternions(orientation1, orientation2), 0, delta=delta) + + def assert_list_is_equal(self, list1: List, list2: List, delta: float): + for i in range(len(list1)): + self.assertAlmostEqual(list1[i], list2[i], delta=delta) diff --git a/test/test_object.py b/test/test_object.py index 5142aa85d..bede0300b 100644 --- a/test/test_object.py +++ b/test/test_object.py @@ -5,16 +5,18 @@ from pycram.datastructures.enums import JointType, ObjectType from pycram.datastructures.pose import Pose from pycram.datastructures.dataclasses import Color +from pycram.failures import UnsupportedFileExtension from pycram.world_concepts.world_object import Object from pycram.object_descriptors.generic import ObjectDescription as GenericObjectDescription from geometry_msgs.msg import Point, Quaternion import pathlib + class TestObject(BulletWorldTestCase): def test_wrong_object_description_path(self): - with self.assertRaises(FileNotFoundError): + with self.assertRaises(UnsupportedFileExtension): milk = Object("milk_not_found", ObjectType.MILK, "wrong_path.sk") def test_malformed_object_description(self): @@ -166,6 +168,7 @@ def test_object_equal(self): class GenericObjectTestCase(BulletWorldTestCase): def test_init_generic_object(self): - gen_obj_desc = lambda: GenericObjectDescription("robokudo_object", [0,0,0], [0.1, 0.1, 0.1]) + gen_obj_desc = GenericObjectDescription("robokudo_object", [0,0,0], [0.1, 0.1, 0.1]) obj = Object("robokudo_object", ObjectType.MILK, None, gen_obj_desc) - self.assertTrue(True) + pose = obj.get_pose() + self.assertTrue(isinstance(pose, Pose)) diff --git a/test/test_orm.py b/test/test_orm.py index 398cf095d..072382cfe 100644 --- a/test/test_orm.py +++ b/test/test_orm.py @@ -1,4 +1,5 @@ import os +import time import unittest import time from sqlalchemy import select diff --git a/test/test_robot_description.py b/test/test_robot_description.py index 066babba6..e845d9ec6 100644 --- a/test/test_robot_description.py +++ b/test/test_robot_description.py @@ -3,7 +3,7 @@ from pycram.robot_description import RobotDescription, KinematicChainDescription, EndEffectorDescription, \ CameraDescription, RobotDescriptionManager from pycram.datastructures.enums import Arms, GripperState -from urdf_parser_py.urdf import URDF +from pycram.object_descriptors.urdf import ObjectDescription as URDF class TestRobotDescription(unittest.TestCase): @@ -11,8 +11,8 @@ class TestRobotDescription(unittest.TestCase): @classmethod def setUpClass(cls): cls.path = str(pathlib.Path(__file__).parent.resolve()) + '/../resources/robots/' + "pr2" + '.urdf' - cls.urdf_obj = URDF.from_xml_file(cls.path) cls.path_turtlebot = str(pathlib.Path(__file__).parent.resolve()) + '/../resources/robots/' + "turtlebot" + '.urdf' + cls.urdf_obj = URDF(cls.path) def test_robot_description_construct(self): robot_description = RobotDescription("pr2", "base_link", "torso_lift_link", "torso_lift_joint", self.path) diff --git a/test/test_task_tree.py b/test/test_task_tree.py index f31aba1d8..01bda73c8 100644 --- a/test/test_task_tree.py +++ b/test/test_task_tree.py @@ -8,7 +8,7 @@ import unittest import anytree from bullet_world_testcase import BulletWorldTestCase -import pycram.plan_failures +import pycram.failures from pycram.designators import object_designator, action_designator @@ -48,11 +48,11 @@ def test_exception(self): @with_tree def failing_plan(): - raise pycram.plan_failures.PlanFailure("PlanFailure for UnitTesting") + raise pycram.failures.PlanFailure("PlanFailure for UnitTesting") pycram.tasktree.task_tree.reset_tree() - self.assertRaises(pycram.plan_failures.PlanFailure, failing_plan) + self.assertRaises(pycram.failures.PlanFailure, failing_plan) tt = pycram.tasktree.task_tree