diff --git a/.gitignore b/.gitignore
index f2520562c..17dce9518 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
+core
+*.pickle
+*.zip
+
# IDEs
.idea*
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 000000000..cfb2b6b85
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,16 @@
+{
+ // Use IntelliSense to learn about possible attributes.
+ // Hover to view descriptions of existing attributes.
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "name": "Python: Current File",
+ "type": "python",
+ "request": "launch",
+ "program": "${file}",
+ "console": "integratedTerminal",
+ "justMyCode": true
+ },
+ ]
+}
\ No newline at end of file
diff --git a/examples/resources/ant.sdf b/examples/resources/ant.sdf
new file mode 100644
index 000000000..c47dc030b
--- /dev/null
+++ b/examples/resources/ant.sdf
@@ -0,0 +1,520 @@
+
+
+
+
+
+ 1.0
+
+ 0.025
+ 0.025
+ 0.025
+ 0.0
+ 0.0
+ 0.0
+
+
+
+
+
+ 0.25
+
+
+
+
+
+
+ 0.2886751345948129 0.2886751345948129 0.2886751345948129
+
+
+
+
+
+ -2.7755575615628914e-17 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+ 0.3
+
+ 0.0044800000000000005
+ 0.0044800000000000005
+ 0.00096
+ 0.0
+ 0.0
+ 0.0
+
+ 0.28 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.28 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+
+
+ 0.07999999999999997 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+
+
+
+ 0.08800000000000001
+
+
+ 0.4800000000000001 5.551115123125783e-17 0.0 0.0 0.0 -5.551115123125783e-17
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.28 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+ 5.551115123125783e-17 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+ 0.2
+
+ 0.002986666666666667
+ 0.002986666666666667
+ 0.00064
+ 0.0
+ 0.0
+ 0.0
+
+ 0.20000000000000007 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.20000000000000007 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+
+
+ 0.4000000000000001 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.20000000000000007 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+
+
+ 0.4000000000000001 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+
+
+ -2.7755575615628914e-17 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+ 0.3
+
+ 0.0044800000000000005
+ 0.0044800000000000005
+ 0.00096
+ 0.0
+ 0.0
+ 0.0
+
+ 0.28 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.28 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+
+
+ 0.07999999999999997 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+
+
+
+ 0.08800000000000001
+
+
+ 0.4800000000000001 -5.551115123125783e-17 0.0 0.0 0.0 5.551115123125783e-17
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.28 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+ 5.551115123125783e-17 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+ 0.2
+
+ 0.002986666666666667
+ 0.002986666666666667
+ 0.00064
+ 0.0
+ 0.0
+ 0.0
+
+ 0.20000000000000007 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.20000000000000007 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+
+
+ 0.4000000000000001 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.20000000000000007 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+
+
+ 0.4000000000000001 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+
+
+ 0.0 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+ 0.3
+
+ 0.0044800000000000005
+ 0.0044800000000000005
+ 0.00096
+ 0.0
+ 0.0
+ 0.0
+
+ 0.28 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.28 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+
+
+ 0.08 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+
+
+
+ 0.08800000000000001
+
+
+ 0.4800000000000001 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.28 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+ 1.1102230246251565e-16 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+ 0.2
+
+ 0.002986666666666667
+ 0.002986666666666667
+ 0.00064
+ 0.0
+ 0.0
+ 0.0
+
+ 0.20000000000000012 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.20000000000000012 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+
+
+ 0.40000000000000013 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.20000000000000012 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+
+
+ 0.40000000000000013 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+
+
+ 0.0 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+ 0.3
+
+ 0.0044800000000000005
+ 0.0044800000000000005
+ 0.00096
+ 0.0
+ 0.0
+ 0.0
+
+ 0.28 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.28 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+
+
+ 0.08 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+
+
+
+ 0.08800000000000001
+
+
+ 0.4800000000000001 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.28 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+ 1.1102230246251565e-16 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+ 0.2
+
+ 0.002986666666666667
+ 0.002986666666666667
+ 0.00064
+ 0.0
+ 0.0
+ 0.0
+
+ 0.20000000000000012 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.20000000000000012 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+
+
+ 0.40000000000000013 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+
+
+
+ 0.08
+ 0.4
+
+
+ 0.20000000000000012 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0
+
+
+
+
+ 0.08
+
+
+ 0.40000000000000013 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+
+
+ torso
+ leg_front_left_upper
+ 0.1767766952966369 0.1767766952966369 0.0 0.0 0.0 0.7853981633974484
+
+ 0 0 1
+
+ -0.5235987755982988
+ 0.5235987755982988
+
+
+
+
+ leg_front_left_upper
+ leg_front_left_lower
+ 0.48 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+ 0 1 0
+
+ 0.5235987755982988
+ 1.2217304763960306
+
+
+
+
+ torso
+ leg_front_right_upper
+ 0.1767766952966369 -0.1767766952966369 0.0 0.0 0.0 -0.7853981633974484
+
+ 0 0 1
+
+ -0.5235987755982988
+ 0.5235987755982988
+
+
+
+
+ leg_front_right_upper
+ leg_front_right_lower
+ 0.48 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+ 0 1 0
+
+ 0.5235987755982988
+ 1.2217304763960306
+
+
+
+
+ torso
+ leg_back_left_upper
+ -0.1767766952966369 0.1767766952966369 0.0 0.0 0.0 2.356194490192345
+
+ 0 0 1
+
+ -0.5235987755982988
+ 0.5235987755982988
+
+
+
+
+ leg_back_left_upper
+ leg_back_left_lower
+ 0.48000000000000004 0.0 0.0 0.0 0.0 5.551115123125783e-17
+
+ 0 1 0
+
+ 0.5235987755982988
+ 1.2217304763960306
+
+
+
+
+ torso
+ leg_back_right_upper
+ -0.1767766952966369 -0.1767766952966369 0.0 0.0 0.0 -2.356194490192345
+
+ 0 0 1
+
+ -0.5235987755982988
+ 0.5235987755982988
+
+
+
+
+ leg_back_right_upper
+ leg_back_right_lower
+ 0.48000000000000004 0.0 0.0 0.0 0.0 -5.551115123125783e-17
+
+ 0 1 0
+
+ 0.5235987755982988
+ 1.2217304763960306
+
+
+
+
+
diff --git a/examples/resources/cartpole.urdf b/examples/resources/cartpole.urdf
new file mode 100644
index 000000000..7df2fd4c9
--- /dev/null
+++ b/examples/resources/cartpole.urdf
@@ -0,0 +1,81 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/pyproject.toml b/pyproject.toml
index 5e229ba9e..4c31bb012 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,3 +15,10 @@ line-length = 88
[tool.isort]
profile = "black"
multi_line_output = 3
+
+[tool.pytest.ini_options]
+minversion = "6.0"
+addopts = "-rsxX -v --strict-markers --forked"
+testpaths = [
+ "tests",
+]
diff --git a/setup.cfg b/setup.cfg
index 398189d00..4bcb49574 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -60,6 +60,7 @@ install_requires =
pptree
rod
scipy
+ typing_extensions; python_version < "3.11"
[options.packages.find]
where = src
@@ -70,13 +71,10 @@ style =
isort
testing =
idyntree
- pytest
+ pytest >= 6.0
+ pytest-forked
pytest-icdiff
robot-descriptions
all =
%(style)s
%(testing)s
-
-[tool:pytest]
-addopts = -rsxX -v --strict-markers
-testpaths = tests
diff --git a/src/jaxgym/__init__.py b/src/jaxgym/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/jaxgym/__main__.py b/src/jaxgym/__main__.py
new file mode 100644
index 000000000..7b9d9fde6
--- /dev/null
+++ b/src/jaxgym/__main__.py
@@ -0,0 +1,1795 @@
+import warnings
+
+warnings.simplefilter(action="ignore", category=FutureWarning)
+
+import functools
+import pathlib
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+
+import gymnasium as gym
+import jax.random
+import matplotlib.pyplot as plt
+import mujoco
+import numpy as np
+import numpy.typing as npt
+import stable_baselines3
+from gymnasium.experimental.vector.vector_env import VectorWrapper
+from sb3_contrib import TRPO
+from scipy.spatial.transform import Rotation
+from stable_baselines3 import PPO
+from stable_baselines3.common import vec_env as vec_env_sb
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.env_checker import check_env
+from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor, VecNormalize
+
+import jaxsim.typing as jtp
+from jaxgym.envs.ant import AntReachTargetFuncEnvV0
+from jaxgym.envs.cartpole import CartpoleSwingUpFuncEnvV0
+from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper, JaxEnv, PyTree
+from jaxgym.vector.jax import FlattenSpacesVecWrapper, JaxVectorEnv
+from jaxgym.wrappers.jax import ( # TimeLimitStableBaselines,
+ ActionNoiseWrapper,
+ ClipActionWrapper,
+ FlattenSpacesWrapper,
+ JaxTransformWrapper,
+ NaNHandlerWrapper,
+ SquashActionWrapper,
+ TimeLimit,
+ ToNumPyWrapper,
+)
+
+#
+#
+#
+
+
+class MujocoModel:
+ """"""
+
+ def __init__(self, xml_path: pathlib.Path) -> None:
+ """"""
+
+ if not xml_path.exists():
+ raise FileNotFoundError(f"Could not find file '{xml_path}'")
+
+ self.model = mujoco.MjModel.from_xml_path(filename=str(xml_path), assets=None)
+
+ self.data = mujoco.MjData(self.model)
+
+ # Populate data
+ mujoco.mj_forward(self.model, self.data)
+
+ # print(self.model.opt)
+
+ def time(self) -> float:
+ """"""
+
+ return self.data.time
+
+ def gravity(self) -> npt.NDArray:
+ """"""
+
+ return self.model.opt.gravity
+
+ def number_of_joints(self) -> int:
+ """"""
+
+ return self.model.njnt
+
+ def number_of_geometries(self) -> int:
+ """"""
+
+ return self.model.ngeom
+
+ def number_of_bodies(self) -> int:
+ """"""
+
+ return self.model.nbody
+
+ def joint_names(self) -> List[str]:
+ """"""
+
+ return [
+ mujoco.mj_id2name(self.model, mujoco.mjtObj.mjOBJ_JOINT, idx)
+ for idx in range(self.number_of_joints())
+ ]
+
+ def joint_dofs(self, joint_name: str) -> int:
+ """"""
+
+ if joint_name not in self.joint_names():
+ raise ValueError(f"Joint '{joint_name}' not found")
+
+ return self.data.joint(joint_name).qpos.size
+
+ def joint_position(self, joint_name: str) -> npt.NDArray:
+ """"""
+
+ if joint_name not in self.joint_names():
+ raise ValueError(f"Joint '{joint_name}' not found")
+
+ return self.data.joint(joint_name).qpos
+
+ def joint_velocity(self, joint_name: str) -> npt.NDArray:
+ """"""
+
+ if joint_name not in self.joint_names():
+ raise ValueError(f"Joint '{joint_name}' not found")
+
+ return self.data.joint(joint_name).qvel
+
+ def body_names(self) -> List[str]:
+ """"""
+
+ return [
+ mujoco.mj_id2name(self.model, mujoco.mjtObj.mjOBJ_BODY, idx)
+ for idx in range(self.number_of_bodies())
+ ]
+
+ def body_position(self, body_name: str) -> npt.NDArray:
+ """"""
+
+ if body_name not in self.body_names():
+ raise ValueError(f"Body '{body_name}' not found")
+
+ return self.data.body(body_name).xpos
+
+ def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray:
+ """"""
+
+ if body_name not in self.body_names():
+ raise ValueError(f"Body '{body_name}' not found")
+
+ return (
+ self.data.body(body_name).xmat if dcm else self.data.body(body_name).xquat
+ )
+
+ def geometry_names(self) -> List[str]:
+ """"""
+
+ return [
+ mujoco.mj_id2name(self.model, mujoco.mjtObj.mjOBJ_GEOM, idx)
+ for idx in range(self.number_of_geometries())
+ ]
+
+ def geometry_position(self, geometry_name: str) -> npt.NDArray:
+ """"""
+
+ if geometry_name not in self.geometry_names():
+ raise ValueError(f"Geometry '{geometry_name}' not found")
+
+ return self.data.geom(geometry_name).xpos
+
+ def geometry_orientation(
+ self, geometry_name: str, dcm: bool = False
+ ) -> npt.NDArray:
+ """"""
+
+ if geometry_name not in self.geometry_names():
+ raise ValueError(f"Geometry '{geometry_name}' not found")
+
+ R = np.reshape(self.data.geom(geometry_name).xmat, (3, 3))
+
+ if dcm:
+ return R
+
+ q_xyzw = Rotation.from_matrix(R).as_quat()
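+ # Reorder from scipy's (x, y, z, w) convention to (w, x, y, z).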
+ return q_xyzw[[3, 0, 1, 2]]
+
+ def to_string(self) -> Tuple[str, str]:
+ """Convert a mujoco model to a string."""
+
+ import tempfile
+
+ with tempfile.NamedTemporaryFile(mode="w+") as f:
+ mujoco.mj_saveLastXML(f.name, self.model)
+ mjcf_string = pathlib.Path(f.name).read_text()
+
+ with tempfile.NamedTemporaryFile(mode="w+") as f:
+ mujoco.mj_printModel(self.model, f.name)
+ compiled_model_string = pathlib.Path(f.name).read_text()
+
+ return mjcf_string, compiled_model_string
+
+
+# Full cartpole example with collection loop
+#
+# -> validate with visualization
+# -> move to pytorch later
+# -> study TensorDict -> create wrapper -> use PPO
+# -> Check JaxToTorch wrapper and adapt it to work for JaxVectorEnv
+# -> alternatively, wait for stable_baselines3 + gymnasium (open issue)
+
+# For TensorDict, check:
+#
+# -> create docker image with torch 2.0
+# - EnvBase https://github.com/pytorch/rl/blob/main/torchrl/envs/common.py#L120 (has batch_size -> maybe vectorized?)
+# - torchrl.collectors
+# - torchrl.envs.libs.gym.GymWrapper
+# - torchrl.envs.libs.brax.BraxWrapper
+
+# TODO: JaxSimEnv with render support?
+
+
+class CustomVecEnvSB(vec_env_sb.VecEnv):
+ """"""
+
+ metadata = {"render_modes": []}
+
+ def __init__(
+ self,
+ jax_vector_env: JaxVectorEnv | VectorWrapper,
+ log_rewards: bool = False,
+ # num_envs: int,
+ # observation_space: spaces.Space,
+ # action_space: spaces.Space,
+ # render_mode: Optional[str] = None,
+ ) -> None:
+ """"""
+
+ if not isinstance(jax_vector_env.unwrapped, JaxVectorEnv):
+ raise TypeError(type(jax_vector_env))
+
+ self.jax_vector_env = jax_vector_env
+
+ single_env_action_space: PyTree = jax_vector_env.unwrapped.single_action_space
+
+ single_env_observation_space: PyTree = (
+ jax_vector_env.unwrapped.single_observation_space
+ )
+
+ super().__init__(
+ num_envs=self.jax_vector_env.num_envs,
+ action_space=single_env_action_space.to_box(),
+ observation_space=single_env_observation_space.to_box(),
+ render_mode=None,
+ )
+
+ self.actions = np.zeros_like(self.jax_vector_env.action_space.sample())
+
+ # Initialize the RNG seed
+ self._seed = None
+ self.seed()
+
+ # Initialize the rewards logger
+ self.logger_rewards = [] if log_rewards else None
+
+ def reset(self) -> vec_env_sb.base_vec_env.VecEnvObs:
+ """"""
+
+ observations, state_infos = self.jax_vector_env.reset(seed=self._seed)
+ return np.array(observations)
+
+ def step_async(self, actions: np.ndarray) -> None:
+ self.actions = actions
+
+ @staticmethod
+ @functools.partial(jax.jit, static_argnames=("batch_size",))
+ def tree_inverse_transpose(pytree: jtp.PyTree, batch_size: int) -> List[jtp.PyTree]:
+ """"""
+
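+ # Example: {"a": [[1, 2], [3, 4]]} with batch_size=2 becomes [{"a": [1, 2]}, {"a": [3, 4]}].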
+ return [
+ jax.tree_util.tree_map(lambda leaf: leaf[i], pytree)
+ for i in range(batch_size)
+ ]
+
+ def step_wait(self) -> vec_env_sb.base_vec_env.VecEnvStepReturn:
+ """"""
+
+ (
+ observations,
+ rewards,
+ terminals,
+ truncated,
+ step_infos,
+ ) = self.jax_vector_env.step(actions=self.actions)
+
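+ # The SB3 VecEnv API expects a single done flag, so merge termination and truncation.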
+ done = np.logical_or(terminals, truncated)
+
+ # list_of_step_infos = [
+ # jax.tree_util.tree_map(lambda l: l[i], step_infos)
+ # for i in range(self.jax_vector_env.num_envs)
+ # ]
+
+ list_of_step_infos = self.tree_inverse_transpose(
+ pytree=step_infos, batch_size=self.jax_vector_env.num_envs
+ )
+
+ # def pytree_to_numpy(pytree: jtp.PyTree) -> jtp.PyTree:
+ # return jax.tree_util.tree_map(lambda leaf: np.array(leaf), pytree)
+ #
+ # list_of_step_infos_numpy = [pytree_to_numpy(pt) for pt in list_of_step_infos]
+
+ list_of_step_infos_numpy = [
+ ToNumPyWrapper.pytree_to_numpy(pytree=pt) for pt in list_of_step_infos
+ ]
+
+ if self.logger_rewards is not None:
+ self.logger_rewards.append(np.array(rewards).mean())
+
+ return (
+ np.array(observations),
+ np.array(rewards),
+ np.array(done),
+ list_of_step_infos_numpy,
+ )
+
+ def close(self) -> None:
+ return self.jax_vector_env.close()
+
+ def get_attr(
+ self, attr_name: str, indices: vec_env_sb.base_vec_env.VecEnvIndices = None
+ ) -> List[Any]:
+ raise NotImplementedError
+
+ def set_attr(
+ self,
+ attr_name: str,
+ value: Any,
+ indices: vec_env_sb.base_vec_env.VecEnvIndices = None,
+ ) -> None:
+ raise NotImplementedError
+
+ def env_method(
+ self,
+ method_name: str,
+ *method_args,
+ indices: vec_env_sb.base_vec_env.VecEnvIndices = None,
+ **method_kwargs,
+ ) -> List[Any]:
+ raise NotImplementedError
+
+ def env_is_wrapped(
+ self,
+ wrapper_class: Type[gym.Wrapper],
+ indices: vec_env_sb.base_vec_env.VecEnvIndices = None,
+ ) -> List[bool]:
+ return [False] * self.num_envs
+ # raise NotImplementedError
+
+ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
+ """"""
+
+ if seed is None:
+ seed = np.random.default_rng().integers(0, 2**32 - 1, dtype="uint32")
+
+ if np.array(seed, dtype="uint32") != np.array(seed):
+ raise ValueError(f"seed must be compatible with 'uint32' casting")
+
+ self._seed = seed
+ return [seed]
+
+ # _ = self.jax_vector_env.reset(seed=seed)
+ # return [None]
+
+
+def make_vec_env_stable_baselines(
+ jax_dataclass_env: JaxDataclassEnv | JaxDataclassWrapper,
+ n_envs: int = 1,
+ seed: Optional[int] = None,
+ # monitor_dir: Optional[str] = None,
+ vec_env_kwargs: Optional[Dict[str, Any]] = None,
+ #
+ # env_id: Union[str, Callable[..., gym.Env]],
+ # # n_envs: int = 1,
+ # # seed: Optional[int] = None,
+ # start_index: int = 0,
+ # monitor_dir: Optional[str] = None,
+ # wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
+ # env_kwargs: Optional[Dict[str, Any]] = None,
+ # vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
+ # vec_env_kwargs: Optional[Dict[str, Any]] = None,
+ # monitor_kwargs: Optional[Dict[str, Any]] = None,
+ # wrapper_kwargs: Optional[Dict[str, Any]] = None,
+) -> vec_env_sb.VecEnv:
+ """"""
+
+ env = jax_dataclass_env
+
+ vec_env_kwargs = vec_env_kwargs if vec_env_kwargs is not None else dict()
+
+ # Vectorize the environment.
+ # Note: it automatically wraps the environment in a TimeLimit wrapper.
+ # Note: the space must be PyTree.
+ vec_env = JaxVectorEnv(
+ func_env=env,
+ num_envs=n_envs,
+ **vec_env_kwargs,
+ )
+
+ # Flatten the PyTree spaces to regular Box spaces
+ vec_env = FlattenSpacesVecWrapper(env=vec_env)
+
+ # if seed is not None:
+ # _ = vec_env.reset(seed=seed)
+
+ vec_env_sb = CustomVecEnvSB(jax_vector_env=vec_env, log_rewards=True)
+
+ if seed is not None:
+ _ = vec_env_sb.seed(seed=seed)
+
+ return vec_env_sb
+
+
+def visualizer(
+ env: JaxEnv | Callable[[], JaxEnv], policy: BaseAlgorithm
+) -> Callable[[Optional[int]], None]:
+ """"""
+
+ import numpy as np
+ import rod
+ from loop_rate_limiters import RateLimiter
+ from meshcat_viz import MeshcatWorld
+
+ from jaxsim import JaxSim
+
+ # Open the visualizer
+ world = MeshcatWorld()
+ world.open()
+
+ # Create the JaxSim environment and get the simulator
+ env = env() if isinstance(env, Callable) else env
+ sim: JaxSim = env.unwrapped.func_env.unwrapped.jaxsim
+
+ # Extract the SDF string from the simulated model
+ jaxsim_model = sim.get_model(model_name="cartpole")
+ rod_model = jaxsim_model.physics_model.description.extra_info["sdf_model"]
+ rod_sdf = rod.Sdf(model=rod_model, version="1.7")
+ sdf_string = rod_sdf.serialize(pretty=True)
+
+ # Insert the model from a URDF/SDF resource
+ model_name = world.insert_model(model_description=sdf_string, is_urdf=False)
+
+ # Create the visualization function
+ def rollout(seed: Optional[int] = None) -> None:
+ """"""
+
+ # Reset the environment
+ observation, state_info = env.reset(seed=seed)
+
+ # Initialize the model state with the initial observation
+ world.update_model(
+ model_name=model_name,
+ joint_names=["linear", "pivot"],
+ joint_positions=np.array([observation[0], observation[2]]),
+ )
+
+ rtf = 1.0
+ down_sampling = 1
+ rate = RateLimiter(frequency=float(rtf / (sim.dt() * down_sampling)))
+
+ done = False
+
+ # Visualization loop
+ while not done:
+ action, _ = policy.predict(observation=observation, deterministic=True)
+ print(action)
+ observation, _, terminated, truncated, _ = env.step(action)
+ done = terminated or truncated
+
+ world.update_model(
+ model_name=model_name,
+ joint_names=["linear", "pivot"],
+ joint_positions=np.array([observation[0], observation[2]]),
+ )
+
+ print(done)
+ rate.sleep()
+
+ print("done")
+
+ return rollout
+
+
+# ============
+# ENVIRONMENTS
+# ============
+
+# TODO:
+#
+# - Initialize ANT already in contact -> otherwise it jumps around (idle after falling?)
+# - Tune spring for joint limits and joint friction
+# - wrapper to squash action space to [-1, 1]
+
+
+def make_jax_env_cartpole(
+ render_mode: Optional[str] = None,
+ max_episode_steps: Optional[int] = 500,
+) -> JaxEnv:
+ """"""
+
+ # TODO: single env -> time limit with stable_baselines?
+
+ import os
+
+ import torch
+
+ if not torch.cuda.is_available():
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
+ os.environ["NVIDIA_VISIBLE_DEVICES"] = ""
+
+ import warnings
+
+ warnings.simplefilter(action="ignore", category=FutureWarning)
+
+ env = NaNHandlerWrapper(env=CartpoleSwingUpFuncEnvV0())
+
+ if max_episode_steps is not None:
+ env = TimeLimit(env=env, max_episode_steps=max_episode_steps)
+
+ return JaxEnv(
+ render_mode=render_mode,
+ func_env=ToNumPyWrapper(
+ env=JaxTransformWrapper(
+ function=jax.jit,
+ env=FlattenSpacesWrapper(
+ env=ClipActionWrapper(
+ env=SquashActionWrapper(env=env),
+ )
+ ),
+ )
+ ),
+ )
+
+
+def make_jax_env_ant(
+ render_mode: Optional[str] = None,
+ max_episode_steps: Optional[int] = 1_000,
+) -> JaxEnv:
+ """"""
+
+ # TODO: single env -> time limit with stable_baselines?
+
+ import os
+
+ import torch
+
+ if not torch.cuda.is_available():
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
+ os.environ["NVIDIA_VISIBLE_DEVICES"] = ""
+
+ import warnings
+
+ warnings.simplefilter(action="ignore", category=FutureWarning)
+
+ env = NaNHandlerWrapper(env=AntReachTargetFuncEnvV0())
+
+ if max_episode_steps is not None:
+ env = TimeLimit(env=env, max_episode_steps=max_episode_steps)
+
+ return JaxEnv(
+ render_mode=render_mode,
+ func_env=ToNumPyWrapper(
+ env=JaxTransformWrapper(
+ function=jax.jit,
+ env=FlattenSpacesWrapper(
+ env=ClipActionWrapper(
+ env=SquashActionWrapper(env=env),
+ )
+ ),
+ )
+ ),
+ )
+
+
+# =============================
+# Test JaxVecEnv vs DummyVecEnv
+# =============================
+
+if __name__ == "__main__?":
+ """"""
+
+ max_episode_steps = 200
+ func_env = NaNHandlerWrapper(env=CartpoleSwingUpFuncEnvV0())
+
+ if max_episode_steps is not None:
+ func_env = TimeLimit(env=func_env, max_episode_steps=max_episode_steps)
+
+ func_env = (
+ # ToNumPyWrapper(env=
+ # env=JaxTransformWrapper(
+ # function=jax.jit,
+ # env=FlattenSpacesWrapper(
+ ClipActionWrapper(
+ env=SquashActionWrapper(env=func_env),
+ )
+ # ),
+ # ),
+ # )
+ )
+
+ vec_env = make_vec_env_stable_baselines(
+ jax_dataclass_env=func_env,
+ n_envs=10,
+ seed=42,
+ vec_env_kwargs=dict(
+ # max_episode_steps=5_000,
+ jit_compile=True,
+ ),
+ )
+
+ # Seed the environment
+ # vec_env.seed(seed=42)
+
+ # Reset the environment.
+ # This has to be done only once since the vectorized environment supports autoreset.
+ observations = vec_env.reset()
+
+ # Initialize a random policy
+ random_policy = lambda obs: vec_env.jax_vector_env.action_space.sample()
+
+ for _ in range(1):
+ # Sample random actions
+ actions = random_policy(observations)
+
+ # Step the environment
+ observations, rewards, dones, step_infos = vec_env.step(actions=actions)
+
+ print(observations, rewards, dones, step_infos)
+ # print()
+ # print(dones)
+
+
+def evaluate(
+ env: gym.Env | Callable[..., gym.Env],
+ num_episodes: int = 1,
+ seed: int | None = None,
+ render: bool = False,
+ policy: Callable[[npt.NDArray], npt.NDArray] | None = None,
+ # vec_env_norm: Optional[VecNormalize] = None,
+) -> None:
+ """"""
+
+ # Create the environment if a callable is passed
+ env = env if isinstance(env, gym.Env) else env()
+
+ # Initialize a random policy if none is passed
+ policy = policy if policy is not None else lambda obs: env.action_space.sample()
+
+ # if vec_env_norm is not None and not isinstance(vec_env_norm, VecNormalize):
+ # raise TypeError(vec_env_norm, VecNormalize)
+
+ episodes_length = []
+ cumulative_rewards = []
+
+ for e in range(num_episodes):
+ # Reset the environment
+ observation, step_info = env.reset(seed=seed)
+
+ # Initialize done flag
+ done = False
+
+ # Render the environment
+ if render:
+ env.render()
+
+ episodes_length += [0]
+ cumulative_rewards += [0]
+
+ # Evaluation loop
+ while not done:
+ # Increase episode length counter
+ episodes_length[-1] += 1
+
+ # Predict the action
+ action = policy(observation)
+
+ # Step the environment
+ observation, reward, terminal, truncated, step_info = env.step(
+ action=action
+ )
+
+ # Determine if the episode is done
+ done = terminal or truncated
+
+ # Store the cumulative reward
+ cumulative_rewards[-1] += reward
+
+ # Render the environment
+ if render:
+ _ = env.render()
+
+ print("ep_len_mean\t", np.array(episodes_length).mean())
+ print("ep_rew_mean\t", np.array(cumulative_rewards).mean())
+
+
+# Train with SB
+if __name__ == "__main__cartpole_cpu_vec_env":
+ """"""
+
+ max_episode_steps = 200
+ func_env = NaNHandlerWrapper(env=CartpoleSwingUpFuncEnvV0())
+
+ if max_episode_steps is not None:
+ func_env = TimeLimit(env=func_env, max_episode_steps=max_episode_steps)
+
+ func_env = ClipActionWrapper(
+ env=SquashActionWrapper(env=func_env),
+ )
+
+ vec_env = make_vec_env_stable_baselines(
+ jax_dataclass_env=func_env,
+ n_envs=10,
+ seed=42,
+ vec_env_kwargs=dict(
+ jit_compile=True,
+ ),
+ )
+
+ import torch as th
+
+ model = PPO(
+ "MlpPolicy",
+ env=vec_env,
+ # n_steps=2048,
+ n_steps=256, # in the vector env -> real ones are x10
+ batch_size=256,
+ n_epochs=10,
+ gamma=0.95,
+ gae_lambda=0.9,
+ clip_range=0.1,
+ normalize_advantage=True,
+ # target_kl=0.010,
+ target_kl=0.025,
+ verbose=1,
+ learning_rate=0.000_300,
+ policy_kwargs=dict(
+ activation_fn=th.nn.ReLU,
+ net_arch=dict(pi=[512, 512], vf=[512, 512]),
+ log_std_init=np.log(0.05),
+ # squash_output=True,
+ ),
+ )
+
+ print(model.policy)
+
+ # Create the evaluation environment
+ env_eval = make_jax_env_cartpole(
+ render_mode="meshcat_viz",
+ max_episode_steps=500,
+ )
+
+ for _ in range(1):
+ # Train the model
+ model = model.learn(total_timesteps=50_000, progress_bar=False)
+
+ # Create the policy closure
+ policy = lambda observation: model.policy.predict(
+ observation=observation, deterministic=True
+ )[0]
+
+ # Evaluate the policy
+ print("Evaluating...")
+ evaluate(
+ env=env_eval,
+ num_episodes=10,
+ seed=None,
+ render=True,
+ policy=policy,
+ )
+
+# Train with SB
+if __name__ == "__main__cartpole_gpu_vec_env":
+ """"""
+
+ max_episode_steps = 200
+ func_env = NaNHandlerWrapper(env=CartpoleSwingUpFuncEnvV0())
+
+ if max_episode_steps is not None:
+ func_env = TimeLimit(env=func_env, max_episode_steps=max_episode_steps)
+
+ func_env = ClipActionWrapper(
+ env=SquashActionWrapper(
+ # env=func_env
+ env=ActionNoiseWrapper(env=func_env)
+ ),
+ )
+
+ vec_env = make_vec_env_stable_baselines(
+ jax_dataclass_env=func_env,
+ # n_envs=10,
+ n_envs=512,
+ # n_envs=2048, # TODO
+ seed=42,
+ vec_env_kwargs=dict(
+ jit_compile=True,
+ ),
+ )
+
+ vec_env = VecMonitor(
+ venv=VecNormalize(
+ venv=vec_env,
+ training=True,
+ norm_obs=True,
+ norm_reward=True,
+ clip_obs=10.0,
+ clip_reward=10.0,
+ gamma=0.95,
+ epsilon=1e-8,
+ )
+ )
+
+ actions = vec_env.jax_vector_env.action_space.sample()
+
+ # 0: ok
+ # 1: ok
+ # 2: ok
+ # 3: ok
+ # 4: ok
+ # 5: ok
+ # 6: ok
+ # 7: -> now
+ # 8:
+ # 9:
+ vec_env.venv.venv.logger_rewards = []
+ seed = vec_env.seed(seed=7)[0]
+ _ = vec_env.reset()
+
+ import torch as th
+
+ model = PPO(
+ "MlpPolicy",
+ env=vec_env,
+ n_steps=5, # in the vector env -> real ones are x512
+ batch_size=256,
+ n_epochs=10,
+ gamma=0.95,
+ gae_lambda=0.9,
+ clip_range=0.1,
+ normalize_advantage=True,
+ target_kl=0.025,
+ verbose=1,
+ learning_rate=0.000_300,
+ policy_kwargs=dict(
+ activation_fn=th.nn.ReLU,
+ net_arch=dict(pi=[512, 512], vf=[512, 512]),
+ log_std_init=np.log(0.05),
+ # squash_output=True,
+ ),
+ )
+
+ print(model.policy)
+
+ # Create the evaluation environment
+ env_eval = make_jax_env_cartpole(
+ render_mode="meshcat_viz",
+ max_episode_steps=200,
+ )
+
+ # from stable_baselines3.common.
+
+ # rewards = np.zeros((10, 982)) DO NOT EXEC
+ # rewards_7 = np.array(vec_env.venv.venv.logger_rewards) CHANGE _X
+ # rewards[seed, :] = np.array(vec_env.venv.venv.logger_rewards)
+
+ # rewards = np.vstack([rewards, np.atleast_2d(vec_env.logger_rewards)])
+
+ # plt.plot(
+ # # np.arange(start=1, stop=len(vec_env.venv.venv.logger_rewards) + 1) * 512,
+ # # vec_env.venv.venv.logger_rewards,
+ # np.arange(start=1, stop=len(vec_env.venv.venv.logger_rewards) + 3) * 512,
+ # rewards.T,
+ # label=r"$\hat{r}$"
+ # )
+ # # plt.plot(step_data[model_js.name()].tf, joint_positions_mj, label=["d", "theta"])
+ # plt.grid(True)
+ # plt.legend()
+ # plt.xlabel("Time steps")
+ # plt.ylabel("Average reward over 512 environments")
+ # # plt.title("Trajectory of the model's base")
+ # plt.show()
+
+ # import pickle
+ # with open(file=pathlib.Path.home()
+ # / "git"
+ # / "jaxsim"
+ # / "scripts"
+ # / f"ppo_cartpole_swingup_rewards.pickle", mode="w+b") as f:
+ # pickle.dump(rewards, f)
+
+ # model.save(
+ # path=pathlib.Path.home()
+ # / "git"
+ # / "jaxsim"
+ # / "scripts"
+ # / f"ppo_cartpole_swing_up_seed={seed}.zip"
+ # )
+
+ for _ in range(5):
+ # Train the model
+ model = model.learn(total_timesteps=50_000, progress_bar=False)
+ # %time model = model.learn(total_timesteps=500_000, progress_bar=False)
+
+ # Create the policy closure
+ policy = lambda observation: model.policy.predict(
+ # observation=observation, deterministic=True
+ observation=vec_env.normalize_obs(observation),
+ deterministic=True,
+ )[0]
+
+ # Evaluate the policy
+ print("Evaluating...")
+ evaluate(
+ env=env_eval,
+ num_episodes=10,
+ seed=None,
+ render=True,
+ policy=policy,
+ # vec_env_norm=vec_env,
+ )
+
+ # for _ in range(n_steps):
+ # observation = mj_observation(mujoco_model=m)
+ # action = model.policy.predict(observation=observation, deterministic=True)[0]
+ # m.data.ctrl = np.atleast_1d(action)
+ # mujoco.mj_step(m.model, m.data)
+ # # mujoco.mj_forward(m.model, m.data)
+
+ # import palettable
+ # # https://jiffyclub.github.io/palettable/cartocolors/diverging/
+ # colors = palettable.cartocolors.diverging.Geyser_5.mpl_colors
+ #
+ # r = rewards.copy()
+ # mean = r.mean(axis=0)
+ # std = r.std(axis=0)
+ # std_up = mean + std/2
+ # std_down = mean - std/2
+ #
+ # fig, ax1 = plt.subplots(1, 1)
+ # ax1.fill_between(
+ # np.arange(start=1, stop=mean.size + 1) * 512,
+ # std_down,
+ # std_up,
+ # label=r"$\pm \sigma$",
+ # color=colors[1],
+ # )
+ # ax1.plot(
+ # np.arange(start=1, stop=mean.size + 1) * 512,
+ # mean,
+ # color=colors[0],
+ # )
+ # ax1.grid()
+ # # ax1.legend(loc="lower right")
+ # ax1.set_title(r"\textbf{Average reward}")
+ # ax1.set_xlabel("Samples")
+ #
+ # # plt.show()
+ #
+ # import tikzplotlib
+ # tikzplotlib.clean_figure()
+ # print(tikzplotlib.get_tikz_code())
+
+ # =======================
+ # Evaluation environments
+ # =======================
+ #
+ # env_eval = make_jax_env_cartpole(render_mode="meshcat_viz", max_episode_steps=None)
+ #
+ # # observation, step_info = env_eval.reset(seed=42)
+ # observation, step_info = env_eval.reset()
+ #
+ # # Initialize done flag
+ # done = False
+ #
+ # env_eval.render()
+ #
+ # i = 0
+ # cum_reward = 0.0
+ #
+ # while not done:
+ # i += 1
+ #
+ # if i == 2000:
+ # done = True
+ #
+ # # Sample a random action
+ # # action = 0.1* random_policy(env, observation)
+ # action, _ = model.policy.predict(observation=observation, deterministic=False)
+ #
+ # # Step the environment
+ # observation, reward, terminal, truncated, step_info = env_eval.step(
+ # action=action
+ # )
+ #
+ # print(reward)
+ # cum_reward += reward
+ #
+ # # Render the environment
+ # _ = env_eval.render()
+ #
+ # # print(observation, reward, terminal, truncated, step_info)
+ # # print(env.state)
+ #
+ # print("reward =", cum_reward)
+ # print("episode length =", i)
+ #
+ # env_eval.close()
+
+
+# Comparison with mujoco
+if __name__ == "__main_comparison_mujoco":
+ """"""
+
+ model = PPO.load(
+ path=pathlib.Path.home()
+ / "git"
+ / "jaxsim"
+ / "scripts"
+ / "ppo_cartpole_swing_up_seed=7.zip"
+ )
+
+ # Create the evaluation environment
+ env_eval = make_jax_env_cartpole(
+ render_mode="meshcat_viz",
+ max_episode_steps=250,
+ )
+
+ # =================
+ # Mujoco evaluation
+ # =================
+
+ def mj_observation(mujoco_model: MujocoModel) -> npt.NDArray:
+ """"""
+
+ mujoco.mj_forward(mujoco_model.model, mujoco_model.data)
+
+ pivot_pos = mujoco_model.joint_position(joint_name="pivot")
+ # θ = np.arctan2(np.sin(pivot_pos), np.cos(pivot_pos))
+ θ = pivot_pos
+
+ return (
+ np.array(
+ [
+ mujoco_model.joint_position(joint_name="linear"),
+ mujoco_model.joint_velocity(joint_name="linear"),
+ θ,
+ mujoco_model.joint_velocity(joint_name="pivot"),
+ ]
+ )
+ .squeeze()
+ .copy()
+ )
+
+ def mj_reset(
+ mujoco_model: MujocoModel, observation: Optional[npt.NDArray] = None
+ ) -> npt.NDArray:
+ """"""
+
+ observation = (
+ observation
+ if observation is not None
+ else np.array([0.0, 0.0, np.deg2rad(180), 0.0])
+ )
+
+ linear_pos = observation[0]
+ linear_vel = observation[1]
+ pivot_pos = observation[2]
+ pivot_vel = observation[3]
+
+ mujoco_model.data.qpos = np.array([linear_pos, pivot_pos])
+ mujoco_model.data.qvel = np.array([linear_vel, pivot_vel])
+ mujoco.mj_forward(mujoco_model.model, mujoco_model.data)
+
+ return mj_observation(mujoco_model=mujoco_model)
+
+ def mj_step(action: npt.NDArray, mujoco_model: MujocoModel) -> None:
+ """"""
+
+ n_steps = int(0.050 / mujoco_model.model.opt.timestep)
+ mujoco_model.data.ctrl = np.atleast_1d(action.squeeze()).copy()
+ mujoco.mj_step(mujoco_model.model, mujoco_model.data, n_steps)
+
+ # m.data.qpos = np.array([0.0, np.deg2rad(180)])
+ # m.data.qvel = np.array([0.0, 0.0])
+ # mujoco.mj_forward(m.model, m.data)
+
+ # # Create the policy closure
+ # policy = lambda observation: model.policy.predict(
+ # # observation=observation, deterministic=True
+ # observation=vec_env.normalize_obs(observation), deterministic=True
+ # )[0]
+
+ # ==============
+ # Mujoco regular
+ # ==============
+
+ model_xml_path = (
+ pathlib.Path.home()
+ / "git"
+ / "jaxsim"
+ / "examples"
+ / "resources"
+ / "cartpole_mj.xml"
+ )
+
+ self = m = MujocoModel(xml_path=model_xml_path)
+
+ mj_action = []
+ mj_pos_cart = []
+ mj_pos_pole = []
+
+ done = False
+ iterations = 0
+ mj_reset(mujoco_model=m)
+
+ while not done:
+ iterations += 1
+ observation = mj_observation(mujoco_model=m)
+ # action = model.policy.predict(observation=observation, deterministic=True)[0]
+ obs_policy = observation.copy()
+ obs_policy[2] = np.arctan2(np.sin(obs_policy[2]), np.cos(obs_policy[2]))
+ action = policy(obs_policy)
+
+ mj_step(action=action, mujoco_model=m)
+
+ mj_action.append(action * 50)
+ mj_pos_cart.append(observation[0])
+ mj_pos_pole.append(observation[2])
+
+ import time
+
+ time.sleep(0.050)
+
+ print(observation, "\t", action)
+
+ if iterations >= 201:
+ break
+
+ # ======
+ # Jaxsim
+ # ======
+
+ # js_action = []
+ # js_pos_cart = []
+ # js_pos_pole = []
+ #
+ # done = False
+ # iterations = 0
+ # observation, _ = env_eval.reset()
+ #
+ # while not done:
+ # iterations += 1
+ # action = policy(observation)
+ # observation, _, _, _, _, = env_eval.step(action)
+ #
+ # js_action.append(action * 50)
+ # js_pos_cart.append(observation[0])
+ # js_pos_pole.append(observation[2])
+ #
+ # # import time
+ # #
+ # # time.sleep(0.050)
+ #
+ # print(observation, "\t", action)
+ #
+ # if iterations >= 201:
+ # break
+
+ # ============
+ # Mujoco alt 1
+ # ============
+
+ model_xml_path = (
+ pathlib.Path.home()
+ / "git"
+ / "jaxsim"
+ / "examples"
+ / "resources"
+ / "cartpole_mj.xml"
+ )
+
+ m = MujocoModel(xml_path=model_xml_path)
+
+ mj_action_alt1 = []
+ mj_pos_cart_alt1 = []
+ mj_pos_pole_alt1 = []
+
+ done = False
+ iterations = 0
+ mj_reset(mujoco_model=m)
+
+ while not done:
+ iterations += 1
+ observation = mj_observation(mujoco_model=m)
+ # action = policy(observation)
+ obs_policy = observation.copy()
+ obs_policy[2] = np.arctan2(np.sin(obs_policy[2]), np.cos(obs_policy[2]))
+ action = policy(obs_policy)
+ mj_step(action=action, mujoco_model=m)
+
+ mj_action_alt1.append(action * 50)
+ mj_pos_cart_alt1.append(observation[0])
+ mj_pos_pole_alt1.append(observation[2])
+
+ import time
+
+ time.sleep(0.050)
+
+ print(observation, "\t", action)
+
+ if iterations >= 201:
+ break
+
+ mj_action_alt1 = np.array(mj_action_alt1)
+ mj_pos_cart_alt1 = np.array(mj_pos_cart_alt1)
+ mj_pos_pole_alt1 = np.array(mj_pos_pole_alt1)
+
+ # ============
+ # Mujoco alt 2
+ # ============
+
+ model_xml_path = (
+ pathlib.Path.home()
+ / "git"
+ / "jaxsim"
+ / "examples"
+ / "resources"
+ / "cartpole_mj.xml"
+ )
+
+ m = MujocoModel(xml_path=model_xml_path)
+
+ mj_action_alt2 = []
+ mj_pos_cart_alt2 = []
+ mj_pos_pole_alt2 = []
+
+ done = False
+ iterations = 0
+ mj_reset(mujoco_model=m)
+
+ while not done:
+ iterations += 1
+ observation = mj_observation(mujoco_model=m)
+ # action = policy(observation)
+ obs_policy = observation.copy()
+ obs_policy[2] = np.arctan2(np.sin(obs_policy[2]), np.cos(obs_policy[2]))
+ action = policy(obs_policy)
+ mj_step(action=action, mujoco_model=m)
+
+ mj_action_alt2.append(action * 50)
+ mj_pos_cart_alt2.append(observation[0])
+ mj_pos_pole_alt2.append(observation[2])
+
+ import time
+
+ time.sleep(0.050)
+
+ print(observation, "\t", action)
+
+ if iterations >= 201:
+ break
+
+ mj_action_alt2 = np.array(mj_action_alt2)
+ mj_pos_cart_alt2 = np.array(mj_pos_cart_alt2)
+ mj_pos_pole_alt2 = np.array(mj_pos_pole_alt2)
+
+ # ====
+ # Plot
+ # ====
+
+ import palettable
+
+ # https://jiffyclub.github.io/palettable/cartocolors/diverging/
+ # colors = palettable.cartocolors.diverging.Geyser_5.mpl_colors
+ colors = palettable.cartocolors.qualitative.Prism_8.mpl_colors
+
+ fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True)
+ time = np.arange(start=0, stop=len(mj_action)) * 0.050
+
+ # ax1.plot(time, js_pos_pole, label=r"Jaxsim", color=colors[1], linewidth=1)
+ # ax1.plot(time, mj_pos_pole, label=r"Mujoco", color=colors[7], linewidth=1)
+ # ax2.plot(time, js_pos_cart, label=r"Jaxsim", color=colors[1], linewidth=1)
+ # ax2.plot(time, mj_pos_cart, label=r"Mujoco", color=colors[7], linewidth=1)
+ # ax3.plot(time, js_action, label=r"Jaxsim", color=colors[1], linewidth=1)
+ # ax3.plot(time, mj_action, label=r"Mujoco", color=colors[7], linewidth=1)
+
+ ax1.plot(time, mj_pos_pole, label=r"nominal", color=colors[1], linewidth=1)
+ ax1.plot(time, mj_pos_pole_alt1, label=r"mass", color=colors[3], linewidth=1)
+ ax1.plot(
+ time, mj_pos_pole_alt2, label=r"mass+friction", color=colors[7], linewidth=1
+ )
+ ax2.plot(time, mj_pos_cart, label=r"nominal", color=colors[1], linewidth=1)
+ ax2.plot(time, mj_pos_cart_alt1, label=r"mass", color=colors[3], linewidth=1)
+ ax2.plot(
+ time, mj_pos_cart_alt2, label=r"mass+friction", color=colors[7], linewidth=1
+ )
+ ax3.plot(time, mj_action, label=r"nominal", color=colors[1], linewidth=1)
+ ax3.plot(time, mj_action_alt1, label=r"mass", color=colors[3], linewidth=1)
+ ax3.plot(time, mj_action_alt2, label=r"mass+friction", color=colors[7], linewidth=1)
+
+ ax1.grid()
+ ax1.set_ylabel(r"Pole angle $\theta$ [rad]")
+ ax2.grid()
+ ax2.set_ylabel(r"Cart position $d$ [m]")
+ ax3.grid()
+ ax3.set_ylabel(r"Force applied to cart $f$ [N]")
+
+ # ax1.set_title(r"Pole angle $\theta$")
+ # ax2.set_title(r"Cart position $d$")
+ # ax3.set_title(r"Force applied to cart $f$")
+
+ # ax1.legend()
+ # ax2.legend()
+ # ax3.legend()
+
+ # plt.legend()
+ # plt.title(r"\textbf{Comparison of cartpole swing-up performance}")
+ # fig.suptitle(r"\textbf{Comparison of cartpole swing-up performance}")
+ fig.supxlabel("Time [s]")
+ # plt.show()
+
+ import tikzplotlib
+
+ tikzplotlib.clean_figure()
+ print(tikzplotlib.get_tikz_code())
+
+# Train with SB
+if __name__ == "__main__ant_vec_gpu_env":
+ """"""
+
+ max_episode_steps = 1000
+ func_env = NaNHandlerWrapper(env=AntReachTargetFuncEnvV0())
+
+ if max_episode_steps is not None:
+ func_env = TimeLimit(env=func_env, max_episode_steps=max_episode_steps)
+
+ func_env = ClipActionWrapper(
+ env=SquashActionWrapper(env=func_env),
+ )
+
+ # TODO: rename _sb to prevent collision with module
+ vec_env_sb = make_vec_env_stable_baselines(
+ jax_dataclass_env=func_env,
+ # n_envs=10,
+ # n_envs=2048, # too many -> JIT compilation takes too long
+ # n_envs=100,
+ # n_envs=1024,
+ # n_envs=2048,
+ n_envs=512,
+ seed=42,
+ vec_env_kwargs=dict(
+ jit_compile=True,
+ ),
+ )
+
+ # %time _ = vec_env_sb.reset()
+ # %time _ = vec_env_sb.reset()
+ # actions = vec_env_sb.jax_vector_env.action_space.sample()
+ # %time _ = vec_env_sb.step(actions)
+ # %time _ = vec_env_sb.step(actions)
+
+ import torch as th
+
+ # TODO: if every reset requires ~1 s of simulation -> extremely slow, because there
+ # will always be some env that is resetting!
+
+ model = PPO(
+ "MlpPolicy",
+ env=vec_env_sb,
+ # n_steps=2048,
+ # n_steps=512, # in the vector env -> real ones are x10
+ # n_steps=10, # in the vector env -> real ones are x10
+ n_steps=4, # in the vector env -> real ones are x2048
+ # batch_size=256,
+ batch_size=1024,
+ n_epochs=10,
+ gamma=0.95,
+ gae_lambda=0.9,
+ clip_range=0.1,
+ normalize_advantage=True,
+ # target_kl=0.010,
+ target_kl=0.025,
+ verbose=1,
+ learning_rate=0.000_300,
+ policy_kwargs=dict(
+ activation_fn=th.nn.ReLU,
+ net_arch=dict(pi=[2048, 1024], vf=[1024, 1024, 256]),
+ # net_arch=dict(pi=[2048, 2048], vf=[2048, 1024, 512]),
+ log_std_init=np.log(0.05),
+ # squash_output=True,
+ ),
+ )
+
+ print(model.policy)
+
+ # Create the evaluation environment
+ env_eval = make_jax_env_ant(
+ render_mode="meshcat_viz",
+ max_episode_steps=1000,
+ )
+
+ for _ in range(10):
+ # Train the model
+ model = model.learn(total_timesteps=500_000, progress_bar=False)
+
+ # Create the policy closure
+ policy = lambda observation: model.policy.predict(
+ observation=observation, deterministic=True
+ )[0]
+
+ # Evaluate the policy
+ print("Evaluating...")
+ evaluate(
+ env=env_eval,
+ num_episodes=10,
+ seed=None,
+ render=True,
+ # policy=policy,
+ )
+
+ # evaluate(
+ # env=env_eval,
+ # num_episodes=10,
+ # seed=None,
+ # render=True,
+ # # policy=policy,
+ # )
+
+# =============
+# RANDOM POLICY
+# =============
+
+# TODO: generate a JaxVecEnv and validate with DummyVecEnv from SB
+
+if __name__ == "__main__+":
+ """"""
+
+ # Create the environment
+ env = make_jax_env_ant(render_mode="meshcat_viz", max_episode_steps=None)
+ # env = make_jax_env_cartpole(render_mode="meshcat_viz", max_episode_steps=10)
+
+ # Reset the environment
+ # observation, state_info = env.reset(seed=42)
+ observation, state_info = env.reset()
+
+ # Initialize a random policy
+ random_policy = lambda env, obs: env.action_space.sample()
+
+ # Initialize done flag
+ done = False
+
+ env.render()
+
+ # s = env.state
+ # env.func_env.env.transition(state=s, action=env.func_env.action_space.sample())
+ #
+ # with env.func_env.unwrapped.jaxsim.editable(validate=True) as sim:
+ # sim.data = env.state["env"]
+
+ # import time
+ #
+ # time.sleep(2)
+
+ i = 0
+ cum_reward = 0.0
+
+ while not done:
+ i += 1
+
+ if i == 2000:
+ done = True
+
+ # Sample a random action
+ # action = 0.1* random_policy(env, observation)
+ action, _ = model.policy.predict(observation=observation, deterministic=False)
+
+ # Step the environment
+ observation, reward, terminal, truncated, step_info = env.step(action=action)
+
+ print(reward)
+ cum_reward += reward
+
+ # Render the environment
+ _ = env.render()
+
+ # print(observation, reward, terminal, truncated, step_info)
+ # print(env.state)
+
+ print(cum_reward)
+
+ env.close()
+
+# =================
+# TRAINING PPO/TRPO
+# =================
+
+if __name__ == "__main__)":
+ """Stable Baselines"""
+
+ # Initialize properties
+ # seed = 42
+
+ # Create a single environment
+ # func_env = CartpoleSwingUpFuncEnvV0()
+
+ # env = JaxEnv(
+ # func_env=ToNumPyWrapper(
+ # env=JaxTransformWrapper(
+ # function=jax.jit,
+ # env=FlattenSpacesWrapper(env=CartpoleSwingUpFuncEnvV0()),
+ # )
+ # )
+ # )
+
+ # TODO: try with single env first?
+
+ # env = JaxEnv(
+ # func_env=ToNumPyWrapper(
+ # env=JaxTransformWrapper(
+ # function=jax.jit,
+ # env=FlattenSpacesWrapper(
+ # env=TimeLimit(env=func_env, max_episode_steps=1_000)
+ # ),
+ # )
+ # )
+ # )
+
+ # def make_env() -> gym.Env:
+ # def make_jax_env(max_episode_steps: Optional[int] = 500) -> JaxEnv:
+ # """"""
+ #
+ # # TODO: single env -> time limit with stable_baselines?
+ #
+ # if max_episode_steps is None:
+ # env = CartpoleSwingUpFuncEnvV0()
+ # else:
+ # env = TimeLimit(
+ # env=CartpoleSwingUpFuncEnvV0(), max_episode_steps=max_episode_steps
+ # )
+ #
+ # return JaxEnv(
+ # func_env=ToNumPyWrapper(
+ # env=JaxTransformWrapper(
+ # function=jax.jit,
+ # env=FlattenSpacesWrapper(env=env),
+ # )
+ # )
+ # )
+
+ # check_env(env=env, warn=True, skip_render_check=True)
+
+ # observation = env.reset()
+ # action = env.action_space.sample()
+ # observation, reward, terminated, truncate, info = env.step(action)
+ # print(observation, reward, terminated, truncate, info)
+
+ #
+ #
+ #
+
+ # vec_env = make_vec_env_stable_baselines(
+ # jax_dataclass_env=func_env,
+ # n_envs=10,
+ # seed=42,
+ # vec_env_kwargs=dict(
+ # max_episode_steps=5_000,
+ # jit_compile=True,
+ # ),
+ # )
+
+ # vec_env.reset()
+
+ vec_env = make_vec_env(
+ # env_id=lambda: make_jax_env_cartpole(max_episode_steps=500),
+ env_id=lambda: make_jax_env_ant(max_episode_steps=2_000),
+ n_envs=10,
+ seed=42,
+ vec_env_cls=SubprocVecEnv,
+ )
+
+ import torch as th
+
+ # ANT
+ model = PPO(
+ "MlpPolicy",
+ env=vec_env,
+ # n_steps=2048,
+ n_steps=512, # in the vector env -> real ones are x10
+ batch_size=256,
+ n_epochs=10,
+ gamma=0.95,
+ gae_lambda=0.98,
+ clip_range=0.1,
+ normalize_advantage=True,
+ # target_kl=0.010,
+ target_kl=0.025,
+ verbose=1,
+ learning_rate=0.000_300,
+ policy_kwargs=dict(
+ activation_fn=th.nn.ReLU,
+ net_arch=dict(pi=[1024, 512], vf=[1024, 512]),
+ # log_std_init=0.10,
+ # log_std_init=2.5,
+ log_std_init=np.log(0.25),
+ # squash_output=True,
+ ),
+ )
+
+ print(model.policy)
+ model = model.learn(total_timesteps=200_000, progress_bar=False)
+ # model = model.learn(total_timesteps=500_000, progress_bar=False)
+
+ #
+ # CARTPOLE
+ #
+
+ # TODO: with squash_output -> do I need to apply tanh to the action?
+ # model = PPO(
+ # "MlpPolicy",
+ # env=vec_env,
+ # # n_steps=2048,
+ # n_steps=256, # in the vector env -> real ones are x10
+ # batch_size=256,
+ # n_epochs=10,
+ # gamma=0.95,
+ # gae_lambda=0.9,
+ # clip_range=0.1,
+ # normalize_advantage=True,
+ # # target_kl=0.010,
+ # target_kl=0.025,
+ # verbose=1,
+ # learning_rate=0.000_300,
+ # policy_kwargs=dict(
+ # activation_fn=th.nn.ReLU,
+ # net_arch=dict(pi=[256, 256], vf=[256, 256]),
+ # log_std_init=0.05,
+ # # squash_output=True,
+ # ),
+ # )
+
+ model = TRPO(
+ "MlpPolicy",
+ env=vec_env,
+ n_steps=256, # in the vector env -> real ones are x10
+ batch_size=1024,
+ gamma=0.95,
+ gae_lambda=0.95,
+ normalize_advantage=True,
+ target_kl=0.025,
+ verbose=1,
+ learning_rate=0.000_300,
+ policy_kwargs=dict(
+ activation_fn=th.nn.ReLU,
+ net_arch=dict(pi=[256, 256], vf=[256, 256]),
+ log_std_init=0.05, # TODO np.log()..., maybe 0.05 too large?
+ # squash_output=True,
+ ),
+ )
+
+ print(model.policy)
+ model = model.learn(total_timesteps=500_000, progress_bar=False)
+ # model.learn(total_timesteps=25_000)
+ # model.save("ppo_cartpole")
+
+ # # Vectorize the environment.
+ # # Note: it automatically wraps the environment in a TimeLimit wrapper.
+ # vec_env = JaxVectorEnv(
+ # func_env=env,
+ # num_envs=10,
+ # max_episode_steps=5_000,
+ # jit_compile=True,
+ # )
+ #
+ # from jaxgym.vector.jax import FlattenSpacesVecWrapper
+ #
+ # vec_env = FlattenSpacesVecWrapper(env=vec_env)
+ #
+ # # test.reset(seed=0)
+ # # test.step(actions=test.action_space.sample())
+
+ # =========
+ # Visualize
+ # =========
+
+ visualize = False
+
+ if visualize:
+ rollout_visualizer = visualizer(env=lambda: make_jax_env(1_000), policy=model)
+
+ import time
+
+ time.sleep(3)
+ rollout_visualizer(None)
+
+
+if __name__ == "__main___":
+ """"""
+
+ # Initialize properties
+ seed = 42
+ # num_envs = 3
+
+ # Create a single environment
+ func_env = CartpoleSwingUpFuncEnvV0()
+
+ from jaxgym.jax.env import JaxEnv
+ from jaxgym.wrappers.jax import (
+ ClipActionWrapper,
+ FlattenSpacesWrapper,
+ JaxTransformWrapper,
+ )
+
+ func_env = ClipActionWrapper(env=func_env)
+ func_env = FlattenSpacesWrapper(env=func_env)
+ func_env = JaxTransformWrapper(env=func_env, function=jax.jit)
+
+ # state = func_env.initial(rng=jax.random.PRNGKey(seed=seed))
+ # action = func_env.action_space.sample()
+ # func_env.transition(state, action)
+
+ env = JaxEnv(func_env=func_env)
+
+ observation, state_info = env.reset(seed=seed)
+
+
+# TODO: specs kind of ok -> see bottom of cartpole
+if __name__ == "__main___":
+ """"""
+
+ # Initialize properties
+ seed = 42
+ # num_envs = 3
+
+ # Create a single environment
+ env = CartpoleSwingUpFuncEnvV0()
+
+ # Vectorize the environment.
+ # Note: it automatically wraps the environment in a TimeLimit wrapper.
+ vec_env = JaxVectorEnv(
+ func_env=env,
+ num_envs=1_000,
+ max_episode_steps=10,
+ jit_compile=True,
+ )
+
+ # FRIDAY:
+ # from jaxgym.functional.jax.flatten_spaces import FlattenSpacesWrapper
+ # test = FlattenSpacesWrapper(env=env)
+ # test.transition(
+ # state=test.initial(rng=jax.random.PRNGKey(0)), action=test.action_space.sample()
+ # )
+
+ from jaxgym.vector.jax import FlattenSpacesVecWrapper
+
+ test = FlattenSpacesVecWrapper(env=vec_env)
+ test.reset(seed=0)
+ test.step(actions=test.action_space.sample())
+
+ # import cProfile
+ # from pstats import SortKey
+ #
+ # with cProfile.Profile() as pr:
+ #
+ # for _ in range(1000):
+ # _ = test.step(actions=test.action_space.sample())
+ #
+ # pr.print_stats(sort=SortKey.CUMULATIVE)
+ #
+ # exit(0)
+
+ # o = test.env.observation_space.sample()
+ # test.env.observation_space.flatten_sample(o)
+
+ # exit(0)
+ # raise # all good!
+ # TODO: benchmark non-jit sections
+ # TODO: benchmark from the code instead cmdline after jit compilation
+
+ # Reset the environment.
+ # This has to be done only once since the vectorized environment supports autoreset.
+ observations, state_infos = vec_env.reset(seed=seed)
+
+ # Initialize a random policy
+ random_policy = lambda obs: vec_env.action_space.sample()
+
+ for _ in range(1):
+ # Sample random actions
+ actions = random_policy(observations)
+
+ # Step the environment
+ observations, rewards, terminals, truncated, step_infos = vec_env.step(
+ actions=actions
+ )
+
+ print(observations, rewards, terminals, truncated, step_infos)
diff --git a/src/jaxgym/_spaces/__init__.py b/src/jaxgym/_spaces/__init__.py
new file mode 100644
index 000000000..36c133a4b
--- /dev/null
+++ b/src/jaxgym/_spaces/__init__.py
@@ -0,0 +1 @@
+from .space import Space
diff --git a/src/jaxgym/_spaces/pytree_orig.py b/src/jaxgym/_spaces/pytree_orig.py
new file mode 100644
index 000000000..7403cf98a
--- /dev/null
+++ b/src/jaxgym/_spaces/pytree_orig.py
@@ -0,0 +1,240 @@
+import gymnasium.spaces
+import jax.flatten_util
+import jax.numpy as jnp
+import jax.tree_util
+import numpy as np
+import numpy.typing as npt
+from gymnasium.spaces.utils import flatdim, flatten
+from gymnasium.vector.utils.spaces import batch_space
+
+import jaxsim.typing as jtp
+from jaxsim.utils import not_tracing, tracing
+
+from .space import Space
+
+# TODO: inherit from gymnasium.spaces?
+
+
+class PyTree(Space):
+ """"""
+
+ def __init__(self, low: jtp.PyTree, high: jtp.PyTree):
+ """"""
+
+ # ==========================
+ # Check low and high pytrees
+ # ==========================
+ # TODO: make generic (pytrees_with_same_dtype|shape|supported_dtype) and move
+ # to utils
+
+ # supported_dtypes = {
+ # jnp.array(0, dtype=jnp.float32).dtype,
+ # jnp.array(0, dtype=jnp.float64).dtype,
+ # jnp.array(0, dtype=int).dtype,
+ # jnp.array(0, dtype=bool).dtype,
+ # }
+ #
+ # dtypes_supported, _ = jax.flatten_util.ravel_pytree(
+ # jax.tree_util.tree_map(
+ # lambda l1, l2: jnp.array(l1).dtype in supported_dtypes
+ # and jnp.array(l2).dtype in supported_dtypes,
+ # low,
+ # high,
+ # )
+ # )
+ #
+ # if not jnp.all(dtypes_supported):
+ # # if not_tracing(low) and not jnp.all(dtypes_supported):
+ # # if jnp.where(jnp.array([tracing(low), jnp.all(dtypes_supported)]).any(), False, True):
+ # # if np.any([not_tracing(low), jnp.all(dtypes_supported)]):
+ # # if not_tracing(low):
+ # raise ValueError(
+ # "Either low or high pytrees have attributes with unsupported dtype"
+ # )
+ #
+ # shape_match, _ = jax.flatten_util.ravel_pytree(
+ # jax.tree_util.tree_map(
+ # lambda l1, l2: jnp.array(l1).shape == jnp.array(l2).shape, low, high
+ # )
+ # )
+ #
+ # if not jnp.all(shape_match):
+ # # if not_tracing(low) and not jnp.all(shape_match):
+ # raise ValueError("Wrong shape of low and high attributes")
+ #
+ # dtype_match, _ = jax.flatten_util.ravel_pytree(
+ # jax.tree_util.tree_map(
+ # lambda l1, l2: jnp.array(l1).dtype == jnp.array(l2).dtype, low, high
+ # )
+ # )
+ #
+ # if not jnp.all(dtype_match):
+ # # if not_tracing(low) and not jnp.all(dtype_match):
+ # raise ValueError("Wrong dtype of low and high attributes")
+
+ # Flatten the pytrees
+ low_flat, _ = jax.flatten_util.ravel_pytree(low)
+ high_flat, _ = jax.flatten_util.ravel_pytree(high)
+
+ if low_flat.dtype != high_flat.dtype:
+ raise ValueError(low_flat.dtype, high_flat.dtype)
+
+ if low_flat.shape != high_flat.shape:
+ raise ValueError(low_flat.shape, high_flat.shape)
+
+ # Transform all leaves to arrays and store them in the object
+ self.low = jax.tree_util.tree_map(lambda l: jnp.array(l), low)
+ self.high = jax.tree_util.tree_map(lambda l: jnp.array(l), high)
+ self.shape = low_flat.shape
+
+ # TODO: what if key is a vector?
+ def sample(self, key: jax.random.PRNGKeyArray) -> jtp.PyTree:
+ """"""
+
+ def random_array(
+ key, shape: tuple, min: jtp.PyTree, max: jtp.PyTree, dtype
+ ) -> jtp.Array:
+ """Helper to select the right sampling function for the supported dtypes"""
+
+ match dtype:
+ case jnp.float32.dtype | jnp.float64.dtype:
+ return jax.random.uniform(
+ key=key,
+ shape=shape,
+ minval=min,
+ maxval=max,
+ dtype=dtype,
+ )
+ case jnp.int16.dtype | jnp.int32.dtype | jnp.int64.dtype:
+ return jax.random.randint(
+ key=key,
+ shape=shape,
+ minval=min,
+ maxval=max + 1,
+ dtype=dtype,
+ )
+ case jnp.bool_.dtype:
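+ # Sample integers in {min, ..., max} and cast the result to bool.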
+ return jax.random.randint(
+ key=key,
+ shape=shape,
+ minval=min,
+ maxval=max + 1,
+ ).astype(bool)
+ case _:
+ raise ValueError(dtype)
+
+ # Create and flatten a tree having a PRNGKey for each leaf
+ key_pytree = jax.tree_util.tree_map(lambda l: jax.random.PRNGKey(0), self.low)
+ key_pytree_flat, unflatten_fn = jax.flatten_util.ravel_pytree(key_pytree)
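+ # A PRNGKey is a (2,)-shaped uint32 array, so the flattened size is twice the number of leaves.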
+
+ # Generate a pytree having a subkey in each leaf
+ key, *subkey_flat = jax.random.split(key=key, num=key_pytree_flat.size // 2 + 1)
+ subkey_pytree = unflatten_fn(jnp.array(subkey_flat).flatten())
+
+ # Generate a pytree sampling leafs according to their dtype and using a
+ # different key for each of them
+ return jax.tree_util.tree_map(
+ lambda low, high, subkey: random_array(
+ key=subkey, shape=low.shape, min=low, max=high, dtype=low.dtype
+ ),
+ self.low,
+ self.high,
+ subkey_pytree,
+ )
+
+ def contains(self, x: jtp.PyTree) -> bool:
+ """"""
+
+ def is_inside_bounds(x, low, high):
+ return jax.lax.select(
+ pred=jnp.all(jnp.array([jnp.all(x >= low), jnp.all(x <= high)])),
+ on_true=True,
+ on_false=False,
+ )
+
+ contains_all_leaves = jax.tree_util.tree_map(
+ lambda low, high, l: is_inside_bounds(x=l, low=low, high=high),
+ self.low,
+ self.high,
+ x,
+ )
+
+ contains_all_leaves_flat, _ = jax.flatten_util.ravel_pytree(contains_all_leaves)
+
+ return jnp.all(contains_all_leaves_flat)
+
+ @property
+ def is_np_flattenable(self) -> bool:
+ """Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`."""
+
+ return True
+
+ def to_box(self) -> gymnasium.spaces.Box:
+ """"""
+
+ # low_flat, _ = jax.flatten_util.ravel_pytree(self.low)
+ # high_flat, _ = jax.flatten_util.ravel_pytree(self.high)
+
+ # return gymnasium.spaces.Box(low=np.array(low_flat), high=np.array(high_flat))
+
+ return gymnasium.spaces.Box(
+ low=self.flatten_sample(x=self.low), high=self.flatten_sample(x=self.high)
+ )
+
+ # TODO: use above
+ def flatten_sample(self, x: jtp.PyTree) -> jtp.VectorJax:
+ """"""
+
+ x_flat, _ = jax.flatten_util.ravel_pytree(x)
+ return x_flat
+
+ def unflatten_sample(self, x: jtp.Vector) -> jtp.PyTree:
+ """"""
+
+ _, unflatten_fn = jax.flatten_util.ravel_pytree(self.low)
+ return unflatten_fn(x)
+
+ def clip(self, x: jtp.PyTree) -> jtp.PyTree:
+ """"""
+
+ return jax.tree_util.tree_map(
+ lambda low, high, l: jnp.array(
+ jnp.clip(a=l, a_min=low, a_max=high), dtype=low.dtype
+ ),
+ self.low,
+ self.high,
+ x,
+ )
+
+ # TODO: flatten()
+ # TODO: unflatten() from float with proper type casting
+ # TODO: normalize?
+
+
+@flatdim.register(PyTree)
+def _flatdim_pytree(space: PyTree) -> int:
+ """"""
+
+ low_flat, _ = jax.flatten_util.ravel_pytree(space.low)
+ return low_flat.size
+
+
+@flatten.register(PyTree)
+def _flatten_pytree(space: PyTree, x: jtp.PyTree) -> npt.NDArray:
+ """"""
+
+ assert x in space
+ x_flat, _ = jax.flatten_util.ravel_pytree(x)
+
+ return x_flat
+
+
+@batch_space.register(PyTree)
+def _batch_space_pytree(space: PyTree, n: int = 1) -> PyTree:
+ """"""
+
+ low_batched = jax.tree_util.tree_map(lambda l: jnp.stack([l] * n), space.low)
+ high_batched = jax.tree_util.tree_map(lambda l: jnp.stack([l] * n), space.high)
+
+ # TODO: np_random
+ return PyTree(low=low_batched, high=high_batched)
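+
+
+# Minimal usage sketch (illustrative only; the dict-based bounds below are an
+# arbitrary example, any pytree with matching structure and dtypes works):
+#
+# import jax
+# import jax.numpy as jnp
+#
+# space = PyTree(
+# low=dict(pos=-jnp.ones(3), steps=jnp.array(0)),
+# high=dict(pos=jnp.ones(3), steps=jnp.array(10)),
+# )
+#
+# sample = space.sample(key=jax.random.PRNGKey(0)) # pytree with the same structure
+# assert sample in space # dispatches to contains()
+# flat = space.flatten_sample(x=sample) # 1D vector containing all leaves
+# restored = space.unflatten_sample(x=flat) # back to the original pytree structure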
diff --git a/src/jaxgym/_spaces/space.py b/src/jaxgym/_spaces/space.py
new file mode 100644
index 000000000..7303e941a
--- /dev/null
+++ b/src/jaxgym/_spaces/space.py
@@ -0,0 +1,30 @@
+import abc
+
+import gymnasium.spaces
+import jax
+
+import jaxsim.typing as jtp
+
+
+class Space(abc.ABC):
+ """"""
+
+ # TODO: add num for multiple samples? Or if multiple keys -> multiple samples?
+ @abc.abstractmethod
+ def sample(self, key: jax.random.PRNGKey) -> jtp.PyTree:
+ """"""
+ pass
+
+ @abc.abstractmethod
+ def contains(self, x: jtp.PyTree) -> bool:
+ """"""
+ pass
+
+ # @abc.abstractmethod
+ # def to_gymnasium(self) -> gymnasium.Space:
+ # """"""
+ # pass
+
+ def __contains__(self, x: jtp.PyTree) -> bool:
+ """"""
+ return self.contains(x)
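+
+
+# Minimal sketch of a concrete subclass (hypothetical, for illustration only):
+#
+# import jax.numpy as jnp
+#
+# class UnitInterval(Space):
+# """Scalar space over the closed interval [0, 1]."""
+#
+# def sample(self, key: jax.random.PRNGKey) -> jtp.PyTree:
+# return jax.random.uniform(key=key, minval=0.0, maxval=1.0)
+#
+# def contains(self, x: jtp.PyTree) -> bool:
+# return bool(jnp.logical_and(x >= 0.0, x <= 1.0).all())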
diff --git a/src/jaxgym/envs/__init__.py b/src/jaxgym/envs/__init__.py
new file mode 100644
index 000000000..1aff17745
--- /dev/null
+++ b/src/jaxgym/envs/__init__.py
@@ -0,0 +1,19 @@
+from typing import Any
+
+from gymnasium.envs.registration import (
+ load_plugin_envs,
+ make,
+ pprint_registry,
+ register,
+ registry,
+ spec,
+)
+
+register(
+ id="CartpoleSwingUpEnv-V0",
+ entry_point="jaxgym.envs.cartpole:CartpoleSwingUpEnvV0",
+ vector_entry_point="jaxgym.envs.cartpole:CartpoleSwingUpVectorEnvV0",
+ # max_episode_steps=5_000,
+ # reward_threshold=195.0,
+ # kwargs=dict(max_episode_steps=5_000, blabla=True),
+)
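+
+# Usage sketch with the id registered above (mirrors the examples in envs/cartpole.py):
+#
+# import gymnasium as gym
+# import jaxgym.envs # noqa: F401 (importing this module performs the registration)
+#
+# env = gym.make("CartpoleSwingUpEnv-V0")
+# vec_env = gym.make_vec("CartpoleSwingUpEnv-V0", num_envs=2, vectorization_mode="custom")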
diff --git a/src/jaxgym/envs/ant.py b/src/jaxgym/envs/ant.py
new file mode 100644
index 000000000..de8b4d472
--- /dev/null
+++ b/src/jaxgym/envs/ant.py
@@ -0,0 +1,732 @@
+import dataclasses
+import pathlib
+from typing import Any, ClassVar, Optional
+
+import jax.numpy as jnp
+import jax.random
+import jax_dataclasses
+import numpy as np
+import numpy.typing as npt
+import rod
+
+import jaxgym.jax.pytree_space as spaces
+import jaxsim.typing as jtp
+from jaxgym.jax import JaxDataclassEnv, JaxEnv
+from jaxgym.vector.jax import JaxVectorEnv
+from jaxsim import JaxSim
+from jaxsim.physics.algos.soft_contacts import SoftContactsParams
+from jaxsim.simulation import simulator_callbacks
+from jaxsim.simulation.ode_integration import IntegratorType
+from jaxsim.simulation.simulator import SimulatorData, VelRepr
+from jaxsim.utils import JaxsimDataclass, Mutability
+
+
+@jax_dataclasses.pytree_dataclass
+class AntObservation(JaxsimDataclass):
+ """Observation of the Ant environment."""
+
+ base_height: jtp.Float
+ gravity_projection: jtp.Array
+
+ joint_positions: jtp.Array
+ joint_velocities: jtp.Array
+
+ base_linear_velocity: jtp.Array
+ base_angular_velocity: jtp.Array
+
+ contact_state: jtp.Array
+
+ distance_from_goal_x: jtp.Float
+ distance_from_goal_y: jtp.Float
+
+ @staticmethod
+ def build(
+ base_height: jtp.Float,
+ gravity_projection: jtp.Array,
+ joint_positions: jtp.Array,
+ joint_velocities: jtp.Array,
+ base_linear_velocity: jtp.Array,
+ base_angular_velocity: jtp.Array,
+ contact_state: jtp.Array,
+ distance_from_goal_x: jtp.Float,
+ distance_from_goal_y: jtp.Float,
+ ) -> "AntObservation":
+ """Build an AntObservation object."""
+
+ return AntObservation(
+ base_height=jnp.array(base_height, dtype=float),
+ gravity_projection=jnp.array(gravity_projection, dtype=float),
+ joint_positions=jnp.array(joint_positions, dtype=float),
+ joint_velocities=jnp.array(joint_velocities, dtype=float),
+ base_linear_velocity=jnp.array(base_linear_velocity, dtype=float),
+ base_angular_velocity=jnp.array(base_angular_velocity, dtype=float),
+ contact_state=jnp.array(contact_state, dtype=bool),
+ distance_from_goal_x=jnp.array(distance_from_goal_x, dtype=float),
+ distance_from_goal_y=jnp.array(distance_from_goal_y, dtype=float),
+ )
+
+
+import multiprocessing
+
+from meshcat_viz import MeshcatWorld
+
+
+@dataclasses.dataclass
+class MeshcatVizRenderState:
+ """Render state of a meshcat-viz visualizer."""
+
+ world: MeshcatWorld = dataclasses.field(default=None, init=False)
+
+ _gui_process: Optional[multiprocessing.Process] = dataclasses.field(
+ default=None, init=False, repr=False, hash=False, compare=False
+ )
+
+ _jaxsim_to_meshcat_viz_name: dict[str, str] = dataclasses.field(
+ default_factory=dict, init=False, repr=False, hash=False, compare=False
+ )
+
+ def __post_init__(self) -> None:
+ """"""
+
+ self.world = MeshcatWorld()
+ self.world.open()
+
+ def close(self) -> None:
+ """"""
+
+ if self.world is not None:
+ self.world.close()
+
+ if self._gui_process is not None:
+ self._gui_process.terminate()
+ self._gui_process.close()
+
+ @staticmethod
+ def open_window(web_url: str) -> None:
+ """Open a new window with the given web url."""
+
+ import webview
+
+ print(web_url)
+ webview.create_window("meshcat", web_url)
+ webview.start(gui="qt")
+
+ def open_window_in_process(self) -> None:
+ """"""
+
+ if self._gui_process is not None:
+ self._gui_process.terminate()
+ self._gui_process.close()
+
+ self._gui_process = multiprocessing.Process(
+ target=MeshcatVizRenderState.open_window, args=(self.world.web_url,)
+ )
+ self._gui_process.start()
+
+
+StateType = dict[str, SimulatorData | jtp.Array]
+ActType = jnp.ndarray
+ObsType = AntObservation
+RewardType = float | jnp.ndarray
+TerminalType = bool | jnp.ndarray
+RenderStateType = MeshcatVizRenderState
+
+
+@jax_dataclasses.pytree_dataclass
+class AntReachTargetFuncEnvV0(
+ JaxDataclassEnv[
+ StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ ]
+):
+ """Ant environment implementing a target reaching task."""
+
+ name: ClassVar = jax_dataclasses.static_field(default="AntReachTargetFuncEnvV0")
+
+ # Store an instance of the JaxSim simulator.
+ # It gets initialized with SimulatorData with a functional approach.
+ _simulator: JaxSim = jax_dataclasses.field(default=None)
+
+ def __post_init__(self) -> None:
+ """Environment initialization."""
+
+ # Dummy initialization (not needed here)
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ _ = self.jaxsim
+ model = self.jaxsim.get_model(model_name="ant")
+
+ # simulator_data = self.initial(rng=jax.random.PRNGKey(seed=0))
+ # dofs = simulator_data.models["ant"].dofs()
+ # dofs = model.dofs()
+
+ # Create the action space (static attribute)
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ high = jnp.array([25.0] * model.dofs(), dtype=float)
+ self._action_space = spaces.PyTree(low=-high, high=high)
+
+ # Get joint limits
+ s_min, s_max = model.joint_limits()
+ s_range = s_max - s_min
+
+ low = AntObservation.build(
+ base_height=0.25,
+ gravity_projection=-jnp.ones(3),
+ # joint_positions=s_min - 0.10 * s_range,
+ joint_positions=s_min,
+ # joint_velocities=-4.0 * jnp.ones_like(s_min),
+ joint_velocities=-50.0 * jnp.ones_like(s_min),
+ base_linear_velocity=-5.0 * jnp.ones(3),
+ base_angular_velocity=-10.0 * jnp.ones(3),
+ contact_state=jnp.array([False] * 4),
+ distance_from_goal_x=-10.0,
+ distance_from_goal_y=-10.0,
+ )
+
+ high = AntObservation.build(
+ base_height=1.0,
+ gravity_projection=jnp.ones(3),
+ # joint_positions=s_max + 0.10 * s_range,
+ joint_positions=s_max,
+ # joint_velocities=4.0 * jnp.ones_like(s_max),
+ joint_velocities=50.0 * jnp.ones_like(s_max),
+ base_linear_velocity=5.0 * jnp.ones(3),
+ base_angular_velocity=10.0 * jnp.ones(3),
+ contact_state=jnp.array([True] * 4),
+ distance_from_goal_x=10.0,
+ distance_from_goal_y=10.0,
+ )
+
+ # Create the observation space (static attribute)
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ self._observation_space = spaces.PyTree(low=low, high=high)
+
+ @property
+ def jaxsim(self) -> JaxSim:
+ """"""
+
+ if self._simulator is not None:
+ return self._simulator
+
+ # Create the jaxsim simulator.
+ # We use a small integration step so that contact detection is more accurate,
+ # and perform multiple integration steps when we apply the action.
+ simulator = JaxSim.build(
+ # Note: any change of either 'step_size' or 'steps_per_run' requires
+ # updating the number of integration steps in the 'transition' method.
+ # step_size=0.000_500,
+ step_size=0.000_250,
+ steps_per_run=1,
+ # velocity_representation=VelRepr.Inertial, # TODO
+ velocity_representation=VelRepr.Body,
+ integrator_type=IntegratorType.EulerSemiImplicit,
+ simulator_data=SimulatorData(
+ gravity=jnp.array([0, 0, -10.0]),
+ # contact_parameters=SoftContactsParams.build(K=5_000, D=10),
+ contact_parameters=SoftContactsParams.build(K=10_000, D=20),
+ ),
+ ).mutable(mutable=True, validate=False)
+
+ # Get the SDF path
+ model_sdf_path = (
+ pathlib.Path.home()
+ / "git"
+ / "jaxsim"
+ / "examples"
+ / "resources"
+ / "ant.sdf"
+ )
+
+ # TODO: load with rod and change the pos limits spring & friction params
+
+ # Insert the model
+ _ = simulator.insert_model_from_description(
+ model_description=model_sdf_path, model_name="ant"
+ )
+
+ # Fix the pytree structure of the model data so that its corresponding shape
+ # does not change. This is important to keep enabled the shape validation
+ # checks for JIT compilation.
+ simulator.data.models = {
+ model_name: jax.tree_util.tree_map(lambda leaf: jnp.array(leaf), model_data)
+ for model_name, model_data in simulator.data.models.items()
+ }
+
+ # Store the simulator object and configure it as immutable with enabled
+ # pytree structure validation.
+ # This is done to ensure that the corresponding pytree structure remains constant,
+ # preventing unwanted JIT recompilations due to mistakes when setting its data.
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ self._simulator = simulator.mutable(mutable=True, validate=True)
+
+ return self._simulator
+
+ def initial(self, rng: Any = None) -> StateType:
+ """"""
+
+ # Split the key
+ subkey1, subkey2 = jax.random.split(rng, num=2)
+
+ # Sample an initial observation
+ initial_observation: AntObservation = self.observation_space.sample_with_key(
+ key=subkey1
+ )
+
+ # Sample a goal position
+ goal_xy_position = jax.random.uniform(
+ key=subkey2, minval=-5.0, maxval=5.0, shape=(2,)
+ )
+
+ with self.jaxsim.editable(validate=False) as simulator:
+ # Reset the simulator and get the model
+ simulator.reset(remove_models=False)
+ model = simulator.get_model(model_name="ant")
+
+ # Reset the joint positions
+ model.reset_joint_positions(
+ positions=initial_observation.joint_positions,
+ joint_names=model.joint_names(),
+ )
+
+ # Reset the joint velocities
+ # model.reset_joint_velocities(
+ # velocities=0.1 * initial_observation.joint_velocities,
+ # joint_names=model.joint_names(),
+ # )
+
+ # TODO: initialize such that there are no leg/terrain penetrations
+ # Reset the base position
+ model.reset_base_position(
+ # position=jnp.array([0, 0, initial_observation.base_height])
+ position=jnp.array([0, 0, 0.5])
+ )
+
+ # Reset the base velocity
+ model.reset_base_velocity(
+ base_velocity=jnp.hstack(
+ [
+ 0.1 * initial_observation.base_linear_velocity,
+ 0.1 * initial_observation.base_angular_velocity,
+ ]
+ )
+ )
+
+ # Simulate for 1s so that the model starts from a
+ # resting pose on the ground
+ # simulator = simulator.step_over_horizon(
+ # horizon_steps=2 * 1000, clear_inputs=True
+ # )
+
+ # Return the simulation state
+ return dict(
+ simulator_data=simulator.data,
+ goal=jnp.array(goal_xy_position, dtype=float),
+ )
+
+ def transition(
+ self, state: StateType, action: ActType, rng: Any = None
+ ) -> StateType:
+ """"""
+
+ # Get the JaxSim simulator
+ simulator = self.jaxsim
+
+ # Initialize the simulator with the environment state (containing SimulatorData)
+ with simulator.editable(validate=True) as simulator:
+ simulator.data = state["simulator_data"]
+
+ @jax_dataclasses.pytree_dataclass
+ class SetTorquesOverHorizon(simulator_callbacks.PreStepCallback):
+ def pre_step(self, sim: JaxSim) -> JaxSim:
+ """"""
+
+ model = sim.get_model(model_name="ant")
+ model.zero_input()
+ model.set_joint_generalized_force_targets(
+ forces=jnp.atleast_1d(action), joint_names=model.joint_names()
+ )
+
+ return sim
+
+ # Compute the number of integration steps to perform
+ # transition_step_duration = 0.050
+ # number_of_integration_steps = int(transition_step_duration / simulator.dt())
+ # number_of_integration_steps = jnp.array(
+ # transition_step_duration / simulator.dt(), dtype=jnp.uint32
+ # )
+
+ # number_of_integration_steps = 100 # 0.050
+ number_of_integration_steps = 40 # 0.010 # TODO 20 for having 0.010
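+ # With the simulator built above (step_size=0.000_250, steps_per_run=1), these
+ # 40 sub-steps advance 40 * 0.000_250 = 0.010 s of simulated time per transition.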
+
+ # Stepping logic
+ with simulator.editable(validate=True) as simulator:
+ # simulator, _ = simulator.step_over_horizon_plain(
+ simulator, _ = simulator.step_over_horizon(
+ horizon_steps=number_of_integration_steps,
+ clear_inputs=False,
+ callback_handler=SetTorquesOverHorizon(),
+ )
+
+ # Return the new environment state (updated SimulatorData)
+ return state | dict(simulator_data=simulator.data)
+
+ def observation(self, state: StateType) -> ObsType:
+ """"""
+
+ # Initialize the simulator with the environment state (containing SimulatorData)
+ # and get the simulated model
+ with self.jaxsim.editable(validate=True) as simulator:
+ simulator.data = state["simulator_data"]
+ model = simulator.get_model("ant")
+
+ # Compute the normalized gravity projection in the body frame
+ W_R_B = model.base_orientation(dcm=True)
+ # W_gravity = state.simulator.gravity()
+ W_gravity = self.jaxsim.gravity()
+ B_gravity = W_R_B.T @ (W_gravity / jnp.linalg.norm(W_gravity))
+
+ W_p_B = model.base_position()
+ W_p_goal = jnp.hstack([state["goal"].squeeze(), 0])
+
+ # Compute the distance between the base and the goal in the body frame
+ B_p_distance = W_R_B.T @ (W_p_goal - W_p_B)
+
+ # Build the observation from the state
+ return AntObservation.build(
+ base_height=model.base_position()[2],
+ gravity_projection=B_gravity,
+ joint_positions=model.joint_positions(),
+ joint_velocities=model.joint_velocities(),
+ base_linear_velocity=model.base_velocity()[0:3],
+ base_angular_velocity=model.base_velocity()[3:6],
+ contact_state=model.in_contact(
+ link_names=[
+ name
+ for name in model.link_names()
+ if name.startswith("leg_") and name.endswith("_lower")
+ ]
+ ),
+ distance_from_goal_x=B_p_distance[0],
+ distance_from_goal_y=B_p_distance[1],
+ # contact_state=jnp.array(
+ # [
+ # model.get_link(name).in_contact()
+ # for name in model.link_names()
+ # if name.startswith("leg_") and name.endswith("_lower")
+ # ],
+ # dtype=bool,
+ # ),
+ )
+
+ def reward(
+ self, state: StateType, action: ActType, next_state: StateType
+ ) -> RewardType:
+ """"""
+
+ # with self.jaxsim.editable(validate=True) as simulator_pre:
+ # simulator_pre.data = state["simulator_data"]
+ # model_pre = simulator_pre.get_model("ant")
+
+ with self.jaxsim.editable(validate=True) as simulator_next:
+ simulator_next.data = next_state["simulator_data"]
+ model_next = simulator_next.get_model("ant")
+
+ # W_p_B_pre = model_pre.base_position()
+ # W_p_B_next = model_next.base_position()
+ # v_WB = (W_p_B_next - W_p_B_pre) / simulator_pre.dt()
+
+ terminal = self.terminal(state=state)
+ obs_in_space = jax.lax.select(
+ pred=self.observation_space.contains(x=self.observation(state=state)),
+ on_true=1.0,
+ on_false=0.0,
+ )
+
+ # Position of the base
+ W_p_B = model_next.base_position()
+ W_p_xy_goal = state["goal"]
+
+ reward = 0.0
+ reward += 1.0 * (1.0 - jnp.array(terminal, dtype=float)) # alive
+ reward += 5.0 * obs_in_space # observation within its space
+ # reward += 100.0 * v_WB[0] # forward velocity
+ reward -= jnp.linalg.norm(W_p_B[0:2] - W_p_xy_goal) # distance from goal
+ reward += 1.0 * model_next.in_contact(
+ link_names=[
+ name
+ for name in model_next.link_names()
+ if name.startswith("leg_") and name.endswith("_lower")
+ ]
+ ).any().astype(
+ float
+ ) # contact status
+ reward -= 0.1 * jnp.linalg.norm(action) / action.size # control cost
+
+ return reward
+
+ def terminal(self, state: StateType) -> TerminalType:
+ """"""
+
+ # Get the current observation
+ observation = self.observation(state=state)
+
+ base_too_high = (
+ observation.base_height >= self.observation_space.high.base_height
+ )
+ #
+ # no_feet_in_contact = jnp.where(observation.contact_state.any(), False, True)
+
+ # The state is terminal if the observation is outside its space
+ # return jax.lax.select(
+ # pred=self.observation_space.contains(x=observation),
+ # on_true=False,
+ # on_false=True,
+ # )
+ # return jnp.array([base_too_high, no_feet_in_contact]).any()
+ # return no_feet_in_contact
+ return base_too_high
+
+ # =========
+ # Rendering
+ # =========
+
+ def render_image(
+ self, state: StateType, render_state: RenderStateType
+ ) -> tuple[RenderStateType, npt.NDArray]:
+ """Show the state."""
+
+ model_name = "ant"
+
+ # Initialize the simulator with the environment state (containing SimulatorData)
+ # and get the simulated model
+ with self.jaxsim.editable(validate=False) as simulator:
+ simulator.data = state["simulator_data"]
+ model = simulator.get_model(model_name=model_name)
+
+ # Insert the model lazily in the visualizer if it is not already there
+ if model_name not in render_state.world._meshcat_models.keys():
+ from rod.urdf.exporter import UrdfExporter
+
+ urdf_string = UrdfExporter.sdf_to_urdf_string(
+ sdf=rod.Sdf(
+ version="1.7",
+ model=model.physics_model.description.extra_info["sdf_model"],
+ ),
+ pretty=True,
+ gazebo_preserve_fixed_joints=False,
+ )
+
+ meshcat_viz_name = render_state.world.insert_model(
+ model_description=urdf_string, is_urdf=True, model_name=None
+ )
+
+ render_state._jaxsim_to_meshcat_viz_name[model_name] = meshcat_viz_name
+
+ # Check that the model is in the visualizer
+ if (
+ render_state._jaxsim_to_meshcat_viz_name[model_name]
+ not in render_state.world._meshcat_models
+ ):
+ raise ValueError(f"The '{model_name}' model is not in the meshcat world")
+
+ # Update the model in the visualizer
+ render_state.world.update_model(
+ model_name=render_state._jaxsim_to_meshcat_viz_name[model_name],
+ joint_names=model.joint_names(),
+ joint_positions=model.joint_positions(),
+ base_position=model.base_position(),
+ base_quaternion=model.base_orientation(dcm=False),
+ )
+
+ return render_state, np.empty(0)
+
+ def render_init(self, open_gui: bool = False, **kwargs) -> RenderStateType:
+ """Initialize the render state."""
+
+ # Initialize the render state
+ meshcat_viz_state = MeshcatVizRenderState()
+
+ if open_gui:
+ meshcat_viz_state.open_window_in_process()
+
+ return meshcat_viz_state
+
+ def render_close(self, render_state: RenderStateType) -> None:
+ """Close the render state."""
+
+ render_state.close()
+ # render_state.world.close()
+ #
+ # if render_state._gui_process is not None:
+ # render_state._gui_process.terminate()
+ # render_state._gui_process.close()
+
+ # TODO: write a generic class that, given a JaxSim instance, visualizes all its models
+ # -> vectorized?? move it inside JaxSim? Create a new class in jaxsim?
+ # def update_meshcat_world(
+ # self, world: "MeshcatWorld", state: StateType # TODO: how to handle the state??
+ # ) -> "MeshcatWorld":
+ # """"""
+ #
+ # # Initialize the simulator with the environment state (containing SimulatorData)
+ # # and get the simulated model
+ # with self.jaxsim.editable(validate=False) as simulator:
+ # simulator.data = state
+ # model = simulator.get_model("ant")
+ #
+ # # Add the model to the world if not already present
+ # if "ant" not in world._meshcat_models.keys():
+ # _ = world.insert_model(
+ # model_description=(
+ # pathlib.Path.home()
+ # / "git"
+ # / "jaxsim"
+ # / "examples"
+ # / "resources"
+ # / "ant.sdf"
+ # ),
+ # is_urdf=False,
+ # model_name="ant",
+ # )
+ #
+ # # Update the model
+ # world.update_model(
+ # model_name="ant",
+ # joint_names=model.joint_names(),
+ # joint_positions=model.joint_positions(),
+ # base_position=model.base_position(),
+ # base_quaternion=model.base_orientation(dcm=False),
+ # )
+ #
+ # return world
+
+
+class AntReachTargetEnvV0(JaxEnv):
+ """"""
+
+ def __init__(self, render_mode: str | None = None, **kwargs: Any) -> None:
+ """"""
+
+ from jaxgym.wrappers.jax import (
+ ClipActionWrapper,
+ FlattenSpacesWrapper,
+ JaxTransformWrapper,
+ TimeLimit,
+ )
+
+ func_env = AntReachTargetFuncEnvV0()
+
+ func_env_wrapped = func_env
+ func_env_wrapped = TimeLimit(
+ env=func_env_wrapped, max_episode_steps=5_000
+ ) # TODO
+ func_env_wrapped = ClipActionWrapper(env=func_env_wrapped)
+ func_env_wrapped = FlattenSpacesWrapper(env=func_env_wrapped)
+ func_env_wrapped = JaxTransformWrapper(env=func_env_wrapped, function=jax.jit)
+
+ super().__init__(
+ func_env=func_env_wrapped,
+ metadata=self.metadata,
+ render_mode=render_mode,
+ )
+
+
+class AntReachTargetVectorEnvV0(JaxVectorEnv):
+ """"""
+
+ metadata = dict()
+
+ def __init__(
+ self,
+ # func_env: JaxDataclassEnv[
+ # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ # ],
+ num_envs: int,
+ render_mode: str | None = None,
+ # max_episode_steps: int = 5_000,
+ jit_compile: bool = True,
+ **kwargs,
+ ) -> None:
+ """"""
+
+ print("+++", kwargs)
+
+ env = AntReachTargetFuncEnvV0()
+
+ # Vectorize the environment.
+ # Note: it automatically wraps the environment in a TimeLimit wrapper.
+ super().__init__(
+ func_env=env,
+ num_envs=num_envs,
+ metadata=self.metadata,
+ render_mode=render_mode,
+ max_episode_steps=5_000, # TODO
+ jit_compile=jit_compile,
+ )
+
+ # from jaxgym.vector.jax import FlattenSpacesVecWrapper
+ #
+ # vec_env_wrapped = FlattenSpacesVecWrapper(env=vec_env)
+
+
+if __name__ == "__main__":
+ """Stable Baselines"""
+
+ from typing import Optional
+
+ from jaxgym.wrappers.jax import (
+ FlattenSpacesWrapper,
+ JaxTransformWrapper,
+ TimeLimit,
+ ToNumPyWrapper,
+ )
+
+ def make_jax_env(
+ max_episode_steps: Optional[int] = 500, jit: bool = True
+ ) -> JaxEnv:
+ """"""
+
+ # TODO: single env -> time limit with stable_baselines?
+
+ if max_episode_steps in {None, 0}:
+ env = AntReachTargetFuncEnvV0()
+ else:
+ env = TimeLimit(
+ env=AntReachTargetFuncEnvV0(), max_episode_steps=max_episode_steps
+ )
+
+ return JaxEnv(
+ func_env=ToNumPyWrapper(
+ env=FlattenSpacesWrapper(env=env)
+ if not jit
+ else JaxTransformWrapper(
+ function=jax.jit,
+ env=FlattenSpacesWrapper(env=env),
+ ),
+ ),
+ render_mode="meshcat_viz",
+ )
+
+ env = make_jax_env(max_episode_steps=5, jit=False)
+
+ obs, state_info = env.reset(seed=0)
+ _ = env.render()
+ raise
+ for _ in range(5):
+ action = env.action_space.sample()
+ # obs, reward, terminated, truncated, info = env.step(action=action)
+ obs, reward, terminated, truncated, info = env.step(
+ action=jnp.zeros_like(action)
+ )
+
+ # =========
+ # Visualize
+ # =========
+
+ visualize = False
+
+ if visualize:
+ rollout_visualizer = visualizer(env=lambda: make_jax_env(1_000), policy=model)
+
+ import time
+
+ time.sleep(3)
+ rollout_visualizer(None)
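+
+ # Vectorized usage sketch (mirrors the JaxVectorEnv example in envs/cartpole.py;
+ # the reset()/step() signatures below are assumptions based on that example):
+ #
+ # vec_env = AntReachTargetVectorEnvV0(num_envs=2)
+ # observations, state_infos = vec_env.reset()
+ # actions = vec_env.action_space.sample(key=jax.random.PRNGKey(0))
+ # observations, rewards, terminals, truncated, step_infos = vec_env.step(action=actions)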
diff --git a/src/jaxgym/envs/cartpole.py b/src/jaxgym/envs/cartpole.py
new file mode 100644
index 000000000..cad6830be
--- /dev/null
+++ b/src/jaxgym/envs/cartpole.py
@@ -0,0 +1,570 @@
+import pathlib
+from typing import Any, ClassVar
+
+import jax.numpy as jnp
+import jax.random
+import jax_dataclasses
+import numpy as np
+import numpy.typing as npt
+
+import jaxgym.jax.pytree_space as spaces
+import jaxsim.typing as jtp
+from jaxgym.envs.ant import MeshcatVizRenderState
+from jaxgym.jax import JaxDataclassEnv, JaxEnv
+from jaxgym.vector.jax import JaxVectorEnv
+from jaxsim import JaxSim, logging
+from jaxsim.simulation.ode_integration import IntegratorType
+from jaxsim.simulation.simulator import SimulatorData, VelRepr
+from jaxsim.utils import JaxsimDataclass, Mutability
+
+
+@jax_dataclasses.pytree_dataclass
+class CartpoleObservation(JaxsimDataclass):
+ """Observation of the CartPole environment."""
+
+ linear_pos: jtp.Float
+ linear_vel: jtp.Float
+
+ pivot_pos: jtp.Float
+ pivot_vel: jtp.Float
+
+ @staticmethod
+ def build(
+ linear_pos: jtp.Float,
+ linear_vel: jtp.Float,
+ pivot_pos: jtp.Float,
+ pivot_vel: jtp.Float,
+ ) -> "CartpoleObservation":
+ """"""
+
+ return CartpoleObservation(
+ linear_pos=jnp.array(linear_pos, dtype=float),
+ linear_vel=jnp.array(linear_vel, dtype=float),
+ pivot_pos=jnp.array(pivot_pos, dtype=float),
+ pivot_vel=jnp.array(pivot_vel, dtype=float),
+ )
+
+
+StateType = SimulatorData
+ActType = jnp.ndarray
+ObsType = CartpoleObservation
+RewardType = float | jnp.ndarray
+TerminalType = bool | jnp.ndarray
+RenderStateType = MeshcatVizRenderState
+
+
+@jax_dataclasses.pytree_dataclass
+class CartpoleSwingUpFuncEnvV0(
+ JaxDataclassEnv[
+ StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ ]
+):
+ """CartPole environment implementing a swing-up task."""
+
+ name: ClassVar = jax_dataclasses.static_field(
+ default="CartpoleSwingUpFunctionalEnvV0"
+ )
+
+ # Store an instance of the JaxSim simulator.
+ # It gets initialized with SimulatorData with a functional approach.
+ _simulator: JaxSim = jax_dataclasses.field(default=None)
+
+ def __post_init__(self) -> None:
+ """Environment initialization."""
+
+ # Dummy initialization (not needed here)
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ _ = self.jaxsim
+ # _ = self.initial(rng=jax.random.PRNGKey(seed=0))
+
+ # Create the action space (static attribute)
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ self._action_space = spaces.PyTree(
+ low=jnp.array(-50.0, dtype=float), high=jnp.array(50.0, dtype=float)
+ )
+
+ low = CartpoleObservation.build(
+ linear_pos=-2.4,
+ linear_vel=-10.0,
+ pivot_pos=-jnp.pi,
+ pivot_vel=-4 * jnp.pi,
+ )
+
+ high = CartpoleObservation.build(
+ linear_pos=2.4,
+ linear_vel=10.0,
+ pivot_pos=jnp.pi,
+ pivot_vel=4 * jnp.pi,
+ )
+
+ # Create the observation space (static attribute)
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ self._observation_space = spaces.PyTree(low=low, high=high)
+
+ @property
+ def jaxsim(self) -> JaxSim:
+ """"""
+
+ if self._simulator is not None:
+ return self._simulator
+
+ # T = 0.010
+ # dt = 0.001
+ T = 0.050
+ dt = 0.000_500
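+ # With these values, steps_per_run = int(T / dt) = 100 sub-steps of 0.5 ms each,
+ # so a single call to simulator.step() advances T = 0.050 s of simulated time.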
+
+ # Create the jaxsim simulator
+ simulator = JaxSim.build(
+ # step_size=0.001,
+ # steps_per_run=10,
+ step_size=dt,
+ steps_per_run=int(T / dt),
+ velocity_representation=VelRepr.Inertial,
+ integrator_type=IntegratorType.EulerSemiImplicit,
+ simulator_data=SimulatorData(gravity=jnp.array([0, 0, -10.0])),
+ ).mutable(mutable=True, validate=False)
+
+ # Get the SDF path
+ model_urdf_path = (
+ pathlib.Path.home()
+ / "git"
+ / "jaxsim"
+ / "examples"
+ / "resources"
+ / "cartpole.urdf"
+ )
+
+ # Insert the model
+ _ = simulator.insert_model_from_description(
+ model_description=model_urdf_path, model_name="cartpole"
+ )
+
+ # Fix the pytree structure of the model data so that its corresponding shape
+ # does not change. This is important to keep enabled the shape validation
+ # checks for JIT compilation.
+ simulator.data.models = {
+ model_name: jax.tree_util.tree_map(lambda leaf: jnp.array(leaf), model_data)
+ for model_name, model_data in simulator.data.models.items()
+ }
+
+ # Store the simulator object and configure it as immutable with enabled
+ # pytree structure validation.
+ # This is done to ensure that the corresponding pytree structure remains constant,
+ # preventing unwanted JIT recompilations due to mistakes when setting its data.
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ self._simulator = simulator.mutable(mutable=True, validate=True)
+
+ return self._simulator
+
+ def initial(self, rng: Any = None) -> StateType:
+ """"""
+
+ # Sample an initial observation
+ initial_observation: CartpoleObservation = (
+ self.observation_space.sample_with_key(key=rng)
+ )
+
+ with self.jaxsim.editable(validate=False) as simulator:
+ # Reset the simulator and get the model
+ simulator.reset(remove_models=False)
+ model = simulator.get_model(model_name="cartpole")
+
+ # Reset the joint positions
+ model.reset_joint_positions(
+ positions=0.9
+ * jnp.array(
+ [initial_observation.linear_pos, initial_observation.pivot_pos]
+ ),
+ joint_names=["linear", "pivot"],
+ )
+
+ # Reset the joint velocities
+ model.reset_joint_velocities(
+ velocities=0.9
+ * jnp.array(
+ [initial_observation.linear_vel, initial_observation.pivot_vel]
+ ),
+ joint_names=["linear", "pivot"],
+ )
+
+ # TODO: reset the joint velocities
+ # logging.error("ZEROOO")
+ # model.reset_joint_positions(positions=jnp.array([0, jnp.deg2rad(180.0)]))
+ # model.reset_joint_velocities(velocities=jnp.array([0, 0.0]))
+
+ # Return the simulation state
+ return simulator.data
+
+ def transition(
+ self, state: StateType, action: ActType, rng: Any = None
+ ) -> StateType:
+ """"""
+
+ # Get the JaxSim simulator
+ simulator = self.jaxsim
+
+ # Initialize the simulator with the environment state (containing SimulatorData)
+ with simulator.editable(validate=True) as simulator:
+ simulator.data = state
+
+ # Stepping logic
+ with simulator.editable(validate=True) as simulator:
+ # Get the simulated model
+ model = simulator.get_model(model_name="cartpole")
+
+ # Zero all the inputs
+ model.zero_input()
+
+ # print(action)
+ # action = action.squeeze()
+
+ # Apply a linear force to the cart
+ model.set_joint_generalized_force_targets(
+ forces=jnp.atleast_1d(action), joint_names=["linear"]
+ )
+
+ # TODO: in multi-step -> reset action?
+ # Or always one step and handle multi-steps with callbacks e.g. controllers?
+ simulator.step(clear_inputs=False)
+
+ # Return the new environment state (updated SimulatorData)
+ return simulator.data
+
+ def observation(self, state: StateType) -> ObsType:
+ """"""
+
+ # Initialize the simulator with the environment state (containing SimulatorData)
+ # and get the simulated model
+ with self.jaxsim.editable(validate=False) as simulator:
+ simulator.data = state
+ model = simulator.get_model("cartpole")
+
+ # Extract the positions and velocities of the joints
+ linear_pos, pivot_pos = model.joint_positions()
+ linear_vel, pivot_vel = model.joint_velocities()
+
+ # Build the observation from the state
+ return CartpoleObservation.build(
+ linear_pos=linear_pos,
+ linear_vel=linear_vel,
+ # Make sure that the pivot position is always in [-π, π]
+ pivot_pos=jnp.arctan2(jnp.sin(pivot_pos), jnp.cos(pivot_pos)),
+ pivot_vel=pivot_vel,
+ )
+
+ def reward(
+ self, state: StateType, action: ActType, next_state: StateType
+ ) -> RewardType:
+ """"""
+
+ # Get the current observation
+ observation = self.observation(state=next_state)
+ # observation = type(self).observation(self=self, state=next_state)
+
+ # Compute the reward terms
+ reward_alive = 1.0 - jnp.array(self.terminal(state=next_state), dtype=float)
+ # type(self).terminal(self=self, state=next_state), dtype=float
+ reward_pivot = jnp.cos(observation.pivot_pos)
+ cost_action = jnp.sqrt(action.dot(action))
+ cost_pivot_vel = jnp.sqrt(observation.pivot_vel ** 2)
+ cost_linear_pos = jnp.abs(observation.linear_pos)
+
+ reward = 0
+ reward += reward_alive
+ reward += reward_pivot
+ reward -= 0.001 * cost_action
+ reward -= 0.100 * cost_pivot_vel
+ reward -= 0.500 * cost_linear_pos
+
+ return reward
+
+ def terminal(self, state: StateType) -> TerminalType:
+ """"""
+
+ # Get the current observation
+ observation = self.observation(state=state)
+ # observation = type(self).observation(self=self, state=state)
+
+ # The state is terminal if the observation is outside its space
+ return jax.lax.select(
+ pred=self.observation_space.contains(x=observation),
+ on_true=False,
+ on_false=True,
+ )
+
+ # =========
+ # Rendering
+ # =========
+
+ def render_image(
+ self, state: StateType, render_state: RenderStateType
+ ) -> tuple[RenderStateType, npt.NDArray]:
+ """Show the state."""
+
+ model_name = "cartpole"
+
+ # Initialize the simulator with the environment state (containing SimulatorData)
+ # and get the simulated model
+ with self.jaxsim.editable(validate=False) as simulator:
+ simulator.data = state
+ model = simulator.get_model(model_name=model_name)
+
+ # Insert the model lazily in the visualizer if it is not already there
+ if model_name not in render_state.world._meshcat_models.keys():
+ import rod
+ from rod.urdf.exporter import UrdfExporter
+
+ urdf_string = UrdfExporter.sdf_to_urdf_string(
+ sdf=rod.Sdf(
+ version="1.7",
+ model=model.physics_model.description.extra_info["sdf_model"],
+ ),
+ pretty=True,
+ gazebo_preserve_fixed_joints=False,
+ )
+
+ meshcat_viz_name = render_state.world.insert_model(
+ model_description=urdf_string, is_urdf=True, model_name=None
+ )
+
+ render_state._jaxsim_to_meshcat_viz_name[model_name] = meshcat_viz_name
+
+ # Check that the model is in the visualizer
+ if (
+ render_state._jaxsim_to_meshcat_viz_name[model_name]
+ not in render_state.world._meshcat_models
+ ):
+ raise ValueError(f"The '{model_name}' model is not in the meshcat world")
+
+ # Update the model in the visualizer
+ render_state.world.update_model(
+ model_name=render_state._jaxsim_to_meshcat_viz_name[model_name],
+ joint_names=model.joint_names(),
+ joint_positions=model.joint_positions(),
+ base_position=model.base_position(),
+ base_quaternion=model.base_orientation(dcm=False),
+ )
+
+ return render_state, np.empty(0)
+
+ def render_init(self, open_gui: bool = False, **kwargs) -> RenderStateType:
+ """Initialize the render state."""
+
+ # Initialize the render state
+ meshcat_viz_state = MeshcatVizRenderState()
+
+ if open_gui:
+ meshcat_viz_state.open_window_in_process()
+
+ return meshcat_viz_state
+
+ def render_close(self, render_state: RenderStateType) -> None:
+ """Close the render state."""
+
+ render_state.close()
+
+
+class CartpoleSwingUpEnvV0(JaxEnv):
+ """"""
+
+ def __init__(self, render_mode: str | None = None, **kwargs: Any) -> None:
+ """"""
+
+ from jaxgym.wrappers.jax import (
+ ClipActionWrapper,
+ FlattenSpacesWrapper,
+ JaxTransformWrapper,
+ TimeLimit,
+ )
+
+ func_env = CartpoleSwingUpFuncEnvV0()
+
+ func_env_wrapped = func_env
+ func_env_wrapped = TimeLimit(env=func_env_wrapped, max_episode_steps=5_000)
+ func_env_wrapped = ClipActionWrapper(env=func_env_wrapped)
+ func_env_wrapped = FlattenSpacesWrapper(env=func_env_wrapped)
+ func_env_wrapped = JaxTransformWrapper(env=func_env_wrapped, function=jax.jit)
+
+ super().__init__(
+ func_env=func_env_wrapped,
+ metadata=self.metadata,
+ render_mode=render_mode,
+ )
+
+
+class CartpoleSwingUpVectorEnvV0(JaxVectorEnv):
+ """"""
+
+ metadata = dict()
+
+ def __init__(
+ self,
+ # func_env: JaxDataclassEnv[
+ # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ # ],
+ num_envs: int,
+ render_mode: str | None = None,
+ # max_episode_steps: int = 5_000,
+ jit_compile: bool = True,
+ **kwargs,
+ ) -> None:
+ """"""
+
+ # print("+++", kwargs)
+
+ env = CartpoleSwingUpFuncEnvV0()
+
+ # Vectorize the environment.
+ # Note: it automatically wraps the environment in a TimeLimit wrapper.
+ super().__init__(
+ func_env=env,
+ num_envs=num_envs,
+ metadata=self.metadata,
+ render_mode=render_mode,
+ max_episode_steps=5_000, # TODO
+ jit_compile=jit_compile,
+ )
+
+ # from jaxgym.vector.jax import FlattenSpacesVecWrapper
+ #
+ # vec_env_wrapped = FlattenSpacesVecWrapper(env=vec_env)
+
+
+if __name__ == "__main__REGISTER":
+ """"""
+
+ import gymnasium as gym
+
+ import jaxgym.envs
+
+ gym.envs.registry.keys()
+
+ #
+ #
+ #
+
+ env = gym.make("CartpoleSwingUpEnv-V0")
+ env.spec.pprint(print_all=True)
+
+ #
+ #
+ #
+
+ vec_env = gym.make_vec(
+ "CartpoleSwingUpEnv-V0", num_envs=2, vectorization_mode="custom"
+ )
+ vec_env.spec.pprint(print_all=True)
+
+ from jaxgym.vector.jax.wrappers import FlattenSpacesVecWrapper
+
+ vec_env_wrapped = FlattenSpacesVecWrapper(env=vec_env)
+
+
+if __name__ == "__main__+":
+ """"""
+
+ # env = CartpoleFunctionalEnvV0()
+ # state = env.initial(rng=jax.random.PRNGKey(0))
+ # action = env.action_space.sample(key=jax.random.PRNGKey(1))
+
+ key = jax.random.PRNGKey(0)
+ num = 1000
+
+ # TODO next week:
+ # - this is ok
+ # - write Env wrapper for autoreset / check gymnasium -> lambda for get_action from pytorch?
+ # - figure out what the info dicts of gymnasium are
+ # - decide how to perform training loop -> rl algos from where (jax-based or pytorch)?
+ #
+ # Pytorch is ok if we sample in parallel only a single step (e.g. on thousands of envs)
+
+ # from jaxgym.functional.wrappers.transform import TransformWrapper
+ # from jaxgym.functional.wrappers.jax.time_limit import TimeLimit
+ #
+ # env = CartpoleSwingUpFunctionalEnvV0()
+ # env = TimeLimit(env=env, max_episode_steps=100)
+ # # vec_env.transform(func=jax.vmap)
+ # # vec_env.transform(func=jax.jit)
+ # vec_env = TransformWrapper(env=env, function=jax.vmap)
+ # vec_env = TransformWrapper(env=vec_env, function=jax.jit)
+ # states = vec_env.initial(rng=jax.random.split(key, num=num))
+ # _ = vec_env.observation(state=states)
+ # action = vec_env.action_space.sample(key=jax.random.split(key, num=1).squeeze())
+ # actions = jnp.repeat(action, repeats=num, axis=0)
+ # next_states = vec_env.transition(state=states, action=actions)
+ # reward = vec_env.reward(state=states, action=actions, next_state=next_states)
+ # infos = vec_env.step_info(state=states, action=actions, next_state=next_states)
+
+ # from jaxgym.functional.jax.vector import JaxVectorEnv
+ # from jaxgym.functional.jax.time_limit import TimeLimit
+ # from jaxgym.functional.wrappers.transform import TransformWrapper
+ # from jaxgym.functional.core import FuncWrapper
+ #
+ # env = CartpoleSwingUpFunctionalEnvV0()
+ #
+ # env_wrapped = TimeLimit(env=env, max_episode_steps=100)
+ # env_wrapped = TransformWrapper(env=env_wrapped, function=jax.jit)
+ # # CartpoleSwingUpFunctionalEnvV0.transform(self=env_wrapped, func=jax.jit)
+ # state = env_wrapped.initial(rng=key)
+ # _ = env_wrapped.observation(state=state)
+ # action = env_wrapped.action_space.sample(key=key)
+ # next_state = env_wrapped.transition(state=state, action=action)
+ # reward = env_wrapped.reward(state=state, action=action, next_state=next_state)
+ # info = env_wrapped.step_info(state=state, action=action, next_state=next_state)
+ # next_state = env_wrapped.transition(state=next_state, action=action)
+
+ from jaxgym.functional.jax.time_limit import TimeLimit
+ from jaxgym.functional.jax.transform import JaxTransformWrapper
+ from jaxgym.functional.jax.vector import JaxVectorEnv
+ from jaxgym.functional.wrappers.transform import TransformWrapper
+
+ env = CartpoleSwingUpFuncEnvV0()
+ # env = TimeLimit(env=env, max_episode_steps=3)
+ num_envs = 2
+ vec_env = JaxVectorEnv(
+ func_env=env, num_envs=num_envs, max_episode_steps=4, jit_compile=True
+ )
+ observations, state_infos = vec_env.reset()
+
+ actions = vec_env.action_space.sample(key=key)
+ # actions = jnp.repeat(jnp.atleast_2d(action).T, repeats=num_envs, axis=1).T
+
+ # TODO: the output dict misses the final observation when truncated
+ # final_observation | final_info
+ _ = vec_env.step(action=actions)
+
+ # self = vec_env
+ # env = self.func_env
+ # states = self.states
+ # keys_1 = self.subkey(num=self.num_envs)
+ # keys_2 = self.subkey(num=self.num_envs)
+ # states, _ = jax.jit(JaxVectorEnv.step_autoreset_func)(
+ # env, states, actions, keys_1, keys_2
+ # )
+
+ # @jax.jit
+ # def split(key: jax.random.PRNGKeyArray, num: int) -> jax.random.PRNGKeyArray:
+ # return jax.random.split(key=key, num=num)
+ #
+ # _ = split(key, 5)
+
+ # _ = vec_env.func_env.transition(state=vec_env.states, action=actions)
+ # _ = vec_env.step(action=actions)
+
+ # observation = env.observation(state=state)
+ # terminal = env.terminal(state=state)
+ # reward = env.reward(state=state, action=action)
+ # next_state = env.transition(state=state, action=action, rng=jax.random.PRNGKey(2))
+
+ # jax.tree_util.tree_structure(state)
+ # jax.tree_util.tree_structure(next_state)
+ #
+ # jax.tree_util.tree_leaves(state)
+ # jax.tree_util.tree_leaves(next_state)
+
+ # with env.jaxsim.editable(validate=True) as simulator:
+ # simulator.data = next_state
diff --git a/src/jaxgym/envs/ergocub.py b/src/jaxgym/envs/ergocub.py
new file mode 100644
index 000000000..c3d45457c
--- /dev/null
+++ b/src/jaxgym/envs/ergocub.py
@@ -0,0 +1,810 @@
+import dataclasses
+import functools
+import multiprocessing
+import os
+import warnings
+from typing import Any, ClassVar, Dict, List, Optional, Type, Union
+
+import gymnasium as gym
+import jax.numpy as jnp
+import jax.random
+import jax_dataclasses
+import numpy as np
+import numpy.typing as npt
+import rod
+from gymnasium.experimental.vector.vector_env import VectorWrapper
+from meshcat_viz import MeshcatWorld
+from resolve_robotics_uri_py import resolve_robotics_uri
+from stable_baselines3 import PPO
+from stable_baselines3.common import vec_env as vec_env_sb
+from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.vec_env import VecMonitor, VecNormalize
+from torch import nn
+
+import jaxgym.jax.pytree_space as spaces
+import jaxsim.typing as jtp
+from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper, JaxEnv, PyTree
+from jaxgym.vector.jax import FlattenSpacesVecWrapper, JaxVectorEnv
+from jaxgym.wrappers.jax import (
+ ActionNoiseWrapper,
+ ClipActionWrapper,
+ FlattenSpacesWrapper,
+ JaxTransformWrapper,
+ NaNHandlerWrapper,
+ SquashActionWrapper,
+ TimeLimit,
+ ToNumPyWrapper,
+)
+from jaxsim import JaxSim
+from jaxsim.physics.algos.soft_contacts import SoftContactsParams
+from jaxsim.simulation import simulator_callbacks
+from jaxsim.simulation.ode_integration import IntegratorType
+from jaxsim.simulation.simulator import SimulatorData, VelRepr
+from jaxsim.utils import JaxsimDataclass, Mutability
+
+warnings.simplefilter(action="ignore", category=FutureWarning)
+
+
+@jax_dataclasses.pytree_dataclass
+class ErgoCubObservation(JaxsimDataclass):
+ """Observation of the ErgoCub environment."""
+
+ base_height: jtp.Float
+ gravity_projection: jtp.Array
+
+ joint_positions: jtp.Array
+ joint_velocities: jtp.Array
+
+ base_linear_velocity: jtp.Array
+ base_angular_velocity: jtp.Array
+
+ contact_state: jtp.Array
+
+ @staticmethod
+ def build(
+ base_height: jtp.Float,
+ gravity_projection: jtp.Array,
+ joint_positions: jtp.Array,
+ joint_velocities: jtp.Array,
+ base_linear_velocity: jtp.Array,
+ base_angular_velocity: jtp.Array,
+ contact_state: jtp.Array,
+ ) -> "ErgoCubObservation":
+ """Build an ErgoCubObservation object."""
+
+ return ErgoCubObservation(
+ base_height=jnp.array(base_height, dtype=float),
+ gravity_projection=jnp.array(gravity_projection, dtype=float),
+ joint_positions=jnp.array(joint_positions, dtype=float),
+ joint_velocities=jnp.array(joint_velocities, dtype=float),
+ base_linear_velocity=jnp.array(base_linear_velocity, dtype=float),
+ base_angular_velocity=jnp.array(base_angular_velocity, dtype=float),
+ contact_state=jnp.array(contact_state, dtype=bool),
+ )
+
+
+@dataclasses.dataclass
+class MeshcatVizRenderState:
+ """Render state of a meshcat-viz visualizer."""
+
+ world: MeshcatWorld = dataclasses.field(default=None, init=False)
+
+ _gui_process: Optional[multiprocessing.Process] = dataclasses.field(
+ default=None, init=False, repr=False, hash=False, compare=False
+ )
+
+ _jaxsim_to_meshcat_viz_name: dict[str, str] = dataclasses.field(
+ default_factory=dict, init=False, repr=False, hash=False, compare=False
+ )
+
+ def __post_init__(self) -> None:
+ """"""
+
+ self.world = MeshcatWorld()
+ self.world.open()
+
+ def close(self) -> None:
+ """"""
+
+ if self.world is not None:
+ self.world.close()
+
+ if self._gui_process is not None:
+ self._gui_process.terminate()
+ self._gui_process.close()
+
+ @staticmethod
+ def open_window(web_url: str) -> None:
+ """Open a new window with the given web url."""
+
+ import webview
+
+ print(web_url)
+ webview.create_window("meshcat", web_url)
+ webview.start(gui="qt")
+
+ def open_window_in_process(self) -> None:
+ """"""
+
+ if self._gui_process is not None:
+ self._gui_process.terminate()
+ self._gui_process.close()
+
+ self._gui_process = multiprocessing.Process(
+ target=MeshcatVizRenderState.open_window, args=(self.world.web_url,)
+ )
+ self._gui_process.start()
+
+
+StateType = dict[str, SimulatorData | jtp.Array]
+ActType = jnp.ndarray
+ObsType = ErgoCubObservation
+RewardType = float | jnp.ndarray
+TerminalType = bool | jnp.ndarray
+RenderStateType = MeshcatVizRenderState
+
+
+@jax_dataclasses.pytree_dataclass
+class ErgoCubWalkFuncEnvV0(
+ JaxDataclassEnv[
+ StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ ]
+):
+ """ErgoCub environment implementing a target reaching task."""
+
+ name: ClassVar = jax_dataclasses.static_field(default="ErgoCubWalkFuncEnvV0")
+
+ # Store an instance of the JaxSim simulator.
+ # It gets initialized with SimulatorData with a functional approach.
+ _simulator: JaxSim = jax_dataclasses.field(default=None)
+
+ def __post_init__(self) -> None:
+ """Environment initialization."""
+
+ model = self.jaxsim.get_model(model_name="ErgoCub")
+
+ # Create the action space (static attribute)
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ high = jnp.array([25.0] * model.dofs(), dtype=float)
+
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ self._action_space = spaces.PyTree(low=-high, high=high)
+
+ # Get joint limits
+ s_min, s_max = model.joint_limits()
+ s_range = s_max - s_min
+
+ low = ErgoCubObservation.build(
+ base_height=0.25,
+ gravity_projection=-jnp.ones(3),
+ joint_positions=s_min,
+ joint_velocities=-50.0 * jnp.ones_like(s_min),
+ base_linear_velocity=-5.0 * jnp.ones(3),
+ base_angular_velocity=-10.0 * jnp.ones(3),
+ contact_state=jnp.array([False] * 4),
+ )
+
+ high = ErgoCubObservation.build(
+ base_height=1.0,
+ gravity_projection=jnp.ones(3),
+ joint_positions=s_max,
+ joint_velocities=50.0 * jnp.ones_like(s_max),
+ base_linear_velocity=5.0 * jnp.ones(3),
+ base_angular_velocity=10.0 * jnp.ones(3),
+ contact_state=jnp.array([True] * 4),
+ )
+
+ # Create the observation space (static attribute)
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ self._observation_space = spaces.PyTree(low=low, high=high)
+
+ @property
+ def jaxsim(self) -> JaxSim:
+ """"""
+
+ if self._simulator is not None:
+ return self._simulator
+
+ # Create the jaxsim simulator.
+ simulator = JaxSim.build(
+ # Note: any change of either 'step_size' or 'steps_per_run' requires
+ # updating the number of integration steps in the 'transition' method.
+ step_size=0.000_250,
+ steps_per_run=1,
+ velocity_representation=VelRepr.Body,
+ integrator_type=IntegratorType.EulerSemiImplicit,
+ simulator_data=SimulatorData(
+ gravity=jnp.array([0, 0, -10.0]),
+ contact_parameters=SoftContactsParams.build(K=10_000, D=20),
+ ),
+ ).mutable(mutable=True, validate=False)
+
+ # Get the SDF path
+ model_sdf_path = resolve_robotics_uri(
+ "package://ergoCub/robots/ergoCubGazeboV1_minContacts/model.urdf"
+ )
+
+ # Insert the model
+ _ = simulator.insert_model_from_description(
+ model_description=model_sdf_path, model_name="ErgoCub"
+ )
+
+ simulator.data.models = {
+ model_name: jax.tree_util.tree_map(lambda leaf: jnp.array(leaf), model_data)
+ for model_name, model_data in simulator.data.models.items()
+ }
+
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ self._simulator = simulator.mutable(mutable=True, validate=True)
+
+ return self._simulator
+
+ def initial(self, rng: Any = None) -> StateType:
+ """"""
+ assert jax.dtypes.issubdtype(rng.dtype, jax.dtypes.prng_key)
+
+ # Split the key
+ subkey1, subkey2 = jax.random.split(rng, num=2)
+
+ # Sample an initial observation
+ initial_observation: ErgoCubObservation = (
+ self.observation_space.sample_with_key(key=subkey1)
+ )
+
+ # Sample a goal position
+ goal_xy_position = jax.random.uniform(
+ key=subkey2, minval=-5.0, maxval=5.0, shape=(2,)
+ )
+
+ with self.jaxsim.editable(validate=False) as simulator:
+ # Reset the simulator and get the model
+ simulator.reset(remove_models=False)
+ model = simulator.get_model(model_name="ErgoCub")
+
+ # Reset the joint positions
+ model.reset_joint_positions(
+ positions=initial_observation.joint_positions,
+ joint_names=model.joint_names(),
+ )
+
+ # Reset the base position
+ model.reset_base_position(position=jnp.array([0, 0, 0.5]))
+
+ # Reset the base velocity
+ model.reset_base_velocity(
+ base_velocity=jnp.hstack(
+ [
+ 0.1 * initial_observation.base_linear_velocity,
+ 0.1 * initial_observation.base_angular_velocity,
+ ]
+ )
+ )
+
+ # Return the simulation state
+ return dict(
+ simulator_data=simulator.data,
+ goal=jnp.array(goal_xy_position, dtype=float),
+ )
+
+ def transition(
+ self, state: StateType, action: ActType, rng: Any = None
+ ) -> StateType:
+ """"""
+
+ # Get the JaxSim simulator
+ simulator = self.jaxsim
+
+ # Initialize the simulator with the environment state (containing SimulatorData)
+ with simulator.editable(validate=True) as simulator:
+ simulator.data = state["simulator_data"]
+
+ @jax_dataclasses.pytree_dataclass
+ class SetTorquesOverHorizon(simulator_callbacks.PreStepCallback):
+ def pre_step(self, sim: JaxSim) -> JaxSim:
+ """"""
+
+ model = sim.get_model(model_name="ErgoCub")
+ model.zero_input()
+ model.set_joint_generalized_force_targets(
+ forces=jnp.atleast_1d(action), joint_names=model.joint_names()
+ )
+
+ return sim
+
+ number_of_integration_steps = 40 # 0.010 # TODO 20 for having 0.010
+
+ # Stepping logic
+ with simulator.editable(validate=True) as simulator:
+ simulator, _ = simulator.step_over_horizon(
+ horizon_steps=number_of_integration_steps,
+ clear_inputs=False,
+ callback_handler=SetTorquesOverHorizon(),
+ )
+
+ # Return the new environment state (updated SimulatorData)
+ return state | dict(simulator_data=simulator.data)
+
+ def observation(self, state: StateType) -> ObsType:
+ """"""
+
+ # Initialize the simulator with the environment state (containing SimulatorData)
+ # and get the simulated model
+ with self.jaxsim.editable(validate=True) as simulator:
+ simulator.data = state["simulator_data"]
+ model = simulator.get_model("ErgoCub")
+
+ # Compute the normalized gravity projection in the body frame
+ W_R_B = model.base_orientation(dcm=True)
+ W_gravity = self.jaxsim.gravity()
+ B_gravity = W_R_B.T @ (W_gravity / jnp.linalg.norm(W_gravity))
+
+ W_p_B = model.base_position()
+ W_p_goal = jnp.hstack([state["goal"].squeeze(), 0])
+
+ # Compute the distance between the base and the goal in the body frame
+ B_p_distance = W_R_B.T @ (W_p_goal - W_p_B)
+
+ # Build the observation from the state
+ return ErgoCubObservation.build(
+ base_height=model.base_position()[2],
+ gravity_projection=B_gravity,
+ joint_positions=model.joint_positions(),
+ joint_velocities=model.joint_velocities(),
+ base_linear_velocity=model.base_velocity()[0:3],
+ base_angular_velocity=model.base_velocity()[3:6],
+ contact_state=model.in_contact(
+ link_names=[name for name in model.link_names() if "_ankle" in name]
+ ),
+ )
+
+ def reward(
+ self, state: StateType, action: ActType, next_state: StateType
+ ) -> RewardType:
+ """"""
+
+ with self.jaxsim.editable(validate=True) as simulator_next:
+ simulator_next.data = next_state["simulator_data"]
+ model_next = simulator_next.get_model("ErgoCub")
+
+ terminal = self.terminal(state=state)
+ obs_in_space = jax.lax.select(
+ pred=self.observation_space.contains(x=self.observation(state=state)),
+ on_true=1.0,
+ on_false=0.0,
+ )
+
+ # Position of the base
+ W_p_B = model_next.base_position()
+ W_p_xy_goal = state["goal"]
+
+ reward = 0.0
+ reward += 1.0 * (1.0 - jnp.array(terminal, dtype=float)) # alive
+ reward += 5.0 * obs_in_space # observation within its space
+ # reward += 100.0 * v_WB[0] # forward velocity
+ reward -= jnp.linalg.norm(W_p_B[0:2] - W_p_xy_goal) # distance from goal
+ reward += 1.0 * model_next.in_contact(
+ link_names=[
+ name
+ for name in model_next.link_names()
+ if name.startswith("leg_") and name.endswith("_lower")
+ ]
+ ).any().astype(float)
+ reward -= 0.1 * jnp.linalg.norm(action) / action.size # control cost
+
+ return reward
+
+ def terminal(self, state: StateType) -> TerminalType:
+ # Get the current observation
+ observation = self.observation(state=state)
+
+ base_too_high = (
+ observation.base_height >= self.observation_space.high.base_height
+ )
+ return base_too_high
+
+ # =========
+ # Rendering
+ # =========
+
+ def render_image(
+ self, state: StateType, render_state: RenderStateType
+ ) -> tuple[RenderStateType, npt.NDArray]:
+ """Show the state."""
+
+ model_name = "ErgoCub"
+
+ # Initialize the simulator with the environment state (containing SimulatorData)
+ # and get the simulated model
+ with self.jaxsim.editable(validate=False) as simulator:
+ simulator.data = state["simulator_data"]
+ model = simulator.get_model(model_name=model_name)
+
+ # Insert the model lazily in the visualizer if it is not already there
+ if model_name not in render_state.world._meshcat_models.keys():
+ from rod.urdf.exporter import UrdfExporter
+
+ urdf_string = UrdfExporter.sdf_to_urdf_string(
+ sdf=rod.Sdf(
+ version="1.7",
+ model=model.physics_model.description.extra_info["sdf_model"],
+ ),
+ pretty=True,
+ gazebo_preserve_fixed_joints=False,
+ )
+
+ meshcat_viz_name = render_state.world.insert_model(
+ model_description=urdf_string, is_urdf=True, model_name=None
+ )
+
+ render_state._jaxsim_to_meshcat_viz_name[model_name] = meshcat_viz_name
+
+ # Check that the model is in the visualizer
+ if (
+ render_state._jaxsim_to_meshcat_viz_name[model_name]
+ not in render_state.world._meshcat_models
+ ):
+ raise ValueError(f"The '{model_name}' model is not in the meshcat world")
+
+ # Update the model in the visualizer
+ render_state.world.update_model(
+ model_name=render_state._jaxsim_to_meshcat_viz_name[model_name],
+ joint_names=model.joint_names(),
+ joint_positions=model.joint_positions(),
+ base_position=model.base_position(),
+ base_quaternion=model.base_orientation(dcm=False),
+ )
+
+ return render_state, np.empty(0)
+
+ def render_init(self, open_gui: bool = False, **kwargs) -> RenderStateType:
+ """Initialize the render state."""
+
+ # Initialize the render state
+ meshcat_viz_state = MeshcatVizRenderState()
+
+ if open_gui:
+ meshcat_viz_state.open_window_in_process()
+
+ return meshcat_viz_state
+
+ def render_close(self, render_state: RenderStateType) -> None:
+ """Close the render state."""
+
+ render_state.close()
+
+
+class ErgoCubWalkEnvV0(JaxEnv):
+ """"""
+
+ def __init__(self, render_mode: str | None = None, **kwargs: Any) -> None:
+ """"""
+
+ from jaxgym.wrappers.jax import (
+ ClipActionWrapper,
+ FlattenSpacesWrapper,
+ JaxTransformWrapper,
+ TimeLimit,
+ )
+
+ func_env = ErgoCubWalkFuncEnvV0()
+
+ func_env_wrapped = func_env
+ func_env_wrapped = TimeLimit(
+ env=func_env_wrapped, max_episode_steps=5_000
+ ) # TODO
+ func_env_wrapped = ClipActionWrapper(env=func_env_wrapped)
+ func_env_wrapped = FlattenSpacesWrapper(env=func_env_wrapped)
+ func_env_wrapped = JaxTransformWrapper(env=func_env_wrapped, function=jax.jit)
+
+ super().__init__(
+ func_env=func_env_wrapped,
+ metadata=self.metadata,
+ render_mode=render_mode,
+ )
+
+
+class ErgoCubWalkVectorEnvV0(JaxVectorEnv):
+ """"""
+
+ metadata = dict()
+
+ def __init__(
+ self,
+ # func_env: JaxDataclassEnv[
+ # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ # ],
+ num_envs: int,
+ render_mode: str | None = None,
+ # max_episode_steps: int = 5_000,
+ jit_compile: bool = True,
+ **kwargs,
+ ) -> None:
+ """"""
+
+ print("+++", kwargs)
+
+ env = ErgoCubWalkFuncEnvV0()
+
+ # Vectorize the environment.
+ # Note: it automatically wraps the environment in a TimeLimit wrapper.
+ super().__init__(
+ func_env=env,
+ num_envs=num_envs,
+ metadata=self.metadata,
+ render_mode=render_mode,
+ max_episode_steps=5_000, # TODO
+ jit_compile=jit_compile,
+ )
+
+ # from jaxgym.vector.jax import FlattenSpacesVecWrapper
+ #
+ # vec_env_wrapped = FlattenSpacesVecWrapper(env=vec_env)
+
+
+if __name__ == "__main__":
+ """Stable Baselines"""
+
+ def make_jax_env(
+ max_episode_steps: Optional[int] = 500, jit: bool = True
+ ) -> JaxEnv:
+ """"""
+
+ # TODO: single env -> time limit with stable_baselines?
+
+ if max_episode_steps in {None, 0}:
+ env = ErgoCubWalkFuncEnvV0()
+ else:
+ env = TimeLimit(
+ env=ErgoCubWalkFuncEnvV0(), max_episode_steps=max_episode_steps
+ )
+
+ return JaxEnv(
+ func_env=ToNumPyWrapper(
+ env=FlattenSpacesWrapper(env=env)
+ if not jit
+ else JaxTransformWrapper(
+ function=jax.jit,
+ env=FlattenSpacesWrapper(env=env),
+ ),
+ ),
+ render_mode="meshcat_viz",
+ )
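+
+    # Example usage (illustrative sketch, not part of the original patch):
+    #   env = make_jax_env(max_episode_steps=500, jit=True)
+    #   obs, info = env.reset(seed=0)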
+
+ class CustomVecEnvSB(vec_env_sb.VecEnv):
+ """"""
+
+ metadata = {"render_modes": []}
+
+ def __init__(
+ self,
+ jax_vector_env: JaxVectorEnv | VectorWrapper,
+ log_rewards: bool = False,
+ # num_envs: int,
+ # observation_space: spaces.Space,
+ # action_space: spaces.Space,
+ # render_mode: Optional[str] = None,
+ ) -> None:
+ """"""
+
+ if not isinstance(jax_vector_env.unwrapped, JaxVectorEnv):
+ raise TypeError(type(jax_vector_env))
+
+ self.jax_vector_env = jax_vector_env
+
+ single_env_action_space: PyTree = (
+ jax_vector_env.unwrapped.single_action_space
+ )
+
+ single_env_observation_space: PyTree = (
+ jax_vector_env.unwrapped.single_observation_space
+ )
+
+ super().__init__(
+ num_envs=self.jax_vector_env.num_envs,
+ action_space=single_env_action_space.to_box(),
+ observation_space=single_env_observation_space.to_box(),
+ )
+
+ self.actions = np.zeros_like(self.jax_vector_env.action_space.sample())
+
+ # Initialize the RNG seed
+ self._seed = None
+ self.seed()
+
+ # Initialize the rewards logger
+ self.logger_rewards = [] if log_rewards else None
+
+ def reset(self) -> vec_env_sb.base_vec_env.VecEnvObs:
+ """"""
+
+ observations, state_infos = self.jax_vector_env.reset(seed=self._seed)
+ return np.array(observations)
+
+ def step_async(self, actions: np.ndarray) -> None:
+ self.actions = actions
+
+ @staticmethod
+ @functools.partial(jax.jit, static_argnames=("batch_size",))
+ def tree_inverse_transpose(
+ pytree: jtp.PyTree, batch_size: int
+ ) -> List[jtp.PyTree]:
+ """"""
+
+ return [
+ jax.tree_util.tree_map(lambda leaf: leaf[i], pytree)
+ for i in range(batch_size)
+ ]
+
+ def step_wait(self) -> vec_env_sb.base_vec_env.VecEnvStepReturn:
+ """"""
+
+ (
+ observations,
+ rewards,
+ terminals,
+ truncated,
+ step_infos,
+ ) = self.jax_vector_env.step(actions=self.actions)
+
+ done = np.logical_or(terminals, truncated)
+
+ # list_of_step_infos = [
+ # jax.tree_util.tree_map(lambda l: l[i], step_infos)
+ # for i in range(self.jax_vector_env.num_envs)
+ # ]
+
+ list_of_step_infos = self.tree_inverse_transpose(
+ pytree=step_infos, batch_size=self.jax_vector_env.num_envs
+ )
+
+ # def pytree_to_numpy(pytree: jtp.PyTree) -> jtp.PyTree:
+ # return jax.tree_util.tree_map(lambda leaf: np.array(leaf), pytree)
+ #
+ # list_of_step_infos_numpy = [pytree_to_numpy(pt) for pt in list_of_step_infos]
+
+ list_of_step_infos_numpy = [
+ ToNumPyWrapper.pytree_to_numpy(pytree=pt) for pt in list_of_step_infos
+ ]
+
+ if self.logger_rewards is not None:
+ self.logger_rewards.append(np.array(rewards).mean())
+
+ return (
+ np.array(observations),
+ np.array(rewards),
+ np.array(done),
+ list_of_step_infos_numpy,
+ )
+
+ def close(self) -> None:
+ return self.jax_vector_env.close()
+
+ def get_attr(
+ self, attr_name: str, indices: vec_env_sb.base_vec_env.VecEnvIndices = None
+ ) -> List[Any]:
+ raise AttributeError
+ # raise NotImplementedError
+
+ def set_attr(
+ self,
+ attr_name: str,
+ value: Any,
+ indices: vec_env_sb.base_vec_env.VecEnvIndices = None,
+ ) -> None:
+ raise NotImplementedError
+
+ def env_method(
+ self,
+ method_name: str,
+ *method_args,
+ indices: vec_env_sb.base_vec_env.VecEnvIndices = None,
+ **method_kwargs,
+ ) -> List[Any]:
+ raise NotImplementedError
+
+ def env_is_wrapped(
+ self,
+ wrapper_class: Type[gym.Wrapper],
+ indices: vec_env_sb.base_vec_env.VecEnvIndices = None,
+ ) -> List[bool]:
+ return [False] * self.num_envs
+ # raise NotImplementedError
+
+ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
+ """"""
+
+ if seed is None:
+ seed = np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")
+
+ if np.array(seed, dtype="uint32") != np.array(seed):
+ raise ValueError(f"seed must be compatible with 'uint32' casting")
+
+ self._seed = seed
+ return [seed]
+
+ # _ = self.jax_vector_env.reset(seed=seed)
+ # return [None]
+
+ def make_vec_env_stable_baselines(
+ jax_dataclass_env: JaxDataclassEnv | JaxDataclassWrapper,
+ n_envs: int = 1,
+ seed: Optional[int] = None,
+ vec_env_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> vec_env_sb.VecEnv:
+ """"""
+
+ env = jax_dataclass_env
+
+ vec_env_kwargs = vec_env_kwargs if vec_env_kwargs is not None else dict()
+
+ vec_env = JaxVectorEnv(
+ func_env=env,
+ num_envs=n_envs,
+ **vec_env_kwargs,
+ )
+
+ # Flatten the PyTree spaces to regular Box spaces
+ vec_env = FlattenSpacesVecWrapper(env=vec_env)
+
+ vec_env_sb = CustomVecEnvSB(jax_vector_env=vec_env, log_rewards=True)
+
+ if seed is not None:
+ _ = vec_env_sb.seed(seed=seed)
+
+ return vec_env_sb
+
+ os.environ["IGN_GAZEBO_RESOURCE_PATH"] = "/conda/share/" # DEBUG
+
+ max_episode_steps = 200
+ func_env = NaNHandlerWrapper(env=ErgoCubWalkFuncEnvV0())
+
+ if max_episode_steps is not None:
+ func_env = TimeLimit(env=func_env, max_episode_steps=max_episode_steps)
+
+ func_env = ClipActionWrapper(
+ env=SquashActionWrapper(env=ActionNoiseWrapper(env=func_env)),
+ )
+
+ vec_env = make_vec_env_stable_baselines(
+ jax_dataclass_env=func_env,
+ n_envs=6000,
+ seed=42,
+ vec_env_kwargs=dict(
+ jit_compile=True,
+ ),
+ )
+
+ vec_env = VecMonitor(
+ venv=VecNormalize(
+ venv=vec_env,
+ training=True,
+ )
+ )
+
+ vec_env.venv.venv.logger_rewards = []
+ seed = vec_env.seed(seed=7)[0]
+ _ = vec_env.reset()
+
+ model = PPO(
+ "MlpPolicy",
+ env=vec_env,
+ n_steps=5, # in the vector env -> real ones are x512
+ batch_size=256,
+ n_epochs=10,
+ gamma=0.95,
+ gae_lambda=0.9,
+ clip_range=0.1,
+ normalize_advantage=True,
+ target_kl=0.025,
+ verbose=2,
+ learning_rate=0.000_300,
+ policy_kwargs=dict(
+ activation_fn=nn.ReLU,
+ net_arch=dict(pi=[512, 512], vf=[512, 512]),
+ log_std_init=np.log(0.05),
+ ),
+ )
+
+ print(model.policy)
+
+ model = model.learn(total_timesteps=50000, progress_bar=True)
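+
+    # Afterwards (illustrative sketch, not part of the original patch), the trained
+    # policy could be saved with SB3's standard API and the mean rewards collected in
+    # 'logger_rewards' could be inspected:
+    #   model.save("ppo_ergocub_walk")  # hypothetical output path
+    #   print(vec_env.venv.venv.logger_rewards)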
diff --git a/src/jaxgym/functional/__init__.py b/src/jaxgym/functional/__init__.py
new file mode 100644
index 000000000..f8a909b60
--- /dev/null
+++ b/src/jaxgym/functional/__init__.py
@@ -0,0 +1,4 @@
+from .func_env import FuncEnv
+from .func_wrapper import FuncWrapper
+
+# from .func_space import FuncSpace
diff --git a/src/jaxgym/functional/_jax/__init__.py b/src/jaxgym/functional/_jax/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/jaxgym/functional/_jax/_autoreset.py b/src/jaxgym/functional/_jax/_autoreset.py
new file mode 100644
index 000000000..a20faa5fd
--- /dev/null
+++ b/src/jaxgym/functional/_jax/_autoreset.py
@@ -0,0 +1,82 @@
+from typing import Any, Generic
+
+import jax.numpy as jnp
+import jax_dataclasses
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.functional import FuncWrapper
+
+WrapperStateType = StateType
+WrapperObsType = ObsType
+WrapperActType = ActType
+WrapperRewardType = RewardType
+
+# TODO: this cannot be implemented by wrapping FuncWrapper because observation would
+# need to call initial, and initial requires an rng.
+# Implement it on top of env.FunctionalJaxEnv instead? -> JaxVecEnv already performs
+# autoreset by default.
+
+
+@jax_dataclasses.pytree_dataclass
+class AutoResetWrapper(
+ FuncWrapper[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ WrapperStateType,
+ WrapperObsType,
+ WrapperActType,
+ WrapperRewardType,
+ ],
+ Generic[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ ],
+):
+ """"""
+
+ def is_done(self, state: WrapperStateType) -> bool:
+ """"""
+
+ info = self.env.step_info()
+
+ return jnp.array(
+ [
+ self.terminal(state=state),
+ "truncated" in info and info["truncated"] is True,
+ ],
+ dtype=bool,
+ ).any()
+
+ def observation(self, state: WrapperStateType) -> WrapperObsType:
+ """"""
+
+ return self.env.observation(
+ state=self.wrapper_state_to_environment_state(wrapper_state=state)
+ )
+
+ def step_info(
+ self, state: WrapperStateType, action: ActType, next_state: WrapperStateType
+ ) -> dict[str, Any]:
+ """Info dict about a full transition."""
+
+ return self.env.step_info(
+ state=self.wrapper_state_to_environment_state(wrapper_state=state),
+ action=action,
+ next_state=self.wrapper_state_to_environment_state(
+ wrapper_state=next_state
+ ),
+ )
diff --git a/src/jaxgym/functional/_jax/env.py b/src/jaxgym/functional/_jax/env.py
new file mode 100644
index 000000000..831ab5c64
--- /dev/null
+++ b/src/jaxgym/functional/_jax/env.py
@@ -0,0 +1,115 @@
+from typing import Any
+
+import gymnasium as gym
+import jax
+import jax.numpy as jnp
+import jax.random as jrng
+import numpy as np
+from gymnasium.envs.registration import EnvSpec
+from gymnasium.experimental.functional import ActType, FuncEnv, StateType
+from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
+from gymnasium.utils import seeding
+from gymnasium.vector.utils import batch_space
+
+
+# TODO: this is still copy-paste from gymnasium
+# TODO: update with my logic
+# TODO: autoreset wrapper works on top of this env
+class FunctionalJaxEnv(gym.Env):
+ """A conversion layer for jax-based environments."""
+
+ state: StateType
+ rng: jrng.PRNGKey
+
+ def __init__(
+ self,
+ func_env: FuncEnv,
+ metadata: dict[str, Any] | None = None,
+ render_mode: str | None = None,
+ reward_range: tuple[float, float] = (-float("inf"), float("inf")),
+ spec: EnvSpec | None = None,
+ ):
+ """Initialize the environment from a FuncEnv."""
+
+ if metadata is None:
+ metadata = {"render_mode": []}
+
+ self.func_env = func_env
+
+ self.observation_space = func_env.observation_space
+ self.action_space = func_env.action_space
+
+ self.metadata = metadata
+ self.render_mode = render_mode
+ self.reward_range = reward_range
+
+ self.spec = spec
+
+ self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
+
+ if self.render_mode == "rgb_array":
+ self.render_state = self.func_env.render_init()
+ else:
+ self.render_state = None
+
+ np_random, _ = seeding.np_random()
+ seed = np_random.integers(0, 2 ** 32 - 1, dtype="uint32")
+
+ self.rng = jrng.PRNGKey(seed)
+
+ def reset(self, *, seed: int | None = None, options: dict | None = None):
+ """Resets the environment using the seed."""
+
+ super().reset(seed=seed)
+ if seed is not None:
+ self.rng = jrng.PRNGKey(seed)
+
+ rng, self.rng = jrng.split(self.rng)
+
+ self.state = self.func_env.initial(rng=rng)
+ obs = self.func_env.observation(self.state)
+ info = self.func_env.state_info(self.state)
+
+ obs = jax_to_numpy(obs)
+
+ return obs, info
+
+ def step(self, action: ActType):
+ """Steps through the environment using the action."""
+
+ if self._is_box_action_space:
+ assert isinstance(self.action_space, gym.spaces.Box) # For typing
+ action = np.clip(action, self.action_space.low, self.action_space.high)
+ else: # Discrete
+ # For now we assume jax envs don't use complex spaces
+ err_msg = f"{action!r} ({type(action)}) invalid"
+ assert self.action_space.contains(action), err_msg
+
+ rng, self.rng = jrng.split(self.rng)
+
+ next_state = self.func_env.transition(self.state, action, rng)
+ observation = self.func_env.observation(next_state)
+ reward = self.func_env.reward(self.state, action, next_state)
+ terminated = self.func_env.terminal(next_state)
+ info = self.func_env.step_info(self.state, action, next_state)
+ self.state = next_state
+
+ observation = jax_to_numpy(observation)
+
+ return observation, float(reward), bool(terminated), False, info
+
+ # def render(self):
+ # """Returns the render state if `render_mode` is "rgb_array"."""
+ # if self.render_mode == "rgb_array":
+ # self.render_state, image = self.func_env.render_image(
+ # self.state, self.render_state
+ # )
+ # return image
+ # else:
+ # raise NotImplementedError
+ #
+ # def close(self):
+ # """Closes the environments and render state if set."""
+ # if self.render_state is not None:
+ # self.func_env.render_close(self.render_state)
+ # self.render_state = None
diff --git a/src/jaxgym/functional/func_env.py b/src/jaxgym/functional/func_env.py
new file mode 100644
index 000000000..65ef80963
--- /dev/null
+++ b/src/jaxgym/functional/func_env.py
@@ -0,0 +1,184 @@
+import abc
+from typing import Any, Generic
+
+import gymnasium as gym
+import numpy.typing as npt
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+
+# Similar to https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/experimental/functional.py
+class FuncEnv(
+ Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType],
+ abc.ABC,
+):
+ """
+ Base class for functional environments.
+
+ Note:
+ This class is meant to be kept stateless.
+ The state of the environment (possibly encapsulated with the states of wrappers)
+ should be stored in the `state` argument of the methods.
+
+ Note:
+ This functional approach mainly targets JAX-based environments, but has been
+ formulated in a generic way so that it can be implemented in other frameworks.
+ """
+
+ # These spaces have to be populated in the __post_init__ method.
+ # If necessary, the __post_init__ method can access the simulator.
+ # _action_space: spaces.Space | None = jax_dataclasses.static_field(init=False)
+ # _observation_space: spaces.Space | None = jax_dataclasses.static_field(init=False)
+ _action_space: gym.Space | None = None
+ _observation_space: gym.Space | None = None
+
+ @property
+ def unwrapped(
+ self,
+ ) -> "FuncEnv[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType]":
+ """Return the innermost environment."""
+
+ return self
+
+ @property
+ def action_space(self) -> gym.Space[ActType]:
+ """Return the action space."""
+
+ return self._action_space
+
+ @property
+ def observation_space(self) -> gym.Space[ObsType]:
+ """Return the observation space."""
+
+ return self._observation_space
+
+ # ================
+ # Abstract methods
+ # ================
+
+ @abc.abstractmethod
+ def initial(self, rng: Any = None) -> StateType:
+ """
+        Initialize the environment and return its initial state.
+
+ Args:
+ rng: A resource to initialize the RNG of the functional environment.
+
+ Returns:
+ The initial state of the environment.
+ """
+ pass
+
+ @abc.abstractmethod
+ def transition(
+ self, state: StateType, action: ActType, rng: Any = None
+ ) -> StateType:
+ """
+ Compute the next state by applying the given action to the functional environment
+ in the given state.
+
+ Args:
+ state:
+ action:
+ rng:
+
+ Returns:
+ The next state of the environment.
+ """
+ pass
+
+ @abc.abstractmethod
+ def observation(self, state: StateType) -> ObsType:
+ """
+ Compute the observation of the environment in the given state.
+
+ Args:
+ state:
+
+ Returns:
+ The observation computed from the given state.
+ """
+ pass
+
+ @abc.abstractmethod
+ def reward(
+ self, state: StateType, action: ActType, next_state: StateType
+ ) -> RewardType:
+ """"""
+ pass
+
+ @abc.abstractmethod
+ def terminal(self, state: StateType) -> TerminalType:
+ """"""
+ pass
+
+ # =========
+ # Rendering
+ # =========
+
+ # @abc.abstractmethod
+ def render_image(
+ self, state: StateType, render_state: RenderStateType
+ ) -> tuple[RenderStateType, npt.NDArray]:
+ """Show the state."""
+ raise NotImplementedError
+
+ # @abc.abstractmethod
+ def render_init(self, **kwargs) -> RenderStateType:
+ """Initialize the render state."""
+ raise NotImplementedError
+
+ # @abc.abstractmethod
+ def render_close(self, render_state: RenderStateType) -> None:
+ """Close the render state."""
+ raise NotImplementedError
+
+ # =============
+ # Other methods
+ # =============
+
+ def state_info(self, state: StateType) -> dict[str, Any]:
+ """Info dict about a single state."""
+
+ return {}
+
+ def step_info(
+ self, state: StateType, action: ActType, next_state: StateType
+ ) -> dict[str, Any]:
+ """Info dict about a full transition."""
+
+ return {
+ # TODO: this is to keep the autoreset structure constant and return
+ # always a dict with the same keys. Check if this is necessary.
+ # "state_info": self.state_info(state=state),
+ # "next_state_info": self.state_info(state=next_state),
+ }
+
+ # TODO: remove
+ # def transform(self, func: Callable[[Callable], Callable]) -> None:
+ # """Functional transformations."""
+ # raise NotImplementedError
+ # # with self.mutable_context(mutability=Mutability.MUTABLE):
+ # self.initial = func(self.initial)
+ # self.transition = func(self.transition)
+ # self.observation = func(self.observation)
+ # self.reward = func(self.reward)
+ # self.terminal = func(self.terminal)
+ # self.state_info = func(self.state_info)
+ # self.step_info = func(self.step_info)
+
+ def __str__(self) -> str:
+ """"""
+
+ return f"<{type(self).__name__}>"
+
+ def __repr__(self) -> str:
+ """"""
+
+ return str(self)
diff --git a/src/jaxgym/functional/func_space.py b/src/jaxgym/functional/func_space.py
new file mode 100644
index 000000000..cb199a4c0
--- /dev/null
+++ b/src/jaxgym/functional/func_space.py
@@ -0,0 +1,6 @@
+import gymnasium as gym
+
+
+class FuncSpace(gym.Space):
+ def __init__(self) -> None:
+ raise NotImplementedError
diff --git a/src/jaxgym/functional/func_wrapper.py b/src/jaxgym/functional/func_wrapper.py
new file mode 100644
index 000000000..e6ada8798
--- /dev/null
+++ b/src/jaxgym/functional/func_wrapper.py
@@ -0,0 +1,225 @@
+from typing import Any, Generic, TypeVar
+
+import gymnasium as gym
+import numpy.typing as npt
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.functional import FuncEnv
+
+WrapperStateType = TypeVar("WrapperStateType")
+WrapperObsType = TypeVar("WrapperObsType")
+WrapperActType = TypeVar("WrapperActType")
+WrapperRewardType = TypeVar("WrapperRewardType")
+
+
+class FuncWrapper(
+ FuncEnv[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType],
+ Generic[
+ # FuncEnv types
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ # FuncWrapper types
+ WrapperStateType,
+ WrapperObsType,
+ WrapperActType,
+ WrapperRewardType,
+ ],
+):
+ """"""
+
+ env: FuncEnv[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType]
+
+ _action_space: gym.Space | None = None
+ _observation_space: gym.Space | None = None
+
+ @property
+ def unwrapped(
+ self,
+ ) -> FuncEnv[
+ StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ ]:
+ """"""
+
+ return self.env.unwrapped
+
+ @property
+ def action_space(self) -> gym.Space[ActType]:
+ """"""
+
+ return (
+ self._action_space
+ if self._action_space is not None
+ else self.env.action_space
+ )
+
+ @property
+ def observation_space(self) -> gym.Space[ObsType]:
+ """"""
+
+ return (
+ self._observation_space
+ if self._observation_space is not None
+ else self.env.observation_space
+ )
+
+ @action_space.setter
+ def action_space(self, space: gym.Space[ActType]) -> None:
+ """"""
+
+ self._action_space = space
+
+ @observation_space.setter
+ def observation_space(self, space: gym.Space[ObsType]) -> None:
+ """"""
+
+ self._observation_space = space
+
+ def __str__(self):
+ """"""
+
+ return f"<{type(self).__name__}{self.env}>"
+
+ def __repr__(self):
+ """"""
+
+ return str(self)
+
+ # ==================================
+    # Implementation of FuncEnv interface
+ # ==================================
+
+ def initial(self, rng: Any = None) -> WrapperStateType:
+ """"""
+
+ return self.env.initial(rng=rng)
+
+ def transition(
+ self, state: WrapperStateType, action: WrapperActType, rng: Any = None
+ ) -> WrapperStateType:
+ """"""
+
+ return self.env.transition(state=state, action=action, rng=rng)
+
+ def observation(self, state: WrapperStateType) -> WrapperObsType:
+ """"""
+
+ return self.env.observation(state=state)
+
+ def reward(
+ self,
+ state: WrapperStateType,
+ action: WrapperActType,
+ next_state: WrapperStateType,
+ ) -> WrapperRewardType:
+ """"""
+
+ return self.env.reward(state=state, action=action, next_state=next_state)
+
+ def terminal(self, state: WrapperStateType) -> TerminalType:
+ """"""
+
+ return self.env.terminal(state=state)
+
+ def state_info(self, state: WrapperStateType) -> dict[str, Any]:
+ """"""
+
+ return self.env.state_info(state=state)
+
+ def step_info(
+ self,
+ state: WrapperStateType,
+ action: WrapperActType,
+ next_state: WrapperStateType,
+ ) -> dict[str, Any]:
+ """"""
+
+ return self.env.step_info(state=state, action=action, next_state=next_state)
+
+ # =========
+ # Rendering
+ # =========
+
+ def render_image(
+ self, state: StateType, render_state: RenderStateType
+ ) -> tuple[RenderStateType, npt.NDArray]:
+ """"""
+ return self.env.render_image(state=state, render_state=render_state)
+
+ def render_init(self, **kwargs) -> RenderStateType:
+ """"""
+ return self.env.render_init(**kwargs)
+
+ def render_close(self, render_state: RenderStateType) -> None:
+ """"""
+ return self.env.render_close(render_state=render_state)
+
+
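+# Minimal subclassing sketch (illustrative, not part of the original patch): a wrapper
+# only needs to override the callbacks it wants to customize, e.g. a hypothetical
+# reward-scaling wrapper:
+#
+#   class ScaleRewardWrapper(FuncWrapper):
+#       def reward(self, state, action, next_state):
+#           return 0.5 * self.env.reward(state=state, action=action, next_state=next_state)
+
+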
+# class ActionFuncWrapper(
+# FuncWrapper[
+# # FuncEnv types
+# StateType,
+# ObsType,
+# ActType,
+# RewardType,
+# TerminalType,
+# RenderStateType,
+# # FuncWrapper types
+# StateType,
+# ObsType,
+# WrapperActType,
+# RewardType,
+# ],
+# Generic[WrapperActType],
+# abc.ABC,
+# ):
+# """"""
+#
+# @abc.abstractmethod
+# def action(self, action: WrapperActType) -> ActType:
+# """"""
+#
+# pass
+#
+# def reward(
+# self,
+# state: WrapperStateType,
+# action: WrapperActType,
+# next_state: WrapperStateType,
+# ) -> WrapperRewardType:
+# """"""
+#
+# return self.env.reward(
+# state=state, action=self.action(action=action), next_state=next_state
+# )
+#
+# def transition(
+# self, state: WrapperStateType, action: WrapperActType, rng: Any = None
+# ) -> WrapperStateType:
+# """"""
+#
+# return self.env.transition(
+# state=state, action=self.action(action=action), rng=rng
+# )
+#
+# def step_info(
+# self,
+# state: WrapperStateType,
+# action: WrapperActType,
+# next_state: WrapperStateType,
+# ) -> dict[str, Any]:
+# """"""
+#
+# return self.env.step_info(
+# state=state, action=self.action(action=action), next_state=next_state
+# )
diff --git a/src/jaxgym/jax/__init__.py b/src/jaxgym/jax/__init__.py
new file mode 100644
index 000000000..3f85d53fb
--- /dev/null
+++ b/src/jaxgym/jax/__init__.py
@@ -0,0 +1,6 @@
+from .dataclass_func_env import JaxDataclassEnv
+from .dataclass_func_env_wrapper import JaxDataclassActionWrapper, JaxDataclassWrapper
+from .env import JaxEnv
+from .pytree_space import PyTree
+
+# from .jaxsim_func_env import JaxSimFuncEnv
diff --git a/src/jaxgym/jax/dataclass_func_env.py b/src/jaxgym/jax/dataclass_func_env.py
new file mode 100644
index 000000000..90086b5f3
--- /dev/null
+++ b/src/jaxgym/jax/dataclass_func_env.py
@@ -0,0 +1,41 @@
+import jax_dataclasses
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+import jaxgym.jax.pytree_space as spaces
+from jaxgym.functional import FuncEnv
+from jaxsim.utils import JaxsimDataclass
+
+
+@jax_dataclasses.pytree_dataclass
+class JaxDataclassEnv(
+ FuncEnv[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType],
+ JaxsimDataclass,
+):
+ """
+ Base class for JAX-based functional environments.
+
+ Note:
+ All environments implementing this class must be pytree_dataclasses.
+ """
+
+ # Override spaces for JAX, storing them as static fields.
+ # Note: currently only PyTree spaces are supported.
+ # Note: always sample from these spaces using functional methods in order to
+ # avoid incurring in JIT recompilations.
+ _action_space: spaces.PyTree | None = jax_dataclasses.static_field(init=False)
+ _observation_space: spaces.PyTree | None = jax_dataclasses.static_field(init=False)
+
+ @property
+ def action_space(self) -> spaces.PyTree:
+ return self._action_space
+
+ @property
+ def observation_space(self) -> spaces.PyTree:
+ return self._observation_space
diff --git a/src/jaxgym/jax/dataclass_func_env_wrapper.py b/src/jaxgym/jax/dataclass_func_env_wrapper.py
new file mode 100644
index 000000000..757399013
--- /dev/null
+++ b/src/jaxgym/jax/dataclass_func_env_wrapper.py
@@ -0,0 +1,174 @@
+import abc
+from typing import Any, Generic
+
+import jax_dataclasses
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+import jaxgym.jax.pytree_space as spaces
+from jaxgym.functional.func_wrapper import (
+ FuncWrapper,
+ WrapperActType,
+ WrapperObsType,
+ WrapperRewardType,
+ WrapperStateType,
+)
+from jaxgym.jax import JaxDataclassEnv
+from jaxsim.utils import JaxsimDataclass
+
+
+@jax_dataclasses.pytree_dataclass
+class JaxDataclassWrapper(
+ FuncWrapper[
+ #
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ #
+ WrapperStateType,
+ WrapperObsType,
+ WrapperActType,
+ WrapperRewardType,
+ ],
+ # TODO
+ Generic[
+ #
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ #
+ WrapperStateType,
+ WrapperObsType,
+ WrapperActType,
+ WrapperRewardType,
+ ],
+ JaxsimDataclass,
+):
+ """"""
+
+ env: JaxDataclassEnv[
+ StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ ]
+
+ _action_space: spaces.PyTree | None = jax_dataclasses.static_field(init=False)
+ _observation_space: spaces.PyTree | None = jax_dataclasses.static_field(init=False)
+
+ def __post_init__(self) -> None:
+ """"""
+
+ if not isinstance(self.env.unwrapped, JaxDataclassEnv):
+ raise TypeError(type(self.env.unwrapped), JaxDataclassEnv)
+
+ @property
+ def action_space(self) -> spaces.PyTree:
+ """"""
+
+ return (
+ self._action_space
+ if self._action_space is not None
+ else self.env.action_space
+ )
+
+ @property
+ def observation_space(self) -> spaces.PyTree:
+ """"""
+
+ return (
+ self._observation_space
+ if self._observation_space is not None
+ else self.env.observation_space
+ )
+
+ @action_space.setter
+ def action_space(self, space: spaces.PyTree) -> None:
+ """"""
+
+ self._action_space = space
+
+ @observation_space.setter
+ def observation_space(self, space: spaces.PyTree) -> None:
+ """"""
+
+ self._observation_space = space
+
+
+@jax_dataclasses.pytree_dataclass
+class JaxDataclassActionWrapper(
+ JaxDataclassWrapper[
+ # FuncEnv types
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ # FuncWrapper types
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ ],
+ # TODO
+ Generic[
+ #
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ ],
+ abc.ABC,
+):
+ """"""
+
+ @abc.abstractmethod
+ def action(self, action: WrapperActType) -> ActType:
+ """"""
+
+ pass
+
+ def reward(
+ self,
+ state: WrapperStateType,
+ action: WrapperActType,
+ next_state: WrapperStateType,
+ ) -> WrapperRewardType:
+ """"""
+
+ return self.env.reward(
+ state=state, action=self.action(action=action), next_state=next_state
+ )
+
+ def transition(
+ self, state: WrapperStateType, action: WrapperActType, rng: Any = None
+ ) -> WrapperStateType:
+ """"""
+
+ return self.env.transition(
+ state=state, action=self.action(action=action), rng=rng
+ )
+
+ def step_info(
+ self,
+ state: WrapperStateType,
+ action: WrapperActType,
+ next_state: WrapperStateType,
+ ) -> dict[str, Any]:
+ """"""
+
+ return self.env.step_info(
+ state=state, action=self.action(action=action), next_state=next_state
+ )
diff --git a/src/jaxgym/jax/env.py b/src/jaxgym/jax/env.py
new file mode 100644
index 000000000..87856e63e
--- /dev/null
+++ b/src/jaxgym/jax/env.py
@@ -0,0 +1,265 @@
+# import multiprocessing
+from typing import Any, Generic, SupportsFloat
+
+import gymnasium as gym
+import jax.numpy as jnp
+import jax.random
+import numpy as np
+from gymnasium.core import ActType, ObsType, RenderFrame
+from gymnasium.envs.registration import EnvSpec
+from gymnasium.utils import seeding
+
+import jaxgym.jax.pytree_space as spaces
+from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper
+from jaxsim import logging
+
+# from meshcat_viz import MeshcatWorld
+
+
+class JaxEnv(gym.Env[ObsType, ActType], Generic[ObsType, ActType]):
+ """"""
+
+ action_space: spaces.PyTree
+ observation_space: spaces.PyTree
+
+ metadata: dict[str, Any] = {"render_modes": ["meshcat_viz", "meshcat_viz_gui"]}
+
+ def __init__(
+ self,
+ # func_env: FuncEnv | FuncWrapper,
+ func_env: JaxDataclassEnv | JaxDataclassWrapper,
+ metadata: dict[str, Any] | None = None,
+ render_mode: str | None = None,
+ reward_range: tuple[float, float] = (-float("inf"), float("inf")),
+ spec: EnvSpec | None = None,
+ ) -> None:
+ """"""
+
+ if not isinstance(func_env.unwrapped, JaxDataclassEnv):
+ raise TypeError(type(func_env.unwrapped), JaxDataclassEnv)
+
+        metadata = metadata if metadata is not None else dict(render_modes=list())
+
+ # Store the jax environment
+ self.func_env = func_env
+
+ # Initialize the state of the environment
+ self.state = None
+
+ # Expose the same spaces
+ self.action_space = func_env.action_space
+ self.observation_space = func_env.observation_space
+ # assert isinstance(self.action_space, spaces.PyTree)
+ # assert isinstance(self.observation_space, spaces.PyTree)
+
+ # Store the other mandatory attributes that gym.Env expects
+ self.metadata = metadata
+ self.render_mode = render_mode
+ self.reward_range = reward_range
+
+ # Store the environment specs
+ self.spec = spec
+
+ # self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
+ #
+ # if self.render_mode == "rgb_array":
+ # self.render_state = self.func_env.render_init()
+ # else:
+ # self.render_state = None
+ self.render_state = None
+ self._meshcat_world = None # old
+ self._meshcat_window = None # old
+
+ # Initialize the RNGs with a random seed
+ seed = np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")
+ self._np_random, _ = seeding.np_random(seed=int(seed))
+ self.rng = jax.random.PRNGKey(seed=seed)
+
+ def subkey(self, num: int = 1) -> jax.random.PRNGKeyArray:
+ """
+ Generate one or multiple sub-keys from the internal key.
+
+ Note:
+ The internal key is automatically updated, there's no need to handle
+ the environment key externally.
+
+ Args:
+ num: Number of keys to generate.
+
+ Returns:
+ The generated sub-keys.
+ """
+
+ self.rng, *sub_keys = jax.random.split(self.rng, num=num + 1)
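+        # Note (illustrative comment, not part of the original patch): with num=1 the
+        # squeeze below returns a single key of shape (2,), while with num>1 it returns
+        # a stacked array of keys of shape (num, 2).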
+ return jnp.stack(sub_keys).squeeze()
+
+ def step(
+ self, action: ActType
+ ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """"""
+
+ # TODO: clip action with wrapper
+ # assert isinstance(self.action_space, spaces.PyTree)
+ # action = self.action_space.clip(x=action)
+
+ # if self._is_box_action_space:
+ # assert isinstance(self.action_space, gym.spaces.Box) # For typing
+ # action = np.clip(action, self.action_space.low, self.action_space.high)
+ # else: # Discrete
+ # # For now we assume jax envs don't use complex spaces
+ # err_msg = f"{action!r} ({type(action)}) invalid"
+ # assert self.action_space.contains(action), err_msg
+
+ # rng, self.rng = jrng.split(self.rng)
+
+ # Advance the functional environment
+ next_state = self.func_env.transition(
+ state=self.state, action=action, rng=self.subkey(num=1)
+ )
+
+ # Extract updated data from the advanced environment
+ observation = self.func_env.observation(state=next_state)
+ reward = self.func_env.reward(
+ state=self.state, action=action, next_state=next_state
+ )
+ info = self.func_env.step_info(
+ state=self.state, action=action, next_state=next_state
+ )
+
+ # Detect if the environment reached a terminal state
+ terminated = self.func_env.terminal(state=next_state)
+ truncated = (
+ # False if "truncated" not in info else type(terminated)(info["truncated"])
+ type(terminated)(False)
+ if "truncated" not in info
+ else type(terminated)(info["truncated"])
+ )
+
+ # Remove the redundant "truncated" entry from info if present
+ _ = info.pop("truncated", None)
+
+ # Store the updated state
+ self.state = next_state
+
+ # observation = jax_to_numpy(observation)
+
+ return observation, reward, terminated, truncated, info
+
+ def reset(
+ self, *, seed: int | None = None, options: dict | None = None
+ ) -> tuple[ObsType, dict[str, Any]]:
+ """Resets the environment using the seed."""
+
+ super().reset(seed=seed)
+ self.rng = jax.random.PRNGKey(seed) if seed is not None else self.rng
+
+ # Seed the spaces
+ self.action_space.seed(seed=seed)
+ self.observation_space.seed(seed=seed)
+
+ # Generate initial state
+ self.state = self.func_env.initial(rng=self.subkey(num=1))
+
+ # Sample the initial observation and info
+ obs = self.func_env.observation(state=self.state)
+ info = self.func_env.state_info(state=self.state)
+
+ # obs = jax_to_numpy(obs)
+
+ # assert self.observation_space.contains(obs), obs
+ if obs not in self.observation_space:
+ logging.warning(f"Initial observation not in observation space")
+ logging.debug(obs)
+
+ return obs, info
+
+ # @property
+ # def visualizer(self) -> MeshcatWorld:
+ # """Returns the visualizer if `render_mode` is 'meshcat_viz'."""
+ #
+ # if self._meshcat_world is not None:
+ # return self._meshcat_world
+ #
+ # world = MeshcatWorld()
+ # world.open()
+ #
+ # def open_window(web_url: str) -> None:
+ # import tkinter
+ #
+ # import webview
+ #
+ # # Create an instance of tkinter frame or window
+ # win = tkinter.Tk()
+ # win.geometry("700x350")
+ #
+ # webview.create_window("meshcat", web_url)
+ # webview.start(gui="qt")
+ #
+    #     # TODO: nothing opens in the subprocess!
+ # p = multiprocessing.Process(target=open_window, args=(world.web_url,))
+ #
+ # self._meshcat_window = p
+ # self._meshcat_world = world
+ # return self._meshcat_world
+
+ # def update_meshcat_world(self) -> None:
+ # """"""
+ #
+ # return None
+
+ def render(self) -> RenderFrame | list[RenderFrame] | None:
+ """Returns the render state if `render_mode` is 'rgb_array'."""
+
+ if self.render_mode not in {None, "meshcat_viz", "meshcat_viz_gui"}:
+ raise NotImplementedError(self.render_mode)
+
+ if self.render_mode is None:
+ return None
+
+ if self.render_state is None:
+ if self.render_mode == "meshcat_viz":
+ self.render_state = self.func_env.render_init(open_gui=False)
+ elif self.render_mode == "meshcat_viz_gui":
+ self.render_state = self.func_env.render_init(open_gui=True)
+ else:
+ raise ValueError(self.render_mode)
+
+ self.render_state, image = self.func_env.render_image(
+ self.state, self.render_state
+ )
+
+ return image
+
+ # # TODO: how to create proper interfaces?
+ # self._meshcat_world = self.func_env.unwrapped.update_meshcat_world(
+ # world=self.visualizer, state=self.state["env"]
+ # )
+ #
+ # return None
+
+ # if self.render_mode == "rgb_array":
+ # self.render_state, image = self.func_env.render_image(
+ # self.state, self.render_state
+ # )
+ # return image
+ # else:
+ # raise NotImplementedError
+
+ def close(self) -> None:
+ """"""
+
+ # import meshcat_viz.meshcat
+ #
+ # if self._meshcat_world is not None:
+ # # self._meshcat_world.close()
+ # self._meshcat_world = None
+ # self._meshcat_window.kill()
+ # self._meshcat_window = None
+
+ if self.render_state is not None:
+ self.func_env.render_close(self.render_state)
+ self.render_state = None
+
+ def __str__(self):
+ """Returns the wrapper name and the :attr:`env` representation string."""
+ return f"<{type(self).__name__}{self.func_env}>"
diff --git a/src/jaxgym/jax/jaxsim_func_env.py b/src/jaxgym/jax/jaxsim_func_env.py
new file mode 100644
index 000000000..fb7c15af0
--- /dev/null
+++ b/src/jaxgym/jax/jaxsim_func_env.py
@@ -0,0 +1,3 @@
+class JaxSimFuncEnv:
+ def __init__(self) -> None:
+ raise NotImplementedError
diff --git a/src/jaxgym/jax/pytree_space.py b/src/jaxgym/jax/pytree_space.py
new file mode 100644
index 000000000..6cdee9a3b
--- /dev/null
+++ b/src/jaxgym/jax/pytree_space.py
@@ -0,0 +1,378 @@
+import copy
+import logging
+from typing import Any
+
+import gymnasium as gym
+import jax.flatten_util
+import jax.numpy as jnp
+import jax.tree_util
+import numpy as np
+import numpy.typing as npt
+from gymnasium.spaces.utils import flatdim, flatten
+from gymnasium.vector.utils.spaces import batch_space
+
+import jaxsim.typing as jtp
+from jaxsim.utils import not_tracing
+
+
+class PyTree(gym.Space[jtp.PyTree]):
+ """A generic space operating on JAX PyTree objects."""
+
+ def __init__(
+ self,
+ low: jtp.PyTree,
+ high: jtp.PyTree,
+ seed: int | None = None,
+ # TODO:
+ vectorize: int | None = None,
+ ) -> None:
+ """"""
+
+ # ====================
+ # Handle vectorization
+ # ====================
+
+ self.vectorize = vectorize
+ self.vectorized = False
+
+ if vectorize is not None and vectorize < 2:
+ msg = f"Ignoring 'vectorize={vectorize}' argument since it is < 2"
+ logging.warning(msg=msg)
+
+ if vectorize is not None and vectorize >= 2:
+ self.vectorized = True
+ low = jax.tree_util.tree_map(lambda l: jnp.stack([l] * vectorize), low)
+ high = jax.tree_util.tree_map(lambda l: jnp.stack([l] * vectorize), high)
+
+ # ==========================
+ # Check low and high pytrees
+ # ==========================
+ # TODO: make generic (pytrees_with_same_dtype|shape|supported_dtype) and move
+ # to utils
+
+ def check() -> None:
+ supported_dtypes = {
+ jnp.array(0, dtype=jnp.float32).dtype,
+ jnp.array(0, dtype=jnp.float64).dtype,
+ jnp.array(0, dtype=int).dtype,
+ jnp.array(0, dtype=bool).dtype,
+ }
+
+ dtypes_supported = self.flatten_pytree(
+ pytree=jax.tree_util.tree_map(
+ lambda l1, l2: jnp.array(l1).dtype in supported_dtypes
+ and jnp.array(l2).dtype in supported_dtypes,
+ low,
+ high,
+ )
+ )
+
+ if not jnp.all(dtypes_supported):
+ raise ValueError(
+ "Either low or high pytrees have attributes with unsupported dtype"
+ )
+
+ shape_match = self.flatten_pytree(
+ pytree=jax.tree_util.tree_map(
+ lambda l1, l2: jnp.array(l1).shape == jnp.array(l2).shape, low, high
+ )
+ )
+
+ if not jnp.all(shape_match):
+ raise ValueError("Wrong shape of low and high attributes")
+
+ dtype_match = self.flatten_pytree(
+ pytree=jax.tree_util.tree_map(
+ lambda l1, l2: jnp.array(l1).dtype == jnp.array(l2).dtype, low, high
+ )
+ )
+
+ if not jnp.all(dtype_match):
+ raise ValueError("Wrong dtype of low and high attributes")
+
+ if not_tracing(var=low):
+ check()
+
+ # ===============
+ # Build the space
+ # ===============
+
+ # Flatten the pytrees
+ low_flat = self.flatten_pytree(pytree=low)
+ high_flat = self.flatten_pytree(pytree=high)
+
+ if low_flat.dtype != high_flat.dtype:
+ raise ValueError(low_flat.dtype, high_flat.dtype)
+
+ if low_flat.shape != high_flat.shape:
+ raise ValueError(low_flat.shape, high_flat.shape)
+
+ # Transform all leafs to array and store them in the object
+ self.low = jax.tree_util.tree_map(lambda l: jnp.array(l), low)
+ self.high = jax.tree_util.tree_map(lambda l: jnp.array(l), high)
+
+ # Initialize the seed if not given
+ seed = (
+ seed
+ if seed is not None
+ else np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")
+ )
+
+ # Initialize the JAX random key
+ self.key = jax.random.PRNGKey(seed=seed)
+
+ # Initialize parent class
+ super().__init__(shape=None, dtype=None, seed=int(seed))
+
+ def subkey(self, num: int = 1) -> jax.random.PRNGKeyArray:
+ """
+ Generate one or multiple sub-keys from the internal key.
+
+ Note:
+ The internal key is automatically updated, there's no need to handle
+ the environment key externally.
+
+ Args:
+ num: Number of keys to generate.
+
+ Returns:
+ The generated sub-keys.
+ """
+
+ self.key, *sub_keys = jax.random.split(self.key, num=num + 1)
+ return jnp.stack(sub_keys).squeeze()
+
+ # TODO: what if key is a vector? -> multiple outputs?
+ def sample_with_key(self, key: jax.random.PRNGKeyArray) -> jtp.PyTree:
+ """"""
+
+ def random_array(
+ key, shape: tuple, min: jtp.PyTree, max: jtp.PyTree, dtype
+ ) -> jtp.Array:
+ """Helper to select the right sampling function for the supported dtypes"""
+
+ match dtype:
+ case jnp.float32.dtype | jnp.float64.dtype:
+ return jax.random.uniform(
+ key=key,
+ shape=shape,
+ minval=min,
+ maxval=max,
+ dtype=dtype,
+ )
+ case jnp.int16.dtype | jnp.int32.dtype | jnp.int64.dtype:
+ return jax.random.randint(
+ key=key,
+ shape=shape,
+ minval=min,
+ maxval=max + 1,
+ dtype=dtype,
+ )
+ case jnp.bool_.dtype:
+ return jax.random.randint(
+ key=key,
+ shape=shape,
+ minval=min,
+ maxval=max + 1,
+ ).astype(bool)
+ case _:
+ raise ValueError(dtype)
+
+ # Create and flatten a tree having a PRNGKey for each leaf.
+ # We do this just to get the number of keys we need to generate, and the
+ # function to unflatten the ravelled tree.
+ dummy_pytree = jax.tree_util.tree_map(lambda l: jax.random.PRNGKey(0), self.low)
+ dummy_pytree_flat, unflatten_fn = jax.flatten_util.ravel_pytree(dummy_pytree)
+
+ # Use the subkey to generate new keys, one for each leaf.
+ # Note: the division by 2 is needed because keys are vector of 2 elements.
+ subkey_flat = jax.random.split(key=key, num=dummy_pytree_flat.size // 2)
+
+ # Generate a pytree having a different subkey in each leaf
+ subkey_pytree = unflatten_fn(jnp.array(subkey_flat).flatten())
+
+ # Generate a pytree by sampling leafs according to their dtype and using a
+ # different subkey for each of them
+ return jax.tree_util.tree_map(
+ lambda low, high, subkey: random_array(
+ key=subkey, shape=low.shape, min=low, max=high, dtype=low.dtype
+ ),
+ self.low,
+ self.high,
+ subkey_pytree,
+ )
+
+ def sample(self, mask: Any | None = None) -> jtp.PyTree:
+ """"""
+
+ # Generate a subkey
+ subkey = self.subkey(num=1)
+
+ return self.sample_with_key(key=subkey)
+
+ def seed(self, seed: int | None = None) -> list[int]:
+ """"""
+
+ seed = (
+ seed
+ if seed is not None
+ else np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")
+ )
+
+ self.key = jax.random.PRNGKey(seed=seed)
+ return super().seed(seed=seed)
+
+ def contains(self, x: jtp.PyTree) -> bool:
+ """"""
+
+ def is_inside_bounds(x, low, high):
+ return jax.lax.select(
+ pred=x.size == 0 or jnp.all((x >= low) & (x <= high)),
+ on_true=True,
+ on_false=False,
+ )
+
+ contains_all_leaves = jax.tree_util.tree_map(
+ lambda low, high, l: is_inside_bounds(x=l, low=low, high=high),
+ self.low,
+ self.high,
+ x,
+ )
+
+ contains_all_leaves_flat = self.flatten_pytree(pytree=contains_all_leaves)
+
+ return jnp.all(contains_all_leaves_flat)
+
+ @property
+ def is_np_flattenable(self) -> bool:
+ """Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`."""
+
+ return True
+
+ def to_box(self) -> gym.spaces.Box:
+ """"""
+
+ get_first_element = lambda pytree: jax.tree_util.tree_map(
+ lambda l: l[0], pytree
+ )
+
+ low = self.low if not self.vectorized else get_first_element(self.low)
+ high = self.high if not self.vectorized else get_first_element(self.high)
+
+ low_flat = np.array(self.flatten_pytree(pytree=low))
+ high_flat = np.array(self.flatten_pytree(pytree=high))
+
+ if self.vectorized:
+ assert self.vectorize >= 2
+
+ repeats = tuple([self.vectorize] + [1] * low_flat.ndim)
+
+ low_flat = np.tile(low_flat, repeats)
+ high_flat = np.tile(high_flat, repeats)
+
+ return gym.spaces.Box(
+ low=np.array(low_flat, dtype=np.float32),
+ high=np.array(high_flat, dtype=np.float32),
+ seed=copy.deepcopy(self.np_random),
+ )
+
+ def to_dict(self) -> gym.spaces.Dict:
+ # if low/high are a dataclass -> convert to dict
+ raise NotImplementedError
+
+ @staticmethod
+ def flatten_pytree(pytree: jtp.PyTree) -> jtp.VectorJax:
+ """"""
+
+ # print("flatten_pytree")
+ pytree_flat, _ = jax.flatten_util.ravel_pytree(pytree)
+ return pytree_flat
+
+ def flatten_sample(self, pytree: jtp.PyTree) -> jtp.VectorJax:
+ """"""
+
+ if not self.vectorized:
+ return self.flatten_pytree(pytree=pytree)
+
+ # @jax.jit
+ # def flatten_pytree(pytree: jtp.PyTree) -> jtp.ArrayJax:
+ # print("compiling")
+ # return jax.vmap(self.flatten_pytree)(pytree)
+ #
+ # return flatten_pytree(pytree=pytree)
+
+        # TODO: this triggers recompilation -> find some trick to avoid it
+ # return jax.jit(jax.vmap(self.flatten_pytree))(pytree)
+ return PyTree._flatten_sample_vmap(pytree)
+
+ @staticmethod
+ @jax.jit
+ def _flatten_sample_vmap(pytree: jtp.PyTree) -> jtp.VectorJax:
+ return jax.vmap(PyTree.flatten_pytree)(pytree)
+
+ def unflatten_sample(self, x: jtp.Vector) -> jtp.PyTree:
+ """"""
+
+ if not self.vectorized:
+ _, unflatten_fn = jax.flatten_util.ravel_pytree(self.low)
+ return unflatten_fn(x)
+
+ # low_1d = jax.tree_util.tree_map(lambda l: l[0], self.low)
+ # low_1d_flat, unflatten_fn = jax.flatten_util.ravel_pytree(low_1d)
+ #
+ # @jax.jit
+ # def unflatten_sample(x: jtp.Vector) -> jtp.PyTree:
+ # return jax.vmap(unflatten_fn)(x)
+ #
+ # return unflatten_sample(x=x)
+ return PyTree._unflatten_sample_vmap(x=x, low=self.low)
+
+ @staticmethod
+ @jax.jit
+ def _unflatten_sample_vmap(x: jtp.Vector, low: jtp.PyTree) -> jtp.PyTree:
+ """"""
+
+ low_1d = jax.tree_util.tree_map(lambda l: l[0], low)
+ low_1d_flat, unflatten_fn = jax.flatten_util.ravel_pytree(low_1d)
+
+ return jax.vmap(unflatten_fn)(x)
+
+ def clip(self, x: jtp.PyTree) -> jtp.PyTree:
+ """"""
+
+ # TODO: prevent recompilation
+ # @jax.jit
+ def _clip(pytree: jtp.PyTree) -> jtp.PyTree:
+ return jax.tree_util.tree_map(
+ lambda low, high, leaf: jnp.array(
+ jnp.clip(a=leaf, a_min=low, a_max=high), dtype=jnp.array(low).dtype
+ ),
+ self.low,
+ self.high,
+ pytree,
+ )
+
+ return _clip(pytree=x)
+
+
+@flatdim.register(PyTree)
+def _flatdim_pytree(space: PyTree) -> int:
+ """"""
+
+ low_flat = space.flatten_sample(pytree=space.low)
+ return low_flat.size
+
+
+@flatten.register(PyTree)
+def _flatten_pytree(space: PyTree, x: jtp.PyTree) -> npt.NDArray:
+ """"""
+
+ assert x in space
+ return space.flatten_sample(pytree=x)
+
+
+@batch_space.register(PyTree)
+def _batch_space_pytree(space: PyTree, n: int = 1) -> PyTree:
+ """"""
+
+ return PyTree(low=space.low, high=space.high, vectorize=n)
diff --git a/src/jaxgym/jax/utils.py b/src/jaxgym/jax/utils.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/jaxgym/stable_baselines.py b/src/jaxgym/stable_baselines.py
new file mode 100644
index 000000000..c6526ab7c
--- /dev/null
+++ b/src/jaxgym/stable_baselines.py
@@ -0,0 +1,209 @@
+import functools
+from typing import Any, Dict, List, Optional, Type, Union
+
+import gymnasium as gym
+import jax.random
+import numpy as np
+from gymnasium.experimental.vector.vector_env import VectorWrapper
+from stable_baselines3.common import vec_env as vec_env_sb
+
+import jaxsim.typing as jtp
+from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper, PyTree
+from jaxgym.vector.jax import FlattenSpacesVecWrapper, JaxVectorEnv
+from jaxgym.wrappers.jax import ToNumPyWrapper
+
+
+class CustomVecEnvSB(vec_env_sb.VecEnv):
+ """Custom vectorized environment for SB3."""
+
+ metadata = {"render_modes": []}
+
+ def __init__(
+ self,
+ jax_vector_env: JaxVectorEnv | VectorWrapper,
+ ) -> None:
+ """
+ Create a custom vectorized environment for SB3 from a JaxVectorEnv.
+
+ Args:
+ jax_vector_env: The JaxVectorEnv to wrap.
+ """
+
+ if not isinstance(jax_vector_env.unwrapped, JaxVectorEnv):
+ raise TypeError(type(jax_vector_env))
+
+ self.jax_vector_env = jax_vector_env
+
+ single_env_action_space: PyTree = jax_vector_env.unwrapped.single_action_space
+
+ single_env_observation_space: PyTree = (
+ jax_vector_env.unwrapped.single_observation_space
+ )
+
+ super().__init__(
+ num_envs=self.jax_vector_env.num_envs,
+ action_space=single_env_action_space.to_box(),
+ observation_space=single_env_observation_space.to_box(),
+ render_mode=None,
+ )
+
+ self.actions = np.zeros_like(self.jax_vector_env.action_space.sample())
+
+ # Initialize the RNG seed
+ self._seed = None
+ self.seed()
+
+ def reset(self) -> vec_env_sb.base_vec_env.VecEnvObs:
+ """Reset all the environments."""
+
+ observations, state_infos = self.jax_vector_env.reset(seed=self._seed)
+ return np.array(observations)
+
+ def step_async(self, actions: np.ndarray) -> None:
+ self.actions = actions
+
+ @staticmethod
+ @functools.partial(jax.jit, static_argnames=("batch_size",))
+ def tree_inverse_transpose(pytree: jtp.PyTree, batch_size: int) -> List[jtp.PyTree]:
+ """
+ Utility function to perform the inverse of a pytree transpose operation.
+
+ It converts a pytree having the batch size in the first dimension of its leaves
+ to a list of pytrees having a single batch sample in their leaves.
+
+ Note: Check the direct transpose operation in the following link:
+ https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
+
+ Args:
+ pytree: The batched pytree.
+ batch_size: The batch size.
+
+ Returns:
+ A list of pytrees having a single batch sample in their leaves.
+ """
+
+ return [
+ jax.tree_util.tree_map(lambda leaf: leaf[i], pytree)
+ for i in range(batch_size)
+ ]
+
+ def step_wait(self) -> vec_env_sb.base_vec_env.VecEnvStepReturn:
+ """Wait for the step taken with step_async()."""
+
+ (
+ observations,
+ rewards,
+ terminals,
+ truncated,
+ step_infos,
+ ) = self.jax_vector_env.step(actions=self.actions)
+
+ done = np.logical_or(terminals, truncated)
+
+ # Convert the infos from a batched dictionary to a list of dictionaries
+ list_of_step_infos = self.tree_inverse_transpose(
+ pytree=step_infos, batch_size=self.jax_vector_env.num_envs
+ )
+
+ # Convert all info data to numpy
+ list_of_step_infos_numpy = [
+ ToNumPyWrapper.pytree_to_numpy(pytree=pt) for pt in list_of_step_infos
+ ]
+
+ return (
+ np.array(observations),
+ np.array(rewards),
+ np.array(done),
+ list_of_step_infos_numpy,
+ )
+
+ def close(self) -> None:
+ """Clean up the environment's resources."""
+
+ return self.jax_vector_env.close()
+
+ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
+ """Sets the random seeds for all environments."""
+
+ if seed is None:
+ seed = np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")
+
+ if np.array(seed, dtype="uint32") != np.array(seed):
+ raise ValueError(f"seed must be compatible with 'uint32' casting")
+
+ self._seed = seed
+ return [seed]
+
+ def get_attr(
+ self, attr_name: str, indices: vec_env_sb.base_vec_env.VecEnvIndices = None
+ ) -> List[Any]:
+ raise NotImplementedError
+
+ def set_attr(
+ self,
+ attr_name: str,
+ value: Any,
+ indices: vec_env_sb.base_vec_env.VecEnvIndices = None,
+ ) -> None:
+ raise NotImplementedError
+
+ def env_method(
+ self,
+ method_name: str,
+ *method_args,
+ indices: vec_env_sb.base_vec_env.VecEnvIndices = None,
+ **method_kwargs,
+ ) -> List[Any]:
+ raise NotImplementedError
+
+ def env_is_wrapped(
+ self,
+ wrapper_class: Type[gym.Wrapper],
+ indices: vec_env_sb.base_vec_env.VecEnvIndices = None,
+ ) -> List[bool]:
+ raise NotImplementedError
+
+
+def make_vec_env_stable_baselines(
+ jax_dataclass_env: JaxDataclassEnv | JaxDataclassWrapper,
+ n_envs: int = 1,
+ seed: Optional[int] = None,
+ # monitor_dir: Optional[str] = None,
+ vec_env_kwargs: Optional[Dict[str, Any]] = None,
+) -> vec_env_sb.VecEnv:
+ """
+ Create a SB3 vectorized environment from an individual `JaxDataclassEnv`.
+
+ Args:
+ jax_dataclass_env: The individual `JaxDataclassEnv`.
+ n_envs: Number of parallel environments.
+ seed: The seed for the vectorized environment.
+ vec_env_kwargs: Additional arguments to pass upon environment creation.
+
+ Returns:
+ The SB3 vectorized environment.
+ """
+
+ env = jax_dataclass_env
+ vec_env_kwargs = vec_env_kwargs if vec_env_kwargs is not None else dict()
+
+ # Vectorize the environment.
+ # Note: it automatically wraps the environment in a TimeLimit wrapper.
+ # Note: the space must be PyTree.
+ vec_env = JaxVectorEnv(
+ func_env=env,
+ num_envs=n_envs,
+ **vec_env_kwargs,
+ )
+
+ # Flatten the PyTree spaces to regular Box spaces
+ vec_env = FlattenSpacesVecWrapper(env=vec_env)
+
+ # Convert the vectorized environment to a SB3 vectorized environment
+ vec_env_sb = CustomVecEnvSB(jax_vector_env=vec_env)
+
+ # Set the seed
+ if seed is not None:
+ _ = vec_env_sb.seed(seed=seed)
+
+ return vec_env_sb
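+
+
+# Example usage (illustrative sketch; 'MyJaxDataclassEnv' is a hypothetical environment):
+#   vec_env = make_vec_env_stable_baselines(
+#       jax_dataclass_env=MyJaxDataclassEnv(),
+#       n_envs=8,
+#       seed=0,
+#       vec_env_kwargs=dict(jit_compile=True),
+#   )
+#   model = PPO("MlpPolicy", env=vec_env).learn(total_timesteps=10_000)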
diff --git a/src/jaxgym/tests/test_spaces.py b/src/jaxgym/tests/test_spaces.py
new file mode 100644
index 000000000..0a5e656d0
--- /dev/null
+++ b/src/jaxgym/tests/test_spaces.py
@@ -0,0 +1,102 @@
+import jax.numpy as jnp
+import jax.random
+import jax_dataclasses
+import pytest
+
+import jaxsim.typing as jtp
+from jaxgym import spaces
+
+
+def compare(low: jtp.PyTree, high: jtp.PyTree, box: spaces.Box) -> None:
+ """"""
+
+ assert box.contains(x=low)
+ assert box.contains(x=high)
+
+ key = jax.random.PRNGKey(seed=0)
+
+ for _ in range(10):
+ key, subkey = jax.random.split(key=key, num=2)
+ sample = box.sample(key=subkey)
+ assert box.contains(x=sample)
+
+
+def test_box_numpy() -> None:
+ """"""
+
+ low = jnp.zeros(10)
+ high = jnp.ones(10)
+
+ box = spaces.Box(low=low, high=high)
+
+ compare(low=low, high=high, box=box)
+
+ assert box.contains(x=0.5 * jnp.ones_like(low))
+ assert not box.contains(x=1.5 * jnp.ones_like(low))
+ assert not box.contains(x=-0.5 * jnp.ones_like(low))
+
+ with pytest.raises(ValueError):
+ _ = spaces.Box(low=low, high=jnp.ones(low.size + 1))
+
+ with pytest.raises(ValueError):
+ _ = spaces.Box(low=low, high=jnp.ones(low.size, dtype=int))
+
+
+def test_box_pytree() -> None:
+ """"""
+
+ @jax_dataclasses.pytree_dataclass
+ class SimplePyTree:
+ flag: jtp.Bool
+ value: jtp.Float
+ position: jtp.Vector
+ velocity: jtp.Vector
+
+ @staticmethod
+ def zero() -> "SimplePyTree":
+ return SimplePyTree(
+ flag=False,
+ value=0,
+ position=jnp.zeros(5),
+ velocity=jnp.zeros(10),
+ )
+
+ zero = SimplePyTree.zero()
+
+ low = SimplePyTree(
+ flag=False,
+ value=-42.0,
+ position=-10.0 * jnp.ones_like(zero.position),
+ velocity=0.1 * jnp.ones_like(zero.velocity),
+ )
+
+ high = SimplePyTree(
+ flag=True,
+ value=42.0,
+ position=10.0 * jnp.ones_like(zero.position),
+ velocity=5.0 * jnp.ones_like(zero.velocity),
+ )
+
+ box = spaces.Box(low=low, high=high)
+
+ compare(low=low, high=high, box=box)
+
+ # Wrong dimension of 'position'
+ with pytest.raises(ValueError):
+ wrong_high = SimplePyTree(
+ flag=True,
+ value=42.0,
+ position=jnp.zeros(6),
+ velocity=jnp.zeros(10),
+ )
+ _ = spaces.Box(low=low, high=wrong_high)
+
+ # Wrong type of 'position' and 'value'
+ with pytest.raises(ValueError):
+ wrong_high = SimplePyTree(
+ flag=True,
+ value=int(42),
+ position=jnp.zeros(5, dtype=int),
+ velocity=jnp.zeros(10),
+ )
+ _ = spaces.Box(low=low, high=wrong_high)
diff --git a/src/jaxgym/vector/__init__.py b/src/jaxgym/vector/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/jaxgym/vector/jax/__init__.py b/src/jaxgym/vector/jax/__init__.py
new file mode 100644
index 000000000..0cab549d5
--- /dev/null
+++ b/src/jaxgym/vector/jax/__init__.py
@@ -0,0 +1,2 @@
+from .vector_env import JaxVectorEnv
+from .wrappers import FlattenSpacesVecWrapper
diff --git a/src/jaxgym/vector/jax/vector_env.py b/src/jaxgym/vector/jax/vector_env.py
new file mode 100644
index 000000000..096db16a8
--- /dev/null
+++ b/src/jaxgym/vector/jax/vector_env.py
@@ -0,0 +1,359 @@
+import copy
+from typing import Any, Sequence
+
+import jax.flatten_util
+import jax.numpy as jnp
+import jax.random
+import numpy as np
+from gymnasium.envs.registration import EnvSpec
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv
+from gymnasium.utils import seeding
+from gymnasium.vector.utils import batch_space
+
+import jaxgym.jax.pytree_space as spaces
+import jaxsim.typing as jtp
+from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper
+from jaxgym.wrappers.jax import JaxTransformWrapper, TimeLimit
+from jaxsim import logging
+from jaxsim.utils import not_tracing
+
+
+# https://github.com/Farama-Foundation/Gymnasium/blob/e0cd42f77504060e770ab52932bf7eba45ff1976/gymnasium/experimental/functional_jax_env.py#L116
+# TODO: allow num_envs = 1 so that autoreset is handled automatically?
+# class JaxVectorEnv(VectorEnv[VectorObsType, VectorActType, ArrayType]):
+# Note: neither this class nor anything related to VectorEnv is a pytree dataclass.
+class JaxVectorEnv(VectorEnv[ObsType, ActType, ArrayType]):
+ """
+ A vectorized version of JAX-based functional environments exposing `VectorEnv` APIs.
+ """
+
+ observation_space: spaces.PyTree
+ action_space: spaces.PyTree
+
+ def __init__(
+ self,
+ func_env: JaxDataclassEnv[
+ StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ ]
+ | JaxDataclassWrapper,
+ num_envs: int,
+ max_episode_steps: int = 0,
+ metadata: dict[str, Any] | None = None,
+ render_mode: str | None = None,
+ reward_range: tuple[float, float] = (-float("inf"), float("inf")),
+ spec: EnvSpec | None = None,
+ jit_compile: bool = True,
+ ) -> None:
+        """Build the vectorized environment from a functional JAX environment."""
+
+ if not isinstance(func_env.unwrapped, JaxDataclassEnv):
+ raise TypeError(type(func_env.unwrapped), JaxDataclassEnv)
+
+ metadata = metadata if metadata is not None else dict(render_mode=list())
+
+ self.num_envs = num_envs
+ self.func_env_single = func_env
+ self.single_observation_space = func_env.observation_space
+ self.single_action_space = func_env.action_space
+
+ # TODO: convert other spaces to their PyTree equivalent
+ assert isinstance(func_env.action_space, spaces.PyTree)
+ assert isinstance(func_env.observation_space, spaces.PyTree)
+
+ self.action_space = batch_space(self.single_action_space, n=num_envs)
+ self.observation_space = batch_space(self.single_observation_space, n=num_envs)
+
+ # TODO: attributes below
+ self.metadata = metadata
+ self.render_mode = render_mode
+ self.reward_range = reward_range
+ self.spec = spec
+ # self.time_limit = max_episode_steps
+
+ def has_wrapper(
+ func_env: JaxDataclassEnv | JaxDataclassWrapper,
+ wrapper_cls: type,
+ ) -> bool:
+            """Check whether the functional environment is already wrapped by wrapper_cls."""
+
+ while not isinstance(func_env, JaxDataclassEnv):
+ if isinstance(func_env, wrapper_cls):
+ return True
+
+ func_env = func_env.env
+
+ return False
+
+        # Always wrap the environment in a TimeLimit wrapper, which automatically
+        # counts the number of steps and issues a "truncated" flag.
+        # Note: the TimeLimit wrapper is a no-op if max_episode_steps is 0.
+        # Note: the state of the wrapped environment is now different: the state of
+        # the original environment gets encapsulated in a dictionary.
+ # TODO: make this optional? Check if it is already wrapped?
+ # if max_episode_steps is not None:
+ if not has_wrapper(func_env=self.func_env_single, wrapper_cls=TimeLimit):
+ logging.debug(
+ "[JaxVectorEnv] Wrapping the environment in a 'TimeLimit' wrapper"
+ )
+ self.func_env_single = TimeLimit(
+ env=self.func_env_single, max_episode_steps=max_episode_steps
+ )
+
+ # Initialize the attribute that will store the environments state
+ self.states = None
+
+ # Initialize the step counter
+ # TODO: handled by TimeLimit
+ # self.steps = jnp.zeros(self.num_envs, dtype=jnp.uint32)
+
+ # TODO: in our case, assume pytree? -> batch easy and generic?
+ # --> singledispatch from gymnasium.space to pytree? And add a wrapper to_numpy|to_pytorch later?
+ # Doing like this, Obs|Action|Reward are always pytree.
+ # self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
+
+ # if self.render_mode == "rgb_array":
+ # self.render_state = self.func_env.render_init()
+ # else:
+ # self.render_state = None
+
+ # Initialize the RNGs with a random seed
+ seed = np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32")
+ self._np_random, _ = seeding.np_random(seed=int(seed))
+ self._key = jax.random.PRNGKey(seed=seed)
+
+ # self.func_env = TransformWrapper(env=self.func_env, function=jax.vmap)
+ self.func_env = JaxTransformWrapper(env=self.func_env_single, function=jax.vmap)
+
+ # Compile resources in JIT if requested.
+ # Note: this wrapper will override any other JIT wrapper already present.
+ if jit_compile:
+ self.step_autoreset_func = jax.jit(self.step_autoreset_func)
+ self.func_env = JaxTransformWrapper(env=self.func_env, function=jax.jit)
+
+ def subkey(self, num: int = 1) -> jax.random.PRNGKeyArray:
+ """
+ Generate one or multiple sub-keys from the internal key.
+
+ Note:
+            The internal key is automatically updated; there is no need to handle
+            the environment key externally.
+
+ Args:
+ num: Number of keys to generate.
+
+ Returns:
+ The generated sub-keys.
+ """
+
+ self._key, *sub_keys = jax.random.split(self._key, num=num + 1)
+ return jnp.stack(sub_keys).squeeze()
+
+ def reset(
+ self, *, seed: int | None = None, options: dict | None = None
+ ) -> tuple[ObsType, dict[str, Any]]:
+ """
+ Reset the environments.
+
+ Note:
+            This method should be called just once, after the creation of the
+            vectorized environment. This class implements autoreset: environments
+            that either terminate or get truncated are reset automatically.
+
+ Args:
+            seed: The optional seed used to re-initialize the internal RNG.
+            options: Additional reset options (currently unused).
+
+ Returns:
+ A tuple containing the initial observations and the initial states' info.
+ """
+
+ super().reset(seed=seed)
+ self._key = jax.random.PRNGKey(seed) if seed is not None else self._key
+
+ # Generate initial states
+ self.states = self.func_env.initial(rng=self.subkey(num=self.num_envs))
+
+ # Sample initial observations and infos
+ observations = self.func_env.observation(self.states)
+ infos = self.func_env.state_info(self.states)
+
+ return observations, infos
+
+ @staticmethod
+ def binary_mask_pytree(
+ pytree_a: jtp.PyTree, pytree_b: jtp.PyTree, mask: Sequence[bool]
+ ) -> jtp.PyTree:
+ """
+ Compute a new vectorized PyTree selecting elements from either of the
+ two input PyTrees according to the boolean mask.
+
+ Note:
+ The shapes of pytree_a and pytree_b must match, and they must be vectorized,
+            meaning that all their leaves share the dimension of the first axis.
+ The mask should have as many elements as this shared dimension.
+
+ Args:
+ pytree_a: the first vectorized PyTree object.
+ pytree_b: the second vectorized PyTree object.
+ mask: the boolean mask to select elements either from pytree_a (when True)
+ or pytree_b (when False).
+
+ Returns:
+            A new PyTree having elements taken either from pytree_a or pytree_b
+ according to mask.
+ """
+
+ def check():
+ first_dim_a = jax.tree_util.tree_map(lambda l: l.shape[0], pytree_a)
+ first_dim_b = jax.tree_util.tree_map(lambda l: l.shape[0], pytree_b)
+
+ # Check that the input PyTrees have the same first dimension of their leaves
+ if first_dim_a != first_dim_b:
+ raise ValueError()
+
+ in_axis_a = jnp.unique(
+ jax.flatten_util.ravel_pytree(first_dim_a)[0]
+ ).squeeze()
+            in_axis_b = jnp.unique(
+                jax.flatten_util.ravel_pytree(first_dim_b)[0]
+            ).squeeze()
+
+            # Check that all leaves have the same first dimension and that it
+            # matches the length of the mask
+            if not (in_axis_a == in_axis_b == len(mask)):
+                raise ValueError()
+
+ if not_tracing(var=pytree_a):
+ check()
+
+ # Convert the boolean mask to a PyTree having boolean leaves.
+ # True elements of the leaves are taken from pytree_a, False ones from pytree_b.
+ mask_pytree = jax.tree_util.tree_map(
+ lambda l: jnp.ones_like(l, dtype=bool)
+ * mask[(...,) + (jnp.newaxis,) * (l.ndim - 1)],
+ pytree_a,
+ )
+
+ # Create the output pytree taking elements from either pytree_a or pytree_b
+ # according to the boolean PyTree built from the mask
+ tree_out = jax.tree_util.tree_map(
+ lambda a, b, m: jnp.where(m, a, b), pytree_a, pytree_b, mask_pytree
+ )
+
+ return tree_out
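+
+    # Illustrative sketch of how binary_mask_pytree (above) selects between the
+    # two vectorized PyTrees, row by row (hypothetical shapes):
+    #
+    #     a = {"x": jnp.zeros((3, 2))}
+    #     b = {"x": jnp.ones((3, 2))}
+    #     mask = jnp.array([True, False, True])
+    #     out = JaxVectorEnv.binary_mask_pytree(pytree_a=a, pytree_b=b, mask=mask)
+    #     # out["x"] -> [[0., 0.], [1., 1.], [0., 0.]]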
+
+ @staticmethod
+ def step_autoreset_func(
+ env: JaxDataclassEnv | JaxDataclassWrapper,
+ states: StateType,
+ actions: ActType,
+ key1: jax.random.PRNGKeyArray,
+ key2: jax.random.PRNGKeyArray,
+ ) -> tuple[StateType, tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]]:
+        """Step all environments and automatically reset those that terminated or got truncated."""
+
+ # Duplicate the keys
+ # TODO: with new jax version maybe split can be jitted -> pass just one key and split inside
+ # key, *subkey_1 = jax.random.split(key, num=num_envs + 1)
+ # key, *subkey_2 = jax.random.split(key, num=num_envs + 1)
+
+ # Compute data by stepping the environments
+ next_states = env.transition(state=states, action=actions, rng=key1)
+ rewards = env.reward(state=states, action=actions, next_state=next_states)
+ terminals = env.terminal(state=next_states)
+ step_infos = env.step_info(state=states, action=actions, next_state=next_states)
+ truncated = step_infos["truncated"]
+
+ # Check if any environment is done
+ dones = jnp.logical_or(terminals, truncated)
+
+        # Add to step_infos the information about the final state, even for
+        # environments that are not done.
+        # This is necessary to keep the structure of the output pytree constant.
+ # The _final_observation|_final_info masks can be used to filter out the
+ # actual final data from the final_observation|final_info dictionaries.
+ #
+ # Note: the step_info dictionary of done environments that have been
+ # automatically reset shouldn't be consumed. It refers to the environment
+ # before being reset. Users in this case should read final_info.
+ step_infos |= (
+ dict(
+ final_observation=env.observation(state=next_states),
+ terminal_observation=env.observation(state=next_states), # sb3
+ final_info=copy.deepcopy(step_infos),
+ _final_observation=dones,
+ _final_info=dones,
+ is_done=dones,
+ )
+ # Backward compatibility (?) -> SB3 (TODO: done in TimeLimit)
+ # | {
+ # "TimeLimit.truncated": truncated,
+ # "terminal_observation": env.observation(state=next_states),
+ # }
+ )
+
+ # Compute the new state and new state_infos for all environments.
+ # We return this data only for those that are done.
+ # new_states = env.initial(rng=key2)
+ # new_state_infos = env.state_info(state=new_states)
+
+ # Compute the new states for all environments.
+ # We return this data only for those that are done.
+ new_states = env.initial(rng=key2)
+
+ # Merge the environment states
+ new_env_states = JaxVectorEnv.binary_mask_pytree(
+ mask=dones,
+ # If done, return the new initial states
+ pytree_a=new_states,
+ # If not done, return the next states
+ pytree_b=next_states,
+ )
+
+        # Compute the new observations: the regular observation for environments
+        # that are not done, and the observation of the new initial state for
+        # those that were done.
+ new_observations = env.observation(state=new_env_states)
+
+ return new_env_states, (
+ new_observations,
+ rewards,
+ terminals,
+ truncated,
+ step_infos,
+ )
+
+ def step(
+ self, actions: ActType
+ ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
+ """Steps through the environment using the action."""
+
+ # TODO: clip as method of Space.clip(x)? use here with hasattr(space, clip)?
+ assert isinstance(self.action_space, spaces.PyTree)
+ actions = self.action_space.clip(x=actions)
+
+ # TODO: move these inside autoreset as soon as jax.random.split
+ # supports jit compilation
+ keys_1 = self.subkey(num=self.num_envs)
+ keys_2 = self.subkey(num=self.num_envs)
+
+        # Use the instance attribute so that the jit-compiled version (when
+        # jit_compile=True in the constructor) is actually used.
+        self.states, out = self.step_autoreset_func(
+ env=self.func_env,
+ states=self.states,
+ actions=actions,
+ key1=keys_1,
+ key2=keys_2,
+ )
+
+ return out
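+
+
+# Example usage (illustrative sketch, not part of this module): `MyJaxEnv` is a
+# hypothetical JaxDataclassEnv subclass. The vectorized environment is reset once
+# and then stepped repeatedly, relying on the built-in autoreset for episodes
+# that terminate or get truncated. Sampling assumes the batched PyTree space
+# exposes sample(key=...) like the single-environment space does.
+#
+#     env = JaxVectorEnv(func_env=MyJaxEnv(), num_envs=8, max_episode_steps=500)
+#     observations, infos = env.reset(seed=42)
+#     for _ in range(1_000):
+#         actions = env.action_space.sample(key=env.subkey())
+#         observations, rewards, terminals, truncated, infos = env.step(actions)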
diff --git a/src/jaxgym/vector/jax/wrappers/__init__.py b/src/jaxgym/vector/jax/wrappers/__init__.py
new file mode 100644
index 000000000..88d67ef80
--- /dev/null
+++ b/src/jaxgym/vector/jax/wrappers/__init__.py
@@ -0,0 +1,3 @@
+from .flatten_spaces import FlattenSpacesVecWrapper
+
+# from .tensordict import TensorDictVecWrapper
diff --git a/src/jaxgym/vector/jax/wrappers/flatten_spaces.py b/src/jaxgym/vector/jax/wrappers/flatten_spaces.py
new file mode 100644
index 000000000..797aec323
--- /dev/null
+++ b/src/jaxgym/vector/jax/wrappers/flatten_spaces.py
@@ -0,0 +1,81 @@
+import jax.numpy as jnp
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+from gymnasium.experimental.vector.vector_env import VectorWrapper
+
+from jaxgym.vector.jax import JaxVectorEnv
+
+WrapperStateType = StateType
+WrapperObsType = jnp.ndarray # TODO jax.typing
+WrapperActType = jnp.ndarray
+WrapperRewardType = RewardType
+
+
+# TODO: not dataclass when operating on VectorWrapper -> check other ones
+class FlattenSpacesVecWrapper(VectorWrapper):
+    """Vectorized wrapper flattening PyTree observation and action spaces to flat Box spaces."""
+
+ # TODO: vec_env?
+ env: JaxVectorEnv
+
+ def __init__(self, env: JaxVectorEnv) -> None:
+ """"""
+
+ if not isinstance(env, JaxVectorEnv):
+ raise TypeError(type(env))
+
+ self.action_space = env.action_space.to_box()
+ self.observation_space = env.observation_space.to_box()
+
+ super().__init__(env=env)
+
+ def reset(
+ self,
+ **kwargs
+ # *,
+ # seed: int | list[int] | None = None,
+ # options: dict[str, Any] | None = None,
+ # ) -> tuple[ObsType, dict[str, Any]]:
+ ):
+ """"""
+
+ observations, state_infos = self.env.reset(**kwargs)
+ # return self.env.observation_space.flatten_pytree(pytree=observation), state_info
+ return (
+ self.env.observation_space.flatten_sample(pytree=observations),
+ state_infos,
+ )
+
+ def step(self, actions):
+ # ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
+ """"""
+
+ observations, rewards, terminals, truncated, step_infos = self.env.step(
+ actions=self.env.action_space.unflatten_sample(x=actions)
+ )
+
+ if "final_observation" in step_infos:
+ step_infos["final_observation"] = self.env.observation_space.flatten_sample(
+ pytree=step_infos["final_observation"]
+ )
+
+ if "terminal_observation" in step_infos:
+ step_infos[
+ "terminal_observation"
+ ] = self.env.observation_space.flatten_sample(
+ pytree=step_infos["terminal_observation"]
+ )
+
+ return (
+ self.env.observation_space.flatten_sample(pytree=observations),
+ rewards,
+ terminals,
+ truncated,
+ step_infos,
+ )
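+
+
+# Example usage (illustrative sketch, `MyJaxEnv` being a hypothetical
+# JaxDataclassEnv subclass):
+#
+#     vec_env = JaxVectorEnv(func_env=MyJaxEnv(), num_envs=4)
+#     flat_env = FlattenSpacesVecWrapper(env=vec_env)
+#
+# After wrapping, flat_env.observation_space and flat_env.action_space are flat
+# Box spaces, which is the format most RL libraries expect.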
diff --git a/src/jaxgym/vector/jax/wrappers/tensordict.py b/src/jaxgym/vector/jax/wrappers/tensordict.py
new file mode 100644
index 000000000..10fd5e437
--- /dev/null
+++ b/src/jaxgym/vector/jax/wrappers/tensordict.py
@@ -0,0 +1,3 @@
+class TensorDictVecWrapper:
+ def __init__(self) -> None:
+ raise NotImplementedError
diff --git a/src/jaxgym/wrappers/__init__.py b/src/jaxgym/wrappers/__init__.py
new file mode 100644
index 000000000..9926dee6e
--- /dev/null
+++ b/src/jaxgym/wrappers/__init__.py
@@ -0,0 +1,2 @@
+from .state import StateWrapper
+from .transform import TransformWrapper
diff --git a/src/jaxgym/wrappers/clip_action.py b/src/jaxgym/wrappers/clip_action.py
new file mode 100644
index 000000000..04e5d8896
--- /dev/null
+++ b/src/jaxgym/wrappers/clip_action.py
@@ -0,0 +1,25 @@
+# from typing import Generic
+# from jaxgym.functional import ActionFuncWrapper
+# from jaxgym.functional.func_wrapper import WrapperActType
+# from gymnasium.experimental.functional import ActType
+# import gymnasium as gym
+# import numpy as np
+#
+#
+# class ClipActionWrapper(
+# ActionFuncWrapper[WrapperActType],
+# Generic[WrapperActType],
+# ):
+# """"""
+#
+# def action(self, action: WrapperActType) -> ActType:
+# """"""
+#
+# if self.action_space.contains(x=action):
+# return action
+#
+# assert isinstance(self.action_space, gym.spaces.Box)
+#
+# return np.clip(
+# action, a_min=self.action_space.low, a_max=self.action_space.high
+# )
diff --git a/src/jaxgym/wrappers/jax/__init__.py b/src/jaxgym/wrappers/jax/__init__.py
new file mode 100644
index 000000000..66e767ad2
--- /dev/null
+++ b/src/jaxgym/wrappers/jax/__init__.py
@@ -0,0 +1,10 @@
+from .action_noise import ActionNoiseWrapper
+from .clip_action import ClipActionWrapper
+from .flatten_spaces import FlattenSpacesWrapper
+from .nan_handler import NaNHandlerWrapper
+from .squash_action import SquashActionWrapper
+from .time_limit import TimeLimit
+from .to_numpy import ToNumPyWrapper
+from .transform import JaxTransformWrapper
+
+# from .time_limit_sb import TimeLimitStableBaselines
diff --git a/src/jaxgym/wrappers/jax/action_noise.py b/src/jaxgym/wrappers/jax/action_noise.py
new file mode 100644
index 000000000..5fe4f214f
--- /dev/null
+++ b/src/jaxgym/wrappers/jax/action_noise.py
@@ -0,0 +1,77 @@
+from typing import Any, Callable, Generic
+
+import jax.flatten_util
+import jax.numpy as jnp
+import jax.tree_util
+import jax_dataclasses
+import numpy.typing as npt
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.jax import JaxDataclassWrapper
+from jaxsim import logging
+
+WrapperStateType = StateType
+WrapperObsType = ObsType
+WrapperActType = ActType
+WrapperRewardType = RewardType
+
+
+@jax_dataclasses.pytree_dataclass
+class ActionNoiseWrapper(
+ JaxDataclassWrapper[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ WrapperStateType,
+ WrapperObsType,
+ WrapperActType,
+ WrapperRewardType,
+ ],
+ Generic[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ ],
+):
+    """Wrapper that adds noise to the action before stepping the environment."""
+
+ noise_fn: Callable[
+ [npt.NDArray, jax.random.PRNGKeyArray], npt.NDArray
+ ] = jax_dataclasses.static_field(
+ default=lambda action, rng: action
+ + 0.05 * jax.random.normal(key=rng, shape=action.shape)
+ )
+
+ def __post_init__(self) -> None:
+ """"""
+
+ super().__post_init__()
+
+ msg = f"[{self.__class__.__name__}] enabled"
+ logging.debug(msg=msg)
+
+ def transition(
+ self, state: WrapperStateType, action: WrapperActType, rng: Any = None
+ ) -> WrapperStateType:
+ """"""
+
+ rng, subkey = jax.random.split(rng, num=2)
+
+ action_flat, restore_fn = jax.flatten_util.ravel_pytree(pytree=action)
+ action_noisy_flat = self.noise_fn(action_flat, subkey)
+ action_noisy = restore_fn(action_noisy_flat)
+
+ return self.env.transition(state=state, action=action_noisy, rng=rng)
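+
+
+# Note: the default noise_fn adds zero-mean Gaussian noise with standard
+# deviation 0.05 to the flattened action. A custom noise model can be passed
+# when building the wrapper, e.g. (illustrative sketch, `my_env` hypothetical):
+#
+#     wrapper = ActionNoiseWrapper(
+#         env=my_env,
+#         noise_fn=lambda a, rng: a + 0.1 * jax.random.normal(key=rng, shape=a.shape),
+#     )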
diff --git a/src/jaxgym/wrappers/jax/clip_action.py b/src/jaxgym/wrappers/jax/clip_action.py
new file mode 100644
index 000000000..3d2381bca
--- /dev/null
+++ b/src/jaxgym/wrappers/jax/clip_action.py
@@ -0,0 +1,50 @@
+from typing import Generic
+
+import jax_dataclasses
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.functional.func_wrapper import WrapperActType
+from jaxgym.jax import JaxDataclassActionWrapper
+from jaxsim import logging
+
+
+@jax_dataclasses.pytree_dataclass
+class ClipActionWrapper(
+ JaxDataclassActionWrapper[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ ],
+ Generic[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ ],
+):
+    """Wrapper that clips the action to the bounds of the action space."""
+
+ def __post_init__(self) -> None:
+ """"""
+
+ super().__post_init__()
+
+ msg = f"[{self.__class__.__name__}] enabled"
+ logging.debug(msg=msg)
+
+ def action(self, action: WrapperActType) -> ActType:
+ """"""
+
+ return self.action_space.clip(x=action)
diff --git a/src/jaxgym/wrappers/jax/flatten_spaces.py b/src/jaxgym/wrappers/jax/flatten_spaces.py
new file mode 100644
index 000000000..bb31de3e2
--- /dev/null
+++ b/src/jaxgym/wrappers/jax/flatten_spaces.py
@@ -0,0 +1,75 @@
+from typing import Any
+
+import gymnasium as gym
+import jax.numpy as jnp
+import jax_dataclasses
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.jax import JaxDataclassWrapper
+from jaxsim import logging
+
+WrapperStateType = StateType
+WrapperObsType = jnp.ndarray
+WrapperActType = jnp.ndarray
+WrapperRewardType = RewardType
+
+
+# TODO: maybe better over JaxEnv to be consistent with JaxVectorEnv?
+@jax_dataclasses.pytree_dataclass
+class FlattenSpacesWrapper(
+ JaxDataclassWrapper[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ WrapperStateType,
+ WrapperObsType,
+ WrapperActType,
+ WrapperRewardType,
+ ]
+):
+    """Wrapper flattening PyTree observation and action spaces to flat Box spaces."""
+
+ # Propagate to other Jax wrappers
+ def __post_init__(self):
+ """"""
+
+ super().__post_init__()
+
+ msg = f"[{self.__class__.__name__}] enabled"
+ logging.debug(msg=msg)
+
+ @property
+ def action_space(self) -> gym.Space:
+ """"""
+
+ return self.env.action_space.to_box()
+
+ @property
+ def observation_space(self) -> gym.Space:
+ """"""
+
+ return self.env.observation_space.to_box()
+
+ def transition(
+ self, state: WrapperStateType, action: WrapperActType, rng: Any = None
+ ) -> WrapperStateType:
+ """"""
+
+ action_pytree = self.env.action_space.unflatten_sample(x=action)
+ return self.env.transition(state=state, action=action_pytree, rng=rng)
+
+ def observation(self, state: WrapperStateType) -> WrapperObsType:
+ """"""
+
+ observation_pytree = self.env.observation(state=state)
+ return self.env.observation_space.flatten_pytree(pytree=observation_pytree)
diff --git a/src/jaxgym/wrappers/jax/nan_handler.py b/src/jaxgym/wrappers/jax/nan_handler.py
new file mode 100644
index 000000000..73638f5c5
--- /dev/null
+++ b/src/jaxgym/wrappers/jax/nan_handler.py
@@ -0,0 +1,138 @@
+from typing import Any, ClassVar, Generic
+
+import jax.flatten_util
+import jax.numpy as jnp
+import jax.tree_util
+import jax_dataclasses
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+import jaxsim.typing as jtp
+from jaxgym.functional import FuncEnv
+from jaxgym.wrappers import StateWrapper
+from jaxsim import logging
+
+WrapperStateType = StateType
+WrapperObsType = ObsType
+WrapperActType = ActType
+WrapperRewardType = RewardType
+
+
+@jax_dataclasses.pytree_dataclass
+class NaNHandlerWrapper(
+ StateWrapper[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ WrapperStateType,
+ ],
+ Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType],
+):
+ """Reset the environment when a NaN is encountered."""
+
+ env: FuncEnv[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType]
+
+ HasNanKey: ClassVar[str] = "has_nan"
+ EnvironmentStateKey: ClassVar[str] = "env"
+
+ def __post_init__(self) -> None:
+ """"""
+
+ msg = f"[{self.__class__.__name__}] enabled"
+ logging.debug(msg=msg)
+
+ def wrapper_state_to_environment_state(
+ self, wrapper_state: WrapperStateType
+ ) -> StateType:
+ """"""
+
+ return wrapper_state[NaNHandlerWrapper.EnvironmentStateKey]
+
+ def initial(self, rng: Any = None) -> WrapperStateType:
+ """"""
+
+ return {
+ NaNHandlerWrapper.HasNanKey: jnp.array(False),
+ NaNHandlerWrapper.EnvironmentStateKey: self.env.initial(rng=rng),
+ }
+
+ def transition(
+ self, state: WrapperStateType, action: WrapperActType, rng: Any = None
+ ) -> WrapperStateType:
+ """"""
+
+ # Copy the current state
+ old_state = jax.tree_util.tree_map(
+ lambda x: x, state[NaNHandlerWrapper.EnvironmentStateKey]
+ )
+
+ # Step the environment
+ new_state = self.env.transition(
+ state=self.wrapper_state_to_environment_state(wrapper_state=state),
+ action=action,
+ rng=rng,
+ )
+
+ new_state_without_nans = jax.tree_util.tree_map(
+ lambda leaf_new, leaf_old: jax.lax.select(
+ pred=self.pytree_has_nan_values(pytree=leaf_new),
+ on_true=leaf_old,
+ on_false=leaf_new,
+ ),
+ new_state,
+ old_state,
+ )
+
+ return {
+ NaNHandlerWrapper.HasNanKey: self.pytree_has_nan_values(pytree=new_state),
+ NaNHandlerWrapper.EnvironmentStateKey: new_state_without_nans,
+ }
+
+ def step_info(
+ self,
+ state: WrapperStateType,
+ action: WrapperActType,
+ next_state: WrapperStateType,
+ ) -> dict[str, Any]:
+ """"""
+
+ # Get the step info from the environment
+ info = self.env.step_info(
+ state=self.wrapper_state_to_environment_state(wrapper_state=state),
+ action=action,
+ next_state=self.wrapper_state_to_environment_state(
+ wrapper_state=next_state
+ ),
+ )
+
+        # Activate the truncation flag if a NaN was detected
+ truncated = jnp.array(next_state[NaNHandlerWrapper.HasNanKey], dtype=bool)
+
+ # Check if any other wrapper already truncated the environment
+ truncated = jnp.logical_or(truncated, info.get("truncated", False))
+
+        # If the environment also terminated, termination takes precedence and the truncation flag is cleared
+ truncated = jax.lax.select(
+ pred=self.terminal(state=next_state),
+ on_true=False,
+ on_false=truncated,
+ )
+
+ # Return the extended step info
+ return info | dict(truncated=truncated) | {"TimeLimit.truncated": truncated}
+
+ @staticmethod
+ def pytree_has_nan_values(pytree: jtp.PyTree) -> jtp.Bool:
+ """"""
+
+ pytree_flat, _ = jax.flatten_util.ravel_pytree(pytree=pytree)
+ return jnp.isnan(pytree_flat).any()
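+
+
+# Illustrative sketch: the check flattens the whole PyTree and reports True if
+# any leaf contains a NaN:
+#
+#     pytree = {"q": jnp.array([0.0, jnp.nan]), "v": jnp.zeros(3)}
+#     NaNHandlerWrapper.pytree_has_nan_values(pytree=pytree)  # -> True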
diff --git a/src/jaxgym/wrappers/jax/observation_noise.py b/src/jaxgym/wrappers/jax/observation_noise.py
new file mode 100644
index 000000000..782b0ffd6
--- /dev/null
+++ b/src/jaxgym/wrappers/jax/observation_noise.py
@@ -0,0 +1,128 @@
+from typing import Any, Callable, Generic
+
+import jax.numpy as jnp
+import jax.tree_util
+import jax_dataclasses
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.jax import JaxDataclassWrapper
+from jaxsim import logging
+
+WrapperStateType = StateType
+WrapperObsType = ObsType
+WrapperActType = ActType
+WrapperRewardType = RewardType
+
+
+# TODO: cannot do it here because only transition() has rng and not observation()
+@jax_dataclasses.pytree_dataclass
+class ObservationNoiseWrapper(
+ JaxDataclassWrapper[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ WrapperStateType,
+ WrapperObsType,
+ WrapperActType,
+ WrapperRewardType,
+ ],
+ Generic[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ ],
+):
+ """"""
+
+ noise_fn: Callable[[ObsType], ObsType] = jax_dataclasses.static_field()
+
+ def __post_init__(self) -> None:
+ """"""
+
+ super().__post_init__()
+
+ msg = f"[{self.__class__.__name__}] enabled"
+ logging.debug(msg=msg)
+
+ # assert isinstance(self.env.action_space.sample(), np.ndarray)
+ # assert isinstance(self.env.observation_space.sample(), np.ndarray)
+
+ def transition(
+ self, state: WrapperStateType, action: WrapperActType, rng: Any = None
+ ) -> WrapperStateType:
+ """"""
+
+ return self.env.transition(state=state, action=action, rng=rng)
+
+ # def observation(self, state: WrapperStateType) -> WrapperObsType:
+ # """"""
+ #
+ # observation = ToNumPyWrapper.pytree_to_numpy(self.env.observation(state=state))
+ # return np.array(observation, dtype=self.env.observation_space.dtype)
+
+ # def reward(
+ # self,
+ # state: WrapperStateType,
+ # action: WrapperActType,
+ # next_state: WrapperStateType,
+ # ) -> WrapperRewardType:
+ # """"""
+ #
+ # return float(
+ # ToNumPyWrapper.pytree_to_numpy(
+ # self.env.reward(state=state, action=action, next_state=next_state)
+ # )
+ # )
+
+ # def terminal(self, state: WrapperStateType) -> TerminalType:
+ # """"""
+ #
+ # return ToNumPyWrapper.pytree_to_numpy(self.env.terminal(state=state))
+ #
+ # def state_info(self, state: WrapperStateType) -> dict[str, Any]:
+ # """"""
+ #
+ # return ToNumPyWrapper.pytree_to_numpy(self.env.state_info(state=state))
+ #
+ # def step_info(
+ # self,
+ # state: WrapperStateType,
+ # action: WrapperActType,
+ # next_state: WrapperStateType,
+ # ) -> dict[str, Any]:
+ # """"""
+ #
+ # return ToNumPyWrapper.pytree_to_numpy(
+ # self.env.step_info(state=state, action=action, next_state=next_state)
+ # )
+ #
+ # @staticmethod
+ # def pytree_to_numpy(pytree: Any) -> Any:
+ # """"""
+ #
+ # def convert_leaf(leaf: Any) -> Any:
+ # """"""
+ #
+ # if (
+ # isinstance(leaf, (np.ndarray, jnp.ndarray))
+ # and leaf.size == 1
+ # and leaf.dtype == "bool"
+ # ):
+ # return bool(leaf)
+ #
+ # return np.array(leaf)
+ #
+ # return jax.tree_util.tree_map(lambda l: convert_leaf(l), pytree)
diff --git a/src/jaxgym/wrappers/jax/squash_action.py b/src/jaxgym/wrappers/jax/squash_action.py
new file mode 100644
index 000000000..0b3e53b65
--- /dev/null
+++ b/src/jaxgym/wrappers/jax/squash_action.py
@@ -0,0 +1,105 @@
+from typing import Generic
+
+import jax.numpy as jnp
+import jax.tree_util
+import jax_dataclasses
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.functional.func_wrapper import WrapperActType
+from jaxgym.jax import JaxDataclassActionWrapper
+from jaxsim import logging
+from jaxsim import typing as jtp
+from jaxsim.utils import Mutability
+
+
+@jax_dataclasses.pytree_dataclass
+class SquashActionWrapper(
+ JaxDataclassActionWrapper[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ ],
+ Generic[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ ],
+):
+    """Wrapper exposing a [-1, 1] action space and mapping actions back to the original bounds."""
+
+ def __post_init__(self) -> None:
+ """"""
+
+ super().__post_init__()
+
+ msg = f"[{self.__class__.__name__}] enabled"
+ logging.debug(msg=msg)
+
+ # Replace the action space with the squashed action space.
+ # Note: we assume the entire action space is bounded and there are no +-inf.
+        # Note: we assume the entire action space is composed of floats (no bools, etc).
+
+ import copy
+
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ # First make a copy of the PyTree
+ self.action_space = copy.deepcopy(self.env.action_space)
+ # self.action_space = jax.tree_util.tree_map(
+ # lambda l: l, self.env.action_space
+ # )
+
+            # Then update both the low and high bounds
+ self.action_space.low = jax.tree_util.tree_map(
+ lambda l: -1.0 * jnp.ones_like(l), self.env.action_space.low
+ )
+ self.action_space.high = jax.tree_util.tree_map(
+ lambda l: 1.0 * jnp.ones_like(l), self.env.action_space.high
+ )
+
+ def action(self, action: WrapperActType) -> ActType:
+ """"""
+
+ return self.unsquash(
+ pytree=action,
+ low=self.env.action_space.low,
+ high=self.env.action_space.high,
+ )
+
+ @staticmethod
+ def squash(pytree: jtp.PyTree, low: jtp.PyTree, high: jtp.PyTree) -> jtp.PyTree:
+ """"""
+
+ pytree_squashed = jax.tree_util.tree_map(
+ lambda leaf, l, h: 2 * (leaf - l) / (h - l) - 1,
+ pytree,
+ low,
+ high,
+ )
+
+ return pytree_squashed
+
+ @staticmethod
+ def unsquash(pytree: jtp.PyTree, low: jtp.PyTree, high: jtp.PyTree) -> jtp.PyTree:
+ """"""
+
+ pytree_unsquashed = jax.tree_util.tree_map(
+ lambda leaf, l, h: (leaf + 1) * (h - l) / 2 + l,
+ pytree,
+ low,
+ high,
+ )
+
+ return pytree_unsquashed
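+
+
+# Worked example (illustrative sketch): with bounds low=-2.0 and high=4.0,
+# squash maps 1.0 to 2 * (1 - (-2)) / (4 - (-2)) - 1 = 0.0, and unsquash maps
+# it back:
+#
+#     SquashActionWrapper.squash(pytree=jnp.array(1.0), low=-2.0, high=4.0)    # -> 0.0
+#     SquashActionWrapper.unsquash(pytree=jnp.array(0.0), low=-2.0, high=4.0)  # -> 1.0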
diff --git a/src/jaxgym/wrappers/jax/time_limit.py b/src/jaxgym/wrappers/jax/time_limit.py
new file mode 100644
index 000000000..8c69eb74b
--- /dev/null
+++ b/src/jaxgym/wrappers/jax/time_limit.py
@@ -0,0 +1,159 @@
+from typing import Any, ClassVar, Generic
+
+import jax.lax
+import jax.numpy as jnp
+import jax_dataclasses
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.functional import FuncEnv
+from jaxgym.wrappers import StateWrapper
+from jaxsim import logging
+
+WrapperStateType = dict[str, Any]
+WrapperObsType = ObsType
+WrapperActType = ActType
+WrapperRewardType = RewardType
+
+
+@jax_dataclasses.pytree_dataclass
+class TimeLimit(
+ StateWrapper[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ WrapperStateType,
+ ],
+ Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType],
+):
+    """Wrapper that truncates the episode after a maximum number of steps."""
+
+ env: FuncEnv[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType]
+ max_episode_steps: int = jax_dataclasses.static_field()
+
+ ElapsedStepsKey: ClassVar[str] = "elapsed_steps"
+ EnvironmentStateKey: ClassVar[str] = "env"
+
+ def __post_init__(self) -> None:
+ """"""
+
+ # TODO assert >=1?
+ msg = f"[{self.__class__.__name__}] max_episode_steps={self.max_episode_steps}"
+ logging.debug(msg=msg)
+
+ def wrapper_state_to_environment_state(
+ self, wrapper_state: WrapperStateType
+ ) -> StateType:
+ """"""
+
+ return wrapper_state[TimeLimit.EnvironmentStateKey]
+
+ def initial(self, rng: Any = None) -> WrapperStateType:
+ """"""
+
+ environment_state = self.env.initial(rng=rng)
+
+ return {
+ TimeLimit.EnvironmentStateKey: environment_state,
+ TimeLimit.ElapsedStepsKey: 0,
+ }
+
+ def transition(
+ self, state: WrapperStateType, action: WrapperActType, rng: Any = None
+ ) -> WrapperStateType:
+ """"""
+
+ elapsed_steps = state[TimeLimit.ElapsedStepsKey]
+ elapsed_steps += 1
+
+ # print("+++")
+ # print(state)
+ # print(self.wrapper_state_to_environment_state(wrapper_state=state))
+
+ environment_state = self.env.transition(
+ state=self.wrapper_state_to_environment_state(wrapper_state=state),
+ action=action,
+ rng=rng,
+ )
+
+ return {
+ TimeLimit.EnvironmentStateKey: environment_state,
+ TimeLimit.ElapsedStepsKey: elapsed_steps,
+ }
+
+ def step_info(
+ self, state: WrapperStateType, action: ActType, next_state: WrapperStateType
+ ) -> dict[str, Any]:
+ """"""
+
+ # Get the step info from the environment
+ info = self.env.step_info(
+ state=self.wrapper_state_to_environment_state(wrapper_state=state),
+ action=action,
+ next_state=self.wrapper_state_to_environment_state(
+ wrapper_state=next_state
+ ),
+ )
+
+ # assert "truncated" not in info # gymnasium
+
+ # TODO: make a specific wrapper for stable baselines?
+ # 1. add TimeLimit.truncated
+ # 2. add terminal_observation
+ # 3. step_dict -> list[step_dict]
+ # 4. all to numpy
+ # assert "TimeLimit.truncated" not in info # stable-baselines3
+ # TODO: in stable-baselines -> truncated and terminated are mutually exclusive
+
+ # Activate the truncation flag if the episode is over
+ truncated = jnp.array(
+ next_state[TimeLimit.ElapsedStepsKey] >= self.max_episode_steps, dtype=bool
+ )
+
+ # If max_episode_steps=0, this wrapper is a no-op
+ truncated = truncated if self.max_episode_steps != 0 else False
+
+ # Check if any other wrapper already truncated the environment
+ truncated = jnp.logical_or(truncated, info.get("truncated", False))
+ # truncated = jax.lax.select(
+ # pred="truncated" in info,
+ # on_true=jnp.logical_or(truncated, info["truncated"]),
+ # on_false=truncated,
+ # )
+
+        # If the environment also terminated, termination takes precedence and the truncation flag is cleared
+ truncated = jax.lax.select(
+ pred=self.terminal(state=next_state),
+ on_true=False,
+ on_false=truncated,
+ )
+
+ # Return the extended step info
+ # return info | dict(truncated=truncated)
+ return (
+ info
+ | dict(truncated=truncated)
+ | {"TimeLimit.truncated": truncated}
+ # the following has to be done in the vecwrapper because other wrappers
+ # might have changed the observation structure:
+ # | dict(terminal_observation=self.observation(state=next_state))
+ )
+
+
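+# Example usage (illustrative sketch, `my_func_env` being a hypothetical
+# functional environment): episodes are truncated after 200 steps, and the
+# wrapper state encapsulates the environment state in a dictionary.
+#
+#     env = TimeLimit(env=my_func_env, max_episode_steps=200)
+#     state = env.initial(rng=jax.random.PRNGKey(0))
+#     # state == {"env": <environment state>, "elapsed_steps": 0}
+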
+# @jax.jit
+# def has_field(d) -> bool:
+# import jax.lax
+# return jax.lax.select(
+# pred="f" in d,
+# on_true=True,
+# on_false=False,
+# )
diff --git a/src/jaxgym/wrappers/jax/time_limit_sb.py b/src/jaxgym/wrappers/jax/time_limit_sb.py
new file mode 100644
index 000000000..285c8006d
--- /dev/null
+++ b/src/jaxgym/wrappers/jax/time_limit_sb.py
@@ -0,0 +1,109 @@
+from typing import Any, Generic
+
+import jax_dataclasses
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.functional import FuncEnv
+from jaxgym.jax import JaxDataclassWrapper
+from jaxsim import logging
+
+WrapperStateType = StateType
+WrapperObsType = ObsType
+WrapperActType = ActType
+WrapperRewardType = RewardType
+
+
+# DO NOT USE -> this is already handled directly in TimeLimit, and that's it
+
+
+@jax_dataclasses.pytree_dataclass
+class TimeLimitStableBaselines(
+ JaxDataclassWrapper[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ WrapperStateType,
+ WrapperObsType,
+ WrapperActType,
+ WrapperRewardType,
+ ],
+ Generic[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ ],
+):
+    """Wrapper adding the 'TimeLimit.truncated' key expected by Stable-Baselines3."""
+
+ env: FuncEnv[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType]
+
+ def __post_init__(self) -> None:
+ """"""
+
+ msg = f"[{self.__class__.__name__}] enabled"
+ logging.debug(msg=msg)
+
+ def step_info(
+ self, state: WrapperStateType, action: ActType, next_state: WrapperStateType
+ ) -> dict[str, Any]:
+ """"""
+
+ # Get the step info from the environment
+ info = self.env.step_info(state=state, action=action, next_state=next_state)
+
+ return info | {"TimeLimit.truncated": info.get("truncated", False)}
+
+ # has_truncated_key = jax.lax.select(
+ # pred="TimeLimit.truncated" in info, on_true=True, on_false=False
+ # )
+ #
+ # return jax.lax.select(
+ # # pred=jnp.array([]).all(),
+ # pred=info.get("truncated", False),
+ # on_true=info | {"TimeLimit.truncated": True},
+ # on_false=info | {"TimeLimit.truncated": False},
+ # )
+
+ # assert "truncated" not in info # gymnasium
+ #
+ # # TODO: make a specific wrapper for stable baselines?
+ # # 1. add TimeLimit.truncated
+ # # 2. add terminal_observation
+ # # 3. step_dict -> list[step_dict]
+ # # 4. all to numpy
+ # assert "TimeLimit.truncated" not in info # stable-baselines3
+ # # TODO: in stable-baselines -> truncated and terminated are mutually exclusive
+ #
+ # # Activate the truncation flag if the episode is over
+ # truncated = jnp.array(
+ # next_state[TimeLimit.ElapsedStepsKey] >= self.max_episode_steps, dtype=bool
+ # )
+ #
+ # # If max_episode_steps=0, this wrapper is a no-op
+ # truncated = truncated if self.max_episode_steps != 0 else False
+ #
+ # # Return the extended step info
+ # return info | dict(truncated=truncated) | {"TimeLimit.truncated": truncated}
+
+
+# @jax.jit
+# def has_field(d) -> bool:
+# import jax.lax
+# return jax.lax.select(
+# pred="f" in d,
+# on_true=True,
+# on_false=False,
+# )
diff --git a/src/jaxgym/wrappers/jax/to_numpy.py b/src/jaxgym/wrappers/jax/to_numpy.py
new file mode 100644
index 000000000..610d018bb
--- /dev/null
+++ b/src/jaxgym/wrappers/jax/to_numpy.py
@@ -0,0 +1,157 @@
+from typing import Any
+
+import jax.numpy as jnp
+import jax.tree_util
+import jax_dataclasses
+import numpy as np
+import numpy.typing as npt
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.jax import JaxDataclassWrapper
+from jaxsim import logging
+
+WrapperStateType = StateType
+WrapperObsType = npt.NDArray
+WrapperActType = npt.NDArray
+WrapperRewardType = float
+
+
+@jax_dataclasses.pytree_dataclass
+class ToNumPyWrapper(
+ JaxDataclassWrapper[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ WrapperStateType,
+ WrapperObsType,
+ WrapperActType,
+ WrapperRewardType,
+ ]
+):
+ # class ToNumPyWrapper(
+ # FuncWrapper[
+ # #
+ # StateType,
+ # ObsType,
+ # ActType,
+ # RewardType,
+ # TerminalType,
+ # RenderStateType,
+ # #
+ # WrapperStateType,
+ # WrapperObsType,
+ # WrapperActType,
+ # WrapperRewardType,
+ # ],
+ # Generic[
+ # StateType,
+ # ObsType,
+ # ActType,
+ # RewardType,
+ # TerminalType,
+ # RenderStateType,
+ # ],
+ # ):
+    """Wrapper converting the JAX outputs of the environment to NumPy objects."""
+
+ # def __init__(
+ # self,
+ # env: FuncEnv[
+ # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ # ],
+ # ):
+ # """"""
+ #
+ # self.env = env
+ # assert isinstance(self.env.action_space.sample(), np.ndarray)
+ # assert isinstance(self.env.observation_space.sample(), np.ndarray)
+
+ def __post_init__(self) -> None:
+ """"""
+
+ super().__post_init__()
+
+ msg = f"[{self.__class__.__name__}] enabled"
+ logging.debug(msg=msg)
+
+ assert isinstance(self.env.action_space.sample(), np.ndarray)
+ assert isinstance(self.env.observation_space.sample(), np.ndarray)
+
+ # def transition(
+ # self, state: WrapperStateType, action: WrapperActType, rng: Any = None
+ # ) -> WrapperStateType:
+ # """"""
+ #
+ # return ToNumPyWrapper.pytree_to_numpy(
+ # self.env.transition(state=state, action=action, rng=rng)
+ # )
+
+ def observation(self, state: WrapperStateType) -> WrapperObsType:
+ """"""
+
+ observation = ToNumPyWrapper.pytree_to_numpy(self.env.observation(state=state))
+ return np.array(observation, dtype=self.env.observation_space.dtype)
+
+ def reward(
+ self,
+ state: WrapperStateType,
+ action: WrapperActType,
+ next_state: WrapperStateType,
+ ) -> WrapperRewardType:
+ """"""
+
+ return float(
+ ToNumPyWrapper.pytree_to_numpy(
+ self.env.reward(state=state, action=action, next_state=next_state)
+ )
+ )
+
+ def terminal(self, state: WrapperStateType) -> TerminalType:
+ """"""
+
+ return ToNumPyWrapper.pytree_to_numpy(self.env.terminal(state=state))
+
+ def state_info(self, state: WrapperStateType) -> dict[str, Any]:
+ """"""
+
+ return ToNumPyWrapper.pytree_to_numpy(self.env.state_info(state=state))
+
+ def step_info(
+ self,
+ state: WrapperStateType,
+ action: WrapperActType,
+ next_state: WrapperStateType,
+ ) -> dict[str, Any]:
+ """"""
+
+ return ToNumPyWrapper.pytree_to_numpy(
+ self.env.step_info(state=state, action=action, next_state=next_state)
+ )
+
+ @staticmethod
+ def pytree_to_numpy(pytree: Any) -> Any:
+ """"""
+
+ def convert_leaf(leaf: Any) -> Any:
+ """"""
+
+ if (
+ isinstance(leaf, (np.ndarray, jnp.ndarray))
+ and leaf.size == 1
+ and leaf.dtype == "bool"
+ ):
+ return bool(leaf)
+
+ return np.array(leaf)
+
+ return jax.tree_util.tree_map(lambda l: convert_leaf(l), pytree)
diff --git a/src/jaxgym/wrappers/jax/transform.py b/src/jaxgym/wrappers/jax/transform.py
new file mode 100644
index 000000000..3eae512a0
--- /dev/null
+++ b/src/jaxgym/wrappers/jax/transform.py
@@ -0,0 +1,69 @@
+import dataclasses
+from typing import Callable
+
+import jax_dataclasses
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper
+from jaxgym.wrappers import TransformWrapper
+from jaxsim import logging
+from jaxsim.utils import JaxsimDataclass, Mutability
+
+
+# TODO: Make Jit and Vmap explicit wrappers so that we can check that env is JaxDataClass?
+@jax_dataclasses.pytree_dataclass
+# class JaxTransformWrapper(TransformWrapper, JaxsimDataclass):
+class JaxTransformWrapper(TransformWrapper, JaxDataclassWrapper):
+    """Pytree-dataclass wrapper applying a function transformation (e.g. jax.vmap or jax.jit) to the environment methods."""
+
+ # env: JaxDataclassEnv[
+ # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ # ]
+
+ function: Callable[[Callable], Callable] = jax_dataclasses.static_field()
+
+ transform_initial: dataclasses.InitVar[bool] = dataclasses.field(default=True)
+ transform_transition: dataclasses.InitVar[bool] = dataclasses.field(default=True)
+ transform_observation: dataclasses.InitVar[bool] = dataclasses.field(default=True)
+ transform_reward: dataclasses.InitVar[bool] = dataclasses.field(default=True)
+ transform_terminal: dataclasses.InitVar[bool] = dataclasses.field(default=True)
+ transform_state_info: dataclasses.InitVar[bool] = dataclasses.field(default=True)
+ transform_step_info: dataclasses.InitVar[bool] = dataclasses.field(default=True)
+
+ def __post_init__( # noqa
+ self,
+ transform_initial: bool,
+ transform_transition: bool,
+ transform_observation: bool,
+ transform_reward: bool,
+ transform_terminal: bool,
+ transform_state_info: bool,
+ transform_step_info: bool,
+ ) -> None:
+ """"""
+
+ JaxDataclassWrapper.__post_init__(self)
+
+ msg = f"[{self.__class__.__name__}] function={self.function}"
+ logging.debug(msg=msg)
+
+ with self.mutable_context(mutability=Mutability.MUTABLE):
+ # super().__post_init__(
+ super().__init__(
+ env=self.env,
+ function=self.function,
+ transform_initial=transform_initial,
+ transform_transition=transform_transition,
+ transform_observation=transform_observation,
+ transform_reward=transform_reward,
+ transform_terminal=transform_terminal,
+ transform_state_info=transform_state_info,
+ transform_step_info=transform_step_info,
+ )
diff --git a/src/jaxgym/wrappers/state.py b/src/jaxgym/wrappers/state.py
new file mode 100644
index 000000000..223986126
--- /dev/null
+++ b/src/jaxgym/wrappers/state.py
@@ -0,0 +1,132 @@
+import abc
+from typing import Any, Generic, TypeVar
+
+import numpy.typing as npt
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.functional import FuncWrapper
+
+WrapperStateType = TypeVar("WrapperStateType")
+WrapperObsType = ObsType
+WrapperActType = ActType
+WrapperRewardType = RewardType
+
+
+class StateWrapper(
+ FuncWrapper[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ WrapperStateType,
+ WrapperObsType,
+ WrapperActType,
+ WrapperRewardType,
+ ],
+ Generic[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ WrapperStateType,
+ ],
+ abc.ABC,
+):
+    """Base class for wrappers that extend or replace the environment state."""
+
+ @abc.abstractmethod
+ def wrapper_state_to_environment_state(
+ self, wrapper_state: WrapperStateType
+ ) -> StateType:
+ """"""
+
+ pass
+
+ @abc.abstractmethod
+ def initial(self, rng: Any = None) -> WrapperStateType:
+ """"""
+
+ pass
+
+ @abc.abstractmethod
+ def transition(
+ self, state: WrapperStateType, action: WrapperActType, rng: Any = None
+ ) -> WrapperStateType:
+ """"""
+
+ pass
+
+ #
+ #
+ #
+
+ def observation(self, state: WrapperStateType) -> WrapperObsType:
+ """"""
+
+ return self.env.observation(
+ state=self.wrapper_state_to_environment_state(wrapper_state=state)
+ )
+
+ def reward(
+ self,
+ state: WrapperStateType,
+ action: WrapperActType,
+ next_state: WrapperStateType,
+ ) -> WrapperRewardType:
+ """"""
+
+ return self.env.reward(
+ state=self.wrapper_state_to_environment_state(wrapper_state=state),
+ action=action,
+ next_state=self.wrapper_state_to_environment_state(
+ wrapper_state=next_state
+ ),
+ )
+
+ def terminal(self, state: WrapperStateType) -> TerminalType:
+ """"""
+
+ return self.env.terminal(
+ state=self.wrapper_state_to_environment_state(wrapper_state=state)
+ )
+
+ def state_info(self, state: WrapperStateType) -> dict[str, Any]:
+ """Info dict about a single state."""
+
+ return self.env.state_info(
+ state=self.wrapper_state_to_environment_state(wrapper_state=state)
+ )
+
+ def step_info(
+ self, state: WrapperStateType, action: ActType, next_state: WrapperStateType
+ ) -> dict[str, Any]:
+ """Info dict about a full transition."""
+
+ return self.env.step_info(
+ state=self.wrapper_state_to_environment_state(wrapper_state=state),
+ action=action,
+ next_state=self.wrapper_state_to_environment_state(
+ wrapper_state=next_state
+ ),
+ )
+
+ def render_image(
+ self, state: StateType, render_state: RenderStateType
+ ) -> tuple[RenderStateType, npt.NDArray]:
+ """Render the state."""
+
+ return self.env.render_image(
+ state=self.wrapper_state_to_environment_state(wrapper_state=state),
+ render_state=render_state,
+ )
diff --git a/src/jaxgym/wrappers/transform.py b/src/jaxgym/wrappers/transform.py
new file mode 100644
index 000000000..d568204a2
--- /dev/null
+++ b/src/jaxgym/wrappers/transform.py
@@ -0,0 +1,125 @@
+from typing import Callable, Generic
+
+from gymnasium.experimental.functional import (
+ ActType,
+ ObsType,
+ RenderStateType,
+ RewardType,
+ StateType,
+ TerminalType,
+)
+
+from jaxgym.functional import FuncEnv, FuncWrapper
+
+WrapperStateType = StateType
+WrapperObsType = ObsType
+WrapperActType = ActType
+WrapperRewardType = RewardType
+
+
+class TransformWrapper(
+ FuncWrapper[
+ #
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ #
+ WrapperStateType,
+ WrapperObsType,
+ WrapperActType,
+ WrapperRewardType,
+ ],
+ Generic[
+ StateType,
+ ObsType,
+ ActType,
+ RewardType,
+ TerminalType,
+ RenderStateType,
+ ],
+):
+    """Wrapper applying a function transformation (e.g. jax.vmap or jax.jit) to the methods of a functional environment."""
+
+ def __init__(
+ self,
+ env: FuncEnv[
+ StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ ],
+ function: Callable[[Callable], Callable] = lambda f: f,
+ transform_initial: bool = True,
+ transform_transition: bool = True,
+ transform_observation: bool = True,
+ transform_reward: bool = True,
+ transform_terminal: bool = True,
+ transform_state_info: bool = True,
+ transform_step_info: bool = True,
+ ):
+ """"""
+
+ self.env = env
+
+ # Here to show up in repr
+ self.function = function
+
+ if transform_initial:
+ self.initial = function(self.initial)
+
+ if transform_transition:
+ self.transition = function(self.transition)
+
+ if transform_observation:
+ self.observation = function(self.observation)
+
+ if transform_reward:
+ self.reward = function(self.reward)
+
+ if transform_terminal:
+ self.terminal = function(self.terminal)
+
+ if transform_state_info:
+ self.state_info = function(self.state_info)
+
+ if transform_step_info:
+ self.step_info = function(self.step_info)
+
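+    # Example usage (illustrative sketch, assuming `import jax` and a functional
+    # environment `my_func_env`): vectorize or JIT-compile all environment methods.
+    #
+    #     batched_env = TransformWrapper(env=my_func_env, function=jax.vmap)
+    #     jitted_env = TransformWrapper(env=my_func_env, function=jax.jit)
+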
+ # @staticmethod
+ # def transform_env(
+ # env: FuncEnv[
+ # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ # ],
+ # function: Callable[[Callable], Callable],
+ # transform_initial: bool = True,
+ # transform_transition: bool = True,
+ # transform_observation: bool = True,
+ # transform_reward: bool = True,
+ # transform_terminal: bool = True,
+ # transform_state_info: bool = True,
+ # transform_step_info: bool = True,
+ # ) -> FuncWrapper[
+ # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ # ]:
+ # """"""
+ #
+ # if transform_initial:
+ # self.initial = function(self.initial)
+ #
+ # if transform_transition:
+ # self.transition = function(self.transition)
+ #
+ # if transform_observation:
+ # self.observation = function(self.observation)
+ #
+ # if transform_reward:
+ # self.reward = function(self.reward)
+ #
+ # if transform_terminal:
+ # self.terminal = function(self.terminal)
+ #
+ # if transform_state_info:
+ # self.state_info = function(self.state_info)
+ #
+ # if transform_step_info:
+ # self.step_info = function(self.step_info)
diff --git a/src/jaxsim/simulation/integrators.py b/src/jaxsim/simulation/integrators.py
index 71ab42c48..450c7bedc 100644
--- a/src/jaxsim/simulation/integrators.py
+++ b/src/jaxsim/simulation/integrators.py
@@ -338,7 +338,7 @@ def odeint_euler(
t: TimeHorizon,
*args,
num_sub_steps: int = 1,
- return_aux: bool = False
+ return_aux: bool = False,
) -> Union[State, Tuple[State, Dict[str, Any]]]:
"""
Integrate a system of ODEs using the Euler method.
@@ -378,7 +378,7 @@ def odeint_euler_semi_implicit(
t: TimeHorizon,
*args,
num_sub_steps: int = 1,
- return_aux: bool = False
+ return_aux: bool = False,
) -> Union[State, Tuple[State, Dict[str, Any]]]:
"""
Integrate a system of ODEs using the Semi-Implicit Euler method.
@@ -418,7 +418,7 @@ def odeint_rk4(
t: TimeHorizon,
*args,
num_sub_steps: int = 1,
- return_aux: bool = False
+ return_aux: bool = False,
) -> Union[State, Tuple[State, Dict[str, Any]]]:
"""
Integrate a system of ODEs using the Runge-Kutta 4 method.
diff --git a/src/jaxsim/training/__init__.py b/src/jaxsim/training/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/jaxsim/training/agent.py b/src/jaxsim/training/agent.py
new file mode 100644
index 000000000..3fc9edb40
--- /dev/null
+++ b/src/jaxsim/training/agent.py
@@ -0,0 +1,791 @@
+import copy
+import dataclasses
+import datetime
+import functools
+import pathlib
+from typing import Any, Callable, Dict, Tuple, Union
+
+import flax.training.checkpoints
+import flax.training.train_state
+import gym.spaces
+import jax
+import jax.experimental.loops
+import jax.numpy as jnp
+import jax_dataclasses
+import optax
+from flax.core.frozen_dict import FrozenDict
+from optax._src.alias import ScalarOrSchedule
+
+import jaxsim.typing as jtp
+from jaxsim import logging
+from jaxsim.utils import JaxsimDataclass
+
+from .memory import DataLoader, Memory
+from .networks import ActorCriticNetworks, ActorNetwork, CriticNetwork
+
+
+@jax_dataclasses.pytree_dataclass
+class PPOParams:
+ # Gradient descent
+ alpha: float = jax_dataclasses.field(default=0.0003)
+ optimizer: Callable[
+ [ScalarOrSchedule], optax.GradientTransformation
+ ] = jax_dataclasses.static_field(default=optax.adam)
+
+ # RL Algorithm
+ gamma: float = jax_dataclasses.field(default=0.99)
+ lambda_gae: float = jax_dataclasses.field(default=0.95)
+
+ # PPO
+ # beta_kl: float = jax_dataclasses.static_field(default=0.0)
+ beta_kl: float = jax_dataclasses.field(default=0.0)
+ target_kl: float = jax_dataclasses.field(default=0.010)
+ epsilon_clip: float = jax_dataclasses.field(default=0.2)
+
+ # Other params
+ entropy_loss_weight: float = jax_dataclasses.field(default=0.0)
+
+
+class PPOTrainState(flax.training.train_state.TrainState):
+ beta_kl: float
+
+
+@dataclasses.dataclass(frozen=True)
+class CheckpointManager:
+ checkpoint_path: pathlib.Path
+
+ def save_best(
+ self,
+ train_state: PPOTrainState,
+ measure: str,
+ metric: Union[int, float],
+ keep: int = 1,
+ ) -> None:
+ checkpoint_path = self.checkpoint_path / f"best_{measure}"
+
+ path_to_checkpoint = flax.training.checkpoints.save_checkpoint(
+ ckpt_dir=checkpoint_path,
+ target=train_state,
+ step=metric,
+ prefix=f"checkpoint_",
+ keep=keep,
+ overwrite=True,
+ )
+
+ logging.info(msg=f"Saved checkpoint: {path_to_checkpoint}")
+
+ def save_latest(
+ self, train_state: PPOTrainState, keep_every_n_steps: int = None
+ ) -> None:
+ path_to_checkpoint = flax.training.checkpoints.save_checkpoint(
+ ckpt_dir=self.checkpoint_path,
+ target=train_state,
+ step=train_state.step,
+ prefix=f"checkpoint_",
+ keep=1,
+ keep_every_n_steps=keep_every_n_steps,
+ )
+
+ logging.info(msg=f"Saved checkpoint: {path_to_checkpoint}")
+
+ def load_best(
+ self, dummy_train_state: PPOTrainState, measure: str
+ ) -> PPOTrainState:
+ checkpoint_path = self.checkpoint_path / f"best_{measure}"
+
+ train_state = flax.training.checkpoints.restore_checkpoint(
+ ckpt_dir=checkpoint_path,
+ target=dummy_train_state,
+ prefix=f"checkpoint_",
+ )
+
+ return train_state
+
+ def load_latest(self, dummy_train_state: PPOTrainState) -> PPOTrainState:
+ train_state = flax.training.checkpoints.restore_checkpoint(
+ ckpt_dir=self.checkpoint_path,
+ target=dummy_train_state,
+ prefix=f"checkpoint_",
+ )
+
+ return train_state
+
+
+@jax_dataclasses.pytree_dataclass
+class Agent(JaxsimDataclass):
+ key: jax.random.PRNGKey = jax_dataclasses.field(
+ default_factory=lambda: jax.random.PRNGKey(seed=0), repr=False
+ )
+
+ params: PPOParams = jax_dataclasses.field(default_factory=PPOParams)
+ train_state: PPOTrainState = jax_dataclasses.field(default=None, repr=False)
+
+ action_space: gym.spaces.Box = jax_dataclasses.static_field(default=None)
+ observation_space: gym.spaces.Box = jax_dataclasses.static_field(default=None)
+
+ checkpoint_manager: CheckpointManager = jax_dataclasses.static_field(default=None)
+
+ _num_timesteps: int = jax_dataclasses.field(default=0)
+ _num_iterations: int = jax_dataclasses.field(default=0)
+
+ _min_reward: float = jnp.finfo(jnp.float32).max
+ _max_reward: float = jnp.finfo(jnp.float32).min
+
+ @staticmethod
+ def build(
+ actor: ActorNetwork = None,
+ critic: CriticNetwork = None,
+ action_space: gym.spaces.Space = None,
+ observation_space: gym.spaces.Space = None,
+ key: jax.random.PRNGKey = jax.random.PRNGKey(seed=0),
+ params: PPOParams = PPOParams(),
+ train_state: PPOTrainState = None, # TODO remove?
+ checkpoint_label: str = "training",
+ load_checkpoint_from_path: pathlib.Path = None,
+ ) -> "Agent":
+ date_str = datetime.datetime.now().strftime("%y%m%d_%H%M")
+ checkpoint_folder = f"{checkpoint_label}_{date_str}"
+
+ checkpoint_path = (
+ pathlib.Path("~/jaxsim_results").expanduser() / checkpoint_folder
+ )
+
+ agent = Agent(
+ key=key,
+ params=params,
+ train_state=train_state,
+ action_space=action_space,
+ observation_space=observation_space,
+ checkpoint_manager=CheckpointManager(checkpoint_path),
+ )
+
+ if agent.train_state is None:
+ # Generate a new network RNG key
+ key = agent.advance_key(num_sub_keys=1)
+
+ # Get a dummy observation
+ observation_dummy = jnp.zeros(shape=observation_space.shape)
+
+ # Create the actor/critic network
+ actor_critic = ActorCriticNetworks(actor=actor, critic=critic)
+
+ # Initialize the actor/critic network
+ out, params_critic = actor_critic.init_with_output(key, observation_dummy)
+ distribution_dummy, value_dummy = out
+
+ # Check value shape
+ if value_dummy.shape != (1,):
+ raise ValueError(value_dummy.shape, (1,))
+
+ # Check action shape
+ dummy_action = distribution_dummy.sample(seed=key)
+ if dummy_action.shape != action_space.shape:
+ raise ValueError(dummy_action.shape, action_space.shape)
+
+ # Initialize the actor/critic train state
+ train_state = PPOTrainState.create(
+ apply_fn=actor_critic.apply,
+ params=params_critic,
+ tx=agent.params.optimizer(agent.params.alpha),
+ beta_kl=agent.params.beta_kl,
+ )
+
+ with agent.editable(validate=False) as agent:
+ agent.train_state = train_state
+
+ # Replace the actor/critic train state
+ # agent = jax_dataclasses.replace(agent, train_state=train_state) # noqa
+
+ if load_checkpoint_from_path is not None:
+ load_checkpoint_from_path = (
+ load_checkpoint_from_path.expanduser().absolute()
+ )
+
+ if not load_checkpoint_from_path.exists():
+                raise FileNotFoundError(load_checkpoint_from_path)
+
+ with agent.editable(validate=False) as agent:
+ agent.train_state = flax.training.checkpoints.restore_checkpoint(
+ ckpt_dir=load_checkpoint_from_path,
+ target=agent.train_state,
+ prefix=f"checkpoint_",
+ )
+
+ return agent
+
+ # /tmp/jaxsim/cartpole_20220614_1756/checkpoint_#i
+ # /tmp/jaxsim/cartpole_20220614_1756/reward/checkpoint_REWARD
+ # /tmp/jaxsim/cartpole_20220614_1756/episode_steps/checkpoint_REWARD
+
+ def save_checkpoint(
+ self,
+ prefix: str = "checkpoint_",
+ checkpoint_path: pathlib.Path = pathlib.Path.cwd() / "checkpoints",
+ ) -> None:
+ path_to_checkpoint = flax.training.checkpoints.save_checkpoint(
+ target=self.train_state,
+ prefix=prefix,
+ ckpt_dir=checkpoint_path,
+ step=self.train_state.step,
+ )
+
+ logging.info(msg=f"Save checkpoint: {path_to_checkpoint}")
+
+ def load_checkpoint(
+ self, checkpoint_path: pathlib.Path = pathlib.Path.cwd() / "checkpoints"
+ ) -> "Agent":
+ train_state = flax.training.checkpoints.restore_checkpoint(
+ ckpt_dir=checkpoint_path, target=self.train_state
+ )
+
+ with self.editable(validate=True) as agent:
+ agent.train_state = train_state
+
+ agent._set_mutability(mutability=self._mutability())
+ return agent
+
+ # return jax_dataclasses.replace(self, train_state=train_state)
+
+ def advance_key(self, num_sub_keys: int = 1) -> jax.random.PRNGKey:
+ keys = jax.random.split(key=self.key, num=num_sub_keys + 1)
+
+ object.__setattr__(self, "key", keys[0])
+
+ return keys[1:].squeeze()
+
+ def advance_key2(self, num_sub_keys: int = 1) -> Tuple[jax.random.PRNGKey, "Agent"]:
+ keys = jax.random.split(key=self.key, num=num_sub_keys + 1)
+ return keys[1:].squeeze(), jax_dataclasses.replace(self, key=keys[0]) # noqa
+
+ def choose_action(
+ self,
+ observation: jtp.Vector,
+ explore: bool = True,
+ key: jax.random.PRNGKey = None,
+ ) -> Tuple[jtp.VectorJax, jtp.VectorJax, jtp.VectorJax]:
+ # Get a new key if not passed
+ key = key if key is not None else self.advance_key()
+
+ # Infer π_θ(⋅|sₜ) and V_ϕ(sₜ)
+ distribution, value = self.train_state.apply_fn(
+ self.train_state.params, data=observation
+ )
+
+ # Sample an action from π_θ(⋅|sₜ)
+ action = jax.lax.select(
+ pred=explore,
+ on_true=distribution.sample(seed=key),
+ on_false=distribution.mode(),
+ )
+
+ # Compute log-likelihood of action: log[π_θ(aₜ|sₜ)]
+ log_prob_action = distribution.log_prob(value=action)
+
+ return action, log_prob_action, value
+ # return (
+ # jnp.array(action).squeeze(),
+ # jnp.array(log_prob_action).squeeze(),
+ # jnp.array(value).squeeze(),
+ # )
+
+ @staticmethod
+ @functools.partial(jax.jit)
+ def estimate_advantage_gae_jit(
+ train_state: PPOTrainState, memory: Memory, gamma: float, lambda_gae: float
+ ) -> jtp.VectorJax:
+ # Closure that computes V(o). Used to bootstrap the return when necessary.
+ value_of = lambda observation: train_state.apply_fn(
+ train_state.params, data=observation
+ )[1].squeeze()
+
+ return Agent.estimate_advantage_gae(
+ memory=memory, gamma=gamma, lambda_gae=lambda_gae, V=value_of
+ )
+
+ @staticmethod
+ def estimate_advantage_gae(
+ memory: Memory,
+ gamma: float,
+ lambda_gae: float,
+ V: Callable[[jtp.ArrayJax], jtp.ArrayJax] = lambda _: 0.0,
+ ) -> jtp.VectorJax:
+ r = memory.rewards
+ mask = 1 - memory.dones.astype(dtype=int)
+
+ # Extract additional data from the info dictionary
+ is_terminal = memory.infos["is_terminal"]
+ next_observations = memory.infos["terminal_observation"]
+
+ # The last trajectory in the memory is likely truncated, i.e. is_done[-1] = 0.
+ # In order to estimate Â, we get the next observation stored in the info dict
+ # and bootstrap the return of the last observation of the trajectory with TD(0).
+ Vs = jnp.hstack([memory.values.squeeze(), V(next_observations[-1])])
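+
+ # The loop below implements the GAE recursion, computed backwards in time:
+ #   δₜ = rₜ + γ⋅V(sₜ₊₁) - V(sₜ)
+ #   Âₜ = δₜ + γ⋅λ⋅Âₜ₊₁⋅maskₜ
+ # where maskₜ = 0 at done samples (there, δₜ is also adjusted, see below).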
+
+ with jax.experimental.loops.Scope() as s:
+ # Allocate the Âₜ array (with a trailing zero, like Vs).
+ # We cannot know how the trajectory would continue after the truncated last
+ # trajectory of the memory. In this case, since we have computed the value
+ # of the next observation, we estimate Â of the last sample using the TD(0)
+ # return, i.e. Âₕ = δₕ = rₕ + γ⋅V(sₕ₊₁) - V(sₕ).
+ # This can be done either by setting Âₕ₊₁ = 0 or, equivalently, λₕ = 0.
+ # We proceed with the former.
+ s.A = jnp.zeros_like(Vs)
+
+ # Iteration equivalent to: reversed(range(mask.size))
+ for t in s.range(mask.size - 1, -1, -1):
+ # Classic TD(0) bootstrap: δₜ = rₜ + γ⋅V(sₜ₊₁) - V(sₜ)
+ delta_t_not_done = r[t] + gamma * Vs[t + 1] - Vs[t]
+
+ # Use one step of Monte Carlo if terminal, TD(0) otherwise.
+ # This allows us to consider two different types of termination:
+ # - MC: Reached a terminal state s_T. All the rewards following s_T
+ # are considered to be 0.
+ # - TD(0): The trajectory reached the maximum length, and it was
+ # truncated early (common in continuous control). We cannot
+ # assume that the reward would be 0 after truncation, therefore
+ # we bootstrap the return with TD(0) (similarly to what we do
+ # for the last truncated trajectory of the memory).
+ delta_t_done = jax.lax.select(
+ pred=is_terminal.astype(dtype=bool)[t],
+ on_true=r[t] - Vs[t],
+ on_false=r[t] + gamma * V(next_observations[t]) - Vs[t],
+ )
+
+ # Select the right δₜ
+ delta_t = jax.lax.select(
+ pred=memory.dones.astype(dtype=bool)[t],
+ on_true=delta_t_done,
+ on_false=delta_t_not_done,
+ )
+
+ # Compute the advantage estimate Âₜ
+ A_t = delta_t + gamma * lambda_gae * s.A[t + 1] * mask[t]
+ s.A = s.A.at[t].set(A_t.squeeze())
+
+ # Remove the trailing value we added to handle the last entry of the memory
+ A = s.A[:-1]
+
+ return jax.lax.stop_gradient(jnp.vstack(A))
+
+ @staticmethod
+ def compute_reward_to_go(
+ memory: Memory,
+ gamma: float = 1.0,
+ V: Callable[[jtp.ArrayJax], jtp.ArrayJax] = lambda _: 0.0,
+ ) -> jtp.VectorJax:
+ assert memory.flat is True
+
+ r = memory.rewards.squeeze()
+ r_to_go = jnp.zeros_like(r)
+
+ dones = memory.dones.astype(dtype=bool).squeeze()
+ is_terminal = memory.infos["is_terminal"].astype(dtype=bool).squeeze()
+ next_observations = memory.infos["terminal_observation"].squeeze()
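+
+ # The loop below implements the reward-to-go recursion, computed backwards in time:
+ #   Rₜ = rₜ + γ⋅Rₜ₊₁                   (within a trajectory)
+ #   Rₜ = rₜ                            (at a terminal sample)
+ #   Rₜ = rₜ + γ⋅V(next_observationₜ)   (at a truncated sample)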
+
+ with jax.experimental.loops.Scope() as s:
+ # Approximate the return of the state following the last one with its value.
+ # Note: we store next_observation in the info dict for this reason.
+ s.r_to_go = jnp.hstack([r_to_go, V(next_observations[-1])])
+
+ # Iteration equivalent to: reversed(range(dones.size))
+ for t in s.range(dones.size - 1, -1, -1):
+ # Monte Carlo discounted accumulation.
+ # Note: considering r_to_go[-1], this is TD(0) for the last memory entry.
+ r_to_go_not_done = r[t] + gamma * s.r_to_go[t + 1]
+
+ # If done, use one step of Monte Carlo if terminal, TD(0) otherwise
+ r_to_go_done = jax.lax.select(
+ pred=is_terminal[t],
+ on_true=r[t],
+ on_false=r[t] + gamma * V(next_observations[t]),
+ )
+
+ # Select the right Rₜ
+ r_to_go_t = jax.lax.select(
+ pred=dones[t],
+ on_true=r_to_go_done,
+ on_false=r_to_go_not_done,
+ )
+
+ # Store Rₜ in the buffer
+ s.r_to_go = s.r_to_go.at[t].set(r_to_go_t.squeeze())
+
+ # Remove the trailing value we added to handle the last entry of the memory
+ r_to_go = s.r_to_go[:-1]
+
+ return jax.lax.stop_gradient(jnp.vstack(r_to_go))
+
+ @staticmethod
+ @functools.partial(jax.jit)
+ def compute_reward_to_go_jit(
+ train_state: PPOTrainState, memory: Memory, gamma: float
+ ) -> jtp.VectorJax:
+ # Closure that computes V(o). Used to bootstrap the return when necessary.
+ value_of = lambda observation: train_state.apply_fn(
+ train_state.params, data=observation
+ )[1].squeeze()
+
+ return Agent.compute_reward_to_go(memory=memory, gamma=gamma, V=value_of)
+
+ @staticmethod
+ def explained_variance(y_hat: jtp.Array, y: jtp.Array) -> jtp.Array:
+ assert y_hat.ndim == y.ndim == 1
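+
+ # Explained variance: EV = 1 - Var[y - ŷ] / Var[y].
+ # EV = 1 means perfect prediction, EV = 0 is no better than predicting the
+ # mean of y, and negative values are worse than predicting the mean.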
+
+ var_y = jnp.var(y)
+
+ return jax.lax.select(
+ pred=(var_y == 0.0),
+ on_true=jnp.nan,
+ on_false=(1 - jnp.var(y - y_hat) / var_y),
+ )
+
+ @staticmethod
+ @functools.partial(jax.jit)
+ def train_actor_critic(
+ train_state_target: PPOTrainState,
+ train_state_behavior: PPOTrainState,
+ memory: Memory,
+ returns_target: jtp.VectorJax,
+ policy_gradient_loss_weight: jtp.VectorJax,
+ ppo_params: PPOParams = PPOParams(),
+ ) -> Tuple[PPOTrainState, Dict]:
+ # Adjust 1D arrays
+ returns_target = returns_target.squeeze()
+ policy_gradient_loss_weight = policy_gradient_loss_weight.squeeze()
+
+ # Assume the memory stores column arrays of shape (N, 1)
+ mem_values = jnp.vstack(memory.values.squeeze())
+ mem_actions = jnp.vstack(memory.actions.squeeze())
+ mem_observations = jnp.vstack(memory.states.squeeze())
+ mem_log_prob_actions = jnp.vstack(memory.log_prob_actions.squeeze())
+
+ # Loss function for both the actor and critic networks.
+ # Note: we do not support sharing layers, therefore weighting differently
+ # the two losses should not be relevant.
+ def loss_fn(params: flax.core.FrozenDict[str, Any]) -> Tuple[float, Dict]:
+ # Infer π_θₙ(⋅|sₜ) and V_ϕₙ(sₜ) with the new parameters (θₙ, ϕₙ)
+ new_distributions, new_values = train_state_target.apply_fn(
+ params,
+ data=mem_observations,
+ )
+
+ # ======
+ # Critic
+ # ======
+
+ # The loss uses the returns as targets. Returns could be computed with
+ # a (possibly discounted) reward-to-go or from the estimated advantage.
+
+ # Compute the MSE loss
+ new_values = new_values.squeeze()
+ critic_loss = jnp.linalg.norm(new_values - returns_target, ord=2)
+ # TODO: should this be norm squared?
+
+ # Compute the explained variance. It should start as a very negative number
+ # and progressively converge towards 1.0 as the value function learns to
+ # approximate the sampled returns correctly.
+ returns = policy_gradient_loss_weight.flatten() + mem_values.flatten()
+ explained_variance = Agent.explained_variance(
+ y=returns, y_hat=mem_values.flatten()
+ )
+
+ # =====
+ # Actor
+ # =====
+
+ # Refer to https://arxiv.org/abs/1707.06347 for the surrogate functions
+ # used for both the CLIP and the KLPEN versions of PPO.
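+ #
+ # With probability ratio rₜ(θ) = π_θ(aₜ|sₜ) / π_θₒ(aₜ|sₜ) and weight Âₜ:
+ #   L_CLIP  = E[ min(rₜ(θ)⋅Âₜ, clip(rₜ(θ), 1-ε, 1+ε)⋅Âₜ) ]
+ #   L_KLPEN = E[ rₜ(θ)⋅Âₜ ] - β_KL⋅E[ KL(π_θₒ(⋅|sₜ) ‖ π_θ(⋅|sₜ)) ]
+ # Both are maximized, so the actor loss below is the negative of the objective.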
+
+ # Infer π_θₒ(⋅|sₜ) with the old parameters θₒ. Used only for KLPEN.
+ old_distributions, old_values = train_state_behavior.apply_fn(
+ train_state_behavior.params,
+ data=mem_observations,
+ )
+
+ # Compute new log-likelihood of actions: log[π_θₙ(aₜ|sₜ)]
+ new_log_prob_actions = new_distributions.log_prob(value=mem_actions)
+
+ # Rename the old log-likelihood of actions: log[π_θₒ(aₜ|sₜ)]
+ old_log_prob_actions = mem_log_prob_actions.squeeze()
+
+ # Compute the ratio rₜ(θₙ) of the likelihoods
+ prob_action_ratio = jnp.exp(new_log_prob_actions - old_log_prob_actions)
+
+ # Compute the CPI surrogate objective
+ L = prob_action_ratio * policy_gradient_loss_weight
+
+ # Compute the clipped version of the ratio of the likelihoods
+ prob_action_ratio_clipped = jnp.clip(
+ a=prob_action_ratio,
+ a_min=(1.0 - ppo_params.epsilon_clip),
+ a_max=(1.0 + ppo_params.epsilon_clip),
+ )
+
+ # Compute the clip ratio
+ clipped_elements = jnp.where(
+ prob_action_ratio != prob_action_ratio_clipped, 1, 0
+ )
+ clip_ratio = clipped_elements.sum() / clipped_elements.size
+
+ # Apply the CLIP surrogate objective.
+ # Note: 'epsilon_clip' is zero if CLIP is not enabled.
+ L = jax.lax.select(
+ pred=(ppo_params.epsilon_clip == 0),
+ on_true=L,
+ on_false=jnp.minimum(
+ L, prob_action_ratio_clipped * policy_gradient_loss_weight
+ ),
+ ).mean()
+
+ # Compute the additional KLPEN surrogate objective term.
+ # Note: 'beta_kl' is zero if KLPEN is not enabled.
+ distr_kl = old_distributions.kl_divergence(other_dist=new_distributions)
+ ppo_klpen_term = train_state_behavior.beta_kl * distr_kl.mean()
+
+ # Apply the KLPEN surrogate objective term
+ L -= ppo_klpen_term
+
+ # Compute the loss to minimize from the surrogate objective
+ actor_loss = -L
+
+ # Optional entropy reward
+ entropy_mean = new_distributions.entropy().mean()
+ entropy_reward = ppo_params.entropy_loss_weight * entropy_mean
+
+ return (
+ actor_loss - entropy_reward + 0.100 * critic_loss,
+ dict(
+ actor_loss=actor_loss,
+ critic_loss=critic_loss,
+ entropy=entropy_mean,
+ entropy_reward=entropy_reward,
+ kl=distr_kl.mean(),
+ beta_kl=train_state_behavior.beta_kl,
+ clip_ratio=clip_ratio,
+ explained_variance=explained_variance,
+ ),
+ )
+
+ # Commented-out code to check gradients wrt finite differences
+ # from jax.test_util import check_grads
+ # check_grads(loss_fn, (train_state_target.params,), order=1, eps=1e-4)
+
+ # Compute the loss and its gradient wrt the NN parameters
+ (total_loss, loss_fn_data), grads = jax.value_and_grad(
+ fun=loss_fn, has_aux=True
+ )(train_state_target.params)
+
+ # Pass the gradient to the optimizer and get a new state
+ new_train_state = train_state_target.apply_gradients(grads=grads)
+
+ return new_train_state, dict(total_loss=total_loss, **loss_fn_data)
+
+ @staticmethod
+ @functools.partial(jax.jit)
+ def adaptive_update_beta_ppo_kl_pen(
+ train_state_behavior: PPOTrainState,
+ train_state_target: PPOTrainState,
+ ppo_params: PPOParams,
+ memory: Memory,
+ ) -> PPOTrainState:
+ # Refer to https://arxiv.org/abs/1707.06347 (Sec. 4) for the update rule of β_KL.
+ # We use the default heuristic parameters (1.5 and 2) reported in the paper.
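+ #
+ # With d = KL(π_θₒ ‖ π_θₙ) and target d_targ = target_kl:
+ #   if d > 1.5⋅d_targ:  β_KL ← 2⋅β_KL
+ #   if d < d_targ/1.5:  β_KL ← β_KL/2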
+
+ # Infer π_θₙ(⋅|sₜ) with the new parameters
+ new_distributions, _ = train_state_target.apply_fn(
+ train_state_target.params, data=memory.states
+ )
+
+ # Infer π_θₒ(⋅|sₜ) with the old parameters
+ old_distributions, _ = train_state_behavior.apply_fn(
+ train_state_behavior.params, data=memory.states
+ )
+
+ # Compute the KL divergence of the old policy from the new policy
+ kl = old_distributions.kl_divergence(other_dist=new_distributions).mean()
+
+ # Get the old β_KL used as weight of PPO-KLPEN surrogate objective
+ beta_kl = train_state_behavior.beta_kl
+
+ # Increase β_KL
+ train_state_target = jax.lax.cond(
+ pred=(kl > ppo_params.target_kl * 1.5),
+ true_fun=lambda _: train_state_target.replace(beta_kl=beta_kl * 2.0),
+ false_fun=lambda _: train_state_target,
+ operand=(),
+ )
+
+ # Decrease β_KL
+ train_state_target = jax.lax.cond(
+ pred=(kl < ppo_params.target_kl / 1.5),
+ true_fun=lambda _: train_state_target.replace(beta_kl=beta_kl * 0.5),
+ false_fun=lambda _: train_state_target,
+ operand=(),
+ )
+
+ return train_state_target
+
+ def train(
+ self,
+ memory: Memory,
+ num_epochs: int = 1,
+ batch_size: int = 512,
+ print_report: bool = False,
+ ) -> "Agent":
+ # Truncate the last trajectory by modifying its last is_done flag. In this way
+ # we can correctly bootstrap its return, treating the last sample as non-terminal.
+ # Also flatten the memory if it was sampled from parallel environments.
+ memory_flat = memory.truncate_last_trajectory().flatten()
+
+ mean_reward = memory_flat.rewards.mean()
+ self.checkpoint_manager.save_best(
+ train_state=self.train_state, measure="reward", metric=mean_reward, keep=5
+ )
+
+ self.checkpoint_manager.save_latest(
+ train_state=self.train_state, keep_every_n_steps=25
+ )
+
+ # Update the training metadata
+ with self.editable(validate=True) as agent:
+ agent._num_iterations += 1
+ agent._num_timesteps += len(memory)
+
+ # Update the training metadata
+ # agent = jax_dataclasses.replace(
+ # self, # noqa
+ # _num_iterations=(self._num_iterations + 1),
+ # _num_timesteps=(self._num_timesteps + len(memory)),
+ # )
+
+ # Log the min/max reward ever seen
+ min_reward = jnp.array([agent._min_reward, memory.rewards.min()]).min()
+ max_reward = jnp.array([agent._max_reward, memory.rewards.max()]).max()
+ agent = jax_dataclasses.replace(
+ agent, _min_reward=min_reward, _max_reward=max_reward # noqa
+ )
+
+ # Estimate the advantages Â_π_θₒ(sₜ, aₜ) with GAE used to train the actor.
+ advantages = Agent.estimate_advantage_gae_jit(
+ train_state=agent.train_state,
+ memory=memory_flat,
+ gamma=agent.params.gamma,
+ lambda_gae=agent.params.lambda_gae,
+ )
+
+ # Compute the rewards-to-go R̂ₜ used to train the critic.
+ # Note: same as estimating Â_GAE with λ=1.0 and γ=1.0.
+ rewards_to_go = Agent.compute_reward_to_go_jit(
+ train_state=agent.train_state,
+ memory=memory_flat,
+ gamma=agent.params.gamma,
+ )
+
+ # Select the returns to use. We can use any of the following:
+ # - Rewards-to-go: R = R̂ₜ
+ # - GAE advantages plus values: R = Â_GAE + V
+ # returns = rewards_to_go
+ returns = advantages + memory_flat.values
+
+ # Store the behavior train state (old parameters θₒ and ϕₒ).
+ # Note: for safety, we make sure that params do not get overridden by the
+ # next optimization step by taking a deep copy of them.
+ train_state_behaviour = agent.train_state.replace(
+ params=copy.deepcopy(agent.train_state.params)
+ )
+
+ for epoch_idx in range(num_epochs):
+ for batch_slice in DataLoader(memory=memory_flat).batch_slices_generator(
+ batch_size=batch_size,
+ shuffle=True,
+ seed=epoch_idx,
+ allow_partial_batch=False,
+ ):
+ # Perform one step of gradient descent
+ train_state, extra_data = Agent.train_actor_critic(
+ train_state_behavior=train_state_behaviour,
+ train_state_target=agent.train_state,
+ memory=memory_flat[batch_slice],
+ returns_target=returns[batch_slice],
+ policy_gradient_loss_weight=advantages[batch_slice],
+ ppo_params=agent.params,
+ )
+
+ # Update the agent with the new train state
+ agent = jax_dataclasses.replace(agent, train_state=train_state)
+
+ # TODO
+ # Create the log data of the training step
+ log_data = FrozenDict(
+ extra_data, reward_range=(agent._min_reward, agent._max_reward)
+ )
+
+ # Print to output
+ # TODO: logging?
+ print(log_data)
+
+ # ======================================
+ # Update the KL parameters for PPO-KLPEN
+ # ======================================
+
+ # @jax.jit
+ # def update_beta_kl(agent: Agent) -> Agent:
+ #
+ # # Update the train state with the adjusted β_KL
+ # train_state_kl = Agent.adaptive_update_beta_ppo_kl_pen(
+ # train_state_behavior=train_state_behaviour,
+ # train_state_target=agent.train_state,
+ # ppo_params=agent.params,
+ # memory=memory_flat,
+ # )
+ #
+ # # Update the agent with the new train state
+ # return jax_dataclasses.replace(agent, train_state=train_state_kl)
+
+ # agent = jax.lax.cond(
+ # pred=agent.params.beta_kl != 0.0,
+ # true_fun=update_beta_kl,
+ # false_fun=lambda agent: agent,
+ # operand=agent,
+ # )
+
+ if agent.params.beta_kl != 0.0:
+ # Update the train state with the adjusted β_KL
+ train_state_kl = Agent.adaptive_update_beta_ppo_kl_pen(
+ train_state_behavior=train_state_behaviour,
+ train_state_target=agent.train_state,
+ ppo_params=agent.params,
+ memory=memory_flat,
+ )
+
+ # Update the agent with the new train state
+ agent = jax_dataclasses.replace(agent, train_state=train_state_kl)
+
+ if not print_report:
+ return agent
+
+ # ===================
+ # Print training data
+ # ===================
+
+ avg_length_of_trajectories = []
+
+ for trajectory in memory.trajectories():
+ avg_length_of_trajectories.append(trajectory.rewards.size)
+
+ from rich.console import Console
+ from rich.table import Table
+
+ table = Table(title=f"Iteration #{agent._num_iterations}")
+ table.add_column("Name", justify="left", style="cyan", no_wrap=True)
+ table.add_column("Value", justify="right", style="cyan", no_wrap=True)
+
+ table.add_row("Epochs", f"{num_epochs}")
+ table.add_row("Timesteps", f"{len(memory)}")
+ table.add_row("Total timesteps", f"{agent._num_timesteps}")
+ table.add_row("Avg reward", f"{float(memory_flat.rewards.mean()):10.4f}")
+ console = Console()
+ print()
+ console.print(table)
+
+ return agent
diff --git a/src/jaxsim/training/distributions.py b/src/jaxsim/training/distributions.py
new file mode 100644
index 000000000..ae60d6dbe
--- /dev/null
+++ b/src/jaxsim/training/distributions.py
@@ -0,0 +1,260 @@
+from typing import Optional, Sequence, Tuple, Union
+
+import distrax
+import jax.numpy as jnp
+import jax.scipy.special
+from distrax import MultivariateNormalDiag
+from distrax._src.distributions import distribution
+
+
+class SquashedMultivariateNormalDiagBase(distribution.Distribution):
+ def __init__(
+ self,
+ low: distribution.Array,
+ high: distribution.Array,
+ loc: Optional[distribution.Array] = None,
+ scale_diag: Optional[distribution.Array] = None,
+ ):
+ self.normal = MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)
+ self.low, _ = jnp.broadcast_arrays(low, self.normal.loc)
+ self.high, _ = jnp.broadcast_arrays(high, self.normal.loc)
+
+ @property
+ def event_shape(self) -> Tuple[int, ...]:
+ return self.normal.event_shape
+
+ def _sample_n(self, key: jax.random.PRNGKey, n: int) -> distribution.Array:
+ samples_normal = self.normal._sample_n(key=key, n=n)
+ return self._squash(unsquashed_sample=samples_normal)
+
+ def _squash(self, unsquashed_sample: distribution.Array) -> distribution.Array:
+ raise NotImplementedError
+
+ def _unsquash(self, squashed_sample: distribution.Array) -> distribution.Array:
+ raise NotImplementedError
+
+ def _log_abs_det_grad_squash(
+ self, unsquashed_sample: distribution.Array
+ ) -> distribution.Array:
+ raise NotImplementedError
+
+ def entropy(self) -> distribution.Array:
+ raise NotImplementedError
+
+ def mean(self) -> distribution.Array:
+ mean_normal = self.normal.mean()
+ return self._squash(unsquashed_sample=mean_normal)
+
+ def median(self) -> distribution.Array:
+ median_normal = self.normal.median()
+ return self._squash(unsquashed_sample=median_normal)
+
+ def variance(self) -> distribution.Array:
+ raise NotImplementedError
+
+ def stddev(self) -> distribution.Array:
+ raise NotImplementedError
+
+ def mode(self) -> distribution.Array:
+ mode_normal = self.normal.mode()
+ return self._squash(unsquashed_sample=mode_normal)
+
+ def cdf(self, value: distribution.Array) -> distribution.Array:
+ raise NotImplementedError
+
+ def log_cdf(self, value: distribution.Array) -> distribution.Array:
+ raise NotImplementedError
+
+ def sample(
+ self,
+ *,
+ seed: Union[distribution.IntLike, distribution.PRNGKey],
+ sample_shape: Union[distribution.IntLike, Sequence[distribution.IntLike]] = (),
+ ) -> distribution.Array:
+ # Sample from the MultivariateNormal distribution
+ unsquashed_sample = self.normal.sample(seed=seed, sample_shape=sample_shape)
+
+ # Squash the sample into [low, high]
+ squashed_sample = self._squash(unsquashed_sample=unsquashed_sample)
+ assert squashed_sample.shape == unsquashed_sample.shape
+
+ return squashed_sample
+
+ def log_prob(self, value: distribution.Array) -> distribution.Array:
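+ # Change of variables: with a = squash(u) and u ~ N(μ, Σ), the density is
+ #   log p(a) = log N(u; μ, Σ) - log|det ∂squash(u)/∂u|
+ # which is exactly what the three steps below compute.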
+ # Unsquash from [low, high] to ]-∞, ∞[
+ value_unsquashed = self._unsquash(squashed_sample=value)
+
+ # Compute the log-prob of the underlying normal distribution
+ log_prob_norm = self.normal.log_prob(value=value_unsquashed)
+
+ # Compute the correction term due to squashing
+ log_abs_det_grad_squash = self._log_abs_det_grad_squash(
+ unsquashed_sample=value_unsquashed
+ )
+
+ # Adjust the log-prob with the gradient of the squashing function
+ return log_prob_norm - log_abs_det_grad_squash
+
+ def kl_divergence(
+ self, other_dist: "SquashedMultivariateNormalDiagBase", **kwargs
+ ) -> distribution.Array:
+ if not isinstance(other_dist, type(self)):
+ raise TypeError(type(other_dist), type(self))
+
+ # The squashing function does not influence the KL divergence
+ return self.normal.kl_divergence(other_dist=other_dist.normal)
+
+
+class GaussianSquashedMultivariateNormalDiag(SquashedMultivariateNormalDiagBase):
+ Epsilon: float = 1e-5
+ ScaleOfSquashingGaussian = 0.5 * 1.8137
+
+ def __init__(
+ self,
+ low: distribution.Array,
+ high: distribution.Array,
+ loc: Optional[distribution.Array] = None,
+ scale_diag: Optional[distribution.Array] = None,
+ ):
+ # Clip the mean of the underlying gaussian in a valid range
+ loc = jnp.clip(a=loc, a_min=-4, a_max=4)
+
+ # Initialize base class
+ super().__init__(low=low, high=high, loc=loc, scale_diag=scale_diag)
+
+ # Create the gaussian whose CDF is used as squashing function.
+ # Note: with the default scale, the CDF approximates tanh.
+ self.squash_dist = distrax.MultivariateNormalDiag(
+ loc=jnp.array([0.0]), scale_diag=jnp.array([self.ScaleOfSquashingGaussian])
+ )
+
+ def _squash_dist_cdfi(self, value: distribution.Array) -> distribution.Array:
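+ # Inverse CDF of the squashing gaussian: jax.scipy.special.ndtri returns the
+ # standard-normal quantile, which is then un-standardized with loc and scale.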
+ def unstandardize(
+ dist: MultivariateNormalDiag, value_std: distribution.Array
+ ) -> distribution.Array:
+ return value_std * dist.scale_diag + dist.loc
+
+ return unstandardize(
+ dist=self.squash_dist, value_std=jax.scipy.special.ndtri(p=value)
+ )
+
+ def _squash(self, unsquashed_sample: distribution.Array) -> distribution.Array:
+ # Squash the input sample into the [0, 1] interval
+ squashed_sample = jnp.vectorize(self.squash_dist.cdf)(unsquashed_sample)
+
+ # Adjust boundaries in order to prevent getting infinite log-prob
+ clipped_squashed_sample = jnp.clip(
+ squashed_sample, a_min=self.Epsilon, a_max=(1 - self.Epsilon)
+ )
+
+ # Project the squashed sample into the output space [low, high]
+ return clipped_squashed_sample * (self.high - self.low) + self.low
+
+ def _unsquash(self, squashed_sample: distribution.Array) -> distribution.Array:
+ import jaxsim
+
+ # if (
+ # not jaxsim.utils.tracing(squashed_sample)
+ # and jnp.where(squashed_sample > self.high, True, False).any()
+ # ):
+ # raise ValueError(squashed_sample, self.high)
+ #
+ # if (
+ # not jaxsim.utils.tracing(squashed_sample)
+ # and jnp.where(squashed_sample < self.low, True, False).any()
+ # ):
+ # raise ValueError(squashed_sample, self.low)
+ # Project the squashed sample into the normalized space [0, 1]
+ normalized_squashed_example = (squashed_sample - self.low) / (
+ self.high - self.low
+ )
+
+ # Unsquash the sample
+ return self._squash_dist_cdfi(value=normalized_squashed_example)
+
+ def _log_abs_det_grad_squash(
+ self, unsquashed_sample: distribution.Array
+ ) -> distribution.Array:
+ # Compute the log-grad of the squashing function
+ log_grad = jnp.vectorize(self.squash_dist.log_prob)(unsquashed_sample)
+ log_grad += jnp.log(self.high - self.low) # TODO: add this again
+
+ # Adjust size
+ log_grad = log_grad if log_grad.ndim > 1 else jnp.array([log_grad])
+
+ # Sum over sample dimension (i.e. return a single value for each sample)
+ return log_grad.sum(axis=1)
+
+ def entropy(self) -> distribution.Array:
+ return (
+ -self.normal.kl_divergence(other_dist=self.squash_dist)
+ + jnp.log(self.high - self.low).sum()
+ )
+
+
+class TanhSquashedMultivariateNormalDiag(SquashedMultivariateNormalDiagBase):
+ Epsilon: float = 1e-6
+
+ def __init__(
+ self,
+ low: distribution.Array,
+ high: distribution.Array,
+ loc: Optional[distribution.Array] = None,
+ scale_diag: Optional[distribution.Array] = None,
+ ):
+ # Clip the mean of the underlying gaussian in a valid range
+ loc = jnp.clip(a=loc, a_min=-4, a_max=4)
+
+ # Initialize base class
+ super().__init__(low=low, high=high, loc=loc, scale_diag=scale_diag)
+
+ def _squash(self, unsquashed_sample: distribution.Array) -> distribution.Array:
+ # Squash the input sample into the [0, 1] interval
+ squashed_sample = (jnp.tanh(unsquashed_sample) + 1) / 2
+
+ # Project the squashed sample into the output space [low, high]
+ return squashed_sample * (self.high - self.low) + self.low
+
+ def _unsquash(self, squashed_sample: distribution.Array) -> distribution.Array:
+ import jaxsim
+
+ # if (
+ # not jaxsim.utils.tracing(squashed_sample)
+ # and jnp.where(squashed_sample > self.high, True, False).any()
+ # ):
+ # raise ValueError(squashed_sample, self.high)
+ # if (
+ # not jaxsim.utils.tracing(squashed_sample)
+ # and jnp.where(squashed_sample < self.low, True, False).any()
+ # ):
+ # raise ValueError(squashed_sample, self.low)
+ # Project the squashed sample into the normalized space [0, 1]
+ normalized_squashed_sample = (squashed_sample - self.low) / (
+ self.high - self.low
+ )
+
+ # Project into [-1, 1]
+ normalized_squashed_sample = normalized_squashed_sample * 2 - 1.0
+
+ # Clip so that tanh output is stabilized
+ clipped_squashed_sample = jnp.clip(
+ normalized_squashed_sample, -1.0 + self.Epsilon, 1.0 - self.Epsilon
+ )
+
+ # Unsquash the sample
+ return jnp.arctanh(clipped_squashed_sample)
+
+ def _log_abs_det_grad_squash(
+ self, unsquashed_sample: distribution.Array
+ ) -> distribution.Array:
+ # Compute the log-grad of the squashing function
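+ # For a = (tanh(u) + 1)/2 ⋅ (high - low) + low we have
+ #   |da/du| = (1 - tanh²(u)) ⋅ (high - low)/2
+ # hence log|da/du| = log(1 - tanh²(u)) + log(high - low) - log(2).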
+ log_grad = jnp.log(1 - jnp.tanh(unsquashed_sample) ** 2 + self.Epsilon)
+ log_grad += jnp.log(self.high - self.low) - jnp.log(2)
+
+ # Adjust size
+ log_grad = log_grad if log_grad.ndim > 1 else jnp.array([log_grad])
+
+ # Sum over sample dimension (i.e. return a single value for each sample)
+ return log_grad.sum(axis=1)
+
+ def entropy(self) -> distribution.Array:
+ return jnp.zeros_like(self.low).sum(axis=1)
diff --git a/src/jaxsim/training/memory.py b/src/jaxsim/training/memory.py
new file mode 100644
index 000000000..1c6339e11
--- /dev/null
+++ b/src/jaxsim/training/memory.py
@@ -0,0 +1,299 @@
+import operator
+from typing import Generator, List, NamedTuple, Sequence, Tuple, Union
+
+import jax
+import jax.flatten_util
+import jax.numpy as jnp
+import jax_dataclasses
+import numpy as np
+import numpy.typing as npt
+from flax.core.frozen_dict import FrozenDict
+
+import jaxsim.typing as jtp
+
+
+@jax_dataclasses.pytree_dataclass
+class Memory(Sequence):
+ states: jtp.MatrixJax
+ actions: jtp.MatrixJax
+ rewards: jtp.VectorJax
+ dones: jtp.VectorJax
+
+ values: jtp.VectorJax
+ log_prob_actions: jtp.VectorJax
+
+ infos: FrozenDict[str, jtp.PyTree] = jax_dataclasses.field(default_factory=dict)
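+
+ # Shape convention (see check_valid): leaves are either (S, ⋅) when the memory
+ # is flat, or (L, S, ⋅) when it stacks L parallel environments of S samples each.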
+
+ @property
+ def flat(self) -> bool:
+ # return self.dones.ndim == 2
+ # return self.dones.squeeze().ndim == 1
+ # return self.dones.squeeze().ndim <= 1
+ return self.dones.ndim < 3
+
+ @property
+ def number_of_levels(self) -> int:
+ if self.flat:
+ return 0
+
+ self.check_valid()
+ return self.dones.shape[0]
+
+ def truncate_last_trajectory(self) -> "Memory":
+ if self.flat:
+ dones = self.dones.at[-1].set(True)
+ memory = jax_dataclasses.replace(self, dones=dones) # noqa
+
+ else:
+ dones = self.dones.at[:, -1].set(True)
+ memory = jax_dataclasses.replace(self, dones=dones) # noqa
+
+ return memory
+
+ def flatten(self) -> "Memory":
+ if self.flat:
+ return self
+
+ self.check_valid()
+ return jax.tree_map(lambda x: jnp.vstack(x), self)
+
+ def unflatten(self) -> "Memory":
+ if not self.flat:
+ return self
+
+ # return jax.tree_map(lambda x: jnp.stack([jnp.vstack(x)]), self)
+
+ self.check_valid()
+
+ memory = jax.tree_map(lambda x: x[jnp.newaxis, :], self)
+ memory.check_valid()
+ return memory
+
+ @staticmethod
+ def build(
+ states: jtp.MatrixJax,
+ actions: jtp.MatrixJax,
+ rewards: jtp.VectorJax,
+ dones: jtp.VectorJax,
+ values: jtp.VectorJax,
+ log_prob_actions: jtp.VectorJax,
+ infos: FrozenDict[str, jtp.ArrayJax] = FrozenDict(),
+ ) -> "Memory":
+ memory = Memory(
+ states=states,
+ actions=actions,
+ rewards=rewards,
+ dones=dones,
+ values=values,
+ log_prob_actions=log_prob_actions,
+ infos=infos,
+ )
+
+ # We work on (N, 1) 1D arrays
+ # Transform scalars to 1D arrays
+ memory = jax.tree_map(lambda x: x.squeeze(), memory)
+ # memory = jax.tree_map(lambda x: jnp.vstack(x.squeeze()), memory)
+ memory = jax.tree_map(lambda x: jnp.array([x]) if x.ndim == 0 else x, memory)
+ # memory = jax.tree_map(lambda x: jnp.vstack(x), memory) # TODO
+ # memory = jax.tree_map(lambda x: x[jnp.newaxis,:] if x.ndim == 1 else x, memory)
+
+ # D, (S, D), (L, S, D)
+
+ # If there's a single sample (S=1), add a new trivial S axis: (D,) -> (1, D)
+ # if memory.dones.ndim == 1:
+ if memory.dones.size == 1:
+ memory = jax.tree_map(lambda x: x[jnp.newaxis, :], memory)
+
+ # memory = jax.tree_map(lambda x: jnp.vstack(x) if x.ndim == 1 else x, memory)
+ # TODO: vstack necessary??
+
+ memory.check_valid()
+ return memory
+
+ def __len__(self) -> int:
+ self.check_valid()
+ return self.dones.flatten().size
+
+ def __getitem__(self, item) -> "Memory":
+ try:
+ item = operator.index(item)
+ except TypeError:
+ pass
+
+ def get_slice(s: Union[slice, np.ndarray, jnp.ndarray, Tuple]) -> Memory:
+ return jax.tree_map(lambda x: x[s], self)
+ # squeezed = jax.tree_map(lambda x: jnp.squeeze(x[s]), self)
+ # return jax.tree_map(lambda x: jnp.vstack(x[s]), squeezed)
+
+ if isinstance(item, int):
+ if item >= len(self):
+ raise IndexError(
+ f"index {item} is out of bounds for axis 0 with size {len(self)}"
+ )
+
+ return get_slice(jnp.s_[item])
+
+ if isinstance(item, slice):
+ return get_slice(item)
+
+ if isinstance(item, (np.ndarray, jnp.ndarray)):
+ return get_slice(item)
+
+ if isinstance(item, tuple):
+ return get_slice(item)
+
+ raise TypeError(item)
+
+ def flat_level(self, level: int) -> "Memory":
+ if self.flat:
+ return self
+
+ if level > self.number_of_levels - 1:
+ raise ValueError(level, self.number_of_levels)
+
+ return jax.tree_map(lambda x: x[level], self)
+
+ def has_nans(self) -> jtp.Bool:
+ has_nan = jax.tree_map(lambda l: jnp.isnan(l).any(), self)
+ return jax.flatten_util.ravel_pytree(has_nan)[0].any()
+
+ def check_valid(self) -> None:
+ # if self.has_nans():
+ # raise ValueError("Found NaN values")
+
+ # L, S, D = (None, None, None)
+
+ if self.dones.ndim < 2 or self.dones.ndim > 3:
+ raise ValueError(self.dones.shape)
+
+ # if self.dones.ndim == 2:
+ # L, S, D = None, self.dones.shape[0], self.dones.shape[1]
+ #
+ # if self.dones.ndim == 3:
+ # L, S, D = self.dones[0].shape, self.dones[1].shape, self.dones.shape[2]
+ #
+ # if D != 1:
+ # raise ValueError(D, self.dones.shape)
+
+ # return True
+ shape_of_leaves = jax.tree_map(
+ lambda x: x.shape, jax.tree_util.tree_leaves(self)
+ )
+
+ # (L, S, D)
+ # (S, D)
+
+ # Check (S, ⋅) TODO: same as check below with L and S? can be removed?
+ if self.flat and len(set([s[0:1] for s in shape_of_leaves])) != 1:
+ raise ValueError(shape_of_leaves)
+
+ # Check (L, S, ⋅) TODO: same as check below with L and S? can be removed?
+ if not self.flat and len(set([s[0:2] for s in shape_of_leaves])) != 1:
+ raise ValueError(shape_of_leaves)
+
+ # Check all leaves have same shape (L,S,⋅)
+ # if len(set([s[:-1] for s in shape_of_leaves])) != 1:
+ # raise ValueError(shape_of_leaves)
+
+ # Get the Level and Samples dimensions
+ L = self.dones.shape[0] if not self.flat else None
+ S = self.dones.shape[1] if not self.flat else self.dones.shape[0]
+
+ # If flat, check that all leaves have S samples
+ if self.flat and set(s[0] for s in shape_of_leaves) != {S}:
+ raise ValueError(shape_of_leaves)
+
+ # If not flat, check that all leaves have L levels and S samples
+ if not self.flat and set(s[0:2] for s in shape_of_leaves) != {(L, S)}:
+ raise ValueError(shape_of_leaves)
+
+ def trajectories(self) -> Generator["Memory", None, None]:
+ # In this method we operate on non-flat memory
+ memory = self if not self.flat else self.unflatten()
+
+ def trajectory_slices_of_level(
+ memory: Memory, level: int = 0
+ ) -> List[Tuple[npt.NDArray, npt.NDArray, slice]]:
+ idx_of_ones, _ = np.where(memory.dones[level] == 1)
+
+ if idx_of_ones.size < 2:
+ return []
+
+ start = idx_of_ones[:-1] + 1
+ stop = idx_of_ones[1:] + 1
+
+ return [
+ np.s_[level, idx_start:idx_stop]
+ for idx_start, idx_stop in zip(start, stop)
+ ]
+
+ for level in range(memory.number_of_levels):
+ for s in trajectory_slices_of_level(memory=memory, level=level):
+ yield memory[s]
+
+
+class DataLoader:
+ def __init__(self, memory: Memory):
+ self.memory = memory if memory.flat else memory.flatten()
+
+ def batch_slices_generator(
+ self,
+ batch_size: int,
+ shuffle: bool = False,
+ seed: int = None,
+ # key: jax.random.PRNGKeyArray = None,
+ allow_partial_batch: bool = False,
+ ) -> Generator[jtp.ArrayJax, None, None]:
+ # Create the index mask
+ # mask_indices = jnp.arange(0, len(self.memory), dtype=int)
+ mask_indices = np.arange(0, len(self.memory), dtype=int)
+
+ seed = seed if seed is not None else 0
+ # key = key if key is not None else jax.random.PRNGKey(seed=seed)
+
+ # When this function is JIT compiled with shuffle=True, the shuffled indices
+ # are always the same, according to the seed
+ if shuffle:
+ rng = np.random.default_rng(seed)
+ mask_indices = rng.permutation(mask_indices)
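+
+ # Illustrative example (not executed): with len(memory) == 10, batch_size == 4
+ # and shuffle == False, the generator below yields boolean masks selecting
+ # samples [0:4] and [4:8]; the remaining 2 samples are dropped unless
+ # allow_partial_batch is True.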
+
+ # mask_indices = (
+ # mask_indices
+ # if shuffle is False
+ # else jax.random.permutation(key=key, x=mask_indices)
+ # )
+
+ def boolean_mask_generator(
+ a: npt.NDArray, size: int
+ ) -> Generator[jtp.ArrayJax, None, None]:
+ if a.ndim != 1:
+ raise ValueError(a.ndim)
+
+ if size > a.size:
+ raise ValueError(size, a.size)
+
+ idx = 0
+ mask = jnp.zeros(a.shape, dtype=bool)
+
+ while idx + size <= a.size:
+ batch_slice = np.s_[idx : idx + size]
+ # yield mask.at[batch_slice].set(True)
+
+ mask = np.zeros(a.shape, dtype=bool)
+ mask[batch_slice] = True
+ yield mask[np.array(mask_indices)]
+
+ # low, high = idx, idx + size
+ # indices = jnp.arange(start=0, stop=mask.size)
+ # batch_slice_low = jnp.where(indices >= low, True, False)
+ # batch_slice_high = jnp.where(indices < high, True, False)
+ # yield batch_slice_low * batch_slice_high
+
+ idx += size
+
+ if allow_partial_batch:
+ # Yield the remaining samples as a final partial batch
+ mask = np.zeros(a.shape, dtype=bool)
+ mask[np.s_[idx:]] = True
+ yield mask[np.array(mask_indices)]
+
+ yield from boolean_mask_generator(a=mask_indices, size=batch_size)
diff --git a/src/jaxsim/training/networks.py b/src/jaxsim/training/networks.py
new file mode 100644
index 000000000..335ae89a6
--- /dev/null
+++ b/src/jaxsim/training/networks.py
@@ -0,0 +1,173 @@
+from typing import Any, Callable, Sequence, Tuple
+
+import distrax
+import flax.linen as nn
+import gym.spaces
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+import jaxsim.typing as jtp
+from jaxsim import logging
+from jaxsim.training import distributions
+
+
+class CriticNetwork(nn.Module):
+ layer_sizes: Sequence[int]
+
+ activation: Callable[[jtp.Array], jtp.Array] = nn.relu
+ kernel_init: Callable[..., Any] = jax.nn.initializers.normal()
+
+ bias: bool = True
+ activate_final: bool = False
+
+ def setup(self):
+ logging.debug("Configuring critic network...")
+
+ # Automatically add a final layer so that the NN outputs a scalar value V(s)
+ all_layer_sizes = np.hstack([self.layer_sizes, 1])
+
+ self.layers = [
+ nn.Dense(
+ features=size,
+ kernel_init=self.kernel_init,
+ use_bias=self.bias,
+ )
+ for size in all_layer_sizes
+ ]
+
+ logging.debug(f" Layers: {all_layer_sizes.tolist()}")
+ activated = [
+ self.activate_layer(layer_idx=idx) for idx, _ in enumerate(all_layer_sizes)
+ ]
+ logging.debug(f" Activated: {activated}")
+ logging.debug(f" Activation function: {self.activation.__name__}")
+ logging.debug(f" Kernel initializer: {self.kernel_init}")
+
+ def activate_layer(self, layer_idx: int) -> bool:
+ return (layer_idx != len(self.layers) - 1) or self.activate_final
+
+ def __call__(self, data: jtp.Array) -> jtp.Array:
+ x = data
+
+ for idx, layer in enumerate(self.layers):
+ x = layer(x)
+ x = self.activation(x) if self.activate_layer(layer_idx=idx) else x
+
+ return x
+
+
+class ActorNetwork(nn.Module):
+ layer_sizes: Sequence[int]
+ action_space: gym.spaces.Box
+
+ activation: Callable[[jtp.Array], jtp.Array] = nn.relu
+ kernel_init: Callable[..., Any] = jax.nn.initializers.normal()
+
+ bias: bool = True
+ activate_final: bool = False
+
+ log_std_min: float = None
+ log_std_max: float = jnp.log(1.0)
+ log_std_init_val: float = jnp.log(0.50)
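+
+ # Note: log(σ) is a state-independent learned parameter (see setup below).
+ # With these defaults, σ is initialized to 0.50 and, when clipping is active,
+ # capped at 1.0.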
+
+ def setup(self) -> None:
+ logging.debug("Configuring actor network...")
+
+ # Automatically add a final layer so that the NN outputs an array containing
+ # the means μ of the action distribution
+ all_layer_sizes = np.hstack([self.layer_sizes, self.action_space.sample().size])
+
+ # Get the index of the last layer
+ last_layer_idx = all_layer_sizes.size - 1
+
+ def get_kernel_init(layer_idx: int) -> Callable[..., Any]:
+ # For the last layer we use a normal initializer with a standard deviation
+ # 100 times smaller than the default one (1e-2)
+ last_layer_kernel_init = jax.nn.initializers.normal(stddev=1e-2 / 100)
+
+ return (
+ self.kernel_init
+ if layer_idx != last_layer_idx
+ else last_layer_kernel_init
+ )
+
+ # Fully-connected network for the mean μ
+ self.layers = [
+ nn.Dense(
+ features=size,
+ kernel_init=get_kernel_init(layer_idx=layer_idx),
+ use_bias=self.bias,
+ )
+ for layer_idx, size in enumerate(all_layer_sizes)
+ ]
+
+ # State-independent parameters for the log of the standard deviation σ
+ self.log_std = self.param(
+ "log_std",
+ lambda rng, shape: self.log_std_init_val * jnp.ones(shape),
+ self.action_space.sample().size,
+ )
+
+ logging.debug(f" Layers: {all_layer_sizes.tolist()}")
+ activated = [
+ self.activate_layer(layer_idx=idx) for idx, _ in enumerate(all_layer_sizes)
+ ]
+ logging.debug(f" Activated: {activated}")
+ logging.debug(f" Activation function: {self.activation.__name__}")
+ logging.debug(f" Kernel initializer: {self.kernel_init}")
+ logging.debug(
+ " log(σ): {:.4f} (min={}, max={})".format(
+ self.log_std_init_val, self.log_std_min, self.log_std_max
+ )
+ )
+
+ def activate_layer(self, layer_idx: int) -> bool:
+ return (layer_idx != len(self.layers) - 1) or self.activate_final
+
+ def __call__(self, data: jtp.Array) -> distrax.Distribution:
+ x = data
+
+ for idx, layer in enumerate(self.layers):
+ x = layer(x)
+ x = self.activation(x) if self.activate_layer(layer_idx=idx) else x
+
+ # The mean μ is the NN output
+ mu = x
+
+ # ==========================
+ # Distributional exploration
+ # ==========================
+
+ # Helper to clip the standard deviation
+ def clip_log_std(value: jtp.Vector) -> jtp.Vector:
+ return jnp.clip(a=value, a_min=self.log_std_min, a_max=self.log_std_max)
+
+ # Clip log(σ) if limits are defined
+ log_std_limits = np.array([self.log_std_min, self.log_std_max])
+ log_std = clip_log_std(self.log_std) if log_std_limits.any() else self.log_std
+
+ # Compute σ taking the exponential
+ std = jnp.exp(log_std)
+
+ # Return the actor distribution
+ return distributions.TanhSquashedMultivariateNormalDiag(
+ # return distributions.GaussianSquashedMultivariateNormalDiag(
+ loc=mu,
+ scale_diag=std,
+ low=self.action_space.low,
+ high=self.action_space.high,
+ )
+
+
+class ActorCriticNetworks(nn.Module):
+ actor: ActorNetwork
+ critic: CriticNetwork
+
+ def __call__(self, data: jtp.Array) -> Tuple[distrax.Distribution, jtp.Array]:
+ observation = data
+
+ value = self.critic(data=observation)
+ distribution = self.actor(data=observation)
+
+ return distribution, value
diff --git a/src/jaxsim/training/sampler.py b/src/jaxsim/training/sampler.py
new file mode 100644
index 000000000..feef77d47
--- /dev/null
+++ b/src/jaxsim/training/sampler.py
@@ -0,0 +1,279 @@
+import dataclasses
+import functools
+import time
+from typing import Any, Callable, Dict, Tuple, Union
+
+import jax
+import jax.flatten_util
+import jax.numpy as jnp
+import jax_dataclasses
+import numpy.typing as npt
+from flax.core.frozen_dict import FrozenDict
+from rich.console import Console
+
+import jaxsim.typing as jtp
+from jaxsim.simulation import systems as sys
+from jaxsim.training.agent import Agent
+from jaxsim.training.memory import Memory
+
+
+@jax_dataclasses.pytree_dataclass
+class EnvironmentSamplerBuffer:
+ key: jtp.ArrayJax
+ state: FrozenDict
+
+
+@dataclasses.dataclass
+class EnvironmentSampler:
+ parallel: bool
+ environment: sys.EnvironmentSystem
+
+ duration: float
+ horizon: jtp.VectorJax
+
+ _buffer: EnvironmentSamplerBuffer
+ _sample_function: Callable = jax_dataclasses.field(default=None)
+
+ def get_sample_function(
+ self,
+ ) -> Callable[
+ [FrozenDict, jtp.ArrayJax, Dict[str, Any]], Tuple[Memory, sys.EnvironmentSystem]
+ ]:
+ """
+ Build the sampling function for either a single environment or parallel environments.
+
+ Returns:
+ The JIT-compiled sampling function.
+ """
+
+ if self._sample_function is not None:
+ return self._sample_function
+
+ # We split the arguments batched and non-batched.
+ # The non-batched arguments will be stored in a kwargs dictionary.
+ func = lambda x0, key, kwargs: EnvironmentSampler.step_environment_over_horizon(
+ x0=x0, key=key, **kwargs
+ )
+
+ if not self.parallel:
+ self._sample_function = func
+ return self._sample_function
+
+ # Thanks to the definition of the lambda, specifying the batched axes
+ # is much simpler in vmap
+ self._sample_function = jax.jit(jax.vmap(fun=func, in_axes=(0, 0, None)))
+ return self._sample_function
+
+ def _sample(
+ self, x0: FrozenDict, key: jtp.Array, extra_args: Dict[str, Any] = None
+ ) -> Tuple[Memory, sys.EnvironmentSystem]:
+ """
+ Call the JIT-compiled sampling function.
+
+ The first call to this method logs the time spent JIT-compiling the function.
+
+ Args:
+ x0: The state of the (possibly parallelized) system.
+ key: The key of the (possibly parallelized) system.
+ extra_args: The non-batched arguments of the system.
+
+ Returns:
+ The sampled memory and the system with the final state.
+ """
+
+ extra_args = extra_args if extra_args is not None else dict()
+
+ if self._sample_function is not None:
+ return self.get_sample_function()(x0, key, extra_args)
+
+ console = Console()
+
+ with console.status("[bold green]JIT compiling sampling function...") as _:
+ start = time.time()
+ out = self.get_sample_function()(x0, key, extra_args)
+ elapsed = time.time() - start
+ console.log(f"JIT compiled sampler in {elapsed:.3f} seconds.")
+
+ return out
+
+ @staticmethod
+ def build(
+ environment: sys.EnvironmentSystem,
+ t0: float = 0.0,
+ tf: float = 5.0,
+ dt: float = 0.001,
+ seed: int = 0,
+ parallel_environments: int = 1,
+ ) -> "EnvironmentSampler":
+ def build_state_and_key_parallel() -> EnvironmentSamplerBuffer:
+ env_list = [
+ environment.reset_environment(environment.seed(seed=i + 1))
+ for i in range(parallel_environments)
+ ]
+
+ def tree_transpose(list_of_trees):
+ # return jax.tree_multimap(lambda *xs: jnp.stack(xs), *list_of_trees)
+ return jax.tree_map(lambda *xs: jnp.stack(xs), *list_of_trees)
+
+ keys_list = [env.key for env in env_list]
+ states_list = [env.state_subsystem for env in env_list]
+
+ key = tree_transpose(keys_list)
+ state = tree_transpose(states_list)
+
+ return EnvironmentSamplerBuffer(state=state, key=key) # noqa
+
+ def build_state_and_key_not_parallel() -> EnvironmentSamplerBuffer:
+ env = environment.reset_environment(environment.seed(seed=seed))
+
+ key = env.key
+ state = env.state_subsystem
+
+ return EnvironmentSamplerBuffer(state=state, key=key) # noqa
+
+ buffer = (
+ build_state_and_key_parallel()
+ if parallel_environments > 1
+ else build_state_and_key_not_parallel()
+ )
+
+ # Build the sampler
+ sampler = EnvironmentSampler(
+ environment=environment,
+ parallel=parallel_environments > 1,
+ horizon=jnp.arange(start=t0, stop=tf, step=dt),
+ duration=tf - t0,
+ _buffer=buffer,
+ )
+
+ return sampler
+
+ def sample(self, agent: Agent, explore: bool = True) -> Memory:
+ # Update the horizon
+ horizon = self.horizon + self.duration
+
+ # Build the dict of non-batched arguments
+ inputs_dict = dict(
+ system=self.environment, agent=agent, t=horizon, explore=explore
+ )
+
+ # Sample from the environment
+ memory, environment = self._sample(
+ self._buffer.state, self._buffer.key, inputs_dict
+ )
+
+ # Create the new buffer making sure that shapes didn't change
+ with jax_dataclasses.copy_and_mutate(self._buffer) as buffer:
+ buffer.key = environment.key
+ buffer.state = environment.state_subsystem
+
+ # Double check that the state dict didn't change
+ assert jax.tree_structure(self._buffer.state) == jax.tree_structure(
+ buffer.state
+ )
+
+ # Update the sampler state
+ self.horizon = horizon
+ self._buffer = buffer
+
+ # Return the sampled memory
+ return memory
+
+ @staticmethod
+ @functools.partial(jax.jit)
+ def run_actor(
+ agent: Agent, environment: sys.EnvironmentSystem, explore: bool = True
+ ) -> Tuple[jtp.VectorJax, jtp.VectorJax, jtp.VectorJax, Agent]:
+ key, agent = agent.advance_key2()
+
+ observation = environment.get_observation(system=environment)
+
+ action, log_prob_action, value = agent.choose_action(
+ observation=observation,
+ explore=jnp.array(explore).any(),
+ key=key,
+ )
+
+ return action, log_prob_action, value, agent
+
+ @staticmethod
+ @functools.partial(jax.jit)
+ def step_environment_over_horizon(
+ system: sys.EnvironmentSystem,
+ agent: Agent,
+ t: npt.NDArray,
+ x0: FrozenDict = None,
+ key: jax.random.PRNGKey = None,
+ explore: Union[bool, jtp.VectorJax] = jnp.array(True),
+ ) -> Tuple[Memory, sys.EnvironmentSystem]:
+ # Handle state and key
+ system = system if x0 is None else system.update_subsystem_state(new_state=x0)
+ system = system if key is None else system.update_key(key=key)
+
+ # Compute a dummy output of the system, used to initialize the buffer.
+ # We flatten the output (dict) in order to allocate a dense jax array
+ # used inside a JIT-compiled for loop.
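+ # Note: ravel_pytree returns both the flattened array and a function that
+ # restores the original pytree structure, used after the loop to unflatten
+ # the whole horizon with vmap.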
+ out, system = system(t0=t[0], tf=t[0], u0=None)
+ out_flattened, restore_output_fn = jax.flatten_util.ravel_pytree(out)
+
+ # Create the buffers storing environment data
+ values = jnp.zeros(shape=(t.size, 1))
+ log_prob_actions = jnp.zeros(shape=(t.size, 1))
+ system_output = jnp.zeros(shape=(t.size, out_flattened.size))
+
+ # Generate a new key from the environment used for sampling actions
+ agent_key, system = system.generate_subkey()
+ agent = jax_dataclasses.replace(agent, key=agent_key) # noqa
+
+ # Initialize the loop carry
+ carry_init = (system_output, log_prob_actions, values, system, agent)
+
+ def body_fun(i: int, carry: Tuple) -> Tuple:
+ # Unpack the loop carry
+ system_output, log_prob_actions, values, system, agent = carry
+
+ # Execute the actor
+ (action, log_prob_action, value, agent) = EnvironmentSampler.run_actor(
+ agent=agent, environment=system, explore=explore
+ )
+
+ # Update values and log_prob
+ values = values.at[i].set(value)
+ log_prob_actions = log_prob_actions.at[i].set(log_prob_action)
+
+ # Advance the environment and get its output
+ out, system = system(t0=t[i], tf=t[i + 1], u0=FrozenDict(action=action))
+
+ # Store the environment output in the buffer
+ out_flattened, _ = jax.flatten_util.ravel_pytree(out)
+ system_output = system_output.at[i, :].set(out_flattened)
+
+ # Return the loop carry
+ return system_output, log_prob_actions, values, system, agent
+
+ # Execute the rollout. The environment automatically resets when it's done.
+ system_output, log_prob_actions, values, system, agent = jax.lax.fori_loop(
+ lower=0,
+ upper=t.size,
+ body_fun=body_fun,
+ init_val=carry_init,
+ )
+
+ # Unflatten the output
+ output_horizon: FrozenDict = jax.vmap(lambda b: restore_output_fn(b))(
+ system_output
+ )
+
+ # Create the memory object
+ memory = Memory.build(
+ states=output_horizon["observation"],
+ actions=output_horizon["action"],
+ rewards=output_horizon["reward"],
+ dones=output_horizon["done"],
+ values=values,
+ log_prob_actions=log_prob_actions,
+ infos=output_horizon["info"],
+ )
+
+ # Return objects after stepping over the horizon
+ return memory, system
diff --git a/src/jaxsim/training/trajectory_sampler.py b/src/jaxsim/training/trajectory_sampler.py
new file mode 100644
index 000000000..7ecb6d1a8
--- /dev/null
+++ b/src/jaxsim/training/trajectory_sampler.py
@@ -0,0 +1,242 @@
+from typing import Callable, Tuple
+
+import jax
+import jax.numpy as jnp
+import jax_dataclasses
+
+from jaxsim.gym import Env, EnvironmentState
+from jaxsim.gym.typing import *
+from jaxsim.training.agent import Agent
+from jaxsim.training.memory import Memory
+from jaxsim.utils import JaxsimDataclass
+
+
+@jax_dataclasses.pytree_dataclass
+class TrajectorySampler(JaxsimDataclass):
+ env: Env
+ agent: Agent
+ state: EnvironmentState
+
+ _sampling_fn: Callable = jax_dataclasses.static_field(default=None)
+
+ @staticmethod
+ def build(
+ env: Env, agent: Agent, state: EnvironmentState, number_of_environments: int = 1
+ ) -> "TrajectorySampler":
+ if number_of_environments > 1:
+ env, state = TrajectorySampler.make_parallel_environment(
+ number=number_of_environments, env=env, state=state
+ )
+
+ with TrajectorySampler(env=env, agent=agent, state=state).editable(
+ validate=False
+ ) as sampler:
+ # Handle parallel environments
+ sampling_fn = (
+ TrajectorySampler.Sample_trajectory
+ if sampler.number_of_environments() == 1
+ else jax.vmap(
+ fun=TrajectorySampler.Sample_trajectory,
+ in_axes=(0, 0, None, None, None),
+ )
+ )
+
+ # JIT compile sampling function
+ sampler._sampling_fn = jax.jit(
+ sampling_fn, static_argnames=["horizon_length"]
+ )
+
+ return sampler
+
+ def seed(self, seed: int = 0) -> None:
+ if self.number_of_environments() == 1:
+ state = self.env.seed(state=self.state, seed=seed)
+
+ else:
+ with self.state.editable(validate=True) as state:
+ state.key = jax.random.split(
+ key=jax.random.PRNGKey(seed=seed), num=self.number_of_environments()
+ )
+
+ self.set_mutability(mutable=True, validate=False)
+ self.state = state
+ self.set_mutability(mutable=False)
+
+ _ = self.reset()
+
+ @staticmethod
+ def reset_fn(
+ env: Env, state: EnvironmentState
+ ) -> Tuple[EnvironmentState, Tuple[Observation, Reward, IsDone, Info]]:
+ return env.reset(state=state)
+
+ def reset(self) -> Observation:
+ reset_fn_jit = (
+ jax.jit(jax.vmap(self.reset_fn))
+ if self.number_of_environments() > 1
+ else jax.jit(self.reset_fn)
+ )
+
+ state, observation = reset_fn_jit(env=self.env, state=self.state)
+
+ self.set_mutability(mutable=True, validate=False)
+ self.state = state
+ self.set_mutability(mutable=False)
+
+ return observation
+
+ def number_of_environments(self) -> int:
+ shape_of_key = self.state.key.shape
+ return 1 if len(shape_of_key) == 1 else shape_of_key[0]
+
+ def sample_trajectory(
+ self, horizon_length: int = 1, explore: bool = True
+ ) -> Memory:
+ # We cannot use named arguments: https://github.com/google/jax/issues/7465
+ memory, state = self._sampling_fn(
+ self.env, self.state, self.agent, int(horizon_length), bool(explore)
+ )
+
+ self.set_mutability(mutable=True, validate=True)
+ self.state = state
+ self.set_mutability(mutable=False)
+
+ return memory
+
+ # ===============
+ # Private methods
+ # ===============
+
+ @staticmethod
+ def make_parallel_environment(
+ number: int, env: Env, state: EnvironmentState
+ ) -> Tuple[Env, EnvironmentState]:
+ envs = jax.tree_map(lambda *l: jnp.stack(l), *[env] * number)
+ states = jax.tree_map(lambda *l: jnp.stack(l), *[state] * number)
+
+ return envs, states
+
+ @staticmethod
+ def Sample_trajectory(
+ env: Env,
+ state: EnvironmentState,
+ agent: Agent,
+ horizon_length: int = 1,
+ explore: bool = True,
+ ) -> Tuple[Memory, EnvironmentState]:
+ # Create the memory object of the entire batch
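+ # The zero sample is stacked horizon_length times to pre-allocate fixed-shape
+ # buffers that are filled in place inside jax.lax.fori_loop below.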
+ memory = jax.tree_map(
+ lambda *x0: jnp.stack(x0),
+ *[TrajectorySampler.Zero_memory_sample(env=env, state=state, agent=agent)]
+ * horizon_length,
+ )
+
+ carry_init = (state, memory)
+
+ def body_fun(idx: int, carry: Tuple) -> Tuple:
+ # Unpack the carry
+ state, memory = carry
+
+ # Get the current observation
+ observation = env.get_observation(state=state).flatten()
+
+ # Sample a new action
+ subkey, state = state.generate_key()
+
+ action, log_prob_action, value = agent.choose_action(
+ observation=observation, explore=explore, key=subkey
+ )
+ # distribution, value = agent.train_state.apply_fn(
+ # agent.train_state.params, data=observation
+ # )
+ # action = jax.lax.select(
+ # pred=explore,
+ # on_true=distribution.sample(seed=subkey),
+ # on_false=distribution.mode(),
+ # )
+ # log_prob_action = distribution.log_prob(value=action)
+
+ # Step the environment with automatic reset
+ state, (_, reward, is_done, info) = TrajectorySampler.Step_environment(
+ env=env, state=state, action=action
+ )
+
+ # Build a single-entry memory object with the sample
+ sample = Memory.build(
+ states=observation,
+ actions=action,
+ rewards=reward,
+ dones=is_done,
+ values=value,
+ log_prob_actions=log_prob_action,
+ infos=info,
+ )
+
+ # Store the new sample
+ memory = jax.tree_map(
+ lambda stacked, leaf: stacked.at[idx].set(leaf),
+ memory,
+ TrajectorySampler.Memory_to_memory_1D(sample=sample),
+ )
+
+ return state, memory
+
+ state, memory = jax.lax.fori_loop(
+ lower=0,
+ upper=horizon_length,
+ body_fun=body_fun,
+ init_val=carry_init,
+ )
+
+ return memory, state
+
+ @staticmethod
+ def Step_environment(
+ env: Env, state: EnvironmentState, action: Action
+ ) -> Tuple[EnvironmentState, Tuple[Observation, Reward, IsDone, Info]]:
+ # Step the environment
+ state, (observation, reward, is_done, info) = env.step(
+ action=action, state=state
+ )
+
+ # Automatically reset the environment if done
+ state = jax.lax.cond(
+ pred=is_done,
+ true_fun=lambda: env.reset(state=state)[0],
+ false_fun=lambda: state,
+ )
+
+ return state, (observation, reward, is_done, info)
+
+ @staticmethod
+ def Memory_to_memory_1D(sample: Memory) -> Memory:
+ # Fix Memory to handle just one sample (S=1)
+ def memory_to_sample(leaf):
+ l = leaf.squeeze()
+ return jnp.array([l]) if l.ndim == 0 else l
+
+ return jax.tree_map(memory_to_sample, sample)
+
+ @staticmethod
+ def Zero_memory_sample(env: Env, state: EnvironmentState, agent: Agent) -> Memory:
+ _, obs = env.reset(state=state)
+
+ action, log_prob_action, value = agent.choose_action(
+ observation=obs.flatten(), explore=False, key=state.key
+ )
+
+ _, (obs, reward, is_done, info) = env.step(action=action, state=state)
+
+ sample = Memory.build(
+ states=obs.flatten(),
+ actions=action,
+ rewards=reward,
+ dones=is_done,
+ values=value,
+ log_prob_actions=log_prob_action,
+ infos=info,
+ )
+
+ zero_sample = jax.tree_map(lambda l: jnp.zeros_like(l), sample)
+
+ return TrajectorySampler.Memory_to_memory_1D(sample=zero_sample)
diff --git a/src/jaxsim/utils/__init__.py b/src/jaxsim/utils/__init__.py
new file mode 100644
index 000000000..b79fd990f
--- /dev/null
+++ b/src/jaxsim/utils/__init__.py
@@ -0,0 +1,8 @@
+from jax_dataclasses._copy_and_mutate import _Mutability as Mutability
+
+from .jaxsim_dataclass import JaxsimDataclass
+from .tracing import not_tracing, tracing
+from .vmappable import Vmappable
+
+# Leave this below the others to prevent circular imports
+from .oop import jax_tf # isort: skip
diff --git a/src/jaxsim/utils.py b/src/jaxsim/utils/jaxsim_dataclass.py
similarity index 56%
rename from src/jaxsim/utils.py
rename to src/jaxsim/utils/jaxsim_dataclass.py
index 07a6bef5e..c995b4ff6 100644
--- a/src/jaxsim/utils.py
+++ b/src/jaxsim/utils/jaxsim_dataclass.py
@@ -1,44 +1,32 @@
import abc
import contextlib
import copy
-from typing import Any, ContextManager, TypeVar
+import dataclasses
+from typing import Generator
import jax.abstract_arrays
import jax.flatten_util
import jax.interpreters.partial_eval
import jax_dataclasses
-from jax_dataclasses._copy_and_mutate import _Mutability as Mutability
import jaxsim.typing as jtp
-T = TypeVar("T")
+from . import Mutability
-
-def tracing(var: Any) -> bool:
- """Returns True if the variable is being traced by JAX, False otherwise."""
-
- return jax.numpy.array(
- [
- isinstance(var, t)
- for t in (
- jax.abstract_arrays.ShapedArray,
- jax.interpreters.partial_eval.DynamicJaxprTracer,
- )
- ]
- ).any()
-
-
-def not_tracing(var: Any) -> bool:
- """Returns True if the variable is not being traced by JAX, False otherwise."""
-
- return True if tracing(var) is False else False
+try:
+ from typing import Self
+except ImportError:
+ from typing_extensions import Self
class JaxsimDataclass(abc.ABC):
""""""
+ # This attribute is set by jax_dataclasses
+ __mutability__ = None
+
@contextlib.contextmanager
- def editable(self: T, validate: bool = True) -> ContextManager[T]:
+ def editable(self: Self, validate: bool = True) -> Generator[Self, None, None]:
""""""
mutability = (
@@ -48,23 +36,34 @@ def editable(self: T, validate: bool = True) -> ContextManager[T]:
with JaxsimDataclass.mutable_context(self.copy(), mutability=mutability) as obj:
yield obj
- # with jax_dataclasses.copy_and_mutate(self, validate=validate) as self_rw:
- # yield self_rw
- #
- # self_rw._set_mutability(self._mutability())
-
@contextlib.contextmanager
- def mutable_context(self: T, mutability: Mutability) -> ContextManager[T]:
+ def mutable_context(
+ self: Self, mutability: Mutability, restore_after_exception: bool = True
+ ) -> Generator[Self, None, None]:
""""""
original_mutability = self._mutability()
- self._set_mutability(mutability)
- yield self
-
- self._set_mutability(original_mutability)
-
- def is_mutable(self: T, validate: bool = False) -> bool:
+ if restore_after_exception:
+ self_copy = copy.copy(self)
+
+ def restore_self():
+ self._set_mutability(mutability=Mutability.MUTABLE)
+ for f in dataclasses.fields(self_copy):
+ setattr(self, f.name, getattr(self_copy, f.name))
+
+ try:
+ self._set_mutability(mutability)
+ yield self
+ except Exception as e:
+ if restore_after_exception:
+ restore_self()
+ self._set_mutability(original_mutability)
+ raise e
+ finally:
+ self._set_mutability(original_mutability)
+
+ def is_mutable(self, validate: bool = False) -> bool:
""""""
return (
@@ -91,21 +90,21 @@ def _set_mutability(self, mutability: Mutability) -> None:
self, mutable=mutability, visited=set()
)
- def mutable(self: T, mutable: bool = True, validate: bool = False) -> T:
+ def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self:
self.set_mutability(mutable=mutable, validate=validate)
return self
- def copy(self: T) -> T:
- obj = copy.deepcopy(self)
+ def copy(self: Self) -> Self:
+ obj = jax.tree_util.tree_map(lambda leaf: leaf, self)
obj._set_mutability(mutability=self._mutability())
return obj
- def replace(self: T, validate: bool = True, **kwargs) -> T:
+ def replace(self: Self, validate: bool = True, **kwargs) -> Self:
with self.editable(validate=validate) as obj:
_ = [obj.__setattr__(k, copy.copy(v)) for k, v in kwargs.items()]
obj._set_mutability(mutability=self._mutability())
return obj
- def flatten(self: T) -> jtp.VectorJax:
+ def flatten(self) -> jtp.VectorJax:
return jax.flatten_util.ravel_pytree(self)[0]
diff --git a/src/jaxsim/utils/oop.py b/src/jaxsim/utils/oop.py
new file mode 100644
index 000000000..cc4099446
--- /dev/null
+++ b/src/jaxsim/utils/oop.py
@@ -0,0 +1,497 @@
+import contextlib
+import dataclasses
+import functools
+import inspect
+import os
+from typing import Any, Callable, Generator
+
+import jax
+import jax.flatten_util
+
+from jaxsim import logging
+from jaxsim.utils import tracing
+
+from . import Mutability, Vmappable
+
+
+class jax_tf:
+ """
+ Class containing decorators applicable to methods of Vmappable objects.
+ """
+
+ # Environment variables that can be used to disable the transformations
+ EnvVarOOP: str = "JAXSIM_OOP_DECORATORS"
+ EnvVarJitOOP: str = "JAXSIM_OOP_DECORATORS_JIT"
+ EnvVarVmapOOP: str = "JAXSIM_OOP_DECORATORS_VMAP"
+
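+ # Note (illustrative): exporting JAXSIM_OOP_DECORATORS=0 disables the jit and
+ # vmap logic applied by these decorators, leaving only the mutability handling
+ # active. The variables are read at call time, e.g. from a shell:
+ #
+ #     export JAXSIM_OOP_DECORATORS=0
+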
+ @staticmethod
+ def method_ro(
+ fn: Callable,
+ jit: bool = True,
+ static_argnames: tuple[str, ...] | list[str] = (),
+ vmap: bool | None = None,
+ vmap_in_axes: tuple[int, ...] | int | None = None,
+ vmap_out_axes: tuple[int, ...] | int | None = None,
+ ):
+ """
+ Decorator for r/o methods of classes inheriting from Vmappable.
+ """
+
+ return jax_tf.method(
+ fn=fn,
+ read_only=True,
+ validate=True,
+ jit_enabled=jit,
+ static_argnames=static_argnames,
+ vmap_enabled=vmap,
+ vmap_in_axes=vmap_in_axes,
+ vmap_out_axes=vmap_out_axes,
+ )
+
+ @staticmethod
+ def method_rw(
+ fn: Callable,
+ validate: bool = True,
+ jit: bool = True,
+ static_argnames: tuple[str, ...] | list[str] = (),
+ vmap: bool | None = None,
+ vmap_in_axes: tuple[int, ...] | int | None = None,
+ vmap_out_axes: tuple[int, ...] | int | None = None,
+ ):
+ """
+ Decorator for r/w methods of classes inheriting from Vmappable.
+ """
+
+ return jax_tf.method(
+ fn=fn,
+ read_only=False,
+ validate=validate,
+ jit_enabled=jit,
+ static_argnames=static_argnames,
+ vmap_enabled=vmap,
+ vmap_in_axes=vmap_in_axes,
+ vmap_out_axes=vmap_out_axes,
+ )
+
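+ # Illustrative usage sketch (the class `MyModel` and its field `value` are
+ # hypothetical; tests/test_jax_oop.py contains a complete example):
+ #
+ #     from jaxsim.utils import Vmappable, oop
+ #
+ #     @jax_dataclasses.pytree_dataclass
+ #     class MyModel(Vmappable):
+ #         value: jax.Array = dataclasses.field(
+ #             default_factory=lambda: jnp.array(0.0)
+ #         )
+ #
+ #         @oop.jax_tf.method_ro
+ #         def twice(self) -> jax.Array:
+ #             return 2 * self.value
+ #
+ #         @oop.jax_tf.method_rw
+ #         def bump(self, amount: jax.typing.ArrayLike) -> jax.Array:
+ #             self.value = self.value + amount
+ #             return self.value
+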
+ @staticmethod
+ def method(
+ fn: Callable,
+ read_only: bool = True,
+ validate: bool = True,
+ jit_enabled: bool = True,
+ static_argnames: tuple[str, ...] | list[str] = (),
+ vmap_enabled: bool | None = None,
+ vmap_in_axes: tuple[int, ...] | int | None = None,
+ vmap_out_axes: tuple[int, ...] | int | None = None,
+ ):
+ """
+ Decorator for methods of classes inheriting from Vmappable.
+
+ This decorator enables executing the method on an object with the desired
+ mutability, which is selected from the read-only and validation flags.
+ It also allows transforming the method with the jit/vmap transformations.
+ If the Vmappable object is vectorized, the method is automatically vmapped, and
+ the in_axes are properly post-processed to simplify the combination with jit.
+
+ Args:
+ fn: The method to decorate.
+ read_only: Whether the method operates on a read-only object.
+ validate: Whether r/w methods should preserve the pytree structure.
+ jit_enabled: Whether to apply the jit transformation.
+ static_argnames: The names of the arguments that should be static.
+ vmap_enabled: Whether to apply the vmap transformation.
+ vmap_in_axes: The in_axes to use for the vmap transformation.
+ vmap_out_axes: The out_axes to use for the vmap transformation.
+
+ Returns:
+ The decorated method.
+ """
+ """
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ """The wrapper function that is returned by this decorator."""
+
+ # Methods of classes inheriting from Vmappable decorated by this wrapper
+ # automatically support jit/vmap/mutability features when called standalone.
+ # However, when objects are arguments of plain functions transformed with
+ # jit/vmap, and decorated methods are called inside those functions, we need
+ # to disable this decorator to avoid double wrapping and execution errors.
+ # We do so by iterating over the arguments, and checking whether they are
+ # being traced by JAX.
+ for argument in list(args) + list(kwargs.values()):
+ try:
+ argument_flat, _ = jax.flatten_util.ravel_pytree(argument)
+
+ if tracing(argument_flat):
+ return fn(*args, **kwargs)
+ except Exception:
+ continue
+
+ # ===============================================================
+ # Wrap fn so that jit/vmap/mutability transformations are applied
+ # ===============================================================
+
+ # Initialize the mutability of the instance over which the method is running.
+ # * In r/o methods, this approach prevents any type of mutation.
+ # * In r/w methods, this approach allows catching early JIT recompilations
+ # caused by unwanted changes in the pytree structure.
+ if read_only:
+ mutability = Mutability.FROZEN
+ else:
+ mutability = (
+ Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
+ )
+
+ # Extract the class instance over which fn is called
+ instance: Vmappable = args[0]
+ assert isinstance(instance, Vmappable)
+
+ # Inspect the environment to detect whether to enforce disabling jit/vmap
+ deco_on = jax_tf.env_var_on(jax_tf.EnvVarOOP)
+ jit_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarJitOOP) and deco_on
+ vmap_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarVmapOOP) and deco_on
+
+ # Get the transformed function (possibly cached by functools.cache).
+ # Note that all the arguments of the following methods, when hashed, should
+ # uniquely identify the returned function so that a new function is built
+ # when arguments change and either jit or vmap have to be called again.
+ fn_db = jax_tf.wrap_fn(
+ fn=fn, # noqa
+ mutability=mutability,
+ jit=jit_enabled_env and jit_enabled,
+ static_argnames=tuple(static_argnames),
+ vmap=vmap_enabled_env
+ and (
+ vmap_enabled is True
+ or (vmap_enabled is None and instance.vectorized)
+ ),
+ in_axes=vmap_in_axes,
+ out_axes=vmap_out_axes,
+ )
+
+ # Call the transformed (mutable/jit/vmap) method
+ out, obj = fn_db(*args, **kwargs)
+
+ if read_only:
+ return out
+
+ # =================================================================
+ # From here we assume that the wrapper is operating on a r/w method
+ # =================================================================
+
+ from jax_dataclasses._dataclasses import JDC_STATIC_MARKER
+
+ # Select the right runtime mutability. The only difference here is when a r/w
+ # method is called on a frozen object. In this case, we enable updating the
+ # pytree data and preserve its structure only if validation is enabled.
+ mutability_dict = {
+ Mutability.MUTABLE_NO_VALIDATION: Mutability.MUTABLE_NO_VALIDATION,
+ Mutability.MUTABLE: Mutability.MUTABLE,
+ Mutability.FROZEN: Mutability.MUTABLE
+ if validate
+ else Mutability.MUTABLE_NO_VALIDATION,
+ }
+
+ # We need to replace all the dynamic leaves of the original instance with those
+ # computed by the functional transformation.
+ # We do so by iterating over the fields of the jax_dataclass and ignoring
+ # all the fields that are marked as static.
+ with instance.mutable_context(
+ mutability=mutability_dict[instance._mutability()]
+ ):
+ for f in dataclasses.fields(instance): # noqa
+ if (
+ hasattr(f, "type")
+ and hasattr(f.type, "__metadata__")
+ and JDC_STATIC_MARKER in f.type.__metadata__
+ ):
+ continue
+
+ try:
+ setattr(instance, f.name, getattr(obj, f.name))
+ except AssertionError:
+ raise RuntimeError(
+ "Failed to update field '{}' (old={}|new={})".format(
+ f.name, getattr(instance, f.name), getattr(obj, f.name)
+ )
+ )
+
+ return out
+
+ return wrapper
+
+ @staticmethod
+ @functools.cache
+ def wrap_fn(
+ fn: Callable,
+ mutability: Mutability,
+ jit: bool,
+ static_argnames: tuple[str, ...] | list[str],
+ vmap: bool,
+ in_axes: tuple[int, ...] | int | None,
+ out_axes: tuple[int, ...] | int | None,
+ ) -> Callable:
+ """
+ Transform a method with jit/vmap and execute it on an object characterized
+ by the desired mutability.
+
+ Note:
+ The method should take the object (self) as first argument.
+
+ Note:
+ The returned transformed method is cached by considering the hash of all
+ the arguments. It will re-apply jit/vmap transformations only if needed.
+
+ Args:
+ fn: The method to consider.
+ mutability: The mutability of the object on which the method is called.
+ jit: Whether to apply jit transformations.
+ static_argnames: The names of the arguments that should be considered static.
+ vmap: Whether to apply vmap transformations.
+ in_axes: The axes along which to vmap input arguments.
+ out_axes: The axes along which to vmap output arguments.
+
+ Note:
+ In order to simplify the application of vmap, we close the method over all
+ the non-mapped input arguments. Furthermore, to improve compatibility with
+ jit, we also close the vmap application over the static arguments.
+
+ Returns:
+ The transformed method operating on an object with the desired mutability.
+ We maintain the same signature of the original method.
+ """
+
+ # Extract the signature of the function
+ sig = inspect.signature(fn)
+
+ # All static arguments must be actual arguments of fn
+ for name in static_argnames:
+ if name not in sig.parameters:
+ raise ValueError(f"Static argument '{name}' not found in {fn}")
+
+ # If in_axes is a tuple, its dimension should match the number of arguments
+ if isinstance(in_axes, tuple) and len(in_axes) != len(sig.parameters):
+ msg = "The length of 'in_axes' must match the number of arguments ({})"
+ raise ValueError(msg.format(len(sig.parameters)))
+
+ # Check that static arguments are not mapped with vmap.
+ # This case would not work since static arguments are not traced, and vmap
+ # needs to trace arguments in order to map them.
+ if isinstance(in_axes, tuple):
+ for mapped_axis, arg_name in zip(in_axes, sig.parameters.keys()):
+ if mapped_axis is not None and arg_name in static_argnames:
+ raise ValueError(
+ f"Static argument '{arg_name}' cannot be mapped with vmap"
+ )
+
+ def fn_tf_vmap(function_to_vmap: Callable, *args, **kwargs):
+ """Wrapper applying the vmap transformation"""
+
+ # Canonicalize the arguments so that all of them are kwargs
+ bound = sig.bind(*args, **kwargs)
+ bound.apply_defaults()
+
+ # Build a dictionary mapping all arguments to a mapped axis, even when
+ # None is passed (defaults to in_axes=0) or an int is passed (the same
+ # axis is used for all arguments).
+ match in_axes:
+ case None:
+ argname_to_mapped_axis = {name: 0 for name in bound.arguments}
+ case tuple():
+ argname_to_mapped_axis = {
+ name: in_axes[i] for i, name in enumerate(bound.arguments)
+ }
+ case int():
+ argname_to_mapped_axis = {name: in_axes for name in bound.arguments}
+ case _:
+ raise ValueError(in_axes)
+
+ # Build a dictionary (argument_name -> argument) for all mapped arguments.
+ # Note that a mapped argument is an argument whose axis is not None and
+ # is not a static jit argument.
+ vmap_mapped_args = {
+ arg: value
+ for arg, value in bound.arguments.items()
+ if argname_to_mapped_axis[arg] is not None
+ and arg not in static_argnames
+ }
+
+ # Build a dictionary (argument_name -> argument) for all unmapped arguments
+ vmap_unmapped_args = {
+ arg: value
+ for arg, value in bound.arguments.items()
+ if arg not in vmap_mapped_args
+ }
+
+ # Close the function over the unmapped arguments of vmap
+ fn_closed = functools.partial(function_to_vmap, **vmap_unmapped_args)
+
+ # Create the in_axes tuple of only the mapped arguments
+ in_axes_mapped = tuple(
+ argname_to_mapped_axis[name] for name in vmap_mapped_args
+ )
+
+ # If all in_axes are the same, simplify in_axes tuple to be just an integer
+ if len(set(in_axes_mapped)) == 1:
+ in_axes_mapped = list(set(in_axes_mapped))[0]
+
+ # If, instead, in_axes has different elements, we need to replace the mapped
+ # axis of "self" with a pytree having as leafs the mapped axis.
+ # This is because the vmap in_axes specification must be a tree prefix of
+ # the corresponding value.
+ if isinstance(in_axes_mapped, tuple) and "self" in vmap_mapped_args:
+ argname_to_mapped_axis["self"] = jax.tree_util.tree_map(
+ lambda _: argname_to_mapped_axis["self"], vmap_mapped_args["self"]
+ )
+ in_axes_mapped = tuple(
+ argname_to_mapped_axis[name] for name in vmap_mapped_args
+ )
+
+ # Apply the vmap transformation and call the function passing only the
+ # mapped arguments. The unmapped arguments have been closed over.
+ # Note that we altered the "in_axes" tuple so that it does not have any
+ # None elements.
+ # Note: if in_axes_mapped is a tuple, the following call fails when passing
+ # kwargs, so we pass the unpacked args tuple instead.
+ return jax.vmap(
+ fn_closed,
+ in_axes=in_axes_mapped,
+ **dict(out_axes=out_axes) if out_axes is not None else {},
+ )(*list(vmap_mapped_args.values()))
+
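+ # Illustrative example of the post-processing above (hypothetical method
+ # fn(self, a, b) mapped with in_axes=(0, 1, None)): `b` has a None axis, so
+ # fn gets closed over it; the remaining axes (0, 1) differ, hence the axis
+ # of `self` is expanded to a pytree of 0s matching its structure before
+ # jax.vmap is applied to the mapped arguments only.
+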
+ def fn_tf_jit(function_to_jit: Callable, *args, **kwargs):
+ """Wrapper applying the jit transformation"""
+
+ # Canonicalize the arguments so that all of them are kwargs
+ bound = sig.bind(*args, **kwargs)
+ bound.apply_defaults()
+
+ # Apply the jit transformation and call the function passing all arguments
+ # as keyword arguments
+ return jax.jit(function_to_jit, static_argnames=static_argnames)(
+ **bound.arguments
+ )
+
+ # First applied wrapper that executes fn in a mutable context
+ fn_mutable = functools.partial(
+ jax_tf.call_class_method_in_mutable_context, fn, jit, mutability
+ )
+
+ # Second applied wrapper that transforms fn with vmap
+ fn_vmap = fn_mutable if not vmap else functools.partial(fn_tf_vmap, fn_mutable)
+
+ # Third applied wrapper that transforms fn with jit
+ fn_jit_vmap = fn_vmap if not jit else functools.partial(fn_tf_jit, fn_vmap)
+
+ return fn_jit_vmap
+
+ @staticmethod
+ def call_class_method_in_mutable_context(
+ fn: Callable, jit: bool, mutability: Mutability, *args, **kwargs
+ ) -> tuple[Any, Vmappable]:
+ """
+ Wrapper to call a method on an object with the desired mutable context.
+
+ Args:
+ fn: The method to call.
+ jit: Whether the method is being jit compiled or not.
+ mutability: The desired mutability context.
+ *args: The positional arguments to pass to the method (including self).
+ **kwargs: The keyword arguments to pass to the method.
+
+ Returns:
+ A tuple containing the return value of the method and the object
+ possibly updated by the method, if the method is read-write.
+
+ Note:
+ This approach enables to jit-compile methods of a stateful object without
+ leaking traces, therefore obtaining a jax-compatible OOP pattern.
+ """
+
+ # Log here whether the method is being jit compiled or not.
+ # This log message does not get printed from compiled code, so here is the
+ # most appropriate place to be sure that we log it correctly.
+ if jit:
+ logging.debug(msg=f"JIT compiling {fn}")
+
+ # Canonicalize the arguments so that all of them are kwargs
+ sig = inspect.signature(fn)
+ bound = sig.bind(*args, **kwargs)
+ bound.apply_defaults()
+
+ # Extract the class instance over which fn is called
+ instance: Vmappable = bound.arguments["self"]
+
+ # Select the right mutability. If the instance is mutable with validation
+ # disabled, we override the input mutability so that we do not fail in case
+ # of mismatched tree structure.
+ mut = (
+ Mutability.MUTABLE_NO_VALIDATION
+ if instance._mutability() is Mutability.MUTABLE_NO_VALIDATION
+ else mutability
+ )
+
+ # Call fn in a mutable context
+ with instance.mutable_context(mutability=mut):
+ # Methods could call other decorated methods. When it happens, the decorator
+ # of the called method is invoked, which applies jit and vmap transformations.
+ # This is not desired as it calls vmap inside an already vmapped method.
+ # We work around this occurrence by disabling the jit/vmap decorators of all
+ # methods called inside fn through a context manager.
+ # Note that we already work around this at the beginning of the wrapper
+ # function by detecting traced arguments; however, the decorator is active
+ # also when jit=False and vmap=False, in which case it only enforces the
+ # mutability.
+ with jax_tf.disabled_oop_decorators():
+ out = fn(**bound.arguments)
+
+ return out, instance
+
+ @staticmethod
+ def env_var_on(var_name: str, default_value: str = "1") -> bool:
+ """
+ Check whether an environment variable is set to a value that is considered on.
+
+ Args:
+ var_name: The name of the environment variable.
+ default_value: The default variable value to consider if the variable has not
+ been exported.
+
+ Returns:
+ True if the environment variable contains an on value, False otherwise.
+ """
+
+ on_values = {"1", "true", "on", "yes"}
+ return os.environ.get(var_name, default_value).lower() in on_values
+
+ @staticmethod
+ @contextlib.contextmanager
+ def disabled_oop_decorators() -> Generator[None, None, None]:
+ """
+ Context manager to disable the application of jax transformations performed by
+ the decorators of this class.
+
+ Note: when the transformations are disabled, the only logic still applied is
+ the selection of the object mutability over which the method is running.
+ """
+
+ # Check whether the environment variable is part of the environment and
+ # save its value. We restore the original value before exiting the context.
+ env_cache = (
+ None if jax_tf.EnvVarOOP not in os.environ else os.environ[jax_tf.EnvVarOOP]
+ )
+
+ # Disable both jit and vmap transformations
+ os.environ[jax_tf.EnvVarOOP] = "0"
+
+ try:
+ # Execute the code in the context with disabled transformations
+ yield
+
+ finally:
+ # Restore the original value of the environment variable or remove it if
+ # it was not present before entering the context
+ if env_cache is not None:
+ os.environ[jax_tf.EnvVarOOP] = env_cache
+ else:
+ _ = os.environ.pop(jax_tf.EnvVarOOP)
diff --git a/src/jaxsim/utils/tracing.py b/src/jaxsim/utils/tracing.py
new file mode 100644
index 000000000..e85deeb4e
--- /dev/null
+++ b/src/jaxsim/utils/tracing.py
@@ -0,0 +1,26 @@
+from typing import Any
+
+import jax._src.core
+import jax.interpreters.partial_eval
+import jax.numpy
+
+
+def tracing(var: Any) -> bool | jax.Array:
+ """Returns True if the variable is being traced by JAX, False otherwise."""
+
+ return jax.numpy.array(
+ [
+ isinstance(var, t)
+ for t in (
+ jax._src.core.Tracer,
+ jax.interpreters.partial_eval.DynamicJaxprTracer,
+ )
+ ]
+ ).any()
+
+
+def not_tracing(var: Any) -> bool:
+ """Returns True if the variable is not being traced by JAX, False otherwise."""
+
+ return not bool(tracing(var))
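+
+
+# Illustrative behavior (sketch): outside any transformation, concrete values
+# are not tracers, so tracing() evaluates to a false value; inside jax.jit,
+# function arguments are tracers, so it evaluates to a true value:
+#
+#     tracing(jax.numpy.zeros(3))           # false
+#     jax.jit(lambda x: tracing(x))(1.0)    # true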
diff --git a/src/jaxsim/utils/vmappable.py b/src/jaxsim/utils/vmappable.py
new file mode 100644
index 000000000..0e449f4b8
--- /dev/null
+++ b/src/jaxsim/utils/vmappable.py
@@ -0,0 +1,117 @@
+import dataclasses
+from typing import Type
+
+import jax
+import jax.numpy as jnp
+import jax_dataclasses
+
+from . import JaxsimDataclass, Mutability
+
+try:
+ from typing import Self
+except ImportError:
+ from typing_extensions import Self
+
+
+@jax_dataclasses.pytree_dataclass
+class Vmappable(JaxsimDataclass):
+ """Abstract class with utilities for vmappable pytrees."""
+
+ batch_size: jax_dataclasses.Static[int] = dataclasses.field(
+ default=int(0), repr=False, compare=False, hash=False, kw_only=True
+ )
+
+ @property
+ def vectorized(self) -> bool:
+ """Marks this pytree as vectorized."""
+
+ return self.batch_size > 0
+
+ @classmethod
+ def build_from_list(cls: Type[Self], list_of_obj: list[Self]) -> Self:
+ """
+ Build a vectorized pytree from a list of pytree of the same type.
+
+ Args:
+ list_of_obj: The list of pytrees to vectorize.
+
+ Returns:
+ The vectorized pytree having as leaves the stacked leaves of the input list.
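+
+ Example (sketch; `a` and `b` are assumed to be non-vectorized instances of
+ the same Vmappable subclass `MyData`):
+
+ vec = MyData.build_from_list(list_of_obj=[a, b])
+ assert vec.vectorized and vec.batch_size == 2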
+ """
+
+ if set(type(el) for el in list_of_obj) != {cls}:
+ msg = "The input list must contain only objects of type '{}'"
+ raise ValueError(msg.format(cls.__name__))
+
+ # Create a pytree by stacking all the leafs of the input list
+ data_vec: Vmappable = jax.tree_map(
+ lambda *leafs: jnp.array(leafs), *list_of_obj
+ )
+
+ # Store the batch dimension
+ with data_vec.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ data_vec.batch_size = len(list_of_obj)
+
+ # Detect the most common mutability in the input list
+ mutabilities = [e._mutability() for e in list_of_obj]
+ mutability = max(set(mutabilities), key=mutabilities.count)
+
+ # Update the mutability of the vectorized pytree
+ data_vec._set_mutability(mutability)
+
+ return data_vec
+
+ def vectorize(self: Self, batch_size: int) -> Self:
+ """
+ Return a vectorized version of this pytree.
+
+ Args:
+ batch_size: The batch size.
+
+ Returns:
+ A vectorized version of this pytree obtained by stacking the leaves of the
+ original pytree along a new batch dimension (the first one).
+ """
+
+ if self.vectorized:
+ raise RuntimeError("Cannot vectorize an already vectorized object")
+
+ if batch_size == 0:
+ return self.copy()
+
+ # TODO validate if mutability is maintained
+
+ return self.__class__.build_from_list(list_of_obj=[self] * batch_size)
+
+ def extract_element(self: Self, index: int) -> Self:
+ """
+ Extract the i-th element from a vectorized pytree.
+
+ Args:
+ index: The index of the element to extract.
+
+ Returns:
+ A non vectorized pytree obtained by extracting the i-th element from the
+ vectorized pytree.
+ """
+
+ if index < 0:
+ raise ValueError("The index of the desired element cannot be negative")
+
+ if index == 0 and self.batch_size == 0:
+ return self.copy()
+
+ if not self.vectorized:
+ raise RuntimeError("Cannot extract elements from a non-vectorized object")
+
+ if index >= self.batch_size:
+ raise ValueError("The index must be smaller than the batch size")
+
+ # Get the i-th pytree by extracting the i-th element from the vectorized pytree
+ data = jax.tree_map(lambda leaf: leaf[index], self)
+
+ # Update the batch size of the extracted scalar pytree
+ with data.mutable_context(mutability=Mutability.MUTABLE):
+ data.batch_size = 0
+
+ return data
diff --git a/tests/test_jax_oop.py b/tests/test_jax_oop.py
new file mode 100644
index 000000000..f887ecfae
--- /dev/null
+++ b/tests/test_jax_oop.py
@@ -0,0 +1,422 @@
+import dataclasses
+import io
+from contextlib import redirect_stdout
+from typing import Any, Type
+
+import jax
+import jax.numpy as jnp
+import jax_dataclasses
+import numpy as np
+import pytest
+
+from jaxsim.utils import Mutability, Vmappable, oop
+
+try:
+ from typing import Self
+except ImportError:
+ from typing_extensions import Self
+
+
+@jax_dataclasses.pytree_dataclass
+class AlgoData(Vmappable):
+ """Class storing vmappable data of a given algorithm."""
+
+ counter: jax.Array = dataclasses.field(
+ default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
+ )
+
+ @classmethod
+ def build(cls: Type[Self], counter: jax.typing.ArrayLike) -> Self:
+ """Builder method. Helpful for enforcing type and shape of fields."""
+
+ # Counter can be int / scalar numpy array / scalar jax array / etc.
+ if jnp.array(counter).squeeze().size != 1:
+ raise ValueError("The counter must be a scalar")
+
+ # Create the object enforcing `counter` to be a scalar jax array
+ data = AlgoData(
+ counter=jnp.array(counter, dtype=jnp.uint64).squeeze(),
+ )
+
+ return data
+
+
+def test_data():
+ """Test AlgoData class."""
+
+ data1 = AlgoData.build(counter=0)
+ data2 = AlgoData.build(counter=np.array(10))
+ data3 = AlgoData.build(counter=jnp.array(50))
+
+ assert isinstance(data1.counter, jax.Array) and data1.counter.dtype == jnp.uint64
+ assert isinstance(data2.counter, jax.Array) and data2.counter.dtype == jnp.uint64
+ assert isinstance(data3.counter, jax.Array) and data3.counter.dtype == jnp.uint64
+
+ assert data1.batch_size == 0
+ assert data2.batch_size == 0
+ assert data3.batch_size == 0
+
+ # ==================
+ # Vectorizing PyTree
+ # ==================
+
+ for batch_size in (0, 10, 100):
+ data_vec = data1.vectorize(batch_size=batch_size)
+
+ assert data_vec.batch_size == batch_size
+
+ if batch_size > 0:
+ assert data_vec.counter.shape[0] == batch_size
+
+ # =========================================
+ # Extracting element from vectorized PyTree
+ # =========================================
+
+ data_vec = AlgoData.build_from_list(list_of_obj=[data1, data2, data3])
+ assert data_vec.batch_size == 3
+ assert data_vec.extract_element(index=0) == data1
+ assert data_vec.extract_element(index=1) == data2
+ assert data_vec.extract_element(index=2) == data3
+
+ with pytest.raises(ValueError):
+ _ = data_vec.extract_element(index=3)
+
+ out = data1.extract_element(index=0)
+ assert out == data1
+ assert id(out) != id(data1)
+
+ with pytest.raises(RuntimeError):
+ _ = data1.extract_element(index=1)
+
+ with pytest.raises(ValueError):
+ _ = AlgoData.build_from_list(list_of_obj=[data1, data2, data3, 42])
+
+
+@jax_dataclasses.pytree_dataclass
+class MyClassWithAlgorithms(Vmappable):
+ """
+ Class to demonstrate how to use `Vmappable`.
+ """
+
+ # Dynamic data of the algorithm
+ data: AlgoData = dataclasses.field(default=None)
+
+ # Static attribute of the pytree (triggers recompilation if changed)
+ double_input: jax_dataclasses.Static[bool] = dataclasses.field(default=None)
+
+ # Non-static attribute of the pytree that is not transparently vmap-able.
+ done: jax.typing.ArrayLike = dataclasses.field(
+ default_factory=lambda: jnp.array(False, dtype=bool)
+ )
+
+ # Additional leaves to test the behaviour of mutable and immutable python objects
+ my_tuple: tuple[int] = dataclasses.field(default=tuple(jnp.array([1, 2, 3])))
+ my_list: list[int] = dataclasses.field(
+ default_factory=lambda: [4, 5, 6], init=False
+ )
+ my_array: jax.Array = dataclasses.field(
+ default_factory=lambda: jnp.array([10, 20, 30])
+ )
+
+ @classmethod
+ def build(cls: Type[Self], double_input: bool = False) -> Self:
+ """"""
+
+ obj = MyClassWithAlgorithms()
+
+ with obj.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ obj.data = AlgoData.build(counter=0)
+ obj.double_input = jnp.array(double_input)
+
+ return obj
+
+ @oop.jax_tf.method_ro
+ def algo_ro(self, advance: int | jax.typing.ArrayLike) -> Any:
+ """This is a read-only algorithm. It does not alter any pytree leaf."""
+
+ # This should be printed only the first execution since it is disabled
+ # in the execution of the JIT-compiled function.
+ print("__algo_ro__")
+
+ # Use the dynamic condition that doubles the input value
+ mul = jax.lax.select(self.double_input, 2, 1)
+
+ # Increase the counter
+ counter_old = jnp.atleast_1d(self.data.counter)[0]
+ counter_new = counter_old + mul * advance
+
+ # Return the updated counter
+ return counter_new
+
+ @oop.jax_tf.method_rw
+ def algo_rw(self, advance: int | jax.typing.ArrayLike) -> Any:
+ """
+ This is a read-write algorithm. It may alter pytree leaves either belonging
+ to the vmappable data or generic non-static dataclass attributes.
+ """
+
+ print(self)
+
+ # This should be printed only the first execution since it is disabled
+ # in the execution of the JIT-compiled function.
+ print("__algo_rw__")
+
+ # Use the dynamic condition that doubles the input value
+ mul = jax.lax.select(self.double_input, 2, 1)
+
+ # Increase the internal counter
+ counter_old = jnp.atleast_1d(self.data.counter)[0]
+ self.data.counter = jnp.array(counter_old + mul * advance, dtype=jnp.uint64)
+
+ # Update the non-static and non-vmap-able attribute
+ self.done = jax.lax.cond(
+ pred=self.data.counter > 100,
+ true_fun=lambda _: jnp.array(True),
+ false_fun=lambda _: jnp.array(False),
+ operand=None,
+ )
+
+ print(self)
+
+ # Return the updated counter
+ return self.data.counter
+
+
+def test_mutability():
+ """Test MyClassWithAlgorithms class."""
+
+ # Build the object
+ obj_ro = MyClassWithAlgorithms.build(double_input=True)
+
+ # By default, pytrees built with jax_dataclasses are frozen (read-only)
+ assert obj_ro._mutability() == Mutability.FROZEN
+ with pytest.raises(dataclasses.FrozenInstanceError):
+ obj_ro.data.counter = 42
+
+ # Data can be changed through a context manager, in this case operating on a copy...
+ with obj_ro.editable(validate=True) as obj_ro_copy:
+ obj_ro_copy.data.counter = jnp.array(42, dtype=obj_ro.data.counter.dtype)
+ assert obj_ro_copy.data.counter == pytest.approx(42)
+ assert obj_ro.data.counter != pytest.approx(42)
+
+ # ... or a context manager that does not copy the pytree...
+ with obj_ro.mutable_context(mutability=Mutability.MUTABLE):
+ obj_ro.data.counter = jnp.array(42, dtype=obj_ro.data.counter.dtype)
+ assert obj_ro.data.counter == pytest.approx(42)
+
+ # ... that raises if the leaves change type
+ with pytest.raises(AssertionError):
+ with obj_ro.mutable_context(mutability=Mutability.MUTABLE):
+ obj_ro.data.counter = 42
+
+ # Pytrees can be copied...
+ obj_ro_copy = obj_ro.copy()
+ assert id(obj_ro) != id(obj_ro_copy)
+ # ... operation that does not copy the leaves
+ # TODO describe
+ assert id(obj_ro.done) == id(obj_ro_copy.done)
+ assert id(obj_ro.data.counter) == id(obj_ro_copy.data.counter)
+ assert id(obj_ro.my_array) == id(obj_ro_copy.my_array)
+ assert id(obj_ro.my_tuple) != id(obj_ro_copy.my_tuple)
+ assert id(obj_ro.my_list) != id(obj_ro_copy.my_list)
+
+ # They can be converted as mutable pytrees to update their values without
+ # using context managers (maybe useful for debugging or quick prototyping)
+ obj_rw = obj_ro.copy().mutable(validate=True)
+ assert obj_rw._mutability() == Mutability.MUTABLE
+ obj_rw.data.counter = jnp.array(42, dtype=obj_rw.data.counter.dtype)
+
+ # However, with validation enabled, this works only if the leaf does not
+ # change its type (shape, dtype, weakness, ...)
+ with pytest.raises(AssertionError):
+ obj_rw.data.counter = 100
+ with pytest.raises(AssertionError):
+ obj_rw.data.counter = jnp.array(100, dtype=float)
+ with pytest.raises(AssertionError):
+ obj_rw.data.counter = jnp.array([100, 200], dtype=obj_rw.data.counter.dtype)
+
+ # Instead, with validation disabled, the pytree structure can be altered
+ # (and this might cause JIT recompilations, so use it at your own risk)
+ obj_rw_noval = obj_ro.copy().mutable(validate=False)
+ assert obj_rw_noval._mutability() == Mutability.MUTABLE_NO_VALIDATION
+ obj_rw_noval.data.counter = jnp.array(42, dtype=obj_rw.data.counter.dtype)
+
+ # Now this should work without exceptions
+ obj_rw_noval.data.counter = 100
+ obj_rw_noval.data.counter = jnp.array(100, dtype=float)
+ obj_rw_noval.data.counter = jnp.array([100, 200], dtype=obj_rw.data.counter.dtype)
+
+ # Build another object and check mutability changes
+ obj_ro = MyClassWithAlgorithms.build(double_input=True)
+ assert obj_ro.is_mutable(validate=True) is False
+ assert obj_ro.is_mutable(validate=False) is False
+
+ obj_rw_val = obj_ro.mutable(validate=True)
+ assert id(obj_ro) == id(obj_rw_val)
+ assert obj_rw_val.is_mutable(validate=True) is True
+ assert obj_rw_val.is_mutable(validate=False) is False
+
+ obj_rw_noval = obj_rw_val.mutable(validate=False)
+ assert id(obj_rw_noval) == id(obj_rw_val)
+ assert obj_rw_noval.is_mutable(validate=True) is False
+ assert obj_rw_noval.is_mutable(validate=False) is True
+
+ # Checking mutable leaves behavior
+ obj_rw = MyClassWithAlgorithms.build(double_input=True).mutable(validate=True)
+ obj_rw_copy = obj_rw.copy()
+
+ # Memory of JAX arrays cannot be altered in place so this is safe
+ obj_rw.my_array = obj_rw.my_array.at[1].set(-20)
+ assert obj_rw_copy.my_array[1] != -20
+
+ # Tuples are immutable so this should be safe too
+ obj_rw.my_tuple = tuple(jnp.array([1, -2, 3]))
+ assert obj_rw_copy.my_tuple[1] != -2
+
+ # Lists are treated as tuples (they are not leaves) but since they are mutable,
+ # their id changes
+ obj_rw.my_list[1] = -5
+ assert obj_rw_copy.my_list[1] != -5
+
+ # Check that exceptions in mutable context do not alter the object
+ obj_ro = MyClassWithAlgorithms.build(double_input=True)
+ assert obj_ro.data.counter == 0
+ assert obj_ro.double_input == jnp.array(True)
+
+ with pytest.raises(RuntimeError):
+ with obj_ro.mutable_context(mutability=Mutability.MUTABLE):
+ obj_ro.double_input = jnp.array(False, dtype=obj_ro.double_input.dtype)
+ obj_ro.data.counter = jnp.array(33, dtype=obj_ro.data.counter.dtype)
+ raise RuntimeError
+ assert obj_ro.data.counter == 0
+ assert obj_ro.double_input == jnp.array(True)
+
+
+def test_decorators_jit_compilation():
+ """Test JIT features of MyClassWithAlgorithms class."""
+
+ obj = MyClassWithAlgorithms.build(double_input=False)
+ assert obj.data.counter == 0
+ assert obj.is_mutable(validate=True) is False
+ assert obj.is_mutable(validate=False) is False
+
+ # JIT compilation should happen only the first function call.
+ # We test this by checking that the first execution prints some output.
+ with io.StringIO() as buf, redirect_stdout(buf):
+ _ = obj.algo_ro(advance=1)
+ printed = buf.getvalue()
+ assert "__algo_ro__" in printed
+ with io.StringIO() as buf, redirect_stdout(buf):
+ _ = obj.algo_ro(advance=1)
+ printed = buf.getvalue()
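+ # Illustrative usage of mutable_context (sketch; `obj` is an instance of a
+ # JaxsimDataclass subclass and `new_value` matches the leaf type): changes
+ # applied inside the context are kept, while with restore_after_exception=True
+ # the original fields are restored if an exception is raised.
+ #
+ #     with obj.mutable_context(mutability=Mutability.MUTABLE):
+ #         obj.some_field = new_value
+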
+ assert "__algo_ro__" not in printed
+
+ # JIT compilation should happen only the first function call.
+ # We test this by checking that the first execution prints some output.
+ with io.StringIO() as buf, redirect_stdout(buf):
+ _ = obj.algo_rw(advance=1)
+ printed = buf.getvalue()
+ assert "__algo_rw__" in printed
+ with io.StringIO() as buf, redirect_stdout(buf):
+ _ = obj.algo_rw(advance=1)
+ printed = buf.getvalue()
+ assert "__algo_rw__" not in printed
+
+ # Create a new object
+ obj = MyClassWithAlgorithms.build(double_input=False)
+
+ # New objects should be able to re-use the JIT-compiled functions from other objects
+ with io.StringIO() as buf, redirect_stdout(buf):
+ _ = obj.algo_ro(advance=1)
+ _ = obj.algo_rw(advance=1)
+ printed = buf.getvalue()
+ assert "__algo_ro__" not in printed
+ assert "__algo_rw__" not in printed
+
+ # Create a new object
+ obj = MyClassWithAlgorithms.build(double_input=False)
+
+ # Read-only methods can be called on r/o objects
+ out = obj.algo_ro(advance=1)
+ assert out == obj.data.counter + 1
+ out = obj.algo_ro(advance=1)
+ assert out == obj.data.counter + 1
+
+ # Read-write methods can be called too on r/o objects since they are marked as r/w
+ out = obj.algo_rw(advance=1)
+ assert out == 1
+ out = obj.algo_rw(advance=1)
+ assert out == 2
+ out = obj.algo_rw(advance=2)
+ assert out == 4
+
+ # Create a new object with a different dynamic attribute
+ obj_dyn = MyClassWithAlgorithms.build(double_input=False).mutable(validate=True)
+ obj_dyn.done = jnp.array(not obj_dyn.done, dtype=bool)
+
+ # New objects with different dynamic attributes should be able to re-use the
+ # JIT-compiled functions from other objects
+ with io.StringIO() as buf, redirect_stdout(buf):
+ _ = obj.algo_ro(advance=1)
+ _ = obj.algo_rw(advance=1)
+ printed = buf.getvalue()
+ assert "__algo_ro__" not in printed
+ assert "__algo_rw__" not in printed
+
+ # Create a new object with a different static attribute
+ obj_stat = MyClassWithAlgorithms.build(double_input=True)
+
+ # New objects with different static attributes trigger the recompilation of the
+ # JIT-compiled functions...
+ with io.StringIO() as buf, redirect_stdout(buf):
+ _ = obj_stat.algo_ro(advance=1)
+ _ = obj_stat.algo_rw(advance=1)
+ printed = buf.getvalue()
+ assert "__algo_ro__" in printed
+ assert "__algo_rw__" in printed
+
+ # ... that are cached as well by jax
+ with io.StringIO() as buf, redirect_stdout(buf):
+ _ = obj_stat.algo_ro(advance=1)
+ _ = obj_stat.algo_rw(advance=1)
+ printed = buf.getvalue()
+ assert "__algo_ro__" not in printed
+ assert "__algo_rw__" not in printed
+
+
+def test_decorators_vmap():
+ """Test automatic vectorization features of MyClassWithAlgorithms class."""
+
+ # Create a new object with scalar data
+ obj = MyClassWithAlgorithms.build(double_input=False)
+
+ # Vectorize the entire object
+ obj_vec = obj.vectorize(batch_size=10)
+ assert obj_vec.vectorized is True
+ assert obj_vec.batch_size == 10
+ assert id(obj_vec) != id(obj)
+
+ # Calling methods of vectorized objects with scalar arguments should raise an error
+ with pytest.raises(ValueError):
+ _ = obj_vec.algo_ro(advance=1)
+ with pytest.raises(ValueError):
+ _ = obj_vec.algo_rw(advance=1)
+
+ # Check that the r/o method provides automatically vectorized output and accepts
+ # vectorized input
+ out_vec = obj_vec.algo_ro(advance=jnp.array([1] * obj_vec.batch_size))
+ assert out_vec.shape[0] == 10
+ assert set(out_vec.tolist()) == {1}
+
+ # Check that the r/w method provides automatically vectorized output and accepts
+ # vectorized input
+ out_vec = obj_vec.algo_rw(advance=jnp.array([1] * obj_vec.batch_size))
+ assert out_vec.shape[0] == 10
+ assert set(out_vec.tolist()) == {1}
+ out_vec = obj_vec.algo_rw(advance=jnp.array([1] * obj_vec.batch_size))
+ assert set(out_vec.tolist()) == {2}
+
+ # Extract a single object from the vectorized object
+ obj = obj_vec.extract_element(index=5)
+ assert obj.vectorized is False
+ assert obj.data.counter == obj_vec.data.counter[5]
diff --git a/train.py b/train.py
new file mode 100644
index 000000000..f5d8b8700
--- /dev/null
+++ b/train.py
@@ -0,0 +1,562 @@
+import dataclasses
+import multiprocessing
+import pathlib
+from typing import Any, ClassVar, Optional
+
+import jax.numpy as jnp
+import jax.random
+import jax_dataclasses
+import numpy as np
+import numpy.typing as npt
+import rod
+from meshcat_viz import MeshcatWorld
+from resolve_robotics_uri_py import resolve_robotics_uri
+
+import jaxgym.jax.pytree_space as spaces
+import jaxsim.typing as jtp
+from jaxgym.jax import JaxDataclassEnv, JaxEnv
+from jaxgym.vector.jax import JaxVectorEnv
+from jaxgym.wrappers.jax import (
+ FlattenSpacesWrapper,
+ JaxTransformWrapper,
+ TimeLimit,
+ ToNumPyWrapper,
+)
+from jaxsim import JaxSim
+from jaxsim.physics.algos.soft_contacts import SoftContactsParams
+from jaxsim.simulation import simulator_callbacks
+from jaxsim.simulation.ode_integration import IntegratorType
+from jaxsim.simulation.simulator import SimulatorData, VelRepr
+from jaxsim.utils import JaxsimDataclass, Mutability
+
+
+@jax_dataclasses.pytree_dataclass
+class ErgoCubObservation(JaxsimDataclass):
+ """Observation of the ErgoCub environment."""
+
+ base_height: jtp.Float
+ gravity_projection: jtp.Array
+
+ joint_positions: jtp.Array
+ joint_velocities: jtp.Array
+
+ base_linear_velocity: jtp.Array
+ base_angular_velocity: jtp.Array
+
+ contact_state: jtp.Array
+
+ @staticmethod
+ def build(
+ base_height: jtp.Float,
+ gravity_projection: jtp.Array,
+ joint_positions: jtp.Array,
+ joint_velocities: jtp.Array,
+ base_linear_velocity: jtp.Array,
+ base_angular_velocity: jtp.Array,
+ contact_state: jtp.Array,
+ ) -> "ErgoCubObservation":
+ """Build an ErgoCubObservation object."""
+
+ return ErgoCubObservation(
+ base_height=jnp.array(base_height, dtype=float),
+ gravity_projection=jnp.array(gravity_projection, dtype=float),
+ joint_positions=jnp.array(joint_positions, dtype=float),
+ joint_velocities=jnp.array(joint_velocities, dtype=float),
+ base_linear_velocity=jnp.array(base_linear_velocity, dtype=float),
+ base_angular_velocity=jnp.array(base_angular_velocity, dtype=float),
+ contact_state=jnp.array(contact_state, dtype=bool),
+ )
+
+
+@dataclasses.dataclass
+class MeshcatVizRenderState:
+ """Render state of a meshcat-viz visualizer."""
+
+ world: MeshcatWorld = dataclasses.field(init=False)
+
+ _gui_process: Optional[multiprocessing.Process] = dataclasses.field(
+ default=None, init=False, repr=False, hash=False, compare=False
+ )
+
+ _jaxsim_to_meshcat_viz_name: dict[str, str] = dataclasses.field(
+ default_factory=dict, init=False, repr=False, hash=False, compare=False
+ )
+
+ def __post_init__(self) -> None:
+ """"""
+
+ self.world = MeshcatWorld()
+ self.world.open()
+
+ def close(self) -> None:
+ """"""
+
+ if self.world is not None:
+ self.world.close()
+
+ if self._gui_process is not None:
+ self._gui_process.terminate()
+ self._gui_process.close()
+
+ @staticmethod
+ def open_window(web_url: str) -> None:
+ """Open a new window with the given web url."""
+
+ import webview
+
+ print(web_url)
+ webview.create_window("meshcat", web_url)
+ webview.start(gui="qt")
+
+ def open_window_in_process(self) -> None:
+ """"""
+
+ if self._gui_process is not None:
+ self._gui_process.terminate()
+ self._gui_process.close()
+
+ self._gui_process = multiprocessing.Process(
+ target=MeshcatVizRenderState.open_window, args=(self.world.web_url,)
+ )
+ self._gui_process.start()
+
+
+StateType = dict[str, SimulatorData | jtp.Array]
+ActType = jnp.ndarray
+ObsType = ErgoCubObservation
+RewardType = float | jnp.ndarray
+TerminalType = bool | jnp.ndarray
+RenderStateType = MeshcatVizRenderState
+
+
+@jax_dataclasses.pytree_dataclass
+class ErgoCubWalkFuncEnvV0(
+ JaxDataclassEnv[
+ StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ ]
+):
+ """ErgoCub environment implementing a target reaching task."""
+
+ name: ClassVar = jax_dataclasses.static_field(default="ErgoCubWalkFuncEnvV0")
+
+ # Store an instance of the JaxSim simulator.
+ # It gets initialized with SimulatorData with a functional approach.
+ _simulator: JaxSim = jax_dataclasses.field(default=None)
+
+ def __post_init__(self) -> None:
+ """Environment initialization."""
+
+ # Dummy initialization (not needed here)
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ _ = self.jaxsim
+ model = self.jaxsim.get_model(model_name="ErgoCub")
+
+ # Create the action space (static attribute)
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ high = jnp.array([25.0] * model.dofs(), dtype=float)
+ self._action_space = spaces.PyTree(low=-high, high=high)
+
+ # Get joint limits
+ s_min, s_max = model.joint_limits()
+ s_range = s_max - s_min
+
+ low = ErgoCubObservation.build(
+ base_height=0.25,
+ gravity_projection=-jnp.ones(3),
+ joint_positions=s_min,
+ joint_velocities=-50.0 * jnp.ones_like(s_min),
+ base_linear_velocity=-5.0 * jnp.ones(3),
+ base_angular_velocity=-10.0 * jnp.ones(3),
+ contact_state=jnp.array([False] * 4),
+ )
+
+ high = ErgoCubObservation.build(
+ base_height=1.0,
+ gravity_projection=jnp.ones(3),
+ joint_positions=s_max,
+ joint_velocities=50.0 * jnp.ones_like(s_max),
+ base_linear_velocity=5.0 * jnp.ones(3),
+ base_angular_velocity=10.0 * jnp.ones(3),
+ contact_state=jnp.array([True] * 4),
+ )
+
+ # Create the observation space (static attribute)
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ self._observation_space = spaces.PyTree(low=low, high=high)
+
+ @property
+ def jaxsim(self) -> JaxSim:
+ """"""
+
+ if self._simulator is not None:
+ return self._simulator
+
+ # Create the jaxsim simulator.
+ simulator = JaxSim.build(
+ # Note: any change of either 'step_size' or 'steps_per_run' requires
+ # updating the number of integration steps in the 'transition' method.
+ step_size=0.000_250,
+ steps_per_run=1,
+ velocity_representation=VelRepr.Body,
+ integrator_type=IntegratorType.EulerSemiImplicit,
+ simulator_data=SimulatorData(
+ gravity=jnp.array([0, 0, -10.0]),
+ contact_parameters=SoftContactsParams.build(K=10_000, D=20),
+ ),
+ ).mutable(mutable=True, validate=False)
+
+ # Get the path of the model description (URDF)
+ model_sdf_path = resolve_robotics_uri(
+ "package://ergoCub/robots/ergoCubGazeboV1_minContacts/model.urdf"
+ )
+
+ # Insert the model
+ _ = simulator.insert_model_from_description(
+ model_description=model_sdf_path, model_name="ErgoCub"
+ )
+
+ simulator.data.models = {
+ model_name: jax.tree_util.tree_map(lambda leaf: jnp.array(leaf), model_data)
+ for model_name, model_data in simulator.data.models.items()
+ }
+
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
+ self._simulator = simulator.mutable(mutable=True, validate=True)
+
+ return self._simulator
+
+ def initial(self, rng: Any = None) -> StateType:
+ """"""
+
+ # Split the key
+ subkey1, subkey2 = jax.random.split(rng, num=2)
+
+ # Sample an initial observation
+ initial_observation: ErgoCubObservation = (
+ self.observation_space.sample_with_key(key=subkey1)
+ )
+
+ # Sample a goal position
+ goal_xy_position = jax.random.uniform(
+ key=subkey2, minval=-5.0, maxval=5.0, shape=(2,)
+ )
+
+ with self.jaxsim.editable(validate=False) as simulator:
+ # Reset the simulator and get the model
+ simulator.reset(remove_models=False)
+ model = simulator.get_model(model_name="ErgoCub")
+
+ # Reset the joint positions
+ model.reset_joint_positions(
+ positions=initial_observation.joint_positions,
+ joint_names=model.joint_names(),
+ )
+
+ # Reset the base position
+ model.reset_base_position(position=jnp.array([0, 0, 0.5]))
+
+ # Reset the base velocity
+ model.reset_base_velocity(
+ base_velocity=jnp.hstack(
+ [
+ 0.1 * initial_observation.base_linear_velocity,
+ 0.1 * initial_observation.base_angular_velocity,
+ ]
+ )
+ )
+
+ # Return the simulation state
+ return dict(
+ simulator_data=simulator.data,
+ goal=jnp.array(goal_xy_position, dtype=float),
+ )
+
+ def transition(
+ self, state: StateType, action: ActType, rng: Any = None
+ ) -> StateType:
+ """"""
+
+ # Get the JaxSim simulator
+ simulator = self.jaxsim
+
+ # Initialize the simulator with the environment state (containing SimulatorData)
+ with simulator.editable(validate=True) as simulator:
+ simulator.data = state["simulator_data"]
+
+ @jax_dataclasses.pytree_dataclass
+ class SetTorquesOverHorizon(simulator_callbacks.PreStepCallback):
+ def pre_step(self, sim: JaxSim) -> JaxSim:
+ """"""
+
+ model = sim.get_model(model_name="ErgoCub")
+ model.zero_input()
+ model.set_joint_generalized_force_targets(
+ forces=jnp.atleast_1d(action), joint_names=model.joint_names()
+ )
+
+ return sim
+
+ number_of_integration_steps = 40 # 0.010 # TODO 20 for having 0.010
+
+ # Stepping logic
+ with simulator.editable(validate=True) as simulator:
+ simulator, _ = simulator.step_over_horizon(
+ horizon_steps=number_of_integration_steps,
+ clear_inputs=False,
+ callback_handler=SetTorquesOverHorizon(),
+ )
+
+ # Return the new environment state (updated SimulatorData)
+ return state | dict(simulator_data=simulator.data)
+
+ def observation(self, state: StateType) -> ObsType:
+ """"""
+
+ # Initialize the simulator with the environment state (containing SimulatorData)
+ # and get the simulated model
+ with self.jaxsim.editable(validate=True) as simulator:
+ simulator.data = state["simulator_data"]
+ model = simulator.get_model("ErgoCub")
+
+ # Compute the normalized gravity projection in the body frame
+ W_R_B = model.base_orientation(dcm=True)
+ # W_gravity = state.simulator.gravity()
+ W_gravity = self.jaxsim.gravity()
+ B_gravity = W_R_B.T @ (W_gravity / jnp.linalg.norm(W_gravity))
+
+ W_p_B = model.base_position()
+ W_p_goal = jnp.hstack([state["goal"].squeeze(), 0])
+
+ # Compute the distance between the base and the goal in the body frame
+ B_p_distance = W_R_B.T @ (W_p_goal - W_p_B)
+
+ # Build the observation from the state
+ return ErgoCubObservation.build(
+ base_height=model.base_position()[2],
+ gravity_projection=B_gravity,
+ joint_positions=model.joint_positions(),
+ joint_velocities=model.joint_velocities(),
+ base_linear_velocity=model.base_velocity()[0:3],
+ base_angular_velocity=model.base_velocity()[3:6],
+ contact_state=model.in_contact(
+ link_names=[
+ name for name in model.link_names() if name.endswith("_ankle")
+ ]
+ ),
+ )
+
+ def reward(
+ self, state: StateType, action: ActType, next_state: StateType
+ ) -> RewardType:
+ """"""
+
+ with self.jaxsim.editable(validate=True) as simulator_next:
+ simulator_next.data = next_state["simulator_data"]
+ model_next = simulator_next.get_model("ErgoCub")
+
+ terminal = self.terminal(state=state)
+ obs_in_space = jax.lax.select(
+ pred=self.observation_space.contains(x=self.observation(state=state)),
+ on_true=1.0,
+ on_false=0.0,
+ )
+
+ # Position of the base
+ W_p_B = model_next.base_position()
+ W_p_xy_goal = state["goal"]
+
+ reward = 0.0
+ reward += 1.0 * (1.0 - jnp.array(terminal, dtype=float)) # alive
+ reward += 5.0 * obs_in_space  # observation inside the space bounds
+ # reward += 100.0 * v_WB[0] # forward velocity
+ reward -= jnp.linalg.norm(W_p_B[0:2] - W_p_xy_goal) # distance from goal
+ reward += 1.0 * model_next.in_contact(
+ link_names=[
+ name
+ for name in model_next.link_names()
+ if name.startswith("leg_") and name.endswith("_lower")
+ ]
+ ).any().astype(float)
+ reward -= 0.1 * jnp.linalg.norm(action) / action.size # control cost
+
+ return reward
+
+ def terminal(self, state: StateType) -> TerminalType:
+ # Get the current observation
+ observation = self.observation(state=state)
+
+ base_too_high = (
+ observation.base_height >= self.observation_space.high.base_height
+ )
+ return base_too_high
+
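+ # Illustrative functional rollout (sketch; `key` is assumed to be a JAX PRNG
+ # key, `action` a valid action, and `env` an ErgoCubWalkFuncEnvV0 instance):
+ #
+ #     state = env.initial(rng=key)
+ #     next_state = env.transition(state=state, action=action)
+ #     obs = env.observation(state=next_state)
+ #     rew = env.reward(state=state, action=action, next_state=next_state)
+ #     done = env.terminal(state=next_state)
+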
+ # =========
+ # Rendering
+ # =========
+
+ def render_image(
+ self, state: StateType, render_state: RenderStateType
+ ) -> tuple[RenderStateType, npt.NDArray]:
+ """Show the state."""
+
+ model_name = "ErgoCub"
+
+ # Initialize the simulator with the environment state (containing SimulatorData)
+ # and get the simulated model
+ with self.jaxsim.editable(validate=False) as simulator:
+ simulator.data = state["simulator_data"]
+ model = simulator.get_model(model_name=model_name)
+
+ # Insert the model lazily in the visualizer if it is not already there
+ if model_name not in render_state.world._meshcat_models.keys():
+ from rod.urdf.exporter import UrdfExporter
+
+ urdf_string = UrdfExporter.sdf_to_urdf_string(
+ sdf=rod.Sdf(
+ version="1.7",
+ model=model.physics_model.description.extra_info["sdf_model"],
+ ),
+ pretty=True,
+ gazebo_preserve_fixed_joints=False,
+ )
+
+ meshcat_viz_name = render_state.world.insert_model(
+ model_description=urdf_string, is_urdf=True, model_name=None
+ )
+
+ render_state._jaxsim_to_meshcat_viz_name[model_name] = meshcat_viz_name
+
+ # Check that the model is in the visualizer
+ if (
+ render_state._jaxsim_to_meshcat_viz_name[model_name]
+ not in render_state.world._meshcat_models.keys()
+ ):
+ raise ValueError(f"The '{model_name}' model is not in the meshcat world")
+
+ # Update the model in the visualizer
+ render_state.world.update_model(
+ model_name=render_state._jaxsim_to_meshcat_viz_name[model_name],
+ joint_names=model.joint_names(),
+ joint_positions=model.joint_positions(),
+ base_position=model.base_position(),
+ base_quaternion=model.base_orientation(dcm=False),
+ )
+
+ return render_state, np.empty(0)
+
+ def render_init(self, open_gui: bool = False, **kwargs) -> RenderStateType:
+ """Initialize the render state."""
+
+ # Initialize the render state
+ meshcat_viz_state = MeshcatVizRenderState()
+
+ if open_gui:
+ meshcat_viz_state.open_window_in_process()
+
+ return meshcat_viz_state
+
+ def render_close(self, render_state: RenderStateType) -> None:
+ """Close the render state."""
+
+ render_state.close()
+
+
+class ErgoCubWalkEnvV0(JaxEnv):
+ """"""
+
+ def __init__(self, render_mode: str | None = None, **kwargs: Any) -> None:
+ """"""
+
+ from jaxgym.wrappers.jax import (
+ ClipActionWrapper,
+ FlattenSpacesWrapper,
+ JaxTransformWrapper,
+ TimeLimit,
+ )
+
+ func_env = ErgoCubWalkFuncEnvV0()
+
+ func_env_wrapped = func_env
+ func_env_wrapped = TimeLimit(
+ env=func_env_wrapped, max_episode_steps=5_000
+ ) # TODO
+ func_env_wrapped = ClipActionWrapper(env=func_env_wrapped)
+ func_env_wrapped = FlattenSpacesWrapper(env=func_env_wrapped)
+ func_env_wrapped = JaxTransformWrapper(env=func_env_wrapped, function=jax.jit)
+
+ super().__init__(
+ func_env=func_env_wrapped,
+ metadata=self.metadata,
+ render_mode=render_mode,
+ )
+
+
+class ErgoCubWalkVectorEnvV0(JaxVectorEnv):
+ """"""
+
+ metadata = dict()
+
+ def __init__(
+ self,
+ # func_env: JaxDataclassEnv[
+ # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
+ # ],
+ num_envs: int,
+ render_mode: str | None = None,
+ # max_episode_steps: int = 5_000,
+ jit_compile: bool = True,
+ **kwargs,
+ ) -> None:
+ """"""
+
+ print("+++", kwargs)
+
+ env = ErgoCubWalkFuncEnvV0()
+
+ # Vectorize the environment.
+ # Note: it automatically wraps the environment in a TimeLimit wrapper.
+ super().__init__(
+ func_env=env,
+ num_envs=num_envs,
+ metadata=self.metadata,
+ render_mode=render_mode,
+ max_episode_steps=5_000, # TODO
+ jit_compile=jit_compile,
+ )
+
+ # from jaxgym.vector.jax import FlattenSpacesVecWrapper
+ #
+ # vec_env_wrapped = FlattenSpacesVecWrapper(env=vec_env)
+
+
+if __name__ == "__main__":
+
+ def make_jax_env(
+ max_episode_steps: Optional[int] = 500, jit: bool = True
+ ) -> JaxEnv:
+ """"""
+
+ if max_episode_steps in {None, 0}:
+ env = ErgoCubWalkFuncEnvV0()
+ else:
+ env = TimeLimit(
+ env=ErgoCubWalkFuncEnvV0(), max_episode_steps=max_episode_steps
+ )
+
+ return JaxEnv(
+ func_env=ToNumPyWrapper(
+ env=FlattenSpacesWrapper(env=env)
+ if not jit
+ else JaxTransformWrapper(
+ function=jax.jit,
+ env=FlattenSpacesWrapper(env=env),
+ ),
+ ),
+ render_mode="meshcat_viz",
+ )
+
+ env = make_jax_env(max_episode_steps=5, jit=True)
+
+ obs, state_info = env.reset(seed=0)
+ _ = env.render()
+ print(obs)