
MJX Training Pipeline #14

Closed · wants to merge 18 commits into from

Conversation

@michael-lutz commented:

This PR introduces a new way to massively scale up locomotion training. It builds on Brax and uses the MJX physics engine for all simulation.

Structure

Specifically, this PR includes the following directories:

  • Envs
  • Experiments
  • Utils
  • (example) Weights
  • train.py
  • play.py

Envs includes two Brax environments: DefaultHumanoidEnv and StompyEnv. Each environment has a main class that implements the Brax environment interface and uses MJX for all physics calculations. One important thing to note is that reward functions are modular, allowing for quick experimentation (see the sketch below).
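
As an illustration of that modular design, here is a minimal sketch assuming the signatures visible later in this diff (healthy_reward_fn, get_reward_fn); the registry dict, the mjxState alias, and the combination loop are my guesses, not the PR's actual code:

```python
# Minimal sketch of a modular reward registry. Signatures mirror
# healthy_reward_fn / get_reward_fn from the diff; everything else
# is an illustrative assumption.
from typing import Callable, Dict, Tuple

import jax
import jax.numpy as jp
from mujoco import mjx

mjxState = mjx.Data  # assumed alias used by the PR

RewardFn = Callable[
    [mjxState, jp.ndarray, mjxState, jax.Array, Dict[str, float]],
    Tuple[jp.ndarray, jp.ndarray],
]


def forward_reward_fn(
    state: mjxState, action: jp.ndarray, next_state: mjxState, dt: jax.Array, params: Dict[str, float]
) -> Tuple[jp.ndarray, jp.ndarray]:
    """Reward proportional to the forward velocity of the center of mass."""
    xpos = state.subtree_com[1][0]
    next_xpos = next_state.subtree_com[1][0]
    velocity = (next_xpos - xpos) / dt
    return params["weight"] * velocity, jp.array(1.0)  # (reward, healthy flag)


REWARD_FNS: Dict[str, RewardFn] = {"forward": forward_reward_fn}


def get_reward_fn(reward_params: Dict[str, Dict[str, float]], dt: jax.Array, include_reward_breakdown: bool = False):
    """Compose the configured reward terms into a single callable."""

    def reward_fn(state, action, next_state):
        total = jp.array(0.0)
        is_healthy = jp.array(1.0)
        breakdown = {}
        for name, params in reward_params.items():
            reward, healthy = REWARD_FNS[name](state, action, next_state, dt, params)
            total = total + reward
            is_healthy = is_healthy * healthy  # any 0.0 flag marks the state unhealthy
            breakdown[name] = reward
        if include_reward_breakdown:
            return total, is_healthy, breakdown
        return total, is_healthy

    return reward_fn
```

Adding a new reward term then only requires registering one function, which is what makes experimentation quick.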

Experiments contains two .yaml files with sample configurations for model training.
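
Based on the config keys read in play.py later in this diff (project_name, normalize_observations), the yaml files presumably look roughly like this; every value below is invented:

```python
# Hypothetical experiment config. Only project_name and
# normalize_observations are grounded in this diff; values are made up.
import yaml

config = yaml.safe_load(
    """
    project_name: default_humanoid
    normalize_observations: true
    """
)

print(config["project_name"])  # -> default_humanoid
```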

Utils includes shared helpers: default values, rollout rendering, etc.

Weights currently includes default humanoid weights (for locomotion) that should work out of the box.

train.py and play.py both integrate with wandb. train.py uses the Brax implementation of PPO for now, but this can easily be customized if needed; a sketch of the wiring follows below.
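
As a rough, hedged sketch of how that wiring might look: the hyperparameter values and project name below are invented, but ppo.train and its progress_fn callback are the standard Brax API:

```python
# Sketch of train.py's core loop: Brax PPO with metrics streamed to wandb.
# Hyperparameters and the project name are illustrative, not the PR's values.
import wandb
from brax import envs
from brax.training.agents.ppo import train as ppo

wandb.init(project="default_humanoid")  # hypothetical project name

env = envs.get_environment("default_humanoid")  # assumes the env is registered


def progress(num_steps: int, metrics: dict) -> None:
    # Called periodically by Brax during training; forward everything to wandb.
    wandb.log(metrics, step=num_steps)


make_inference_fn, params, _ = ppo.train(
    environment=env,
    num_timesteps=30_000_000,
    episode_length=1000,
    num_envs=2048,  # MJX makes thousands of parallel envs cheap on GPU
    progress_fn=progress,
)
```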

Performance Samples

Training Curves
[Two training-curve screenshots, captured 2024-05-22]

Example humanoid robot walking in MJX
https://github.com/kscalelabs/sim/assets/43460304/8e12b0e6-48ea-4af0-8283-1dc4880767b4

Humanoid trained in MJX, evaluated in CPU-based MuJoCo
https://github.com/kscalelabs/sim/assets/43460304/7f158aeb-6bc9-4056-bd1d-12882adbd13c

@michael-lutz added the "enhancement" label on May 23, 2024
@michael-lutz self-assigned this on May 23, 2024

@budzianowski (Collaborator) commented:

nit: move weights to mjx_gym/tests

@budzianowski (Collaborator) left a review:

Looks great! A couple of cleaning comments. Was the Stompy env tested at all? If not, let's add it in the next PR since this one is already getting big.

@@ -0,0 +1,13 @@
from brax import envs

from .default_humanoid_env.default_humanoid import DefaultHumanoidEnv

@budzianowski (Collaborator):

avoid relative imports

@@ -0,0 +1,150 @@
import jax

@budzianowski (Collaborator):

Run isort, ruff, et al.; see the Makefile for the formatting setup.

from mujoco import mjx
from etils import epath
from .rewards import get_reward_fn
from utils.default import DEFAULT_REWARD_PARAMS

@budzianowski (Collaborator):

nit: move default reward params to rewards.py

from mujoco import mjx
from etils import epath
import os
from .rewards import get_reward_fn

@budzianowski (Collaborator):

avoid relative imports - sim.mjx_gym.envs.rewards

self.reward_fn = get_reward_fn(self._reward_params, self.dt, include_reward_breakdown=True)

def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state.

@budzianowski (Collaborator):

the formatting looks off.

model_path = "weights/" + config.get('project_name', 'model') + ".pkl"
params = model.load_params(model_path)
normalize = lambda x, y: x
if config.get('normalize_observations', False):

@budzianowski (Collaborator):

Add a comment explaining what's going on here.


# rolling out a trajectory
render_every = 2
n_steps = 1000

@budzianowski (Collaborator):

make this a parameter

# rolling out a trajectory
render_every = 2
n_steps = 1000
if args.use_mujoco:

@budzianowski (Collaborator):

Is there actually any difference?

@michael-lutz (Author):

Good question!

On runtime performance: MJX takes a while to initialize but is very quick to run on GPU. MuJoCo has a quicker "cold start," but its actual rollouts are slightly slower.

On model performance: due to slightly different physical dynamics, models trained in MJX are slightly less performant when run in MuJoCo. Compare the second video to the first in the PR description.
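
For context, the two code paths differ roughly as follows. This is a simplified sketch (the asset path and the stepping loop are placeholders, not play.py's actual code):

```python
# Simplified contrast between the MJX (GPU) and classic MuJoCo (CPU) paths.
import jax
import mujoco
from mujoco import mjx

model = mujoco.MjModel.from_xml_path("humanoid.xml")  # placeholder path

# MJX path: expensive one-time jit compilation, then fast GPU steps.
mjx_model = mjx.put_model(model)
mjx_data = mjx.make_data(mjx_model)
jit_step = jax.jit(mjx.step)
mjx_data = jit_step(mjx_model, mjx_data)

# MuJoCo path: no compilation cost, but every step runs on the CPU.
data = mujoco.MjData(model)
mujoco.mj_step(model, data)
```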

images = render_mjx_rollout(env, inference_fn, n_steps, render_every)
print(f'Rolled out {len(images)} steps')

# render the trajectory

@budzianowski (Collaborator):

Make it optional

@michael-lutz (Author):

I was envisioning this script as dedicated primarily to rendering, since I believe we can access all other metrics directly in the training logs. Do you think it makes sense to make it optional?

ydataerr = []
times = [datetime.now()]

max_y, min_y = 13000, 0

@budzianowski (Collaborator):

I don't think this is used anywhere?

@budzianowski self-requested a review on May 23, 2024, 20:47

@budzianowski (Collaborator) left a review:

Small nits, but looks good for a start!

"""Resets the environment to an initial state.

Args:
rng: Random number generator seed.

@budzianowski (Collaborator):

nit - formatting still off here

Suggested change:
-    rng: Random number generator seed.
+    Args:
+        rng: Random number generator seed.

@@ -0,0 +1,111 @@
from typing import Callable, Dict, Tuple

@budzianowski (Collaborator):

Good practice is to add at least a one-line comment saying what the file is about.

}


def get_reward_fn(

@budzianowski (Collaborator):

This could be used across all envs?

Returns:
A float wrapped in a jax array.
"""
xpos = state.subtree_com[1][0] # TODO: include stricter typing than mjxState to avoid this type error

@budzianowski (Collaborator):

Make it more explicit what you are loading

velocity = (next_xpos - xpos) / dt
forward_reward = params["weight"] * velocity

return forward_reward, jp.array(1.0) # TODO: ensure everything is initialized in a size 2 array instead...

@budzianowski (Collaborator):

What's the logic behind 1.0?

@michael-lutz (Author):

jp.array because we want to keep everything in jax. The 1.0 itself operates like a boolean (it doesn't change the "healthiness" until a 0 comes around).
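
In other words (my paraphrase of the pattern, not the PR's literal code):

```python
# The 1.0 flags behave like booleans under multiplication:
# the state stays healthy until any reward term returns 0.0.
import jax.numpy as jp

flags = [jp.array(1.0), jp.array(1.0), jp.array(0.0)]  # hypothetical per-term flags
is_healthy = jp.array(1.0)
for flag in flags:
    is_healthy = is_healthy * flag  # 1.0 * 1.0 * 0.0 -> 0.0
```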

def healthy_reward_fn(
state: mjxState, action: jp.ndarray, next_state: mjxState, dt: jax.Array, params: Dict[str, float]
) -> Tuple[jp.ndarray, jp.ndarray]:
"""Reward function for staying healthy.

@budzianowski (Collaborator):

Rename to upright_reward.

Resets the environment to an initial state.

Args:
rng: Random number generator seed.

@budzianowski (Collaborator):

nit - weird tabs

@@ -0,0 +1,81 @@
import argparse

@budzianowski (Collaborator):

Add an example of how to run it.

@@ -0,0 +1,80 @@
import argparse

@budzianowski (Collaborator):

Add an example of how to run it.
