Skip to content

Commit

Permalink
Rebase on minor_api_update and lint
Browse files (browse the repository at this point in the history)
  • Loading branch information
icub committed Oct 9, 2023
1 parent 1c53f37 commit 18a5adb
Showing 1 changed file with 33 additions and 44 deletions.
77 changes: 33 additions & 44 deletions src/jaxgym/envs/ergocub.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import multiprocessing
import os
import warnings
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, ClassVar, Dict, List, Optional, Type, Union

import gymnasium as gym
import jax.numpy as jnp
Expand All @@ -17,7 +17,6 @@
from resolve_robotics_uri_py import resolve_robotics_uri
from stable_baselines3 import PPO
from stable_baselines3.common import vec_env as vec_env_sb
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecMonitor, VecNormalize
from torch import nn

Expand Down Expand Up @@ -241,7 +240,7 @@ def jaxsim(self) -> JaxSim:

def initial(self, rng: Any = None) -> StateType:
""""""
assert jax.dtypes.issubdtype(rng, jax.dtypes.prng_key)
# TODO: assert jax.dtypes.issubdtype(rng, jax.dtypes.prng_key)

# Split the key
subkey1, subkey2 = jax.random.split(rng, num=2)
Expand Down Expand Up @@ -281,10 +280,10 @@ def initial(self, rng: Any = None) -> StateType:
)

# Return the simulation state
return dict(
simulator_data=simulator.data,
goal=jnp.array(goal_xy_position, dtype=float),
)
return {
simulator_data: simulator.data,
goal: jnp.array(goal_xy_position, dtype=float),
}

def transition(
self, state: StateType, action: ActType, rng: Any = None
Expand All @@ -309,20 +308,19 @@ def pre_step(self, sim: JaxSim) -> JaxSim:
forces=jnp.atleast_1d(action), joint_names=model.joint_names()
)

return sim
return sim, None

number_of_integration_steps = 40 # 0.010 # TODO 20 for having 0.010

# Stepping logic
with simulator.editable(validate=True) as simulator:
simulator, _ = simulator.step_over_horizon(
horizon_steps=number_of_integration_steps,
clear_inputs=False,
callback_handler=SetTorquesOverHorizon(),
)
simulator, _ = simulator.step_over_horizon(
horizon_steps=number_of_integration_steps,
clear_inputs=False,
callback_handler=SetTorquesOverHorizon(),
)

# Return the new environment state (updated SimulatorData)
return state | dict(simulator_data=simulator.data)
return state | {simulator_data: simulator.data}

def observation(self, state: StateType) -> ObsType:
""""""
Expand Down Expand Up @@ -353,7 +351,9 @@ def observation(self, state: StateType) -> ObsType:
base_linear_velocity=model.base_velocity()[0:3],
base_angular_velocity=model.base_velocity()[3:6],
contact_state=model.in_contact(
link_names=[name for name in model.link_names() if "_ankle" in name]
link_names=tuple(
name for name in model.link_names() if "_ankle" in name
)
),
)

Expand Down Expand Up @@ -383,11 +383,11 @@ def reward(
# reward += 100.0 * v_WB[0] # forward velocity
reward -= jnp.linalg.norm(W_p_B[0:2] - W_p_xy_goal) # distance from goal
reward += 1.0 * model_next.in_contact(
link_names=[
link_names=tuple(
name
for name in model_next.link_names()
if name.startswith("leg_") and name.endswith("_lower")
]
)
).any().astype(float)
reward -= 0.1 * jnp.linalg.norm(action) / action.size # control cost

Expand Down Expand Up @@ -479,13 +479,6 @@ class ErgoCubWalkEnvV0(JaxEnv):
def __init__(self, render_mode: str | None = None, **kwargs: Any) -> None:
""""""

from jaxgym.wrappers.jax import (
ClipActionWrapper,
FlattenSpacesWrapper,
JaxTransformWrapper,
TimeLimit,
)

func_env = ErgoCubWalkFuncEnvV0()

func_env_wrapped = func_env
Expand All @@ -506,7 +499,7 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any) -> None:
class ErgoCubWalkVectorEnvV0(JaxVectorEnv):
""""""

metadata = dict()
metadata = {}

def __init__(
self,
Expand Down Expand Up @@ -549,8 +542,6 @@ def make_jax_env(
) -> JaxEnv:
""""""

# TODO: single env -> time limit with stable_baselines?

if max_episode_steps in {None, 0}:
env = ErgoCubWalkFuncEnvV0()
else:
Expand Down Expand Up @@ -630,10 +621,10 @@ def tree_inverse_transpose(
) -> List[jtp.PyTree]:
""""""

return [
return tuple(
jax.tree_util.tree_map(lambda leaf: leaf[i], pytree)
for i in range(batch_size)
]
)

def step_wait(self) -> vec_env_sb.base_vec_env.VecEnvStepReturn:
""""""
Expand Down Expand Up @@ -717,7 +708,7 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
seed = np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")

if np.array(seed, dtype="uint32") != np.array(seed):
raise ValueError(f"seed must be compatible with 'uint32' casting")
raise ValueError("seed must be compatible with 'uint32' casting")

self._seed = seed
return [seed]
Expand All @@ -735,7 +726,7 @@ def make_vec_env_stable_baselines(

env = jax_dataclass_env

vec_env_kwargs = vec_env_kwargs if vec_env_kwargs is not None else dict()
vec_env_kwargs = vec_env_kwargs if vec_env_kwargs is not None else {}

vec_env = JaxVectorEnv(
func_env=env,
Expand All @@ -755,23 +746,21 @@ def make_vec_env_stable_baselines(

os.environ["IGN_GAZEBO_RESOURCE_PATH"] = "/conda/share/" # DEBUG

max_episode_steps = 200
MAX_EPISODE_STEP = 200
func_env = NaNHandlerWrapper(env=ErgoCubWalkFuncEnvV0())

if max_episode_steps is not None:
func_env = TimeLimit(env=func_env, max_episode_steps=max_episode_steps)
if MAX_EPISODE_STEP is not None:
func_env = TimeLimit(env=func_env, max_episode_steps=MAX_EPISODE_STEP)

func_env = ClipActionWrapper(
env=SquashActionWrapper(env=ActionNoiseWrapper(env=func_env)),
)

vec_env = make_vec_env_stable_baselines(
jax_dataclass_env=func_env,
n_envs=6000,
n_envs=256,
seed=42,
vec_env_kwargs=dict(
jit_compile=True,
),
vec_env_kwargs={jit_compile: True},
)

vec_env = VecMonitor(
Expand All @@ -798,11 +787,11 @@ def make_vec_env_stable_baselines(
target_kl=0.025,
verbose=2,
learning_rate=0.000_300,
policy_kwargs=dict(
activation_fn=nn.ReLU,
net_arch=dict(pi=[512, 512], vf=[512, 512]),
log_std_init=np.log(0.05),
),
policy_kwargs={
activation_fn: nn.ReLU,
net_arch: {pi: [512, 512], vf: [512, 512]},
log_std_init: np.log(0.05),
},
)

print(model.policy)
Expand Down

0 comments on commit 18a5adb

Please sign in to comment.