From 008fa3acb888faf8a5402308afb54536fe7710e5 Mon Sep 17 00:00:00 2001 From: JeanElsner Date: Thu, 23 Nov 2023 12:55:07 +0100 Subject: [PATCH] Additional tests --- test/test_arm.py | 15 ++++++++--- test/test_environment.py | 55 +++++++++++++++++++++++++++++++++++++-- test/test_gripper.py | 17 +++++++++++- test/test_utils.py | 56 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 7 deletions(-) diff --git a/test/test_arm.py b/test/test_arm.py index d9b4646..c8e2880 100644 --- a/test/test_arm.py +++ b/test/test_arm.py @@ -1,11 +1,8 @@ -from unittest import mock - import numpy as np -import pytest from dm_control import mjcf from numpy import testing -from dm_robotics.panda import arm, parameters +from dm_robotics.panda import arm, arm_constants, parameters def test_physics_step(): @@ -30,3 +27,13 @@ def test_arm_effector(): physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model) effector = arm.ArmEffector(robot_params, robot) effector.set_control(physics, np.zeros(7, dtype=np.float32)) + effector.close() + + +def test_arm_haptic(): + robot = arm.Panda() + robot_params = parameters.RobotParams( + actuation=arm_constants.Actuation.HAPTIC) + physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model) + effector = arm.ArmEffector(robot_params, robot) + effector.set_control(physics, np.zeros(0, dtype=np.float32)) diff --git a/test/test_environment.py b/test/test_environment.py index b1ff19c..491bd8d 100644 --- a/test/test_environment.py +++ b/test/test_environment.py @@ -1,11 +1,18 @@ +import numpy as np +from dm_control import mjcf +from dm_env import specs +from dm_robotics.agentflow.preprocessors import rewards +from dm_robotics.moma import effector, entity_initializer, prop +from dm_robotics.moma.sensors import prop_pose_sensor + from dm_robotics.panda import environment, parameters, utils def test_robots(): robot_config = [ parameters.RobotParams(name='test1'), - parameters.RobotParams(name='test2'), - parameters.RobotParams(name='test3') + parameters.RobotParams(name='test2', pose=[0, 1, 0, 0, 0, 0]), + parameters.RobotParams(name='test3', pose=[0, 2, 0, 0, 0, 0]) ] panda_environment = environment.PandaEnvironment(robot_config) @@ -17,3 +24,47 @@ def test_build(): with environment.PandaEnvironment( parameters.RobotParams()).build_task_environment() as env: utils.full_spec(env) + + +class MockEffector(effector.Effector): + + def initialize_episode(self, physics: mjcf.Physics, + random_state: np.random.RandomState) -> None: + pass + + def set_control(self, physics: mjcf.Physics, command: np.ndarray) -> None: + pass + + @property + def prefix(self) -> str: + return 'dummy' + + def action_spec(self, physics: mjcf.Physics) -> specs.BoundedArray: + return specs.BoundedArray((1,), np.float32, (0,), (1,)) + + +def initialize_scene(random_state: np.random.RandomState) -> None: + del random_state + + +def test_components(): + panda_env = environment.PandaEnvironment(parameters.RobotParams()) + + props = [prop.Block()] + extra_sensors = [prop_pose_sensor.PropPoseSensor(props[0], 'prop')] + extra_effectors = [MockEffector()] + preprocessors = [rewards.ComputeReward(lambda obs: 1)] + + panda_env.add_props(props) + + entity_initializers = [ + entity_initializer.PropPlacer(props, [0, 0, 0], [1, 0, 0, 0]) + ] + panda_env.add_extra_effectors(extra_effectors) + panda_env.add_extra_sensors(extra_sensors) + panda_env.add_entity_initializers(entity_initializers) + panda_env.add_scene_initializers([initialize_scene]) + panda_env.add_timestep_preprocessors(preprocessors) + + with panda_env.build_task_environment(): + pass diff --git a/test/test_gripper.py b/test/test_gripper.py index bf8df5e..b37e736 100644 --- a/test/test_gripper.py +++ b/test/test_gripper.py @@ -2,7 +2,7 @@ from dm_control import mjcf from numpy import testing -from dm_robotics.panda import gripper +from dm_robotics.panda import gripper, parameters def test_physics_step(): @@ -10,10 +10,16 @@ def test_physics_step(): physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model) physics.step() + assert len(robot.joints) == 2 + assert len(robot.actuators) == 1 + robot = gripper.DummyHand() physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model) physics.step() + assert len(robot.joints) == 0 + assert len(robot.actuators) == 0 + def test_set_width(): robot = gripper.PandaHand() @@ -25,3 +31,12 @@ def test_set_width(): robot.set_width(physics, 0.08) testing.assert_allclose(physics.bind(robot.joints).qpos, 0.04 * np.ones(2)) + + +def test_effector(): + robot_params = parameters.RobotParams() + robot = gripper.PandaHand() + sensor = gripper.PandaHandSensor(robot, 'hand') + effector = gripper.PandaHandEffector(robot_params, robot, sensor) + physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model) + effector.set_control(physics, np.zeros(1)) diff --git a/test/test_utils.py b/test/test_utils.py index 1fa1946..9eae5a4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,3 +1,4 @@ +import logging from unittest.mock import MagicMock, patch import mujoco @@ -51,3 +52,58 @@ def test_plots(mock_runtime, mock_context, mock_viewport): def test_logging(): utils.init_logging() + + +def test_formatter_warning(): + formatter = utils.Formatter() + + record = logging.LogRecord(name='test_logger', + level=logging.WARNING, + pathname='/path/to/module.py', + lineno=42, + msg='This is a warning message', + args=(), + exc_info=None) + + formatted_msg = formatter.format(record) + + # Check if the formatted message contains the ANSI escape code for yellow color + assert '\033[33m' in formatted_msg + # Check if the formatted message ends with the ANSI escape code for resetting color + assert formatted_msg.endswith('\033[0m') + + +def test_formatter_error(): + formatter = utils.Formatter() + + record = logging.LogRecord(name='test_logger', + level=logging.ERROR, + pathname='/path/to/module.py', + lineno=42, + msg='This is an error message', + args=(), + exc_info=None) + + formatted_msg = formatter.format(record) + + # Check if the formatted message contains the ANSI escape code for red color + assert '\033[31m' in formatted_msg + # Check if the formatted message ends with the ANSI escape code for resetting color + assert formatted_msg.endswith('\033[0m') + + +def test_formatter_info(): + formatter = utils.Formatter() + + record = logging.LogRecord(name='test_logger', + level=logging.INFO, + pathname='/path/to/module.py', + lineno=42, + msg='This is an info message', + args=(), + exc_info=None) + + formatted_msg = formatter.format(record) + + assert '\033[33m' not in formatted_msg + assert '\033[31m' not in formatted_msg