
MJX Training Implementation #1

Merged
merged 25 commits into from May 24, 2024
25 commits
3592d06
feat: created new default humanoid environment class implementing bra…
michael-lutz May 22, 2024
033d01f
feat: ran initial training with new MJX environment
michael-lutz May 22, 2024
31c1b35
feat: added nicer training script
michael-lutz May 22, 2024
3fa16ca
feat: added training config for stompy
michael-lutz May 22, 2024
4df08c3
feat: storing weight checkpoints and added play
michael-lutz May 22, 2024
d3a642c
feat: added stomppy environment
michael-lutz May 22, 2024
778476b
feat: added model checkpointing
michael-lutz May 22, 2024
c443164
feat: printing when rendering
michael-lutz May 22, 2024
e018d18
feat: removed unnecessary model checkpointing
michael-lutz May 22, 2024
9c18a08
chore: cleaned up play script and added CPU-only rendering
michael-lutz May 23, 2024
ede4b2e
chore: removed brax replication
michael-lutz May 23, 2024
23836dc
fix: fixing stompy environment import issues and created simplified mesh
michael-lutz May 23, 2024
8eaf94a
chore: sorting libraries, typing, etc
michael-lutz May 23, 2024
02c1097
chore: cleaned up training scripts and file organization
michael-lutz May 23, 2024
f1b3424
chore: removed unused imports
michael-lutz May 23, 2024
58e84cc
fix: removing duplicate exclude definition in pyproject
michael-lutz May 23, 2024
a34e181
fix: moved mjx_gym to ksim folder
michael-lutz May 23, 2024
398ae86
chore: removing duplicate folders
michael-lutz May 23, 2024
8fe5c18
feat: updated the readme with relevant getting-started information
michael-lutz May 23, 2024
8fc8fef
chore: fixed spacing and updated setup.py
michael-lutz May 23, 2024
498953f
chore: updating pyproject
michael-lutz May 24, 2024
b0ec4e2
thing
michael-lutz May 24, 2024
6d0bc2b
fix some types
codekansas May 24, 2024
c88da97
fix typing
codekansas May 24, 2024
2701d57
chore: fixing import ordering
michael-lutz May 24, 2024
3 changes: 3 additions & 0 deletions .gitignore
@@ -19,3 +19,6 @@ build/
dist/
*.so
out*/

# Training artifacts
wandb/
12 changes: 6 additions & 6 deletions Makefile
@@ -49,8 +49,8 @@ clean:
py-files := $(shell find . -name '*.py')

format:
@black $(py-files)
@ruff format $(py-files)
@black ksim
@ruff format ksim
.PHONY: format

format-cpp:
@@ -59,13 +59,13 @@ format-cpp:
.PHONY: format-cpp

static-checks:
@black --diff --check $(py-files)
@ruff check $(py-files)
@mypy --install-types --non-interactive $(py-files)
@black --diff --check ksim
@ruff check ksim
@mypy --install-types --non-interactive ksim
.PHONY: lint

mypy-daemon:
@dmypy run -- $(py-files)
@dmypy run -- ksim
.PHONY: mypy-daemon

# ------------------------ #
102 changes: 101 additions & 1 deletion README.md
@@ -21,4 +21,104 @@

# K-Scale Sim

A library for simulating Stompy in MuJoCo.
A library for simulating Stompy in MJX and MuJoCo.

## Installation
1. Clone this repository:
```bash
git clone https://github.com/kscalelabs/ksim.git
cd ksim
```

2. It is recommended that you use a virtual environment to install the dependencies for this project. You can create a new conda environment using the following command:
```bash
conda create --name ksim python=3.11
conda activate ksim
```
To install the dependencies, run the following command:

```bash
make install-dev
```

## MJX Gym Usage
MJX Gym is a library for training and evaluating reinforcement learning agents in MJX environments. It is built on top of the Brax library and provides a simple interface for running experiments with Stompy and other humanoid form factors. Currently, we support walking, but we plan to add more tasks and simulator environments in the future.



### Training
For quick experimentation, you may specify all relevant training configurations via YAML files, and simply run train.py with the desired configuration. Our configurations allow for rapid reward function prototyping, environment specification, and hyperparameter tuning. Additionally, all training results are tracked and logged using Weights & Biases.
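A configuration file for this workflow might look roughly like the sketch below. Note that the key names here are illustrative assumptions, not the actual ksim schema; the real keys live in the files under `experiments/`, e.g. `experiments/default_humanoid_walk.yaml`.

```yaml
# Hypothetical training config sketch -- field names are assumptions,
# not the actual ksim schema; see experiments/ for real examples.
env_name: default_humanoid
num_timesteps: 30000000
reward_params:
  ctrl_cost_weight: 0.1
  healthy_reward: 5.0
wandb:
  project: mjx-gym
```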

We recommend starting with the default humanoid environment to get a feel for the simulator. To train the default humanoid environment, first navigate to the ksim/mjx_gym directory:
```bash
cd ksim/mjx_gym
```
Then, run the following command:
```bash
python train.py --config experiments/default_humanoid_walk.yaml
```

Example training curves are shown below:
<p align="center">
<img alt="K-Scale Open Source Robotics" src="https://private-user-images.githubusercontent.com/43460304/333030841-5a4bd24e-0478-4cbb-bb92-18cf5e9c72b7.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTY1MDc4ODcsIm5iZiI6MTcxNjUwNzU4NywicGF0aCI6Ii80MzQ2MDMwNC8zMzMwMzA4NDEtNWE0YmQyNGUtMDQ3OC00Y2JiLWJiOTItMThjZjVlOWM3MmI3LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA1MjMlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwNTIzVDIzMzk0N1omWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWFiZTQyNDFlNjViN2M0MDRiMjgyZWRkYWI3ZmQxYWU3NDM4MDViZDVlYmJlY2NiNTIwZDU2NGVmNWE2YjU1NjcmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.Yi2HzUGt3HYivMyGTuRFSRjsZviZRK2dFEx_KtZyEyA" width="400" />
<img alt="K-Scale Open Source Robotics" src="https://private-user-images.githubusercontent.com/43460304/333030844-341f4a5c-95ea-49ed-b102-df4e00868c5a.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTY1MDc4ODcsIm5iZiI6MTcxNjUwNzU4NywicGF0aCI6Ii80MzQ2MDMwNC8zMzMwMzA4NDQtMzQxZjRhNWMtOTVlYS00OWVkLWIxMDItZGY0ZTAwODY4YzVhLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA1MjMlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwNTIzVDIzMzk0N1omWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTcxYzc5MzlkMGRiZGJkMmE2ZjYwY2FhMmZjZGQ3NTM4YWNlOTVkMTU4MjBjZmM2ZmRmZmE1OTkxYTlmYTJkZDImWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.01D5U4QD8JooJwEDtcw9OKt1tmM_zGYj5U7dDstLC84" width="400" >
</p>


### Testing and Rendering
We provide an easy way to test and render the trained model using the play.py script. This script loads the trained model and runs it in the specified environment.

To test and render the trained model, run the following command:
```bash
python play.py --config experiments/default_humanoid_walk.yaml
```

Here is an example of the trained model walking in the default humanoid environment after roughly 15 minutes of training on the K-Scale cluster.

<video controls>
<source alt="K-Scale Open Source Robotics" src="https://github.com/kscalelabs/sim/assets/43460304/8e12b0e6-48ea-4af0-8283-1dc4880767b4" type="video/mp4">
</video>

You may transfer the model (trained via MJX) to MuJoCo for CPU-based simulation via arguments to the play script. For example, to run the model in the default humanoid environment, use the following command:
```bash
python play.py --config experiments/default_humanoid_walk.yaml --use_mujoco
```

It is important that the configuration file used for testing is the **same as the one used for training.** This ensures that the model is loaded correctly and avoids catastrophic failures during testing. For simplicity, we have included default training weights in the ksim/mjx_gym/weights directory, which are automatically used when running the play script with the command above. However, you may specify a different weight file using the --params_path argument.

Here is an example of the trained model (in MJX) walking in the default humanoid environment in MuJoCo.

<video controls>
<source alt="K-Scale Open Source Robotics" src="https://github.com/kscalelabs/sim/assets/43460304/7f158aeb-6bc9-4056-bd1d-12882adbd13c" type="video/mp4">
</video>

While the model performs well in the MJX environment, it is important to note that the model's performance may vary when transferred to MuJoCo. This is due to slight differences in the simulators and the underlying physics engines. Evaluating the model in both simulators is helpful for determining whether the model will generalize well across different environments.

### Common Issues
You might see the following error when running train.py:
>An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

One way to fix this issue is to uninstall jax and jaxlib, then reinstall them for your specific CUDA version with the following commands:
```bash
pip uninstall jax jaxlib
pip install --upgrade "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
Note that the CUDA version should match the one installed on your machine.
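Before launching a long training run, it can be worth confirming which backend JAX actually selected, since the fallback to CPU is otherwise silent. The helper below is a hypothetical check, not part of ksim; it returns False when jax is missing and otherwise asks jax which backend it picked.

```python
import importlib.util


def cuda_jax_available() -> bool:
    """Best-effort check that jax is importable and reports a GPU backend.

    Hypothetical helper (not part of ksim): returns False when jax is not
    installed; otherwise compares jax's selected backend against "gpu".
    """
    if importlib.util.find_spec("jax") is None:
        return False
    import jax

    return jax.default_backend() == "gpu"


gpu_ready = cuda_jax_available()
```

If this returns False on a GPU machine, the CUDA-enabled jaxlib reinstall above is the usual remedy.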

A similar issue might also occur:
> jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

This error is usually caused by an incorrect version of cuDNN installed in your environment. To fix this issue, you can try installing cuDNN using the following command:
```bash
conda install cudnn
```

Another common issue occurs when rendering on a headless server. For now, we recommend rendering locally or using a remote desktop connection to view the rendering.

## TODO
- [ ] Get Stompy to load efficiently with MJX (currently, the meshes and collision detection are not loading correctly)
- [ ] Add goal conditioning to the humanoid walk task
- [ ] Implement imitation learning techniques for end-to-end walking and recovering tasks
- [ ] Constantly iterate on the reward functions to improve training efficiency
- [ ] Add CPU-based MuJoCo training platform for improved sim-to-sim support
- [ ] Add sim-to-real transfer learning techniques during training
Empty file added ksim/mjx_gym/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions ksim/mjx_gym/envs/__init__.py
@@ -0,0 +1,13 @@
from typing import Any

from brax import envs
Collaborator review comment: If you can omit putting stuff here, that would be preferable.


from ksim.mjx_gym.envs.default_humanoid_env.default_humanoid import DefaultHumanoidEnv
from ksim.mjx_gym.envs.stompy_env.stompy import StompyEnv

environments = {"default_humanoid": DefaultHumanoidEnv, "stompy": StompyEnv}


def get_env(name: str, **kwargs: Any) -> envs.Env: # noqa: ANN401
envs.register_environment(name, environments[name])
return envs.get_environment(name, **kwargs)
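The `get_env` helper above follows a plain registry-dictionary pattern. Stripped of Brax, the same idea can be sketched in isolation with stand-in classes (these are not the real environments, just placeholders to show the lookup-and-construct flow):

```python
from typing import Any, Callable


# Stand-in environment classes; the real ones subclass brax PipelineEnv.
class DefaultHumanoidEnv:
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


class StompyEnv(DefaultHumanoidEnv):
    pass


# Name -> constructor registry, mirroring the `environments` dict above.
environments: dict[str, Callable[..., DefaultHumanoidEnv]] = {
    "default_humanoid": DefaultHumanoidEnv,
    "stompy": StompyEnv,
}


def get_env(name: str, **kwargs: Any) -> DefaultHumanoidEnv:
    """Look up a registered environment by name and construct it."""
    return environments[name](**kwargs)


env = get_env("stompy", backend="mjx")
```

The registry keeps the training script decoupled from concrete environment classes: adding a new environment only requires one new dictionary entry.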
Empty file.
175 changes: 175 additions & 0 deletions ksim/mjx_gym/envs/default_humanoid_env/default_humanoid.py
@@ -0,0 +1,175 @@
"""Defines the default humanoid environment."""

from typing import NotRequired, TypedDict, Unpack

import jax
Collaborator review comment: Adding one line explanation would be useful.

import jax.numpy as jp
Collaborator review comment: isort

import mujoco
from brax import base
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from brax.mjx.base import State as mjxState
from etils import epath

from ksim.mjx_gym.envs.default_humanoid_env.rewards import (
DEFAULT_REWARD_PARAMS,
RewardParams,
get_reward_fn,
)


class EnvKwargs(TypedDict):
sys: base.System
backend: NotRequired[str]
n_frames: NotRequired[int]
debug: NotRequired[bool]


class DefaultHumanoidEnv(PipelineEnv):
"""An environment for humanoid body position, velocities, and angles.

Note: This environment is based on the default humanoid environment in the Brax library.
https://github.com/google/brax/blob/main/brax/envs/humanoid.py

However, this environment is designed to work with modular reward functions, allowing for quicker experimentation.
"""

def __init__(
self,
reward_params: RewardParams = DEFAULT_REWARD_PARAMS,
terminate_when_unhealthy: bool = True,
reset_noise_scale: float = 1e-2,
exclude_current_positions_from_observation: bool = True,
log_reward_breakdown: bool = True,
**kwargs: Unpack[EnvKwargs],
) -> None:
path = epath.Path(epath.resource_path("mujoco")) / ("mjx/test_data/humanoid")
mj_model = mujoco.MjModel.from_xml_path((path / "humanoid.xml").as_posix())
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6

sys = mjcf.load_model(mj_model)

physics_steps_per_control_step = 4 # Should find way to perturb this value in the future
kwargs["n_frames"] = kwargs.get("n_frames", physics_steps_per_control_step)
kwargs["backend"] = "mjx"

super().__init__(sys, **kwargs)

self._reward_params = reward_params
self._terminate_when_unhealthy = terminate_when_unhealthy
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
self._log_reward_breakdown = log_reward_breakdown

self.reward_fn = get_reward_fn(self._reward_params, self.dt, include_reward_breakdown=True)

def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state.

Args:
rng: Random number generator seed.

Returns:
The initial state of the environment.
"""
rng, rng1, rng2 = jax.random.split(rng, 3)

low, hi = -self._reset_noise_scale, self._reset_noise_scale
qpos = self.sys.qpos0 + jax.random.uniform(rng1, (self.sys.nq,), minval=low, maxval=hi)
qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi)

mjx_state = self.pipeline_init(qpos, qvel)
assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}"

obs = self._get_obs(mjx_state, jp.zeros(self.sys.nu))
reward, done, zero = jp.zeros(3)
metrics = {
"x_position": zero,
"y_position": zero,
"distance_from_origin": zero,
"x_velocity": zero,
"y_velocity": zero,
}
for key in self._reward_params.keys():
metrics[key] = zero

return State(mjx_state, obs, reward, done, metrics)
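The reset noise above amounts to a uniform perturbation of the reference pose `qpos0`. A stdlib stand-in for the `jax.random.uniform` calls (the example pose values are made up) looks like this:

```python
import random


def noisy_qpos(qpos0: list, scale: float = 1e-2, seed: int = 0) -> list:
    """Perturb a reference pose with uniform noise in [-scale, scale].

    Stdlib stand-in for the jax.random.uniform call in reset(); the
    function name and example pose are illustrative, not ksim APIs.
    """
    rng = random.Random(seed)
    return [q + rng.uniform(-scale, scale) for q in qpos0]


base_pose = [0.0, 0.0, 1.4]  # illustrative qpos0
sample = noisy_qpos(base_pose)
```

Starting each episode from a slightly different pose prevents the policy from overfitting to one exact initial state.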

def step(self, state: State, action: jp.ndarray) -> State:
"""Runs one timestep of the environment's dynamics.

Args:
state: The current state of the environment.
action: The action to take.

Returns:
The next state of the environment, containing the updated observation, reward, done flag, and metrics.
"""
mjx_state = state.pipeline_state
assert mjx_state, "state.pipeline_state was recorded as None"
# TODO: determine whether to raise an error or reset the environment

next_mjx_state = self.pipeline_step(mjx_state, action)

assert type(next_mjx_state) == mjxState, f"next_mjx_state is of type {type(next_mjx_state)}"
assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}"
# mlutz: from what I've seen, .pipeline_state and .pipeline_step(...)
# actually return a brax.mjx.base.State object; however, the type
# hinting suggests that they should return a brax.base.State object.
# brax.mjx.base.State inherits from brax.base.State but also inherits
# from mjx.Data, which is needed for some rewards.

obs = self._get_obs(mjx_state, action)
reward, is_healthy, reward_breakdown = self.reward_fn(mjx_state, action, next_mjx_state)

if self._terminate_when_unhealthy:
done = 1.0 - is_healthy
else:
done = jp.array(0)

state.metrics.update(
x_position=next_mjx_state.subtree_com[1][0],
y_position=next_mjx_state.subtree_com[1][1],
distance_from_origin=jp.linalg.norm(next_mjx_state.subtree_com[1]),
x_velocity=(next_mjx_state.subtree_com[1][0] - mjx_state.subtree_com[1][0]) / self.dt,
y_velocity=(next_mjx_state.subtree_com[1][1] - mjx_state.subtree_com[1][1]) / self.dt,
)

if self._log_reward_breakdown:
for key, val in reward_breakdown.items():
state.metrics[key] = val

# TODO: fix the type hinting...
return state.replace(
pipeline_state=next_mjx_state,
obs=obs,
reward=reward,
done=done,
)

def _get_obs(self, data: mjxState, action: jp.ndarray) -> jp.ndarray:
"""Observes humanoid body position, velocities, and angles.

Args:
data: The current state of the environment.
action: The current action.

Returns:
Observations of the environment.
"""
position = data.qpos
if self._exclude_current_positions_from_observation:
Collaborator review comment: Why this is needed?

position = position[2:]

# external_contact_forces are excluded
return jp.concatenate(
[
position,
data.qvel,
data.cinert[1:].ravel(),
data.cvel[1:].ravel(),
data.qfrc_actuator,
]
)
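The concatenation at the end of `_get_obs` can be illustrated with plain lists in place of jax arrays (the numbers are made up, not real qpos/qvel data):

```python
# Sketch of the observation assembly in _get_obs, with illustrative values.
qpos = [0.0, 0.0, 1.4, 0.1, 0.2]  # root x, root y, root z, joint angles...
qvel = [0.3, 0.4]
exclude_current_positions = True

# Dropping the root x/y makes the observation translation-invariant:
# the policy cannot key on absolute world position.
position = qpos[2:] if exclude_current_positions else qpos
obs = position + qvel  # -> [1.4, 0.1, 0.2, 0.3, 0.4]
```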