-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MJX Training Implementation #1
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple of nits, lgtm otherwise!
@@ -0,0 +1,10 @@ | |||
from brax import envs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you can omit putting stuff here, that would be preferable.
@@ -0,0 +1,152 @@ | |||
import jax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding one line explanation would be useful.
@@ -0,0 +1,152 @@ | |||
import jax | |||
import jax.numpy as jp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isort
Observations of the environment. | ||
""" | ||
position = data.qpos | ||
if self._exclude_current_positions_from_observation: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this is needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
The following is copied from kscalelabs/sim#14
This PR introduces a new way to massively scale up locomotion training. Building upon Brax, it ultimately uses the MJX physics engine for simulation.
Structure
Specifically, this PR includes the following directories:
Envs
includes two types of Brax environments: DefaultHumanoidEnv and StompyEnv. Each environment includes a main class which implements the Brax environment interface and utilizes MJX for all physics calculations. One important thing to note is that reward functions are modular, allowing for quick experimentation.Experiments
includes two .yaml files that include sample configurations for model training.Utils
include default values, rendering rollouts, etc.Weights
currently include default humanoid weights (for locomotion) that should work out of the boxtrain.py and play.py both integrate with wandb. train.py utilizes the Brax implementation of PPO for now, but can be easily customized if needed.
Performance Samples
Training Curves
Example humanoid robot walking in MJX
https://github.com/kscalelabs/sim/assets/43460304/8e12b0e6-48ea-4af0-8283-1dc4880767b4
Humanoid trained in MJX, eval in CPU-based MuJoCo
https://github.com/kscalelabs/sim/assets/43460304/7f158aeb-6bc9-4056-bd1d-12882adbd13c