Skip to content

Commit

Permalink
Clean up and format
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Oct 9, 2023
1 parent 478a1fa commit 1c53f37
Show file tree
Hide file tree
Showing 16 changed files with 43 additions and 145 deletions.
2 changes: 1 addition & 1 deletion src/jaxgym/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
""""""

if seed is None:
seed = np.random.default_rng().integers(0, 2**32 - 1, dtype="uint32")
seed = np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")

if np.array(seed, dtype="uint32") != np.array(seed):
raise ValueError(f"seed must be compatible with 'uint32' casting")
Expand Down
2 changes: 1 addition & 1 deletion src/jaxgym/envs/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def reward(
# type(self).terminal(self=self, state=next_state), dtype=float
reward_pivot = jnp.cos(observation.pivot_pos)
cost_action = jnp.sqrt(action.dot(action))
cost_pivot_vel = jnp.sqrt(observation.pivot_vel**2)
cost_pivot_vel = jnp.sqrt(observation.pivot_vel ** 2)
cost_linear_pos = jnp.abs(observation.linear_pos)

reward = 0
Expand Down
146 changes: 22 additions & 124 deletions src/jaxgym/envs/ergocub.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
import dataclasses
import functools
import multiprocessing
import pathlib
import os
import warnings
from typing import Any, ClassVar, Dict, List, Optional

warnings.simplefilter(action="ignore", category=FutureWarning)

import functools
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Type, Union

import gymnasium as gym
import jax.numpy as jnp
import jax.random
import jax_dataclasses
import jaxgym.jax.pytree_space as spaces
import matplotlib.pyplot as plt
import mujoco
import numpy as np
import numpy.typing as npt
import rod
import stable_baselines3
from gymnasium.experimental.vector.vector_env import VectorWrapper
from meshcat_viz import MeshcatWorld
from resolve_robotics_uri_py import resolve_robotics_uri
from stable_baselines3 import PPO
from stable_baselines3.common import vec_env as vec_env_sb
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecMonitor, VecNormalize
from torch import nn

import jaxgym.jax.pytree_space as spaces
import jaxsim.typing as jtp
from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper, JaxEnv, PyTree
from jaxgym.vector.jax import FlattenSpacesVecWrapper, JaxVectorEnv
from jaxgym.wrappers.jax import (
Expand All @@ -35,24 +35,15 @@
TimeLimit,
ToNumPyWrapper,
)
from meshcat_viz import MeshcatWorld
from resolve_robotics_uri_py import resolve_robotics_uri
from scipy.spatial.transform import Rotation
from stable_baselines3 import PPO
from stable_baselines3.common import vec_env as vec_env_sb
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor, VecNormalize

import jaxsim.typing as jtp
from jaxsim import JaxSim
from jaxsim.physics.algos.soft_contacts import SoftContactsParams
from jaxsim.simulation import simulator_callbacks
from jaxsim.simulation.ode_integration import IntegratorType
from jaxsim.simulation.simulator import SimulatorData, VelRepr
from jaxsim.utils import JaxsimDataclass, Mutability

warnings.simplefilter(action="ignore", category=FutureWarning)


@jax_dataclasses.pytree_dataclass
class ErgoCubObservation(JaxsimDataclass):
Expand Down Expand Up @@ -173,8 +164,8 @@ def __post_init__(self) -> None:
model = self.jaxsim.get_model(model_name="ErgoCub")

# Create the action space (static attribute)
# with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
high = jnp.array([25.0] * model.dofs(), dtype=float)
with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
high = jnp.array([25.0] * model.dofs(), dtype=float)

with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
self._action_space = spaces.PyTree(low=-high, high=high)
Expand Down Expand Up @@ -250,7 +241,7 @@ def jaxsim(self) -> JaxSim:

def initial(self, rng: Any = None) -> StateType:
""""""
# assert isinstance(rng, jax.random.PRNGKey)
assert jax.dtypes.issubdtype(rng, jax.dtypes.prng_key)

# Split the key
subkey1, subkey2 = jax.random.split(rng, num=2)
Expand Down Expand Up @@ -723,7 +714,7 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
""""""

if seed is None:
seed = np.random.default_rng().integers(0, 2**32 - 1, dtype="uint32")
seed = np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")

if np.array(seed, dtype="uint32") != np.array(seed):
raise ValueError(f"seed must be compatible with 'uint32' casting")
Expand Down Expand Up @@ -762,8 +753,6 @@ def make_vec_env_stable_baselines(

return vec_env_sb

import os

os.environ["IGN_GAZEBO_RESOURCE_PATH"] = "/conda/share/" # DEBUG

max_episode_steps = 200
Expand All @@ -778,7 +767,7 @@ def make_vec_env_stable_baselines(

vec_env = make_vec_env_stable_baselines(
jax_dataclass_env=func_env,
n_envs=512,
n_envs=6000,
seed=42,
vec_env_kwargs=dict(
jit_compile=True,
Expand All @@ -789,21 +778,13 @@ def make_vec_env_stable_baselines(
venv=VecNormalize(
venv=vec_env,
training=True,
norm_obs=True,
norm_reward=True,
clip_obs=10.0,
clip_reward=10.0,
gamma=0.95,
epsilon=1e-8,
)
)

vec_env.venv.venv.logger_rewards = []
seed = vec_env.seed(seed=7)[0]
_ = vec_env.reset()

import torch as th

model = PPO(
"MlpPolicy",
env=vec_env,
Expand All @@ -815,98 +796,15 @@ def make_vec_env_stable_baselines(
clip_range=0.1,
normalize_advantage=True,
target_kl=0.025,
verbose=1,
verbose=2,
learning_rate=0.000_300,
policy_kwargs=dict(
activation_fn=th.nn.ReLU,
activation_fn=nn.ReLU,
net_arch=dict(pi=[512, 512], vf=[512, 512]),
log_std_init=np.log(0.05),
),
)

print(model.policy)

model = model.learn(total_timesteps=50_000, progress_bar=False)

# =========
# Visualize
# =========

visualize = False

def visualizer(
env: JaxEnv | Callable[[None], JaxEnv], policy: BaseAlgorithm
) -> Callable[[Optional[int]], None]:
""""""

import numpy as np
import rod
from loop_rate_limiters import RateLimiter
from meshcat_viz import MeshcatWorld

from jaxsim import JaxSim

# Open the visualizer
world = MeshcatWorld()
world.open()

# Create the JaxSim environment and get the simulator
env = env() if isinstance(env, Callable) else env
sim: JaxSim = env.unwrapped.func_env.unwrapped.jaxsim

# Extract the SDF string from the simulated model
jaxsim_model = sim.get_model(model_name="cartpole")
rod_model = jaxsim_model.physics_model.description.extra_info["sdf_model"]
rod_sdf = rod.Sdf(model=rod_model, version="1.7")
sdf_string = rod_sdf.serialize(pretty=True)

# Insert the model from a URDF/SDF resource
model_name = world.insert_model(model_description=sdf_string, is_urdf=False)

# Create the visualization function
def rollout(seed: Optional[int] = None) -> None:
""""""

# Reset the environment
observation, state_info = env.reset(seed=seed)

# Initialize the model state with the initial observation
world.update_model(
model_name=model_name,
joint_names=["linear", "pivot"],
joint_positions=np.array([observation[0], observation[2]]),
)

rtf = 1.0
down_sampling = 1
rate = RateLimiter(frequency=float(rtf / (sim.dt() * down_sampling)))

done = False

# Visualization loop
while not done:
action, _ = policy.predict(observation=observation, deterministic=True)
print(action)
observation, _, terminated, truncated, _ = env.step(action)
done = terminated or truncated

world.update_model(
model_name=model_name,
joint_names=["linear", "pivot"],
joint_positions=np.array([observation[0], observation[2]]),
)

print(done)
rate.sleep()

print("done")

return rollout

if visualize:
rollout_visualizer = visualizer(env=lambda: make_jax_env(1_000), policy=model)

import time

time.sleep(3)
rollout_visualizer(None)
model = model.learn(total_timesteps=50000, progress_bar=True)
2 changes: 1 addition & 1 deletion src/jaxgym/functional/_jax/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
self.render_state = None

np_random, _ = seeding.np_random()
seed = np_random.integers(0, 2**32 - 1, dtype="uint32")
seed = np_random.integers(0, 2 ** 32 - 1, dtype="uint32")

self.rng = jrng.PRNGKey(seed)

Expand Down
6 changes: 3 additions & 3 deletions src/jaxgym/jax/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from gymnasium.envs.registration import EnvSpec
from gymnasium.utils import seeding

# from meshcat_viz import MeshcatWorld

import jaxgym.jax.pytree_space as spaces
from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper
from jaxsim import logging

# from meshcat_viz import MeshcatWorld


class JaxEnv(gym.Env[ObsType, ActType], Generic[ObsType, ActType]):
""""""
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(
self._meshcat_window = None # old

# Initialize the RNGs with a random seed
seed = np.random.default_rng().integers(0, 2**32 - 1, dtype="uint32")
seed = np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")
self._np_random, _ = seeding.np_random(seed=int(seed))
self.rng = jax.random.PRNGKey(seed=seed)

Expand Down
4 changes: 2 additions & 2 deletions src/jaxgym/jax/pytree_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def check() -> None:
seed = (
seed
if seed is not None
else np.random.default_rng().integers(0, 2**32 - 1, dtype="uint32")
else np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")
)

# Initialize the JAX random key
Expand Down Expand Up @@ -216,7 +216,7 @@ def seed(self, seed: int | None = None) -> list[int]:
seed = (
seed
if seed is not None
else np.random.default_rng().integers(0, 2**32 - 1, dtype="uint32")
else np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")
)

self.key = jax.random.PRNGKey(seed=seed)
Expand Down
2 changes: 1 addition & 1 deletion src/jaxgym/stable_baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
"""Sets the random seeds for all environments."""

if seed is None:
seed = np.random.default_rng().integers(0, 2**32 - 1, dtype="uint32")
seed = np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")

if np.array(seed, dtype="uint32") != np.array(seed):
raise ValueError(f"seed must be compatible with 'uint32' casting")
Expand Down
1 change: 0 additions & 1 deletion src/jaxgym/vector/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from .vector_env import JaxVectorEnv
from .wrappers import FlattenSpacesVecWrapper

4 changes: 2 additions & 2 deletions src/jaxgym/vector/jax/vector_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
from jaxsim import logging
from typing import Any, Sequence

import jax.flatten_util
Expand All @@ -23,6 +22,7 @@
import jaxsim.typing as jtp
from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper
from jaxgym.wrappers.jax import JaxTransformWrapper, TimeLimit
from jaxsim import logging
from jaxsim.utils import not_tracing


Expand Down Expand Up @@ -128,7 +128,7 @@ def has_wrapper(
# self.render_state = None

# Initialize the RNGs with a random seed
seed = np.random.default_rng().integers(0, 2**32 - 1, dtype="uint32")
seed = np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")
self._np_random, _ = seeding.np_random(seed=int(seed))
self._key = jax.random.PRNGKey(seed=seed)

Expand Down
4 changes: 2 additions & 2 deletions src/jaxgym/wrappers/jax/action_noise.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any, Callable, Generic

import numpy.typing as npt
import jax.numpy as jnp
import jax.flatten_util
import jax.numpy as jnp
import jax.tree_util
import jax_dataclasses
import numpy.typing as npt
from gymnasium.experimental.functional import (
ActType,
ObsType,
Expand Down
2 changes: 1 addition & 1 deletion src/jaxgym/wrappers/jax/flatten_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
TerminalType,
)

from jaxsim import logging
from jaxgym.jax import JaxDataclassWrapper
from jaxsim import logging

WrapperStateType = StateType
WrapperObsType = jnp.ndarray
Expand Down
1 change: 1 addition & 0 deletions src/jaxgym/wrappers/jax/time_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def step_info(
# | dict(terminal_observation=self.observation(state=next_state))
)


# @jax.jit
# def has_field(d) -> bool:
# import jax.lax
Expand Down
2 changes: 1 addition & 1 deletion src/jaxgym/wrappers/jax/time_limit_sb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
TerminalType,
)

from jaxgym.jax import JaxDataclassWrapper
from jaxgym.functional import FuncEnv
from jaxgym.jax import JaxDataclassWrapper
from jaxsim import logging

WrapperStateType = StateType
Expand Down
2 changes: 1 addition & 1 deletion src/jaxgym/wrappers/jax/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
TerminalType,
)

from jaxsim import logging
from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper
from jaxgym.wrappers import TransformWrapper
from jaxsim import logging
from jaxsim.utils import JaxsimDataclass, Mutability


Expand Down
Loading

0 comments on commit 1c53f37

Please sign in to comment.