Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MJX Training Implementation #1

Merged
merged 25 commits into from
May 24, 2024
Merged

MJX Training Implementation #1

merged 25 commits into from
May 24, 2024

Conversation

michael-lutz
Copy link
Contributor

The following is copied from kscalelabs/sim#14

This PR introduces a new way to massively scale up locomotion training. Building upon Brax, it ultimately uses the MJX physics engine for simulation.

Structure

Specifically, this PR includes the following directories:

  • Envs
  • Experiments
  • Utils
  • (example) Weights
  • train.py
  • play.py

Envs includes two types of Brax environments: DefaultHumanoidEnv and StompyEnv. Each environment includes a main class which implements the Brax environment interface and utilizes MJX for all physics calculations. One important thing to note is that reward functions are modular, allowing for quick experimentation.

Experiments includes two .yaml files that include sample configurations for model training.

Utils include default values, rendering rollouts, etc.

Weights currently include default humanoid weights (for locomotion) that should work out of the box

train.py and play.py both integrate with wandb. train.py utilizes the Brax implementation of PPO for now, but can be easily customized if needed.

Performance Samples

Training Curves
Screen Shot 2024-05-22 at 9 12 45 PM
Screen Shot 2024-05-22 at 9 14 44 PM

Example humanoid robot walking in MJX
https://github.com/kscalelabs/sim/assets/43460304/8e12b0e6-48ea-4af0-8283-1dc4880767b4

Humanoid trained in MJX, eval in CPU-based MuJoCo
https://github.com/kscalelabs/sim/assets/43460304/7f158aeb-6bc9-4056-bd1d-12882adbd13c

@michael-lutz michael-lutz self-assigned this May 24, 2024
@michael-lutz michael-lutz added the enhancement New feature or request label May 24, 2024
Copy link
Collaborator

@budzianowski budzianowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple of nits, lgtm otherwise!

@@ -0,0 +1,10 @@
from brax import envs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you can omit putting stuff here, that would be preferable.

@@ -0,0 +1,152 @@
import jax
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding one line explanation would be useful.

@@ -0,0 +1,152 @@
import jax
import jax.numpy as jp
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isort

Observations of the environment.
"""
position = data.qpos
if self._exclude_current_positions_from_observation:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this is needed?

Copy link
Member

@codekansas codekansas left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@michael-lutz michael-lutz enabled auto-merge (squash) May 24, 2024 02:25
@michael-lutz michael-lutz disabled auto-merge May 24, 2024 02:26
@michael-lutz michael-lutz merged commit 84606b8 into master May 24, 2024
1 check passed
@michael-lutz michael-lutz deleted the transfer-branch branch May 24, 2024 02:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants