From 18a5adb2e36801eb171c7ed063187c1a6d740821 Mon Sep 17 00:00:00 2001 From: icub Date: Mon, 9 Oct 2023 10:26:47 +0000 Subject: [PATCH] Rebase on `minor_api_update` and lint --- src/jaxgym/envs/ergocub.py | 77 ++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 44 deletions(-) diff --git a/src/jaxgym/envs/ergocub.py b/src/jaxgym/envs/ergocub.py index c3d45457c..769327b55 100644 --- a/src/jaxgym/envs/ergocub.py +++ b/src/jaxgym/envs/ergocub.py @@ -3,7 +3,7 @@ import multiprocessing import os import warnings -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, ClassVar, Dict, List, Optional, Type, Union import gymnasium as gym import jax.numpy as jnp @@ -17,7 +17,6 @@ from resolve_robotics_uri_py import resolve_robotics_uri from stable_baselines3 import PPO from stable_baselines3.common import vec_env as vec_env_sb -from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.vec_env import VecMonitor, VecNormalize from torch import nn @@ -241,7 +240,7 @@ def jaxsim(self) -> JaxSim: def initial(self, rng: Any = None) -> StateType: """""" - assert jax.dtypes.issubdtype(rng, jax.dtypes.prng_key) + # TODO: assert jax.dtypes.issubdtype(rng, jax.dtypes.prng_key) # Split the key subkey1, subkey2 = jax.random.split(rng, num=2) @@ -281,10 +280,10 @@ def initial(self, rng: Any = None) -> StateType: ) # Return the simulation state - return dict( - simulator_data=simulator.data, - goal=jnp.array(goal_xy_position, dtype=float), - ) + return { + simulator_data: simulator.data, + goal: jnp.array(goal_xy_position, dtype=float), + } def transition( self, state: StateType, action: ActType, rng: Any = None @@ -309,20 +308,19 @@ def pre_step(self, sim: JaxSim) -> JaxSim: forces=jnp.atleast_1d(action), joint_names=model.joint_names() ) - return sim + return sim, None number_of_integration_steps = 40 # 0.010 # TODO 20 for having 0.010 # Stepping logic - with simulator.editable(validate=True) as simulator: - simulator, _ = simulator.step_over_horizon( - horizon_steps=number_of_integration_steps, - clear_inputs=False, - callback_handler=SetTorquesOverHorizon(), - ) + simulator, _ = simulator.step_over_horizon( + horizon_steps=number_of_integration_steps, + clear_inputs=False, + callback_handler=SetTorquesOverHorizon(), + ) # Return the new environment state (updated SimulatorData) - return state | dict(simulator_data=simulator.data) + return state | {simulator_data: simulator.data} def observation(self, state: StateType) -> ObsType: """""" @@ -353,7 +351,9 @@ def observation(self, state: StateType) -> ObsType: base_linear_velocity=model.base_velocity()[0:3], base_angular_velocity=model.base_velocity()[3:6], contact_state=model.in_contact( - link_names=[name for name in model.link_names() if "_ankle" in name] + link_names=tuple( + name for name in model.link_names() if "_ankle" in name + ) ), ) @@ -383,11 +383,11 @@ def reward( # reward += 100.0 * v_WB[0] # forward velocity reward -= jnp.linalg.norm(W_p_B[0:2] - W_p_xy_goal) # distance from goal reward += 1.0 * model_next.in_contact( - link_names=[ + link_names=tuple( name for name in model_next.link_names() if name.startswith("leg_") and name.endswith("_lower") - ] + ) ).any().astype(float) reward -= 0.1 * jnp.linalg.norm(action) / action.size # control cost @@ -479,13 +479,6 @@ class ErgoCubWalkEnvV0(JaxEnv): def __init__(self, render_mode: str | None = None, **kwargs: Any) -> None: """""" - from jaxgym.wrappers.jax import ( - ClipActionWrapper, - FlattenSpacesWrapper, - JaxTransformWrapper, - TimeLimit, - ) - func_env = ErgoCubWalkFuncEnvV0() func_env_wrapped = func_env @@ -506,7 +499,7 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any) -> None: class ErgoCubWalkVectorEnvV0(JaxVectorEnv): """""" - metadata = dict() + metadata = {} def __init__( self, @@ -549,8 +542,6 @@ def make_jax_env( ) -> JaxEnv: """""" - # TODO: single env -> time limit with stable_baselines? - if max_episode_steps in {None, 0}: env = ErgoCubWalkFuncEnvV0() else: @@ -630,10 +621,10 @@ def tree_inverse_transpose( ) -> List[jtp.PyTree]: """""" - return [ + return tuple( jax.tree_util.tree_map(lambda leaf: leaf[i], pytree) for i in range(batch_size) - ] + ) def step_wait(self) -> vec_env_sb.base_vec_env.VecEnvStepReturn: """""" @@ -717,7 +708,7 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: seed = np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32") if np.array(seed, dtype="uint32") != np.array(seed): - raise ValueError(f"seed must be compatible with 'uint32' casting") + raise ValueError("seed must be compatible with 'uint32' casting") self._seed = seed return [seed] @@ -735,7 +726,7 @@ def make_vec_env_stable_baselines( env = jax_dataclass_env - vec_env_kwargs = vec_env_kwargs if vec_env_kwargs is not None else dict() + vec_env_kwargs = vec_env_kwargs if vec_env_kwargs is not None else {} vec_env = JaxVectorEnv( func_env=env, @@ -755,11 +746,11 @@ def make_vec_env_stable_baselines( os.environ["IGN_GAZEBO_RESOURCE_PATH"] = "/conda/share/" # DEBUG - max_episode_steps = 200 + MAX_EPISODE_STEP = 200 func_env = NaNHandlerWrapper(env=ErgoCubWalkFuncEnvV0()) - if max_episode_steps is not None: - func_env = TimeLimit(env=func_env, max_episode_steps=max_episode_steps) + if MAX_EPISODE_STEP is not None: + func_env = TimeLimit(env=func_env, max_episode_steps=MAX_EPISODE_STEP) func_env = ClipActionWrapper( env=SquashActionWrapper(env=ActionNoiseWrapper(env=func_env)), @@ -767,11 +758,9 @@ def make_vec_env_stable_baselines( vec_env = make_vec_env_stable_baselines( jax_dataclass_env=func_env, - n_envs=6000, + n_envs=256, seed=42, - vec_env_kwargs=dict( - jit_compile=True, - ), + vec_env_kwargs={jit_compile: True}, ) vec_env = VecMonitor( @@ -798,11 +787,11 @@ def make_vec_env_stable_baselines( target_kl=0.025, verbose=2, learning_rate=0.000_300, - policy_kwargs=dict( - activation_fn=nn.ReLU, - net_arch=dict(pi=[512, 512], vf=[512, 512]), - log_std_init=np.log(0.05), - ), + policy_kwargs={ + activation_fn: nn.ReLU, + net_arch: {pi: [512, 512], vf: [512, 512]}, + log_std_init: np.log(0.05), + }, ) print(model.policy)