diff --git a/sim/mjx_gym/envs/default_humanoid_env/rewards.py b/sim/mjx_gym/envs/default_humanoid_env/rewards.py
index d45104f5..ee86ad7d 100644
--- a/sim/mjx_gym/envs/default_humanoid_env/rewards.py
+++ b/sim/mjx_gym/envs/default_humanoid_env/rewards.py
@@ -2,7 +2,6 @@
 
 import jax
 import jax.numpy as jp
-from brax import base
 from brax.mjx.base import State as mjxState
 
 DEFAULT_REWARD_PARAMS = {
@@ -11,6 +10,7 @@
     "rew_ctrl_cost": {"weight": 0.1},
 }
 
+
 def get_reward_fn(
     reward_params: Dict[str, Dict[str, float]], dt, include_reward_breakdown
 ) -> Callable[[mjxState, jp.ndarray, mjxState], Tuple[jp.ndarray, jp.ndarray, Dict[str, jp.ndarray]]]:
diff --git a/sim/mjx_gym/envs/stompy_env/rewards.py b/sim/mjx_gym/envs/stompy_env/rewards.py
index 4bf7c57a..6e025e85 100644
--- a/sim/mjx_gym/envs/stompy_env/rewards.py
+++ b/sim/mjx_gym/envs/stompy_env/rewards.py
@@ -2,7 +2,6 @@
 
 import jax
 import jax.numpy as jp
-from brax import base
 from brax.mjx.base import State as mjxState
 
 DEFAULT_REWARD_PARAMS = {
@@ -11,6 +10,7 @@
     "rew_ctrl_cost": {"weight": 0.1},
 }
 
+
 def get_reward_fn(
     reward_params: Dict[str, Dict[str, float]], dt, include_reward_breakdown
 ) -> Callable[[mjxState, jp.ndarray, mjxState], Tuple[jp.ndarray, jp.ndarray, Dict[str, jp.ndarray]]]:
diff --git a/sim/mjx_gym/envs/stompy_env/stompy.py b/sim/mjx_gym/envs/stompy_env/stompy.py
index cce4c452..a25652c4 100644
--- a/sim/mjx_gym/envs/stompy_env/stompy.py
+++ b/sim/mjx_gym/envs/stompy_env/stompy.py
@@ -6,12 +6,11 @@
 from brax.envs.base import PipelineEnv, State
 from brax.io import mjcf
 from brax.mjx.base import State as mjxState
-from etils import epath
-from mujoco import mjx
 
 from envs.stompy_env.rewards import DEFAULT_REWARD_PARAMS
 from envs.stompy_env.rewards import get_reward_fn
 
+
 class StompyEnv(PipelineEnv):
     """
     An environment for humanoid body position, velocities, and angles.
@@ -53,9 +52,9 @@ def reset(self, rng: jp.ndarray) -> State:
         Resets the environment to an initial state.
 
         Args:
-            rng: Random number generator seed. 
+            rng: Random number generator seed.
 
         Returns:
-            The initial state of the environment. 
+            The initial state of the environment.
         """
         rng, rng1, rng2 = jax.random.split(rng, 3)
@@ -85,10 +84,10 @@ def step(self, state: State, action: jp.ndarray) -> State:
         Runs one timestep of the environment's dynamics.
 
         Args:
-            state: The current state of the environment. 
-            action: The action to take. 
+            state: The current state of the environment.
+            action: The action to take.
         Returns:
-            A tuple of the next state, the reward, whether the episode has ended, and additional information. 
+            A tuple of the next state, the reward, whether the episode has ended, and additional information.
         """
         mjx_state = state.pipeline_state
         assert mjx_state, "state.pipeline_state was recorded as None"
@@ -131,10 +130,10 @@ def _get_obs(self, data: mjxState, action: jp.ndarray) -> jp.ndarray:
         Observes humanoid body position, velocities, and angles.
 
         Args:
-            data: The current state of the environment. 
-            action: The current action. 
+            data: The current state of the environment.
+            action: The current action.
         Returns:
-            Observations of the environment. 
+            Observations of the environment.
         """
         position = data.qpos
         if self._exclude_current_positions_from_observation:
diff --git a/sim/mjx_gym/play.py b/sim/mjx_gym/play.py
index 4d27a5f3..c50d3861 100644
--- a/sim/mjx_gym/play.py
+++ b/sim/mjx_gym/play.py
@@ -11,6 +11,7 @@
 from envs.default_humanoid_env.default_humanoid import DEFAULT_REWARD_PARAMS
 from utils.rollouts import render_mjx_rollout, render_mujoco_rollout
 
+
 def train(config, n_steps, render_every):
     wandb.init(
         project=config.get("project_name", "robotic_locomotion_training") + "_test",
@@ -38,7 +39,9 @@ def train(config, n_steps, render_every):
     params = model.load_params(model_path)
     normalize = lambda x, y: x
     if config.get("normalize_observations", False):
-        normalize = running_statistics.normalize  # NOTE: very important to keep training & test normalization consistent
+        normalize = (
+            running_statistics.normalize
+        )  # NOTE: very important to keep training & test normalization consistent
     policy_network = ppo_networks.make_ppo_networks(
         env.observation_size, env.action_size, preprocess_observations_fn=normalize
     )
@@ -60,6 +63,7 @@ def train(config, n_steps, render_every):
     wandb.log({"training_rollouts": wandb.Video(images_tchw, fps=fps, format="mp4")})
     media.write_video("video.mp4", images_thwc, fps=fps)
 
+
 if __name__ == "__main__":
     # Parse command-line arguments
     parser = argparse.ArgumentParser(description="Run PPO training with specified config file.")
@@ -74,4 +78,4 @@ def train(config, n_steps, render_every):
     with open(args.config, "r") as file:
         config = yaml.safe_load(file)
 
-    train(config, args.n_steps, args.render_every)
\ No newline at end of file
+    train(config, args.n_steps, args.render_every)
diff --git a/sim/mjx_gym/train.py b/sim/mjx_gym/train.py
index ecaccf87..8536aecf 100644
--- a/sim/mjx_gym/train.py
+++ b/sim/mjx_gym/train.py
@@ -2,7 +2,6 @@
 import functools
 from datetime import datetime
 
-import matplotlib.pyplot as plt
 import wandb
 import yaml
 from brax.io import model
@@ -11,6 +10,7 @@
 from envs import get_env
 from envs.default_humanoid_env.default_humanoid import DEFAULT_REWARD_PARAMS
 
+
 def train(config):
     wandb.init(
         project=config.get("project_name", "robotic-locomotion-training"),
@@ -50,6 +50,7 @@ def train(config):
     )
 
     times = [datetime.now()]
+
     def progress(num_steps, metrics):
         times.append(datetime.now())
 
@@ -65,6 +66,7 @@ def save_model(current_step, make_policy, params):
     print(f"time to jit: {times[1] - times[0]}")
     print(f"time to train: {times[-1] - times[1]}")
 
+
 if __name__ == "__main__":
     # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Run PPO training with specified config file.")
@@ -74,5 +76,5 @@ def save_model(current_step, make_policy, params):
     # Load config from YAML file
     with open(args.config, "r") as file:
         config = yaml.safe_load(file)
-
-    train(config)
\ No newline at end of file
+
+    train(config)