Skip to content

Commit

Permalink
Additional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JeanElsner committed Nov 23, 2023
1 parent 00fb5c8 commit 008fa3a
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 7 deletions.
15 changes: 11 additions & 4 deletions test/test_arm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from unittest import mock

import numpy as np
import pytest
from dm_control import mjcf
from numpy import testing

from dm_robotics.panda import arm, parameters
from dm_robotics.panda import arm, arm_constants, parameters


def test_physics_step():
Expand All @@ -30,3 +27,13 @@ def test_arm_effector():
physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model)
effector = arm.ArmEffector(robot_params, robot)
effector.set_control(physics, np.zeros(7, dtype=np.float32))
effector.close()


def test_arm_haptic():
robot = arm.Panda()
robot_params = parameters.RobotParams(
actuation=arm_constants.Actuation.HAPTIC)
physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model)
effector = arm.ArmEffector(robot_params, robot)
effector.set_control(physics, np.zeros(0, dtype=np.float32))
55 changes: 53 additions & 2 deletions test/test_environment.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import numpy as np
from dm_control import mjcf
from dm_env import specs
from dm_robotics.agentflow.preprocessors import rewards
from dm_robotics.moma import effector, entity_initializer, prop
from dm_robotics.moma.sensors import prop_pose_sensor

from dm_robotics.panda import environment, parameters, utils


def test_robots():
robot_config = [
parameters.RobotParams(name='test1'),
parameters.RobotParams(name='test2'),
parameters.RobotParams(name='test3')
parameters.RobotParams(name='test2', pose=[0, 1, 0, 0, 0, 0]),
parameters.RobotParams(name='test3', pose=[0, 2, 0, 0, 0, 0])
]
panda_environment = environment.PandaEnvironment(robot_config)

Expand All @@ -17,3 +24,47 @@ def test_build():
with environment.PandaEnvironment(
parameters.RobotParams()).build_task_environment() as env:
utils.full_spec(env)


class MockEffector(effector.Effector):

def initialize_episode(self, physics: mjcf.Physics,
random_state: np.random.RandomState) -> None:
pass

def set_control(self, physics: mjcf.Physics, command: np.ndarray) -> None:
pass

@property
def prefix(self) -> str:
return 'dummy'

def action_spec(self, physics: mjcf.Physics) -> specs.BoundedArray:
return specs.BoundedArray((1,), np.float32, (0,), (1,))


def initialize_scene(random_state: np.random.RandomState) -> None:
del random_state


def test_components():
panda_env = environment.PandaEnvironment(parameters.RobotParams())

props = [prop.Block()]
extra_sensors = [prop_pose_sensor.PropPoseSensor(props[0], 'prop')]
extra_effectors = [MockEffector()]
preprocessors = [rewards.ComputeReward(lambda obs: 1)]

panda_env.add_props(props)

entity_initializers = [
entity_initializer.PropPlacer(props, [0, 0, 0], [1, 0, 0, 0])
]
panda_env.add_extra_effectors(extra_effectors)
panda_env.add_extra_sensors(extra_sensors)
panda_env.add_entity_initializers(entity_initializers)
panda_env.add_scene_initializers([initialize_scene])
panda_env.add_timestep_preprocessors(preprocessors)

with panda_env.build_task_environment():
pass
17 changes: 16 additions & 1 deletion test/test_gripper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@
from dm_control import mjcf
from numpy import testing

from dm_robotics.panda import gripper
from dm_robotics.panda import gripper, parameters


def test_physics_step():
robot = gripper.PandaHand()
physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model)
physics.step()

assert len(robot.joints) == 2
assert len(robot.actuators) == 1

robot = gripper.DummyHand()
physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model)
physics.step()

assert len(robot.joints) == 0
assert len(robot.actuators) == 0


def test_set_width():
robot = gripper.PandaHand()
Expand All @@ -25,3 +31,12 @@ def test_set_width():

robot.set_width(physics, 0.08)
testing.assert_allclose(physics.bind(robot.joints).qpos, 0.04 * np.ones(2))


def test_effector():
robot_params = parameters.RobotParams()
robot = gripper.PandaHand()
sensor = gripper.PandaHandSensor(robot, 'hand')
effector = gripper.PandaHandEffector(robot_params, robot, sensor)
physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model)
effector.set_control(physics, np.zeros(1))
56 changes: 56 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from unittest.mock import MagicMock, patch

import mujoco
Expand Down Expand Up @@ -51,3 +52,58 @@ def test_plots(mock_runtime, mock_context, mock_viewport):

def test_logging():
utils.init_logging()


def test_formatter_warning():
formatter = utils.Formatter()

record = logging.LogRecord(name='test_logger',
level=logging.WARNING,
pathname='/path/to/module.py',
lineno=42,
msg='This is a warning message',
args=(),
exc_info=None)

formatted_msg = formatter.format(record)

# Check if the formatted message contains the ANSI escape code for yellow color
assert '\033[33m' in formatted_msg
# Check if the formatted message ends with the ANSI escape code for resetting color
assert formatted_msg.endswith('\033[0m')


def test_formatter_error():
formatter = utils.Formatter()

record = logging.LogRecord(name='test_logger',
level=logging.ERROR,
pathname='/path/to/module.py',
lineno=42,
msg='This is an error message',
args=(),
exc_info=None)

formatted_msg = formatter.format(record)

# Check if the formatted message contains the ANSI escape code for red color
assert '\033[31m' in formatted_msg
# Check if the formatted message ends with the ANSI escape code for resetting color
assert formatted_msg.endswith('\033[0m')


def test_formatter_info():
formatter = utils.Formatter()

record = logging.LogRecord(name='test_logger',
level=logging.INFO,
pathname='/path/to/module.py',
lineno=42,
msg='This is an info message',
args=(),
exc_info=None)

formatted_msg = formatter.format(record)

assert '\033[33m' not in formatted_msg
assert '\033[31m' not in formatted_msg

0 comments on commit 008fa3a

Please sign in to comment.