-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MJX Training Pipeline #14
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: move weights to mjx_gym/tests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Couple cleaning comments. Was stompy env tested at all? If not, let's add it in the next PR since this one is already getting big.
sim/mjx_gym/envs/__init__.py
Outdated
@@ -0,0 +1,13 @@ | |||
from brax import envs | |||
|
|||
from .default_humanoid_env.default_humanoid import DefaultHumanoidEnv |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
avoid relative imports
@@ -0,0 +1,150 @@ | |||
import jax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isort ruff et al., see Makefile for formatting setup
from mujoco import mjx | ||
from etils import epath | ||
from .rewards import get_reward_fn | ||
from utils.default import DEFAULT_REWARD_PARAMS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: move default reward params to rewards.py
from mujoco import mjx | ||
from etils import epath | ||
import os | ||
from .rewards import get_reward_fn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
avoid relative imports - sim.mjx_gym.envs.rewards
self.reward_fn = get_reward_fn(self._reward_params, self.dt, include_reward_breakdown=True) | ||
|
||
def reset(self, rng: jp.ndarray) -> State: | ||
"""Resets the environment to an initial state. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the formatting looks off.
sim/mjx_gym/play.py
Outdated
model_path = "weights/" + config.get('project_name', 'model') + ".pkl" | ||
params = model.load_params(model_path) | ||
normalize = lambda x, y: x | ||
if config.get('normalize_observations', False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add comment what's going here
sim/mjx_gym/play.py
Outdated
|
||
# rolling out a trajectory | ||
render_every = 2 | ||
n_steps = 1000 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make this a parameter
sim/mjx_gym/play.py
Outdated
# rolling out a trajectory | ||
render_every = 2 | ||
n_steps = 1000 | ||
if args.use_mujoco: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there actually any difference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question!
On performance: MJX takes a while to initialize, very quick to run on GPU. MuJoCo has a quicker "cold start" equivalent, but actual rollouts are slightly slower.
On model performance: due to slightly different physical dynamics, models are slightly less performant on MuJoCo if trained on MJX. See the second vs first video included in the description of the PR.
sim/mjx_gym/play.py
Outdated
images = render_mjx_rollout(env, inference_fn, n_steps, render_every) | ||
print(f'Rolled out {len(images)} steps') | ||
|
||
# render the trajectory |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make it optional
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was envisioning this script to be dedicated primarily to rendering, as we can access all other metrics directly in training logs I believe. Do you think it makes sense to make optional?
sim/mjx_gym/train.py
Outdated
ydataerr = [] | ||
times = [datetime.now()] | ||
|
||
max_y, min_y = 13000, 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is used anywhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
small nits but looks good for a start!
"""Resets the environment to an initial state. | ||
|
||
Args: | ||
rng: Random number generator seed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit - formatting still off here
rng: Random number generator seed. | |
Args: | |
rng: Random number generator seed. |
@@ -0,0 +1,111 @@ | |||
from typing import Callable, Dict, Tuple |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good practice is to add even one line comment what the file is about.
} | ||
|
||
|
||
def get_reward_fn( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be used across all envs?
Returns: | ||
A float wrapped in a jax array. | ||
""" | ||
xpos = state.subtree_com[1][0] # TODO: include stricter typing than mjxState to avoid this type error |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make it more explicit what you are loading
velocity = (next_xpos - xpos) / dt | ||
forward_reward = params["weight"] * velocity | ||
|
||
return forward_reward, jp.array(1.0) # TODO: ensure everything is initialized in a size 2 array instead... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the logic behind 1.0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jp.array because we want to keep everything in jax. The 1.0 itself operates like a boolean operator (doesn't change the "healthiness" until a 0 comes around)
def healthy_reward_fn( | ||
state: mjxState, action: jp.ndarray, next_state: mjxState, dt: jax.Array, params: Dict[str, float] | ||
) -> Tuple[jp.ndarray, jp.ndarray]: | ||
"""Reward function for staying healthy. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rename upright _reward
Resets the environment to an initial state. | ||
|
||
Args: | ||
rng: Random number generator seed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit - weird tabs
@@ -0,0 +1,81 @@ | |||
import argparse |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add example how to run it
@@ -0,0 +1,80 @@ | |||
import argparse |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add example how to run it
This PR introduces a new way to massively scale up locomotion training. Building upon Brax, it ultimately uses the MJX physics engine for simulation.
Structure
Specifically, this PR includes the following directories:
Envs
includes two types of Brax environments: DefaultHumanoidEnv and StompyEnv. Each environment includes a main class which implements the Brax environment interface and utilizes MJX for all physics calculations. One important thing to note is that reward functions are modular, allowing for quick experimentation.Experiments
includes two .yaml files that include sample configurations for model training.Utils
include default values, rendering rollouts, etc.Weights
currently include default humanoid weights (for locomotion) that should work out of the boxtrain.py and play.py both integrate with wandb. train.py utilizes the Brax implementation of PPO for now, but can be easily customized if needed.
Performance Samples
Training Curves
Example humanoid robot walking in MJX
https://github.com/kscalelabs/sim/assets/43460304/8e12b0e6-48ea-4af0-8283-1dc4880767b4
Humanoid trained in MJX, eval in CPU-based MuJoCo
https://github.com/kscalelabs/sim/assets/43460304/7f158aeb-6bc9-4056-bd1d-12882adbd13c