diff --git a/.gitignore b/.gitignore index ae8e28a..74955e0 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ build/ dist/ *.so out*/ + +# Training artifacts +wandb/ \ No newline at end of file diff --git a/Makefile b/Makefile index ef07aab..efacbcb 100644 --- a/Makefile +++ b/Makefile @@ -49,8 +49,8 @@ clean: py-files := $(shell find . -name '*.py') format: - @black $(py-files) - @ruff format $(py-files) + @black ksim + @ruff format ksim .PHONY: format format-cpp: @@ -59,13 +59,13 @@ format-cpp: .PHONY: format-cpp static-checks: - @black --diff --check $(py-files) - @ruff check $(py-files) - @mypy --install-types --non-interactive $(py-files) + @black --diff --check ksim + @ruff check ksim + @mypy --install-types --non-interactive ksim .PHONY: lint mypy-daemon: - @dmypy run -- $(py-files) + @dmypy run -- ksim .PHONY: mypy-daemon # ------------------------ # diff --git a/README.md b/README.md index a9f7ade..79f232f 100644 --- a/README.md +++ b/README.md @@ -21,4 +21,104 @@ # K-Scale Sim -A library for simulating Stompy in MuJoCo. +A library for simulating Stompy in MJX and MuJoCo. + +## Installation +1. Clone this repository: +```bash +git clone https://github.com/kscalelabs/ksim.git +cd ksim +``` + +2. It is recommended that you use a virtual environment to install the dependencies for this project. You can create a new conda environment using the following command: +```bash +conda create --name ksim python=3.11 +conda activate ksim +``` +To install the dependencies, run the following command: + +```bash +make install-dev +``` + +## MJX Gym Usage +MJX Gym is a library for training and evaluating reinforcement learning agents in MJX environments. It is built on top of the Brax library and provides a simple interface for running experiments with Stompy and other humanoid formfactors. Currently, we support walking but plan on adding more tasks and simulator environments in the future. + + + +### Training +For quick experimentation, you may specify all relevant training configurations via YAML files, and simply run train.py with the desired configuration. Our configurations allow for rapid reward function prototyping, environment specification, and hyperparameter tuning. Additionally, all training results are tracked and logged using Weights & Biases. + +We recommend starting with the default humanoid environment to get a feel for the simulator. To train the default humanoid environment, first navigate to the /mjx_gym directory: +```bash +cd ksim/mjx_gym +``` +Then, run the following command: +```bash +python train.py --config experiments/default_humanoid_walk.yaml +``` + +Example training curves are shown below: +
+ + +
+ + +### Testing and Rendering +We provide an easy way to test and render the trained model using the play.py script. This script loads the trained model and runs it in the specified environment. + +To test and render the trained model, run the following command: +```bash +python play.py --config experiments/default_humanoid_walk.yaml +``` + +Here is an example of the trained model walking in the default humanoid environment after roughly 15 minutes of training on the K-Scale cluster. + + + +You may transfer the model (trained via MJX) to MuJoCo for CPU-based simulation via arguments to the play script. For example, to run the model in the default humanoid environment, use the following command: +```bash +python play.py --config experiments/default_humanoid_walk.yaml --use_mujoco +``` + +It is important that the configuration file used for testing is the **same as the one used for training.** This ensures that the model is loaded correctly and avoids any catastrophic failures during testing. For simplicity, we have included default training weights in the /mjx_gym/weights directory, which are automatically used when running the play script with the command above. However, you may specify a different weight file using the --params_path argument. + +Here is an example of the trained model (in MJX) walking in the default humanoid environment in MuJoCo. + + + +While the model performs well in the MJX environment, it is important to note that the model's performance may vary when transferred to MuJoCo. This is due to slight differences in the simulators and the underlying physics engines. Evaluating the model in both simulators is helpful for determining whether the model will generalize well across different environments. + +### Common Issues +You might see the following error when running train.py: +>An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu. + +One way to fix this issue is to uninstall jax and jaxlib and reinstall them with your specific CUDA version the following commands: +```bash +pip uninstall jax jaxlib +pip install --upgrade "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``` +Note that the CUDA version should match the one installed on your machine. + +A similar issue might also occur: +> jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details. + +This error is usually caused by an incorrect version of cudnn installed in your environment. To fix this issue, you can try installing cuDNN using the following command: +```bash +conda install cudnn +``` + +Another common issue occurs when rendering on a headless server. For now, we recommend rendering locally or using a remote desktop connection to view the rendering. + +## TODO +- [ ] Get Stompy to load efficiently with MJX (currently, the meshes and collision detection are not loading correctly) +- [ ] Add goal conditioning to the humanoid walk task +- [ ] Implement imitation learning techniques for end-to-end walking and recovering tasks +- [ ] Constantly iterate on the reward functions to improve training efficiency +- [ ] Add CPU-based MuJoCo training platform for improved sim-to-sim support +- [ ] Add sim-to-real transfer learning techniques during training \ No newline at end of file diff --git a/ksim/mjx_gym/__init__.py b/ksim/mjx_gym/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ksim/mjx_gym/envs/__init__.py b/ksim/mjx_gym/envs/__init__.py new file mode 100644 index 0000000..259d7d6 --- /dev/null +++ b/ksim/mjx_gym/envs/__init__.py @@ -0,0 +1,13 @@ +from typing import Any + +from brax import envs + +from ksim.mjx_gym.envs.default_humanoid_env.default_humanoid import DefaultHumanoidEnv +from ksim.mjx_gym.envs.stompy_env.stompy import StompyEnv + +environments = {"default_humanoid": DefaultHumanoidEnv, "stompy": StompyEnv} + + +def get_env(name: str, **kwargs: Any) -> envs.Env: # noqa: ANN401 + envs.register_environment(name, environments[name]) + return envs.get_environment(name, **kwargs) diff --git a/ksim/mjx_gym/envs/default_humanoid_env/__init__.py b/ksim/mjx_gym/envs/default_humanoid_env/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ksim/mjx_gym/envs/default_humanoid_env/default_humanoid.py b/ksim/mjx_gym/envs/default_humanoid_env/default_humanoid.py new file mode 100644 index 0000000..bacd711 --- /dev/null +++ b/ksim/mjx_gym/envs/default_humanoid_env/default_humanoid.py @@ -0,0 +1,175 @@ +"""Defines the default humanoid environment.""" + +from typing import NotRequired, TypedDict, Unpack + +import jax +import jax.numpy as jp +import mujoco +from brax import base +from brax.envs.base import PipelineEnv, State +from brax.io import mjcf +from brax.mjx.base import State as mjxState +from etils import epath + +from ksim.mjx_gym.envs.default_humanoid_env.rewards import ( + DEFAULT_REWARD_PARAMS, + RewardParams, + get_reward_fn, +) + + +class EnvKwargs(TypedDict): + sys: base.System + backend: NotRequired[str] + n_frames: NotRequired[int] + debug: NotRequired[bool] + + +class DefaultHumanoidEnv(PipelineEnv): + """An environment for humanoid body position, velocities, and angles. + + Note: This environment is based on the default humanoid environment in the Brax library. + https://github.com/google/brax/blob/main/brax/envs/humanoid.py + + However, this environment is designed to work with modular reward functions, allowing for quicker experimentation. + """ + + def __init__( + self, + reward_params: RewardParams = DEFAULT_REWARD_PARAMS, + terminate_when_unhealthy: bool = True, + reset_noise_scale: float = 1e-2, + exclude_current_positions_from_observation: bool = True, + log_reward_breakdown: bool = True, + **kwargs: Unpack[EnvKwargs], + ) -> None: + path = epath.Path(epath.resource_path("mujoco")) / ("mjx/test_data/humanoid") + mj_model = mujoco.MjModel.from_xml_path((path / "humanoid.xml").as_posix()) + mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG + mj_model.opt.iterations = 6 + mj_model.opt.ls_iterations = 6 + + sys = mjcf.load_model(mj_model) + + physics_steps_per_control_step = 4 # Should find way to perturb this value in the future + kwargs["n_frames"] = kwargs.get("n_frames", physics_steps_per_control_step) + kwargs["backend"] = "mjx" + + super().__init__(sys, **kwargs) + + self._reward_params = reward_params + self._terminate_when_unhealthy = terminate_when_unhealthy + self._reset_noise_scale = reset_noise_scale + self._exclude_current_positions_from_observation = exclude_current_positions_from_observation + self._log_reward_breakdown = log_reward_breakdown + + self.reward_fn = get_reward_fn(self._reward_params, self.dt, include_reward_breakdown=True) + + def reset(self, rng: jp.ndarray) -> State: + """Resets the environment to an initial state. + + Args: + rng: Random number generator seed. + + Returns: + The initial state of the environment. + """ + rng, rng1, rng2 = jax.random.split(rng, 3) + + low, hi = -self._reset_noise_scale, self._reset_noise_scale + qpos = self.sys.qpos0 + jax.random.uniform(rng1, (self.sys.nq,), minval=low, maxval=hi) + qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi) + + mjx_state = self.pipeline_init(qpos, qvel) + assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}" + + obs = self._get_obs(mjx_state, jp.zeros(self.sys.nu)) + reward, done, zero = jp.zeros(3) + metrics = { + "x_position": zero, + "y_position": zero, + "distance_from_origin": zero, + "x_velocity": zero, + "y_velocity": zero, + } + for key in self._reward_params.keys(): + metrics[key] = zero + + return State(mjx_state, obs, reward, done, metrics) + + def step(self, state: State, action: jp.ndarray) -> State: + """Runs one timestep of the environment's dynamics. + + Args: + state: The current state of the environment. + action: The action to take. + + Returns: + A tuple of the next state, the reward, whether the episode has ended, and additional information. + """ + mjx_state = state.pipeline_state + assert mjx_state, "state.pipeline_state was recorded as None" + # TODO: determine whether to raise an error or reset the environment + + next_mjx_state = self.pipeline_step(mjx_state, action) + + assert type(next_mjx_state) == mjxState, f"next_mjx_state is of type {type(next_mjx_state)}" + assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}" + # mlutz: from what I've seen, .pipeline_state and .pipeline_step(...) + # actually return an brax.mjx.base.State object however, the type + # hinting suggests that it should return a brax.base.State object + # brax.mjx.base.State inherits from brax.base.State but also inherits + # from mjx.Data, which is needed for some rewards + + obs = self._get_obs(mjx_state, action) + reward, is_healthy, reward_breakdown = self.reward_fn(mjx_state, action, next_mjx_state) + + if self._terminate_when_unhealthy: + done = 1.0 - is_healthy + else: + done = jp.array(0) + + state.metrics.update( + x_position=next_mjx_state.subtree_com[1][0], + y_position=next_mjx_state.subtree_com[1][1], + distance_from_origin=jp.linalg.norm(next_mjx_state.subtree_com[1]), + x_velocity=(next_mjx_state.subtree_com[1][0] - mjx_state.subtree_com[1][0]) / self.dt, + y_velocity=(next_mjx_state.subtree_com[1][1] - mjx_state.subtree_com[1][1]) / self.dt, + ) + + if self._log_reward_breakdown: + for key, val in reward_breakdown.items(): + state.metrics[key] = val + + # TODO: fix the type hinting... + return state.replace( + pipeline_state=next_mjx_state, + obs=obs, + reward=reward, + done=done, + ) + + def _get_obs(self, data: mjxState, action: jp.ndarray) -> jp.ndarray: + """Observes humanoid body position, velocities, and angles. + + Args: + data: The current state of the environment. + action: The current action. + + Returns: + Observations of the environment. + """ + position = data.qpos + if self._exclude_current_positions_from_observation: + position = position[2:] + + # external_contact_forces are excluded + return jp.concatenate( + [ + position, + data.qvel, + data.cinert[1:].ravel(), + data.cvel[1:].ravel(), + data.qfrc_actuator, + ] + ) diff --git a/ksim/mjx_gym/envs/default_humanoid_env/rewards.py b/ksim/mjx_gym/envs/default_humanoid_env/rewards.py new file mode 100644 index 0000000..b357e0c --- /dev/null +++ b/ksim/mjx_gym/envs/default_humanoid_env/rewards.py @@ -0,0 +1,147 @@ +"""Defines the rewards for the humanoid environment.""" + +from typing import Callable, NotRequired, TypedDict + +import jax +import jax.numpy as jp +from brax.mjx.base import State as mjxState + + +class RewardDict(TypedDict): + weight: float + healthy_z_lower: NotRequired[float] + healthy_z_upper: NotRequired[float] + + +RewardParams = dict[str, RewardDict] + + +DEFAULT_REWARD_PARAMS: RewardParams = { + "rew_forward": {"weight": 1.25}, + "rew_healthy": {"weight": 5.0, "healthy_z_lower": 1.0, "healthy_z_upper": 2.0}, + "rew_ctrl_cost": {"weight": 0.1}, +} + + +def get_reward_fn( + reward_params: RewardParams, + dt: jax.Array, + include_reward_breakdown: bool, +) -> Callable[[mjxState, jp.ndarray, mjxState], tuple[jp.ndarray, jp.ndarray, dict[str, jp.ndarray]]]: + """Get a combined reward function. + + Args: + reward_params: Dictionary of reward parameters. + dt: Time step. + include_reward_breakdown: Whether to include a breakdown of the reward + into its components. + + Returns: + A reward function that takes in a state, action, and next state and + returns a float wrapped in a jp.ndarray. + """ + + def reward_fn( + state: mjxState, action: jp.ndarray, next_state: mjxState + ) -> tuple[jp.ndarray, jp.ndarray, dict[str, jp.ndarray]]: + reward, is_healthy = jp.array(0.0), jp.array(1.0) + rewards = {} + for key, params in reward_params.items(): + r, h = reward_functions[key](state, action, next_state, dt, params) + is_healthy *= h + reward += r + if include_reward_breakdown: # For more detailed logging, can be disabled for performance + rewards[key] = r + return reward, is_healthy, rewards + + return reward_fn + + +def forward_reward_fn( + state: mjxState, + action: jp.ndarray, + next_state: mjxState, + dt: jax.Array, + params: RewardDict, +) -> tuple[jp.ndarray, jp.ndarray]: + """Reward function for moving forward. + + Args: + state: Current state. + action: Action taken. + next_state: Next state. + dt: Time step. + params: Reward parameters. + + Returns: + A float wrapped in a jax array. + """ + xpos = state.subtree_com[1][0] # TODO: include stricter typing than mjxState to avoid this type error + next_xpos = next_state.subtree_com[1][0] + velocity = (next_xpos - xpos) / dt + forward_reward = params["weight"] * velocity + + return forward_reward, jp.array(1.0) # TODO: ensure everything is initialized in a size 2 array instead... + + +def healthy_reward_fn( + state: mjxState, + action: jp.ndarray, + next_state: mjxState, + dt: jax.Array, + params: RewardDict, +) -> tuple[jp.ndarray, jp.ndarray]: + """Reward function for staying healthy. + + Args: + state: Current state. + action: Action taken. + next_state: Next state. + dt: Time step. + params: Reward parameters. + + Returns: + A float wrapped in a jax array. + """ + min_z = params["healthy_z_lower"] + max_z = params["healthy_z_upper"] + is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0) + is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy) + healthy_reward = jp.array(params["weight"]) * is_healthy + + return healthy_reward, is_healthy + + +def ctrl_cost_reward_fn( + state: mjxState, + action: jp.ndarray, + next_state: mjxState, + dt: jax.Array, + params: RewardDict, +) -> tuple[jp.ndarray, jp.ndarray]: + """Reward function for control cost. + + Args: + state: Current state. + action: Action taken. + next_state: Next state. + dt: Time step. + params: Reward parameters. + + Returns: + A float wrapped in a jax array. + """ + ctrl_cost = -params["weight"] * jp.sum(jp.square(action)) + + return ctrl_cost, jp.array(1.0) + + +RewardFunction = Callable[[mjxState, jp.ndarray, mjxState, jax.Array, RewardDict], tuple[jp.ndarray, jp.ndarray]] + + +# NOTE: After defining the reward functions, they must be added here to be used in the combined reward function. +reward_functions: dict[str, RewardFunction] = { + "rew_forward": forward_reward_fn, + "rew_healthy": healthy_reward_fn, + "rew_ctrl_cost": ctrl_cost_reward_fn, +} diff --git a/ksim/mjx_gym/envs/stompy_env/__init__.py b/ksim/mjx_gym/envs/stompy_env/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ksim/mjx_gym/envs/stompy_env/rewards.py b/ksim/mjx_gym/envs/stompy_env/rewards.py new file mode 100644 index 0000000..6c13df8 --- /dev/null +++ b/ksim/mjx_gym/envs/stompy_env/rewards.py @@ -0,0 +1,141 @@ +"""Defines rewards for the Stompy environment.""" + +from typing import Callable, Dict, Tuple + +import jax +import jax.numpy as jp +from brax.mjx.base import State as mjxState + +from ksim.mjx_gym.envs.default_humanoid_env.rewards import ( + RewardDict, + RewardFunction, + RewardParams, +) + +DEFAULT_REWARD_PARAMS: RewardParams = { + "rew_forward": {"weight": 1.25}, + "rew_healthy": {"weight": 5.0, "healthy_z_lower": 1.0, "healthy_z_upper": 2.0}, + "rew_ctrl_cost": {"weight": 0.1}, +} + + +def get_reward_fn( + reward_params: RewardParams, + dt: jax.Array, + include_reward_breakdown: bool, +) -> Callable[[mjxState, jp.ndarray, mjxState], Tuple[jp.ndarray, jp.ndarray, Dict[str, jp.ndarray]]]: + """Get a combined reward function. + + Args: + reward_params: Dictionary of reward parameters. + dt: Time step. + include_reward_breakdown: Whether to include a breakdown of the reward + into its components. + + Returns: + A reward function that takes in a state, action, and next state and + returns a float wrapped in a jp.ndarray. + """ + + def reward_fn( + state: mjxState, + action: jp.ndarray, + next_state: mjxState, + ) -> Tuple[jp.ndarray, jp.ndarray, Dict[str, jp.ndarray]]: + reward, is_healthy = jp.array(0.0), jp.array(1.0) + rewards = {} + for key, params in reward_params.items(): + r, h = reward_functions[key](state, action, next_state, dt, params) + is_healthy *= h + reward += r + if include_reward_breakdown: # For more detailed logging, can be disabled for performance + rewards[key] = r + return reward, is_healthy, rewards + + return reward_fn + + +def forward_reward_fn( + state: mjxState, + action: jp.ndarray, + next_state: mjxState, + dt: jax.Array, + params: RewardDict, +) -> Tuple[jp.ndarray, jp.ndarray]: + """Reward function for moving forward. + + Args: + state: Current state. + action: Action taken. + next_state: Next state. + dt: Time step. + params: Reward parameters. + + Returns: + A float wrapped in a jax array. + """ + xpos = state.subtree_com[1][0] # TODO: include stricter typing than mjxState to avoid this type error + next_xpos = next_state.subtree_com[1][0] + velocity = (next_xpos - xpos) / dt + forward_reward = params["weight"] * velocity + + return forward_reward, jp.array(1.0) # TODO: ensure everything is initialized in a size 2 array instead... + + +def healthy_reward_fn( + state: mjxState, + action: jp.ndarray, + next_state: mjxState, + dt: jax.Array, + params: RewardDict, +) -> Tuple[jp.ndarray, jp.ndarray]: + """Reward function for staying healthy. + + Args: + state: Current state. + action: Action taken. + next_state: Next state. + dt: Time step. + params: Reward parameters. + + Returns: + A float wrapped in a jax array. + """ + min_z = params["healthy_z_lower"] + max_z = params["healthy_z_upper"] + is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0) + is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy) + healthy_reward = jp.array(params["weight"]) * is_healthy + + return healthy_reward, is_healthy + + +def ctrl_cost_reward_fn( + state: mjxState, + action: jp.ndarray, + next_state: mjxState, + dt: jax.Array, + params: RewardDict, +) -> Tuple[jp.ndarray, jp.ndarray]: + """Reward function for control cost. + + Args: + state: Current state. + action: Action taken. + next_state: Next state. + dt: Time step. + params: Reward parameters. + + Returns: + A float wrapped in a jax array. + """ + ctrl_cost = -params["weight"] * jp.sum(jp.square(action)) + + return ctrl_cost, jp.array(1.0) + + +reward_functions: dict[str, RewardFunction] = { + "rew_forward": forward_reward_fn, + "rew_healthy": healthy_reward_fn, + "rew_ctrl_cost": ctrl_cost_reward_fn, +} diff --git a/ksim/mjx_gym/envs/stompy_env/stompy.py b/ksim/mjx_gym/envs/stompy_env/stompy.py new file mode 100644 index 0000000..3008c64 --- /dev/null +++ b/ksim/mjx_gym/envs/stompy_env/stompy.py @@ -0,0 +1,165 @@ +"""Defines the Stompy MJX environment.""" + +import os +from typing import Unpack + +import jax +import jax.numpy as jp +import mujoco +from brax.envs.base import PipelineEnv, State +from brax.io import mjcf +from brax.mjx.base import State as mjxState + +from ksim.mjx_gym.envs.default_humanoid_env.default_humanoid import EnvKwargs +from ksim.mjx_gym.envs.default_humanoid_env.rewards import ( + DEFAULT_REWARD_PARAMS, + RewardParams, + get_reward_fn, +) + + +class StompyEnv(PipelineEnv): + """An environment for humanoid body position, velocities, and angles.""" + + def __init__( + self, + reward_params: RewardParams = DEFAULT_REWARD_PARAMS, + terminate_when_unhealthy: bool = True, + reset_noise_scale: float = 1e-2, + exclude_current_positions_from_observation: bool = True, + log_reward_breakdown: bool = True, + **kwargs: Unpack[EnvKwargs], + ) -> None: + path = os.getenv("MODEL_DIR", "") + "/robot_simplified.xml" + mj_model = mujoco.MjModel.from_xml_path(path) + mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG + mj_model.opt.iterations = 6 + mj_model.opt.ls_iterations = 6 + + sys = mjcf.load_model(mj_model) + + physics_steps_per_control_step = 4 # Should find way to perturb this value in the future + kwargs["n_frames"] = kwargs.get("n_frames", physics_steps_per_control_step) + kwargs["backend"] = "mjx" + + super().__init__(sys, **kwargs) + + self._reward_params = reward_params + self._terminate_when_unhealthy = terminate_when_unhealthy + self._reset_noise_scale = reset_noise_scale + self._exclude_current_positions_from_observation = exclude_current_positions_from_observation + self._log_reward_breakdown = log_reward_breakdown + + self.reward_fn = get_reward_fn(self._reward_params, self.dt, include_reward_breakdown=True) + + def reset(self, rng: jp.ndarray) -> State: + """Resets the environment to an initial state. + + Args: + rng: Random number generator seed. + + Returns: + The initial state of the environment. + """ + rng, rng1, rng2 = jax.random.split(rng, 3) + + low, hi = -self._reset_noise_scale, self._reset_noise_scale + qpos = self.sys.qpos0 + jax.random.uniform(rng1, (self.sys.nq,), minval=low, maxval=hi) + qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi) + + mjx_state = self.pipeline_init(qpos, qvel) + assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}" + + obs = self._get_obs(mjx_state, jp.zeros(self.sys.nu)) + reward, done, zero = jp.zeros(3) + metrics = { + "x_position": zero, + "y_position": zero, + "distance_from_origin": zero, + "x_velocity": zero, + "y_velocity": zero, + } + for key in self._reward_params.keys(): + metrics[key] = zero + + return State(mjx_state, obs, reward, done, metrics) + + def step(self, state: State, action: jp.ndarray) -> State: + """Runs one timestep of the environment's dynamics. + + Args: + state: The current state of the environment. + action: The action to take. + + Returns: + A tuple of the next state, the reward, whether the episode has ended, and additional information. + """ + mjx_state = state.pipeline_state + assert mjx_state, "state.pipeline_state was recorded as None" + # TODO: determine whether to raise an error or reset the environment + + next_mjx_state = self.pipeline_step(mjx_state, action) + + assert type(next_mjx_state) == mjxState, f"next_mjx_state is of type {type(next_mjx_state)}" + assert type(mjx_state) == mjxState, f"mjx_state is of type {type(mjx_state)}" + # mlutz: from what I've seen, .pipeline_state and .pipeline_step(...) + # actually return an brax.mjx.base.State object however, the type + # hinting suggests that it should return a brax.base.State object + # brax.mjx.base.State inherits from brax.base.State but also inherits + # from mjx.Data, which is needed for some rewards + + obs = self._get_obs(mjx_state, action) + reward, is_healthy, reward_breakdown = self.reward_fn(mjx_state, action, next_mjx_state) + + if self._terminate_when_unhealthy: + done = 1.0 - is_healthy + else: + done = jp.array(0) + + state.metrics.update( + x_position=next_mjx_state.subtree_com[1][0], + y_position=next_mjx_state.subtree_com[1][1], + distance_from_origin=jp.linalg.norm(next_mjx_state.subtree_com[1]), + x_velocity=(next_mjx_state.subtree_com[1][0] - mjx_state.subtree_com[1][0]) / self.dt, + y_velocity=(next_mjx_state.subtree_com[1][1] - mjx_state.subtree_com[1][1]) / self.dt, + ) + + if self._log_reward_breakdown: + for key, val in reward_breakdown.items(): + state.metrics[key] = val + + return state.replace(pipeline_state=next_mjx_state, obs=obs, reward=reward, done=done) + + def _get_obs(self, data: mjxState, action: jp.ndarray) -> jp.ndarray: + """Observes humanoid body position, velocities, and angles. + + Args: + data: The current state of the environment. + action: The current action. + + Returns: + Observations of the environment. + """ + position = data.qpos + if self._exclude_current_positions_from_observation: + position = position[2:] + + # external_contact_forces are excluded + return jp.concatenate( + [ + position, + data.qvel, + data.cinert[1:].ravel(), + data.cvel[1:].ravel(), + data.qfrc_actuator, + ] + ) + + +def adhoc_test() -> None: + print("hello, world!") + + +if __name__ == "__main__": + # python -m ksim.mjx_gym.envs.stompy_env.stompy + adhoc_test() diff --git a/ksim/mjx_gym/experiments/default_humanoid_walk.yaml b/ksim/mjx_gym/experiments/default_humanoid_walk.yaml new file mode 100644 index 0000000..94ae788 --- /dev/null +++ b/ksim/mjx_gym/experiments/default_humanoid_walk.yaml @@ -0,0 +1,31 @@ +project_name: default_humanoid_walk +experiment_name: constant_alive_reward_fixing +num_timesteps: 100000000 +num_evals: 10 +reward_scaling: 0.1 +episode_length: 1000 +normalize_observations: true +action_repeat: 1 +unroll_length: 10 +num_minibatches: 32 +num_updates_per_batch: 8 +discounting: 0.97 +learning_rate: 0.0003 +entropy_cost: 0.001 +num_envs: 2048 +batch_size: 1024 +seed: 0 +env_name: default_humanoid +reward_params: + rew_forward: + weight: 1.25 + rew_healthy: + weight: 5.0 + healthy_z_lower: 1.0 + healthy_z_upper: 2.0 + rew_ctrl_cost: + weight: 0.1 +terminate_when_unhealthy: true +reset_noise_scale: 0.01 +exclude_current_positions_from_observation: true +log_reward_breakdown: true diff --git a/ksim/mjx_gym/experiments/stompy_walk.yaml b/ksim/mjx_gym/experiments/stompy_walk.yaml new file mode 100644 index 0000000..32772a3 --- /dev/null +++ b/ksim/mjx_gym/experiments/stompy_walk.yaml @@ -0,0 +1,31 @@ +project_name: stompy_walk +experiment_name: first_experiment +num_timesteps: 30000000 +num_evals: 5 +reward_scaling: 0.1 +episode_length: 1000 +normalize_observations: true +action_repeat: 1 +unroll_length: 10 +num_minibatches: 32 +num_updates_per_batch: 8 +discounting: 0.97 +learning_rate: 0.0003 +entropy_cost: 0.001 +num_envs: 512 +batch_size: 256 +seed: 0 +env_name: stompy +reward_params: + rew_forward: + weight: 1.25 + rew_healthy: + weight: 5.0 + healthy_z_lower: 5.0 + healthy_z_upper: 2.0 + rew_ctrl_cost: + weight: 0.1 +terminate_when_unhealthy: true +reset_noise_scale: 0.01 +exclude_current_positions_from_observation: true +log_reward_breakdown: true diff --git a/ksim/mjx_gym/play.py b/ksim/mjx_gym/play.py new file mode 100644 index 0000000..9136a8e --- /dev/null +++ b/ksim/mjx_gym/play.py @@ -0,0 +1,96 @@ +"""Defines the CLI for running PPO training with specified config file.""" + +import argparse +import logging +from typing import Any + +import mediapy as media +import numpy as np +import wandb +import yaml +from brax.io import model +from brax.training.acme import running_statistics +from brax.training.agents.ppo import networks as ppo_networks + +from ksim.mjx_gym.envs import get_env +from ksim.mjx_gym.envs.default_humanoid_env.default_humanoid import ( + DEFAULT_REWARD_PARAMS, +) +from ksim.mjx_gym.utils.rollouts import render_mjx_rollout, render_mujoco_rollout + +logger = logging.getLogger(__name__) + + +def train(config: dict[str, Any], n_steps: int, render_every: int) -> None: + wandb.init( + project=config.get("project_name", "robotic_locomotion_training") + "_test", + name=config.get("experiment_name", "ppo-training") + "_test", + ) + + # Load environment + env = get_env( + name=config.get("env_name", "default_humanoid"), + reward_params=config.get("reward_params", DEFAULT_REWARD_PARAMS), + terminate_when_unhealthy=config.get("terminate_when_unhealthy", True), + reset_noise_scale=config.get("reset_noise_scale", 1e-2), + exclude_current_positions_from_observation=config.get("exclude_current_positions_from_observation", True), + log_reward_breakdown=config.get("log_reward_breakdown", True), + ) + + logger.info( + "Loaded environment %s with env.observation_size: %s and env.action_size: %s", + config.get("env_name", ""), + env.observation_size, + env.action_size, + ) + + # Loading params + if args.params_path is not None: + model_path = args.params_path + else: + model_path = "weights/" + config.get("project_name", "model") + ".pkl" + params = model.load_params(model_path) + + def normalize(x: np.ndarray, y: np.ndarray) -> np.ndarray: + return x + + if config.get("normalize_observations", False): + normalize = ( + running_statistics.normalize + ) # NOTE: very important to keep training & test normalization consistent + policy_network = ppo_networks.make_ppo_networks( + env.observation_size, env.action_size, preprocess_observations_fn=normalize + ) + inference_fn = ppo_networks.make_inference_fn(policy_network)(params) + print(f"Loaded params from {model_path}") + + # rolling out a trajectory + if args.use_mujoco: + images_thwc = render_mujoco_rollout(env, inference_fn, n_steps, render_every) + else: + images_thwc = render_mjx_rollout(env, inference_fn, n_steps, render_every) + print(f"Rolled out {len(images_thwc)} steps") + + # render the trajectory + images_tchw = np.transpose(images_thwc, (0, 3, 1, 2)) + + fps = 1 / env.dt + wandb.log({"training_rollouts": wandb.Video(images_tchw, fps=fps, format="mp4")}) + media.write_video("video.mp4", images_thwc, fps=fps) + + +if __name__ == "__main__": + # Parse command-line arguments + parser = argparse.ArgumentParser(description="Run PPO training with specified config file.") + parser.add_argument("--config", type=str, required=True, help="Path to the config YAML file") + parser.add_argument("--use_mujoco", type=bool, default=False, help="Use mujoco instead of mjx for rendering") + parser.add_argument("--params_path", type=str, default=None, help="Path to the params file") + parser.add_argument("--n_steps", type=int, default=1000, help="Number of steps to rollout") + parser.add_argument("--render_every", type=int, default=2, help="Render every nth step") + args = parser.parse_args() + + # Load config file + with open(args.config, "r") as file: + config = yaml.safe_load(file) + + train(config, args.n_steps, args.render_every) diff --git a/ksim/mjx_gym/train.py b/ksim/mjx_gym/train.py new file mode 100644 index 0000000..5eb6f45 --- /dev/null +++ b/ksim/mjx_gym/train.py @@ -0,0 +1,85 @@ +"""Defines the training CLI.""" + +import argparse +import functools +from datetime import datetime +from typing import Any + +import wandb +import yaml +from brax.io import model +from brax.training.agents.ppo import train as ppo + +from ksim.mjx_gym.envs import get_env +from ksim.mjx_gym.envs.default_humanoid_env.default_humanoid import ( + DEFAULT_REWARD_PARAMS, +) + + +def train(config: dict[str, Any]) -> None: + wandb.init( + project=config.get("project_name", "robotic-locomotion-training"), + name=config.get("experiment_name", "ppo-training"), + ) + + print(f"reward_params: {config.get('reward_params', DEFAULT_REWARD_PARAMS)}") + print(f'training on {config["num_envs"]} environments') + + env = get_env( + name=config.get("env_name", "default_humanoid"), + reward_params=config.get("reward_params", DEFAULT_REWARD_PARAMS), + terminate_when_unhealthy=config.get("terminate_when_unhealthy", True), + reset_noise_scale=config.get("reset_noise_scale", 1e-2), + exclude_current_positions_from_observation=config.get("exclude_current_positions_from_observation", True), + log_reward_breakdown=config.get("log_reward_breakdown", True), + ) + print(f'Env loaded: {config.get("env_name", "could not find environment")}') + + train_fn = functools.partial( + ppo.train, + num_timesteps=config["num_timesteps"], + num_evals=config["num_evals"], + reward_scaling=config["reward_scaling"], + episode_length=config["episode_length"], + normalize_observations=config["normalize_observations"], + action_repeat=config["action_repeat"], + unroll_length=config["unroll_length"], + num_minibatches=config["num_minibatches"], + num_updates_per_batch=config["num_updates_per_batch"], + discounting=config["discounting"], + learning_rate=config["learning_rate"], + entropy_cost=config["entropy_cost"], + num_envs=config["num_envs"], + batch_size=config["batch_size"], + seed=config["seed"], + ) + + times = [datetime.now()] + + def progress(num_steps: int, metrics: dict[str, Any]) -> None: # noqa: ANN401 + times.append(datetime.now()) + + wandb.log({"steps": num_steps, "epoch_time": (times[-1] - times[-2]).total_seconds(), **metrics}) + + def save_model(current_step: int, make_policy: str, params: dict[str, Any]) -> None: # noqa: ANN401 + model_path = "weights/" + config.get("project_name", "model") + ".pkl" + model.save_params(model_path, params) + print(f"Saved model at step {current_step} to {model_path}") + + make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress, policy_params_fn=save_model) + + print(f"time to jit: {times[1] - times[0]}") + print(f"time to train: {times[-1] - times[1]}") + + +if __name__ == "__main__": + # Parse command-line arguments + parser = argparse.ArgumentParser(description="Run PPO training with specified config file.") + parser.add_argument("--config", type=str, required=True, help="Path to the config YAML file") + args = parser.parse_args() + + # Load config from YAML file + with open(args.config, "r") as file: + config = yaml.safe_load(file) + + train(config) diff --git a/ksim/mjx_gym/utils/__init__.py b/ksim/mjx_gym/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ksim/mjx_gym/utils/rollouts.py b/ksim/mjx_gym/utils/rollouts.py new file mode 100644 index 0000000..4ee3eb2 --- /dev/null +++ b/ksim/mjx_gym/utils/rollouts.py @@ -0,0 +1,127 @@ +"""Defines some utility functions for rollouts.""" + +import logging +from typing import Callable + +import jax +import jax.numpy as jp +import mujoco +import numpy as np +from brax.mjx.base import State as mjxState +from mujoco import mjx +from tqdm import tqdm + +logger = logging.getLogger(__name__) + +InferenceFn = Callable[[jp.ndarray, jp.ndarray], tuple[jp.ndarray, jp.ndarray]] + + +def mjx_rollout( + env: mujoco.MjModel, + inference_fn: InferenceFn, + n_steps: int = 1000, + render_every: int = 2, + seed: int = 0, +) -> list[mjxState]: + """Rollout a trajectory using MJX. + + It is worth noting that env, a Brax environment, is expected to implement MJX + in the background. See default_humanoid_env for reference. + + Args: + env: Brax environment + inference_fn: Inference function + n_steps: Number of steps to rollout + render_every: Render every nth step + seed: Random seed + + Returns: + A list of pipeline states of the policy rollout + """ + # print(f"Rolling out {n_steps} steps with MJX") + logger.info("Rolling out %d steps with MJX", n_steps) + reset_fn = jax.jit(env.reset) + step_fn = jax.jit(env.step) + inference_fn = jax.jit(inference_fn) + rng = jax.random.PRNGKey(seed) + + state = reset_fn(rng) + rollout = [state.pipeline_state] + for i in tqdm(range(n_steps)): + act_rng, rng = jax.random.split(rng) + ctrl, _ = inference_fn(state.obs, act_rng) + state = step_fn(state, ctrl) + rollout.append(state.pipeline_state) + + if state.done: + state = reset_fn(rng) + + return rollout + + +def render_mjx_rollout( + env: mujoco.MjModel, + inference_fn: InferenceFn, + n_steps: int = 1000, + render_every: int = 2, + seed: int = 0, +) -> np.ndarray: + """Rollout a trajectory using MuJoCo and render it. + + Args: + env: Brax environment + inference_fn: Inference function + n_steps: Number of steps to rollout + render_every: Render every nth step + seed: Random seed + + Returns: + A list of renderings of the policy rollout with dimensions (T, H, W, C) + """ + rollout = mjx_rollout(env, inference_fn, n_steps, render_every, seed) + images = env.render(rollout[::render_every], camera="side") + + return np.array(images) + + +def render_mujoco_rollout( + env: mujoco.MjModel, + inference_fn: InferenceFn, + n_steps: int = 1000, + render_every: int = 2, + seed: int = 0, +) -> np.ndarray: + """Rollout a trajectory using MuJoCo. + + Args: + env: Brax environment + inference_fn: Inference function + n_steps: Number of steps to rollout + render_every: Render every nth step + seed: Random seed + + Returns: + A list of images of the policy rollout (T, H, W, C) + """ + print(f"Rolling out {n_steps} steps with MuJoCo") + model = env.sys.mj_model + data = mujoco.MjData(model) + renderer = mujoco.Renderer(model) + ctrl = jp.zeros(model.nu) + + images: list[np.ndarray] = [] + rng = jax.random.PRNGKey(seed) + for step in tqdm(range(n_steps)): + act_rng, seed = jax.random.split(rng) + obs = env._get_obs(mjx.put_data(model, data), ctrl) + # TODO: implement methods in envs that avoid having to use mjx in a hacky way... + ctrl, _ = inference_fn(obs, act_rng) + data.ctrl = ctrl + for _ in range(env._n_frames): + mujoco.mj_step(model, data) + + if step % render_every == 0: + renderer.update_scene(data, camera="side") + images.append(renderer.render()) + + return np.array(images) diff --git a/ksim/mjx_gym/weights/default_humanoid_walk.pkl b/ksim/mjx_gym/weights/default_humanoid_walk.pkl new file mode 100644 index 0000000..c081b5f Binary files /dev/null and b/ksim/mjx_gym/weights/default_humanoid_walk.pkl differ diff --git a/ksim/requirements-dev.txt b/ksim/requirements-dev.txt index e36b648..7be90d2 100644 --- a/ksim/requirements-dev.txt +++ b/ksim/requirements-dev.txt @@ -5,3 +5,14 @@ darglint mypy pytest ruff + +brax +jax +mediapy +moviepy +mujoco-mjx +numpy +torch +tqdm +imageio +wandb diff --git a/ksim/requirements.txt b/ksim/requirements.txt index 2ee1e7b..f9277d7 100644 --- a/ksim/requirements.txt +++ b/ksim/requirements.txt @@ -1,2 +1,12 @@ # requirements.txt +brax +jax +mediapy +moviepy +mujoco-mjx +numpy +torch +tqdm +imageio +wandb diff --git a/pyproject.toml b/pyproject.toml index 3c7370a..896fa24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,10 +30,20 @@ warn_redundant_casts = true incremental = true namespace_packages = false -# Uncomment to exclude modules from Mypy. -# [[tool.mypy.overrides]] -# module = [] -# ignore_missing_imports = true +enable_incomplete_feature = ["Unpack"] + +[[tool.mypy.overrides]] + +module = [ + "brax.*", + "etils.*", + "mujoco.*", + "tqdm.*", + "jaxlib.cpu.*", + "jaxlib.mlir.*", +] + +ignore_missing_imports = true [tool.isort] diff --git a/setup.py b/setup.py index 078985b..1f4743f 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ name="kscale-sim", version=version, description="K-Scale Simulation Library", - author="Benjamin Bolte", + author="Benjamin Bolte, Pawel Budzianowski, Michael Lutz", url="https://github.com/kscalelabs/ksim", long_description=long_description, long_description_content_type="text/markdown",