MJX Training Implementation #1
Changed file (presumably .gitignore):

@@ -19,3 +19,6 @@ build/
 dist/
 *.so
 out*/
+
+# Training artifacts
+wandb/
New file (presumably ksim/mjx_gym/envs/__init__.py, judging by the import paths):

@@ -0,0 +1,13 @@
from typing import Any

from brax import envs

from ksim.mjx_gym.envs.default_humanoid_env.default_humanoid import DefaultHumanoidEnv
from ksim.mjx_gym.envs.stompy_env.stompy import StompyEnv

environments = {"default_humanoid": DefaultHumanoidEnv, "stompy": StompyEnv}


def get_env(name: str, **kwargs: Any) -> envs.Env:  # noqa: ANN401
    envs.register_environment(name, environments[name])
    return envs.get_environment(name, **kwargs)
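For context, a hedged usage sketch (not part of the diff; the module path is inferred from the imports above): get_env registers the named class with brax's registry and then instantiates it.

```python
# Hypothetical usage, not from this PR: resolve an environment by name.
from ksim.mjx_gym.envs import get_env

env = get_env("default_humanoid")  # registers the class with brax, then instantiates it
print(env.observation_size, env.action_size)
```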
New file (presumably ksim/mjx_gym/envs/default_humanoid_env/default_humanoid.py):

@@ -0,0 +1,175 @@
"""Defines the default humanoid environment.""" | ||
|
||
from typing import NotRequired, TypedDict, Unpack | ||
|
||
import jax | ||
[Review comment] Adding a one-line explanation would be useful.
import jax.numpy as jp
[Review comment] isort (import ordering).
import mujoco
from brax import base
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from brax.mjx.base import State as mjxState
from etils import epath

from ksim.mjx_gym.envs.default_humanoid_env.rewards import (
    DEFAULT_REWARD_PARAMS,
    RewardParams,
    get_reward_fn,
)


class EnvKwargs(TypedDict):
    sys: base.System
    backend: NotRequired[str]
    n_frames: NotRequired[int]
    debug: NotRequired[bool]


class DefaultHumanoidEnv(PipelineEnv):
    """An environment for humanoid body position, velocities, and angles.

    Note: This environment is based on the default humanoid environment in the Brax library.
    https://github.com/google/brax/blob/main/brax/envs/humanoid.py

    However, this environment is designed to work with modular reward functions, allowing for quicker experimentation.
    """

    def __init__(
        self,
        reward_params: RewardParams = DEFAULT_REWARD_PARAMS,
        terminate_when_unhealthy: bool = True,
        reset_noise_scale: float = 1e-2,
        exclude_current_positions_from_observation: bool = True,
        log_reward_breakdown: bool = True,
        **kwargs: Unpack[EnvKwargs],
    ) -> None:
        path = epath.Path(epath.resource_path("mujoco")) / "mjx/test_data/humanoid"
        mj_model = mujoco.MjModel.from_xml_path((path / "humanoid.xml").as_posix())
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
        mj_model.opt.iterations = 6
        mj_model.opt.ls_iterations = 6

        sys = mjcf.load_model(mj_model)

        physics_steps_per_control_step = 4  # Should find a way to perturb this value in the future
        kwargs["n_frames"] = kwargs.get("n_frames", physics_steps_per_control_step)
        kwargs["backend"] = "mjx"

        super().__init__(sys, **kwargs)

        self._reward_params = reward_params
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
        self._log_reward_breakdown = log_reward_breakdown

        self.reward_fn = get_reward_fn(self._reward_params, self.dt, include_reward_breakdown=True)
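The rewards module itself is not in this diff. From its usage here and in step() below, a reward function presumably maps (state, action, next_state) to (reward, is_healthy, reward_breakdown). A hedged, illustrative stand-in follows; the constants and the health check are assumptions, not from this PR.

```python
# Illustrative only: the real terms live in
# ksim.mjx_gym.envs.default_humanoid_env.rewards and may differ.
import jax.numpy as jp
from brax.mjx.base import State as mjxState

def example_reward_fn(
    state: mjxState, action: jp.ndarray, next_state: mjxState, dt: float
) -> tuple[jp.ndarray, jp.ndarray, dict[str, jp.ndarray]]:
    # Forward progress of body 1's subtree center of mass, matching the
    # x_velocity metric computed in step() below.
    x_velocity = (next_state.subtree_com[1][0] - state.subtree_com[1][0]) / dt
    # Assumed health check: torso height above a threshold.
    is_healthy = jp.where(next_state.q[2] > 1.0, 1.0, 0.0)
    breakdown = {
        "forward_reward": 1.25 * x_velocity,
        "healthy_reward": 5.0 * is_healthy,
    }
    reward = breakdown["forward_reward"] + breakdown["healthy_reward"]
    return reward, is_healthy, breakdown
```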
    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to an initial state.

        Args:
            rng: Random number generator key.

        Returns:
            The initial state of the environment.
        """
        rng, rng1, rng2 = jax.random.split(rng, 3)

        low, hi = -self._reset_noise_scale, self._reset_noise_scale
        qpos = self.sys.qpos0 + jax.random.uniform(rng1, (self.sys.nq,), minval=low, maxval=hi)
        qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi)

        mjx_state = self.pipeline_init(qpos, qvel)
        assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}"

        obs = self._get_obs(mjx_state, jp.zeros(self.sys.nu))
        reward, done, zero = jp.zeros(3)
        metrics = {
            "x_position": zero,
            "y_position": zero,
            "distance_from_origin": zero,
            "x_velocity": zero,
            "y_velocity": zero,
        }
        for key in self._reward_params.keys():
            metrics[key] = zero

        return State(mjx_state, obs, reward, done, metrics)
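Because reset is pure JAX, it can be batched with jax.vmap. A hedged sketch (the batch size is illustrative, and the class is assumed importable as defined above):

```python
# Hypothetical, not from this PR: reset many environments in parallel.
import jax

env = DefaultHumanoidEnv()
rngs = jax.random.split(jax.random.PRNGKey(0), 512)  # one key per environment
batched_state = jax.vmap(env.reset)(rngs)  # batched_state.obs has shape (512, obs_size)
```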
    def step(self, state: State, action: jp.ndarray) -> State:
        """Runs one timestep of the environment's dynamics.

        Args:
            state: The current state of the environment.
            action: The action to take.

        Returns:
            The next state of the environment, with the observation, reward, done flag, and metrics updated.
        """
        mjx_state = state.pipeline_state
        assert mjx_state, "state.pipeline_state was recorded as None"
        # TODO: determine whether to raise an error or reset the environment

        next_mjx_state = self.pipeline_step(mjx_state, action)

        assert type(next_mjx_state) == mjxState, f"next_mjx_state is of type {type(next_mjx_state)}"
        assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}"
        # mlutz: from what I've seen, .pipeline_state and .pipeline_step(...)
        # actually return a brax.mjx.base.State object; however, the type
        # hints suggest that they should return a brax.base.State object.
        # brax.mjx.base.State inherits from brax.base.State but also inherits
        # from mjx.Data, which is needed for some rewards.
        obs = self._get_obs(next_mjx_state, action)  # observe the post-step state, as in upstream Brax
        reward, is_healthy, reward_breakdown = self.reward_fn(mjx_state, action, next_mjx_state)

        if self._terminate_when_unhealthy:
            done = 1.0 - is_healthy
        else:
            done = jp.array(0)

        state.metrics.update(
            x_position=next_mjx_state.subtree_com[1][0],
            y_position=next_mjx_state.subtree_com[1][1],
            distance_from_origin=jp.linalg.norm(next_mjx_state.subtree_com[1]),
            x_velocity=(next_mjx_state.subtree_com[1][0] - mjx_state.subtree_com[1][0]) / self.dt,
            y_velocity=(next_mjx_state.subtree_com[1][1] - mjx_state.subtree_com[1][1]) / self.dt,
        )

        if self._log_reward_breakdown:
            for key, val in reward_breakdown.items():
                state.metrics[key] = val

        # TODO: fix the type hinting...
        return state.replace(
            pipeline_state=next_mjx_state,
            obs=obs,
            reward=reward,
            done=done,
        )
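step is likewise jit-compatible, so a typical rollout compiles it once. A hedged sketch (the zero-action policy is a placeholder, not from this PR):

```python
# Hypothetical, not from this PR: compile reset/step once and roll out.
import jax
import jax.numpy as jp

env = DefaultHumanoidEnv()
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

state = jit_reset(jax.random.PRNGKey(0))
for _ in range(100):
    state = jit_step(state, jp.zeros(env.action_size))  # placeholder policy
```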
    def _get_obs(self, data: mjxState, action: jp.ndarray) -> jp.ndarray:
        """Observes humanoid body position, velocities, and angles.

        Args:
            data: The current state of the environment.
            action: The current action.

        Returns:
            Observations of the environment.
        """
        position = data.qpos
        if self._exclude_current_positions_from_observation:
[Review comment] Why is this needed?
            position = position[2:]

        # external_contact_forces are excluded
        return jp.concatenate(
            [
                position,
                data.qvel,
                data.cinert[1:].ravel(),
                data.cvel[1:].ravel(),
                data.qfrc_actuator,
            ]
        )
[Review comment] If you can omit putting stuff here, that would be preferable.
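Since _get_obs concatenates qpos (minus the root x/y when excluded), qvel, per-body cinert and cvel (skipping the world body), and qfrc_actuator, the observation width can be sanity-checked from the model sizes. A hedged sketch, assuming the same bundled humanoid.xml loaded in __init__:

```python
# Hypothetical check, not from this PR: expected width of _get_obs output.
import mujoco
from etils import epath

path = epath.Path(epath.resource_path("mujoco")) / "mjx/test_data/humanoid"
m = mujoco.MjModel.from_xml_path((path / "humanoid.xml").as_posix())
# cinert is (nbody, 10), cvel is (nbody, 6), qfrc_actuator is (nv,).
obs_size = (m.nq - 2) + m.nv + 10 * (m.nbody - 1) + 6 * (m.nbody - 1) + m.nv
print(obs_size)
```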