MJX Training Implementation #1
Changes from 22 commits
@@ -19,3 +19,6 @@ build/
 dist/
 *.so
 out*/
+
+# Training artifacts
+wandb/
@@ -0,0 +1,10 @@
from brax import envs
from envs.default_humanoid_env.default_humanoid import DefaultHumanoidEnv
from envs.stompy_env.stompy import StompyEnv

environments = {"default_humanoid": DefaultHumanoidEnv, "stompy": StompyEnv}


def get_env(name: str, **kwargs) -> envs.Env:
    envs.register_environment(name, environments[name])
    return envs.get_environment(name, **kwargs)
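As a usage note (not part of the diff): keyword arguments passed to get_env are forwarded through Brax's registry to the environment constructor. A minimal sketch, assuming this registry module is importable as `envs` (its actual path is not shown here):

# Hypothetical usage sketch; the `envs` import path is an assumption.
from envs import get_env

# kwargs are forwarded to the environment constructor,
# e.g. DefaultHumanoidEnv(terminate_when_unhealthy=False).
env = get_env("default_humanoid", terminate_when_unhealthy=False)
print(env.action_size)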
@@ -0,0 +1,152 @@
import jax

[Review comment] Adding a one-line explanation would be useful.

import jax.numpy as jp

[Review comment] isort
import mujoco
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from brax.mjx.base import State as mjxState
from etils import epath

from envs.default_humanoid_env.rewards import DEFAULT_REWARD_PARAMS
from envs.default_humanoid_env.rewards import get_reward_fn


class DefaultHumanoidEnv(PipelineEnv):
    """
    An environment for humanoid body position, velocities, and angles.

    Note: This environment is based on the default humanoid environment in the Brax library.
    https://github.com/google/brax/blob/main/brax/envs/humanoid.py

    However, this environment is designed to work with modular reward functions, allowing for quicker experimentation.
    """

    def __init__(
        self,
        reward_params=DEFAULT_REWARD_PARAMS,
        terminate_when_unhealthy=True,
        reset_noise_scale=1e-2,
        exclude_current_positions_from_observation=True,
        log_reward_breakdown=True,
        **kwargs,
    ):
        path = epath.Path(epath.resource_path("mujoco")) / ("mjx/test_data/humanoid")
        mj_model = mujoco.MjModel.from_xml_path((path / "humanoid.xml").as_posix())  # type: ignore
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG  # type: ignore  # TODO: not sure why typing is not working here
        mj_model.opt.iterations = 6
        mj_model.opt.ls_iterations = 6

        sys = mjcf.load_model(mj_model)

        physics_steps_per_control_step = 4  # Should find a way to perturb this value in the future
        kwargs["n_frames"] = kwargs.get("n_frames", physics_steps_per_control_step)
        kwargs["backend"] = "mjx"

        super().__init__(sys, **kwargs)

        self._reward_params = reward_params
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
        self._log_reward_breakdown = log_reward_breakdown

        self.reward_fn = get_reward_fn(self._reward_params, self.dt, include_reward_breakdown=True)
    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to an initial state.

        Args:
            rng: Random number generator seed.
        Returns:
            The initial state of the environment.
        """
        rng, rng1, rng2 = jax.random.split(rng, 3)

        low, hi = -self._reset_noise_scale, self._reset_noise_scale
        qpos = self.sys.qpos0 + jax.random.uniform(rng1, (self.sys.nq,), minval=low, maxval=hi)
        qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi)

        mjx_state = self.pipeline_init(qpos, qvel)
        assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}"

        obs = self._get_obs(mjx_state, jp.zeros(self.sys.nu))
        reward, done, zero = jp.zeros(3)
        metrics = {
            "x_position": zero,
            "y_position": zero,
            "distance_from_origin": zero,
            "x_velocity": zero,
            "y_velocity": zero,
        }
        for key in self._reward_params.keys():
            metrics[key] = zero

        return State(mjx_state, obs, reward, done, metrics)
    def step(self, state: State, action: jp.ndarray) -> State:
        """Runs one timestep of the environment's dynamics.

        Args:
            state: The current state of the environment.
            action: The action to take.
        Returns:
            The next state of the environment, including the reward, the done flag, and updated metrics.
        """
        mjx_state = state.pipeline_state
        assert mjx_state, "state.pipeline_state was recorded as None"
        # TODO: determine whether to raise an error or reset the environment

        next_mjx_state = self.pipeline_step(mjx_state, action)

        assert type(next_mjx_state) == mjxState, f"next_mjx_state is of type {type(next_mjx_state)}"
        assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}"
        # mlutz: from what I've seen, .pipeline_state and .pipeline_step(...) actually return a brax.mjx.base.State
        # object; however, the type hinting suggests that they should return a brax.base.State object.
        # brax.mjx.base.State inherits from brax.base.State but also inherits from mjx.Data, which is needed for some rewards.

        obs = self._get_obs(mjx_state, action)
        reward, is_healthy, reward_breakdown = self.reward_fn(mjx_state, action, next_mjx_state)

        if self._terminate_when_unhealthy:
            done = 1.0 - is_healthy
        else:
            done = jp.array(0)

        state.metrics.update(
            x_position=next_mjx_state.subtree_com[1][0],
            y_position=next_mjx_state.subtree_com[1][1],
            distance_from_origin=jp.linalg.norm(next_mjx_state.subtree_com[1]),
            x_velocity=(next_mjx_state.subtree_com[1][0] - mjx_state.subtree_com[1][0]) / self.dt,
            y_velocity=(next_mjx_state.subtree_com[1][1] - mjx_state.subtree_com[1][1]) / self.dt,
        )

        if self._log_reward_breakdown:
            for key, val in reward_breakdown.items():
                state.metrics[key] = val

        return state.replace(  # type: ignore # TODO: fix the type hinting...
            pipeline_state=next_mjx_state, obs=obs, reward=reward, done=done
        )
    def _get_obs(self, data: mjxState, action: jp.ndarray) -> jp.ndarray:
        """Observes humanoid body position, velocities, and angles.

        Args:
            data: The current state of the environment.
            action: The current action.
        Returns:
            Observations of the environment.
        """
        position = data.qpos
        if self._exclude_current_positions_from_observation:
            position = position[2:]

[Review comment] Why is this needed?

        # external_contact_forces are excluded
        return jp.concatenate(
            [
                position,
                data.qvel,
                data.cinert[1:].ravel(),
                data.cvel[1:].ravel(),
                data.qfrc_actuator,
            ]
        )
[Review comment] If you can omit putting stuff here, that would be preferable.
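For reviewers who want to exercise the change end to end, here is a minimal rollout sketch (not part of the diff). It assumes the registry module above is importable as `envs` and uses only standard JAX and Brax calls; since reset and step are pure functions of the state, they can be wrapped in jax.jit.

# Hypothetical smoke-test sketch, not part of this PR.
# Assumes the registry shown above is importable as `envs` (assumption).
import jax

from envs import get_env

env = get_env("default_humanoid")

jit_reset = jax.jit(env.reset)  # reset/step are pure, so they can be jitted
jit_step = jax.jit(env.step)

rng = jax.random.PRNGKey(0)
state = jit_reset(rng)

for _ in range(10):
    rng, action_rng = jax.random.split(rng)
    # Random actions just to exercise the pipeline; a trained policy would go here.
    action = jax.random.uniform(action_rng, (env.action_size,), minval=-1.0, maxval=1.0)
    state = jit_step(state, action)
    print(float(state.reward), float(state.done))

Because get_env forwards keyword arguments, a different reward_params dict could be passed the same way to try out the modular-reward path without touching the environment code.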