MJX Training Pipeline #14
@@ -0,0 +1,10 @@
from brax import envs
from envs.default_humanoid_env.default_humanoid import DefaultHumanoidEnv
from envs.stompy_env.stompy import StompyEnv

environments = {"default_humanoid": DefaultHumanoidEnv, "stompy": StompyEnv}


def get_env(name: str, **kwargs) -> envs.Env:
    envs.register_environment(name, environments[name])
    return envs.get_environment(name, **kwargs)
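For reviewers who want to try the registry above locally, here is a minimal usage sketch. The `from envs import get_env` import path and the use of default constructor kwargs are assumptions for illustration, not part of this diff.

```python
# Sketch: build a registered environment and take one step.
# Assumes this module is importable as `envs` and that the environment
# constructs with its default kwargs.
import jax
import jax.numpy as jp

from envs import get_env  # import path is an assumption

env = get_env("default_humanoid")
rng = jax.random.PRNGKey(0)
state = jax.jit(env.reset)(rng)
state = jax.jit(env.step)(state, jp.zeros(env.action_size))
print(state.reward, state.done)
```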
@@ -0,0 +1,152 @@
import jax
import jax.numpy as jp
import mujoco
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from brax.mjx.base import State as mjxState
from etils import epath

from envs.default_humanoid_env.rewards import DEFAULT_REWARD_PARAMS
from envs.default_humanoid_env.rewards import get_reward_fn


class DefaultHumanoidEnv(PipelineEnv):
    """
    An environment for humanoid body position, velocities, and angles.

    Note: This environment is based on the default humanoid environment in the Brax library.
    https://github.com/google/brax/blob/main/brax/envs/humanoid.py

    However, this environment is designed to work with modular reward functions, allowing for quicker experimentation.
    """

    def __init__(
        self,
        reward_params=DEFAULT_REWARD_PARAMS,
        terminate_when_unhealthy=True,
        reset_noise_scale=1e-2,
        exclude_current_positions_from_observation=True,
        log_reward_breakdown=True,
        **kwargs,
    ):
        path = epath.Path(epath.resource_path("mujoco")) / ("mjx/test_data/humanoid")
        mj_model = mujoco.MjModel.from_xml_path((path / "humanoid.xml").as_posix())  # type: ignore
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG  # type: ignore  # TODO: not sure why typing is not working here
        mj_model.opt.iterations = 6
        mj_model.opt.ls_iterations = 6

        sys = mjcf.load_model(mj_model)

        physics_steps_per_control_step = 4  # Should find a way to perturb this value in the future
        kwargs["n_frames"] = kwargs.get("n_frames", physics_steps_per_control_step)
        kwargs["backend"] = "mjx"

        super().__init__(sys, **kwargs)

        self._reward_params = reward_params
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
        self._log_reward_breakdown = log_reward_breakdown

        self.reward_fn = get_reward_fn(self._reward_params, self.dt, include_reward_breakdown=True)

    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to an initial state.

        Args:
            rng: Random number generator seed.

[Review comment] nit - formatting still off here (a suggested change was attached).

        Returns:
            The initial state of the environment.
        """
        rng, rng1, rng2 = jax.random.split(rng, 3)

        low, hi = -self._reset_noise_scale, self._reset_noise_scale
        qpos = self.sys.qpos0 + jax.random.uniform(rng1, (self.sys.nq,), minval=low, maxval=hi)
        qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi)

        mjx_state = self.pipeline_init(qpos, qvel)
        assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}"

        obs = self._get_obs(mjx_state, jp.zeros(self.sys.nu))
        reward, done, zero = jp.zeros(3)
        metrics = {
            "x_position": zero,
            "y_position": zero,
            "distance_from_origin": zero,
            "x_velocity": zero,
            "y_velocity": zero,
        }
        for key in self._reward_params.keys():
            metrics[key] = zero

        return State(mjx_state, obs, reward, done, metrics)

    def step(self, state: State, action: jp.ndarray) -> State:
        """Runs one timestep of the environment's dynamics.

        Args:
            state: The current state of the environment.
            action: The action to take.

        Returns:
            The next state of the environment, carrying the updated observation, reward, done flag, and metrics.
        """
        mjx_state = state.pipeline_state
        assert mjx_state, "state.pipeline_state was recorded as None"
        # TODO: determine whether to raise an error or reset the environment

        next_mjx_state = self.pipeline_step(mjx_state, action)

        assert type(next_mjx_state) == mjxState, f"next_mjx_state is of type {type(next_mjx_state)}"
        assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}"
        # mlutz: from what I've seen, .pipeline_state and .pipeline_step(...) actually return a brax.mjx.base.State
        # object; however, the type hinting suggests that they should return a brax.base.State object.
        # brax.mjx.base.State inherits from brax.base.State but also inherits from mjx.Data, which is needed for some rewards.

        obs = self._get_obs(mjx_state, action)
        reward, is_healthy, reward_breakdown = self.reward_fn(mjx_state, action, next_mjx_state)

        if self._terminate_when_unhealthy:
            done = 1.0 - is_healthy
        else:
            done = jp.array(0)

        state.metrics.update(
            x_position=next_mjx_state.subtree_com[1][0],
            y_position=next_mjx_state.subtree_com[1][1],
            distance_from_origin=jp.linalg.norm(next_mjx_state.subtree_com[1]),
            x_velocity=(next_mjx_state.subtree_com[1][0] - mjx_state.subtree_com[1][0]) / self.dt,
            y_velocity=(next_mjx_state.subtree_com[1][1] - mjx_state.subtree_com[1][1]) / self.dt,
        )

        if self._log_reward_breakdown:
            for key, val in reward_breakdown.items():
                state.metrics[key] = val

        return state.replace(  # type: ignore  # TODO: fix the type hinting...
            pipeline_state=next_mjx_state, obs=obs, reward=reward, done=done
        )

    def _get_obs(self, data: mjxState, action: jp.ndarray) -> jp.ndarray:
        """Observes humanoid body position, velocities, and angles.

        Args:
            data: The current state of the environment.
            action: The current action.

        Returns:
            Observations of the environment.
        """
        position = data.qpos
        if self._exclude_current_positions_from_observation:
            position = position[2:]

        # external_contact_forces are excluded
        return jp.concatenate(
            [
                position,
                data.qvel,
                data.cinert[1:].ravel(),
                data.cvel[1:].ravel(),
                data.qfrc_actuator,
            ]
        )
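To sanity-check the reward-breakdown logging described in `step()`, here is a short rollout sketch. The random-action policy and the direct instantiation of `DefaultHumanoidEnv` with its defaults are illustrative assumptions about typical usage, not part of this diff.

```python
# Sketch: roll out a few random-action steps and inspect the logged metrics.
import jax
import jax.numpy as jp

from envs.default_humanoid_env.default_humanoid import DefaultHumanoidEnv

env = DefaultHumanoidEnv()  # defaults assumed; log_reward_breakdown=True
reset_fn, step_fn = jax.jit(env.reset), jax.jit(env.step)

rng = jax.random.PRNGKey(0)
state = reset_fn(rng)
for _ in range(5):
    rng, sub = jax.random.split(rng)
    action = jax.random.uniform(sub, (env.action_size,), minval=-1.0, maxval=1.0)
    state = step_fn(state, action)

# With log_reward_breakdown=True, each reward term appears in state.metrics.
print({k: float(v) for k, v in state.metrics.items()})
```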
@@ -0,0 +1,111 @@
from typing import Callable, Dict, Tuple

[Review comment] Good practice is to add at least a one-line comment describing what the file is about.

import jax
import jax.numpy as jp
from brax.mjx.base import State as mjxState

DEFAULT_REWARD_PARAMS = {
    "rew_forward": {"weight": 1.25},
    "rew_healthy": {"weight": 5.0, "healthy_z_lower": 1.0, "healthy_z_upper": 2.0},
    "rew_ctrl_cost": {"weight": 0.1},
}


[Review comment] This could be used across all envs?

def get_reward_fn(
    reward_params: Dict[str, Dict[str, float]], dt, include_reward_breakdown
) -> Callable[[mjxState, jp.ndarray, mjxState], Tuple[jp.ndarray, jp.ndarray, Dict[str, jp.ndarray]]]:
    """Get a combined reward function.

    Args:
        reward_params: Dictionary of reward parameters.
        dt: Time step.
        include_reward_breakdown: Whether to also return the per-term rewards for logging.

    Returns:
        A reward function that takes in a state, action, and next state and returns the total reward,
        a healthiness flag, and a per-term reward breakdown, all as jax arrays.
    """

    def reward_fn(
        state: mjxState, action: jp.ndarray, next_state: mjxState
    ) -> Tuple[jp.ndarray, jp.ndarray, Dict[str, jp.ndarray]]:
        reward, is_healthy = jp.array(0.0), jp.array(1.0)
        rewards = {}
        for key, params in reward_params.items():
            r, h = reward_functions[key](state, action, next_state, dt, params)
            is_healthy *= h
            reward += r
            if include_reward_breakdown:  # For more detailed logging; can be disabled for performance
                rewards[key] = r
        return reward, is_healthy, rewards

    return reward_fn


def forward_reward_fn(
    state: mjxState, action: jp.ndarray, next_state: mjxState, dt: jax.Array, params: Dict[str, float]
) -> Tuple[jp.ndarray, jp.ndarray]:
    """Reward function for moving forward.

    Args:
        state: Current state.
        action: Action taken.
        next_state: Next state.
        dt: Time step.
        params: Reward parameters.

    Returns:
        A tuple of (forward reward, is_healthy), both as jax arrays.
    """
    xpos = state.subtree_com[1][0]  # TODO: include stricter typing than mjxState to avoid this type error

[Review comment] Make it more explicit what you are loading.

    next_xpos = next_state.subtree_com[1][0]
    velocity = (next_xpos - xpos) / dt
    forward_reward = params["weight"] * velocity

    return forward_reward, jp.array(1.0)  # TODO: ensure everything is initialized in a size 2 array instead...

[Review comment] What's the logic behind 1.0?
[Reply] jp.array because we want to keep everything in jax. The 1.0 itself operates like a boolean operator (it doesn't change the "healthiness" until a 0 comes around).


def healthy_reward_fn(
    state: mjxState, action: jp.ndarray, next_state: mjxState, dt: jax.Array, params: Dict[str, float]
) -> Tuple[jp.ndarray, jp.ndarray]:
    """Reward function for staying healthy.

    Args:
        state: Current state.
        action: Action taken.
        next_state: Next state.
        dt: Time step.
        params: Reward parameters.

    Returns:
        A tuple of (healthy reward, is_healthy), both as jax arrays.
    """
    min_z = params["healthy_z_lower"]
    max_z = params["healthy_z_upper"]
    is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy)
    healthy_reward = jp.array(params["weight"]) * is_healthy

    return healthy_reward, is_healthy


def ctrl_cost_reward_fn(
    state: mjxState, action: jp.ndarray, next_state: mjxState, dt: jax.Array, params: Dict[str, float]
) -> Tuple[jp.ndarray, jp.ndarray]:
    """Reward function for control cost.

    Args:
        state: Current state.
        action: Action taken.
        next_state: Next state.
        dt: Time step.
        params: Reward parameters.

    Returns:
        A tuple of (control cost penalty, is_healthy), both as jax arrays.
    """
    ctrl_cost = -params["weight"] * jp.sum(jp.square(action))

    return ctrl_cost, jp.array(1.0)


# NOTE: After defining the reward functions, they must be added here to be used in the combined reward function.
reward_functions = {
    "rew_forward": forward_reward_fn,
    "rew_healthy": healthy_reward_fn,
    "rew_ctrl_cost": ctrl_cost_reward_fn,
}
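Following the NOTE above, adding a new reward term only requires a function with the same signature, an entry in `reward_functions`, and an entry in the reward params passed to the environment. The sketch below is purely illustrative: the `rew_qvel_cost` name, its weight, and the joint-velocity penalty itself are made up, and it assumes the definitions from this file are in scope.

```python
# Sketch: a hypothetical reward term penalizing large joint velocities.
def qvel_cost_reward_fn(
    state: mjxState, action: jp.ndarray, next_state: mjxState, dt: jax.Array, params: Dict[str, float]
) -> Tuple[jp.ndarray, jp.ndarray]:
    # qd is the generalized velocity on brax states; penalize its squared magnitude.
    qvel_cost = -params["weight"] * jp.sum(jp.square(next_state.qd))
    return qvel_cost, jp.array(1.0)  # 1.0 leaves the healthiness flag unchanged


# Register the term, then include it in the params handed to the environment.
reward_functions["rew_qvel_cost"] = qvel_cost_reward_fn
reward_params = {**DEFAULT_REWARD_PARAMS, "rew_qvel_cost": {"weight": 0.05}}
```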
[Review comment] isort, ruff, et al.; see the Makefile for the formatting setup.