Commit e2a9be4

chore: removed unused imports

michael-lutz committed May 23, 2024
1 parent 9346d4e

Showing 5 changed files with 22 additions and 17 deletions.
2 changes: 1 addition & 1 deletion sim/mjx_gym/envs/default_humanoid_env/rewards.py
@@ -2,7 +2,6 @@
 
 import jax
 import jax.numpy as jp
-from brax import base
 from brax.mjx.base import State as mjxState
 
 DEFAULT_REWARD_PARAMS = {
@@ -11,6 +10,7 @@
     "rew_ctrl_cost": {"weight": 0.1},
 }
 
+
 def get_reward_fn(
     reward_params: Dict[str, Dict[str, float]], dt, include_reward_breakdown
 ) -> Callable[[mjxState, jp.ndarray, mjxState], Tuple[jp.ndarray, jp.ndarray, Dict[str, jp.ndarray]]]:
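For orientation, get_reward_fn builds a single callable out of the per-term weights in DEFAULT_REWARD_PARAMS. Only the rew_ctrl_cost key and the function signature are visible in this diff, so the following is a minimal sketch under stated assumptions: the rew_forward term, the DummyState stand-in for brax.mjx.base.State, and the health check are illustrative, not the repository's actual implementation.

from dataclasses import dataclass
from typing import Callable, Dict, Tuple

import jax.numpy as jp


@dataclass
class DummyState:
    """Minimal stand-in for brax.mjx.base.State, which carries qpos/qvel."""
    qpos: jp.ndarray
    qvel: jp.ndarray


DEFAULT_REWARD_PARAMS: Dict[str, Dict[str, float]] = {
    "rew_forward": {"weight": 1.25},   # hypothetical term
    "rew_ctrl_cost": {"weight": 0.1},  # the one key visible in the diff
}


def get_reward_fn(
    reward_params: Dict[str, Dict[str, float]], dt: float, include_reward_breakdown: bool
) -> Callable[..., Tuple[jp.ndarray, jp.ndarray, Dict[str, jp.ndarray]]]:
    """Builds a reward function that sums one weighted term per params key."""

    def reward_fn(state, action: jp.ndarray, next_state):
        breakdown: Dict[str, jp.ndarray] = {}
        reward = jp.zeros(())
        if "rew_forward" in reward_params:
            # Forward progress of the root joint over one control step.
            vel = (next_state.qpos[0] - state.qpos[0]) / dt
            term = reward_params["rew_forward"]["weight"] * vel
            reward, breakdown["rew_forward"] = reward + term, term
        if "rew_ctrl_cost" in reward_params:
            # Quadratic penalty on actuation, as in the Gymnasium humanoid.
            term = -reward_params["rew_ctrl_cost"]["weight"] * jp.sum(jp.square(action))
            reward, breakdown["rew_ctrl_cost"] = reward + term, term
        is_healthy = jp.where(next_state.qpos[2] > 0.0, 1.0, 0.0)  # placeholder health check
        return reward, is_healthy, breakdown if include_reward_breakdown else {}

    return reward_fn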
2 changes: 1 addition & 1 deletion sim/mjx_gym/envs/stompy_env/rewards.py
@@ -2,7 +2,6 @@
 
 import jax
 import jax.numpy as jp
-from brax import base
 from brax.mjx.base import State as mjxState
 
 DEFAULT_REWARD_PARAMS = {
@@ -11,6 +10,7 @@
     "rew_ctrl_cost": {"weight": 0.1},
 }
 
+
 def get_reward_fn(
     reward_params: Dict[str, Dict[str, float]], dt, include_reward_breakdown
 ) -> Callable[[mjxState, jp.ndarray, mjxState], Tuple[jp.ndarray, jp.ndarray, Dict[str, jp.ndarray]]]:
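Since stompy_env/rewards.py mirrors the default humanoid version, a short usage sketch (reusing the DummyState and get_reward_fn stubs from the sketch above) shows how the returned callable is meant to be invoked from an environment step; the dt value and array sizes are illustrative, not taken from this repository.

reward_fn = get_reward_fn(DEFAULT_REWARD_PARAMS, dt=0.003, include_reward_breakdown=True)

state = DummyState(qpos=jp.zeros(24), qvel=jp.zeros(23))
next_state = DummyState(qpos=jp.zeros(24).at[0].set(0.01), qvel=jp.zeros(23))
action = jp.zeros(17)

# The unpacking matches the Tuple in the signature above.
reward, is_healthy, breakdown = reward_fn(state, action, next_state)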
19 changes: 9 additions & 10 deletions sim/mjx_gym/envs/stompy_env/stompy.py
@@ -6,12 +6,11 @@
 from brax.envs.base import PipelineEnv, State
 from brax.io import mjcf
 from brax.mjx.base import State as mjxState
-from etils import epath
 from mujoco import mjx
 
 from envs.stompy_env.rewards import DEFAULT_REWARD_PARAMS
 from envs.stompy_env.rewards import get_reward_fn
 
 class StompyEnv(PipelineEnv):
     """
     An environment for humanoid body position, velocities, and angles.
@@ -53,9 +52,9 @@ def reset(self, rng: jp.ndarray) -> State:
         Resets the environment to an initial state.
         Args:
-            rng: Random number generator seed.
+            rng: Random number generator seed.
         Returns:
-            The initial state of the environment.
+            The initial state of the environment.
         """
         rng, rng1, rng2 = jax.random.split(rng, 3)
 
@@ -85,10 +84,10 @@ def step(self, state: State, action: jp.ndarray) -> State:
         Runs one timestep of the environment's dynamics.
         Args:
-            state: The current state of the environment.
-            action: The action to take.
+            state: The current state of the environment.
+            action: The action to take.
         Returns:
-            A tuple of the next state, the reward, whether the episode has ended, and additional information.
+            A tuple of the next state, the reward, whether the episode has ended, and additional information.
         """
         mjx_state = state.pipeline_state
         assert mjx_state, "state.pipeline_state was recorded as None"
@@ -131,10 +130,10 @@ def _get_obs(self, data: mjxState, action: jp.ndarray) -> jp.ndarray:
         Observes humanoid body position, velocities, and angles.
         Args:
-            data: The current state of the environment.
-            action: The current action.
+            data: The current state of the environment.
+            action: The current action.
         Returns:
-            Observations of the environment.
+            Observations of the environment.
         """
         position = data.qpos
         if self._exclude_current_positions_from_observation:

(The paired -/+ docstring lines above appear identical because the change is likely whitespace-only; the rendered page does not show the indentation diff.)
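The last hunk is cut off right after the _exclude_current_positions_from_observation check, so the body of _get_obs is hidden by the fold. Here is a self-contained sketch of the conventional Gym/Brax humanoid pattern that flag usually controls, dropping the root x/y world coordinates so the policy cannot condition on absolute position; the array sizes and the bare concatenation are assumptions, not this file's exact contents (the real method likely appends more terms, such as actuator and contact data).

import jax.numpy as jp


def get_obs_sketch(qpos: jp.ndarray, qvel: jp.ndarray, exclude_current_positions: bool) -> jp.ndarray:
    position = qpos
    if exclude_current_positions:
        # qpos[0:2] are the root's world x/y; qpos[2:] keeps height and joint angles.
        position = position[2:]
    return jp.concatenate([position, qvel])


obs = get_obs_sketch(jp.zeros(24), jp.zeros(23), exclude_current_positions=True)
assert obs.shape == (45,)  # (24 - 2) + 23 with these illustrative sizes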
8 changes: 6 additions & 2 deletions sim/mjx_gym/play.py
@@ -11,6 +10,7 @@
 from envs.default_humanoid_env.default_humanoid import DEFAULT_REWARD_PARAMS
 from utils.rollouts import render_mjx_rollout, render_mujoco_rollout
 
+
 def train(config, n_steps, render_every):
     wandb.init(
         project=config.get("project_name", "robotic_locomotion_training") + "_test",
@@ -38,7 +39,9 @@ def train(config, n_steps, render_every):
     params = model.load_params(model_path)
     normalize = lambda x, y: x
     if config.get("normalize_observations", False):
-        normalize = running_statistics.normalize  # NOTE: very important to keep training & test normalization consistent
+        normalize = (
+            running_statistics.normalize
+        )  # NOTE: very important to keep training & test normalization consistent
     policy_network = ppo_networks.make_ppo_networks(
         env.observation_size, env.action_size, preprocess_observations_fn=normalize
     )
@@ -60,6 +63,7 @@ def train(config, n_steps, render_every):
     wandb.log({"training_rollouts": wandb.Video(images_tchw, fps=fps, format="mp4")})
     media.write_video("video.mp4", images_thwc, fps=fps)
 
+
 if __name__ == "__main__":
     # Parse command-line arguments
     parser = argparse.ArgumentParser(description="Run PPO training with specified config file.")
@@ -74,4 +78,4 @@ def train(config, n_steps, render_every):
     with open(args.config, "r") as file:
         config = yaml.safe_load(file)
 
-    train(config, args.n_steps, args.render_every)
\ No newline at end of file
+    train(config, args.n_steps, args.render_every)
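The reformatted NOTE above is the one substantive comment in this commit: evaluation must preprocess observations with the same running statistics accumulated during training, otherwise the restored policy sees inputs on a different scale than it was trained on. A minimal sketch of that mechanism using brax's running_statistics helpers; the observation size and batch contents are illustrative.

import jax.numpy as jp
from brax.training.acme import running_statistics, specs

obs_size = 45  # illustrative
norm_state = running_statistics.init_state(specs.Array((obs_size,), jp.float32))

# Training accumulates mean/std from rollout batches...
batch = jp.ones((256, obs_size))
norm_state = running_statistics.update(norm_state, batch)

# ...and evaluation must reuse those same statistics when normalizing,
# e.g. via preprocess_observations_fn=running_statistics.normalize.
obs = jp.ones((obs_size,))
normalized = running_statistics.normalize(obs, norm_state)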
8 changes: 5 additions & 3 deletions sim/mjx_gym/train.py
@@ -2,7 +2,6 @@
 import functools
 from datetime import datetime
 
-import matplotlib.pyplot as plt
 import wandb
 import yaml
 from brax.io import model
@@ -11,6 +10,7 @@
 from envs import get_env
 from envs.default_humanoid_env.default_humanoid import DEFAULT_REWARD_PARAMS
 
+
 def train(config):
     wandb.init(
         project=config.get("project_name", "robotic-locomotion-training"),
@@ -50,6 +50,7 @@ def train(config):
     )
 
     times = [datetime.now()]
+
     def progress(num_steps, metrics):
         times.append(datetime.now())
 
@@ -65,6 +66,7 @@ def save_model(current_step, make_policy, params):
     print(f"time to jit: {times[1] - times[0]}")
     print(f"time to train: {times[-1] - times[1]}")
 
+
 if __name__ == "__main__":
     # Parse command-line arguments
     parser = argparse.ArgumentParser(description="Run PPO training with specified config file.")
@@ -74,5 +76,5 @@ def save_model(current_step, make_policy, params):
     # Load config from YAML file
     with open(args.config, "r") as file:
         config = yaml.safe_load(file)
-    train(config)
 
+    train(config)
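The blank-line fixes in train.py sit around a useful idiom: the times list together with the progress callback separates XLA compilation time from actual training time, because the trainer's first progress call only fires once jitting has finished. A standalone sketch of just that pattern; the loop is a hypothetical stand-in for the real trainer, which would receive the callback as something like progress_fn=progress.

from datetime import datetime

times = [datetime.now()]


def progress(num_steps: int, metrics: dict) -> None:
    # Invoked at each evaluation; the first call happens only after
    # compilation, so times[1] - times[0] approximates jit time.
    times.append(datetime.now())


for step in range(3):  # hypothetical stand-in for the training loop
    progress(step, {})

print(f"time to jit: {times[1] - times[0]}")
print(f"time to train: {times[-1] - times[1]}")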