
MJX Training Pipeline #14

Closed · wants to merge 18 commits
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -30,7 +30,7 @@ warn_redundant_casts = true
incremental = true
namespace_packages = false

-exclude = ["sim/humanoid_gym/", "sim/deploy", "sim/scripts/create_mjcf.py"]
+exclude = ["sim/humanoid_gym/", "sim/deploy", "sim/scripts/create_mjcf.py", "sim/mjx_gym/"]



@@ -54,7 +54,7 @@ profile = "black"
line-length = 120
target-version = "py310"

-exclude = ["sim/humanoid_gym", "sim/deploy/"]
+exclude = ["sim/humanoid_gym/", "sim/deploy", "sim/scripts/create_mjcf.py", "sim/mjx_gym/"]

[tool.ruff.lint]

Empty file added sim/mjx_gym/__init__.py
10 changes: 10 additions & 0 deletions sim/mjx_gym/envs/__init__.py
@@ -0,0 +1,10 @@
from brax import envs
from envs.default_humanoid_env.default_humanoid import DefaultHumanoidEnv
from envs.stompy_env.stompy import StompyEnv

environments = {"default_humanoid": DefaultHumanoidEnv, "stompy": StompyEnv}


def get_env(name: str, **kwargs) -> envs.Env:
"""Register the named environment with brax and return an instance of it."""
envs.register_environment(name, environments[name])
return envs.get_environment(name, **kwargs)
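
A brief usage sketch of get_env (hedged: it assumes the script runs from sim/mjx_gym so that the envs package resolves as in the imports above, and that brax is installed; the name "default_humanoid" comes from the environments dict):

    from envs import get_env

    # Registers "default_humanoid" with brax and instantiates DefaultHumanoidEnv;
    # keyword arguments are forwarded to the environment constructor.
    env = get_env("default_humanoid")
    print(env.observation_size, env.action_size)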
152 changes: 152 additions & 0 deletions sim/mjx_gym/envs/default_humanoid_env/default_humanoid.py
@@ -0,0 +1,152 @@
import jax
Collaborator: isort, ruff, et al.; see the Makefile for the formatting setup.

import jax.numpy as jp
import mujoco
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from brax.mjx.base import State as mjxState
from etils import epath

from envs.default_humanoid_env.rewards import DEFAULT_REWARD_PARAMS, get_reward_fn


class DefaultHumanoidEnv(PipelineEnv):
"""
An environment for humanoid body position, velocities, and angles.

Note: This environment is based on the default humanoid environment in the Brax library.
https://github.com/google/brax/blob/main/brax/envs/humanoid.py

However, this environment is designed to work with modular reward functions, allowing for quicker experimentation.
"""

def __init__(
self,
reward_params=DEFAULT_REWARD_PARAMS,
terminate_when_unhealthy=True,
reset_noise_scale=1e-2,
exclude_current_positions_from_observation=True,
log_reward_breakdown=True,
**kwargs,
):
path = epath.Path(epath.resource_path("mujoco")) / ("mjx/test_data/humanoid")
mj_model = mujoco.MjModel.from_xml_path((path / "humanoid.xml").as_posix()) # type: ignore
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG # type: ignore # TODO: not sure why typing is not working here
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6

sys = mjcf.load_model(mj_model)

physics_steps_per_control_step = 4 # Should find way to perturb this value in the future
kwargs["n_frames"] = kwargs.get("n_frames", physics_steps_per_control_step)
kwargs["backend"] = "mjx"

super().__init__(sys, **kwargs)

self._reward_params = reward_params
self._terminate_when_unhealthy = terminate_when_unhealthy
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
self._log_reward_breakdown = log_reward_breakdown

self.reward_fn = get_reward_fn(self._reward_params, self.dt, include_reward_breakdown=True)

def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state.

Args:
rng: Random number generator seed.
Collaborator: nit - formatting still off here.
Suggested change:
-rng: Random number generator seed.
+Args:
+    rng: Random number generator seed.

Returns:
The initial state of the environment.
"""
rng, rng1, rng2 = jax.random.split(rng, 3)

low, hi = -self._reset_noise_scale, self._reset_noise_scale
qpos = self.sys.qpos0 + jax.random.uniform(rng1, (self.sys.nq,), minval=low, maxval=hi)
qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi)

mjx_state = self.pipeline_init(qpos, qvel)
assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}"

obs = self._get_obs(mjx_state, jp.zeros(self.sys.nu))
reward, done, zero = jp.zeros(3)
metrics = {
"x_position": zero,
"y_position": zero,
"distance_from_origin": zero,
"x_velocity": zero,
"y_velocity": zero,
}
for key in self._reward_params.keys():
metrics[key] = zero

return State(mjx_state, obs, reward, done, metrics)

def step(self, state: State, action: jp.ndarray) -> State:
"""Runs one timestep of the environment's dynamics.

Args:
state: The current state of the environment.
action: The action to take.
Returns:
The next state of the environment, with updated observation, reward, done flag, and metrics.
"""
mjx_state = state.pipeline_state
assert mjx_state, "state.pipeline_state was recorded as None"
# TODO: determine whether to raise an error or reset the environment

next_mjx_state = self.pipeline_step(mjx_state, action)

assert type(next_mjx_state) == mjxState, f"next_mjx_state is of type {type(next_mjx_state)}"
assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}"
# mlutz: from what I've seen, .pipeline_state and .pipeline_step(...) actually return a brax.mjx.base.State object;
# however, the type hints suggest they should return a brax.base.State object.
# brax.mjx.base.State inherits from brax.base.State but also from mjx.Data, which is needed for some rewards.

obs = self._get_obs(mjx_state, action)
reward, is_healthy, reward_breakdown = self.reward_fn(mjx_state, action, next_mjx_state)

if self._terminate_when_unhealthy:
done = 1.0 - is_healthy
else:
done = jp.array(0)

state.metrics.update(
x_position=next_mjx_state.subtree_com[1][0],
y_position=next_mjx_state.subtree_com[1][1],
distance_from_origin=jp.linalg.norm(next_mjx_state.subtree_com[1]),
x_velocity=(next_mjx_state.subtree_com[1][0] - mjx_state.subtree_com[1][0]) / self.dt,
y_velocity=(next_mjx_state.subtree_com[1][1] - mjx_state.subtree_com[1][1]) / self.dt,
)

if self._log_reward_breakdown:
for key, val in reward_breakdown.items():
state.metrics[key] = val

return state.replace( # type: ignore # TODO: fix the type hinting...
pipeline_state=next_mjx_state, obs=obs, reward=reward, done=done
)

def _get_obs(self, data: mjxState, action: jp.ndarray) -> jp.ndarray:
"""Observes humanoid body position, velocities, and angles.

Args:
data: The current state of the environment.
action: The current action.
Returns:
Observations of the environment.
"""
position = data.qpos
if self._exclude_current_positions_from_observation:
position = position[2:]

# external_contact_forces are excluded
return jp.concatenate(
[
position,
data.qvel,
data.cinert[1:].ravel(),
data.cvel[1:].ravel(),
data.qfrc_actuator,
]
)
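
For context, a minimal rollout sketch for the environment above (a hedged example, not part of the diff; it assumes the same import path as envs/__init__.py and uses brax's standard reset/step API with random actions):

    import jax
    from envs import get_env

    env = get_env("default_humanoid")     # DefaultHumanoidEnv with DEFAULT_REWARD_PARAMS
    reset_fn = jax.jit(env.reset)
    step_fn = jax.jit(env.step)

    rng = jax.random.PRNGKey(0)
    state = reset_fn(rng)
    for _ in range(10):
        rng, act_rng = jax.random.split(rng)
        action = jax.random.uniform(act_rng, (env.action_size,), minval=-1.0, maxval=1.0)
        state = step_fn(state, action)
    print(state.reward, state.metrics["x_velocity"])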
111 changes: 111 additions & 0 deletions sim/mjx_gym/envs/default_humanoid_env/rewards.py
@@ -0,0 +1,111 @@
from typing import Callable, Dict, Tuple
Collaborator: Good practice is to add at least a one-line comment about what the file is about.


import jax
import jax.numpy as jp
from brax.mjx.base import State as mjxState

DEFAULT_REWARD_PARAMS = {
"rew_forward": {"weight": 1.25},
"rew_healthy": {"weight": 5.0, "healthy_z_lower": 1.0, "healthy_z_upper": 2.0},
"rew_ctrl_cost": {"weight": 0.1},
}


def get_reward_fn(
Collaborator: This could be used across all envs?

reward_params: Dict[str, Dict[str, float]], dt, include_reward_breakdown
) -> Callable[[mjxState, jp.ndarray, mjxState], Tuple[jp.ndarray, jp.ndarray, Dict[str, jp.ndarray]]]:
"""Get a combined reward function.

Args:
reward_params: Dictionary of reward parameters.
dt: Time step.
include_reward_breakdown: Whether to also return a per-term reward breakdown (useful for logging).
Returns:
A reward function that takes in a state, action, and next state and returns the total reward, an is_healthy flag, and a dictionary of per-term rewards, all as jp.ndarrays.
"""

def reward_fn(
state: mjxState, action: jp.ndarray, next_state: mjxState
) -> Tuple[jp.ndarray, jp.ndarray, Dict[str, jp.ndarray]]:
reward, is_healthy = jp.array(0.0), jp.array(1.0)
rewards = {}
for key, params in reward_params.items():
r, h = reward_functions[key](state, action, next_state, dt, params)
is_healthy *= h
reward += r
if include_reward_breakdown: # For more detailed logging, can be disabled for performance
rewards[key] = r
return reward, is_healthy, rewards

return reward_fn


def forward_reward_fn(
state: mjxState, action: jp.ndarray, next_state: mjxState, dt: jax.Array, params: Dict[str, float]
) -> Tuple[jp.ndarray, jp.ndarray]:
"""Reward function for moving forward.

Args:
state: Current state.
action: Action taken.
next_state: Next state.
dt: Time step.
params: Reward parameters.
Returns:
A tuple of the forward reward and an is_healthy flag, each a jp.ndarray.
"""
xpos = state.subtree_com[1][0] # TODO: include stricter typing than mjxState to avoid this type error
Collaborator: Make it more explicit what you are loading.

next_xpos = next_state.subtree_com[1][0]
velocity = (next_xpos - xpos) / dt
forward_reward = params["weight"] * velocity

return forward_reward, jp.array(1.0) # TODO: ensure everything is initialized in a size 2 array instead...
Collaborator: What's the logic behind 1.0?

Author: jp.array because we want to keep everything in jax. The 1.0 itself acts like a boolean flag: it doesn't change the "healthiness" until a 0 comes along.
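
In other words, the per-term flags combine multiplicatively inside reward_fn, so 1.0 is the identity and any single 0.0 marks the step unhealthy. A tiny illustration (plain JAX, not part of the diff):

    import jax.numpy as jp

    flags = [jp.array(1.0), jp.array(1.0), jp.array(0.0)]  # per-term health flags
    is_healthy = jp.array(1.0)
    for h in flags:
        is_healthy *= h  # 1.0 leaves the product unchanged; 0.0 zeroes it
    # is_healthy == 0.0, so the episode terminates when terminate_when_unhealthy=True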



def healthy_reward_fn(
state: mjxState, action: jp.ndarray, next_state: mjxState, dt: jax.Array, params: Dict[str, float]
) -> Tuple[jp.ndarray, jp.ndarray]:
"""Reward function for staying healthy.

Args:
state: Current state.
action: Action taken.
next_state: Next state.
dt: Time step.
params: Reward parameters.
Returns:
A tuple of the healthy reward and the is_healthy flag, each a jp.ndarray.
"""
min_z = params["healthy_z_lower"]
max_z = params["healthy_z_upper"]
is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0)
is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy)
healthy_reward = jp.array(params["weight"]) * is_healthy

return healthy_reward, is_healthy


def ctrl_cost_reward_fn(
state: mjxState, action: jp.ndarray, next_state: mjxState, dt: jax.Array, params: Dict[str, float]
) -> Tuple[jp.ndarray, jp.ndarray]:
"""Reward function for control cost.

Args:
state: Current state.
action: Action taken.
next_state: Next state.
dt: Time step.
params: Reward parameters.
Returns:
A tuple of the (negative) control cost and a placeholder is_healthy flag of 1.0, each a jp.ndarray.
"""
ctrl_cost = -params["weight"] * jp.sum(jp.square(action))

return ctrl_cost, jp.array(1.0)


# NOTE: After defining the reward functions, they must be added here to be used in the combined reward function.
reward_functions = {
"rew_forward": forward_reward_fn,
"rew_healthy": healthy_reward_fn,
"rew_ctrl_cost": ctrl_cost_reward_fn,
}
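
As the NOTE says, a new term only needs the same signature plus an entry in reward_functions. A hedged sketch of a hypothetical torso-height term (the name, weight, and target_z parameter are illustrative and not part of this PR; it reuses the imports already at the top of rewards.py):

    def height_reward_fn(
        state: mjxState, action: jp.ndarray, next_state: mjxState, dt: jax.Array, params: Dict[str, float]
    ) -> Tuple[jp.ndarray, jp.ndarray]:
        # Hypothetical term: penalize distance of the torso z-position from a target height.
        height_error = jp.abs(next_state.q[2] - params["target_z"])
        return -params["weight"] * height_error, jp.array(1.0)

    reward_functions["rew_height"] = height_reward_fn
    # Enable it by adding {"rew_height": {"weight": 0.5, "target_z": 1.4}} to the reward_params dict.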