diff --git a/.gitignore b/.gitignore index f2520562c..17dce9518 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +core +*.pickle +*.zip + # IDEs .idea* diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..cfb2b6b85 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": true + }, + ] +} \ No newline at end of file diff --git a/examples/resources/ant.sdf b/examples/resources/ant.sdf new file mode 100644 index 000000000..c47dc030b --- /dev/null +++ b/examples/resources/ant.sdf @@ -0,0 +1,520 @@ + + + + + + 1.0 + + 0.025 + 0.025 + 0.025 + 0.0 + 0.0 + 0.0 + + + + + + 0.25 + + + + + + + 0.2886751345948129 0.2886751345948129 0.2886751345948129 + + + + + + -2.7755575615628914e-17 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + 0.3 + + 0.0044800000000000005 + 0.0044800000000000005 + 0.00096 + 0.0 + 0.0 + 0.0 + + 0.28 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + 0.4 + + + 0.28 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + + + 0.07999999999999997 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + + + + 0.08800000000000001 + + + 0.4800000000000001 5.551115123125783e-17 0.0 0.0 0.0 -5.551115123125783e-17 + + + + + 0.08 + 0.4 + + + 0.28 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0 + + + + 5.551115123125783e-17 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + 0.2 + + 0.002986666666666667 + 0.002986666666666667 + 0.00064 + 0.0 + 0.0 + 0.0 + + 0.20000000000000007 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + 0.4 + + + 0.20000000000000007 0.0 0.0 5.551115123125783e-17 
1.5707963267948966 0.0 + + + + + 0.08 + + + 0.4000000000000001 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + + + + 0.08 + 0.4 + + + 0.20000000000000007 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + + + 0.4000000000000001 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + + + -2.7755575615628914e-17 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + 0.3 + + 0.0044800000000000005 + 0.0044800000000000005 + 0.00096 + 0.0 + 0.0 + 0.0 + + 0.28 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + 0.4 + + + 0.28 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + + + 0.07999999999999997 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + + + + 0.08800000000000001 + + + 0.4800000000000001 -5.551115123125783e-17 0.0 0.0 0.0 5.551115123125783e-17 + + + + + 0.08 + 0.4 + + + 0.28 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0 + + + + 5.551115123125783e-17 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + 0.2 + + 0.002986666666666667 + 0.002986666666666667 + 0.00064 + 0.0 + 0.0 + 0.0 + + 0.20000000000000007 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + 0.4 + + + 0.20000000000000007 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + + + 0.4000000000000001 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + + + + 0.08 + 0.4 + + + 0.20000000000000007 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + + + 0.4000000000000001 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + + + 0.0 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + 0.3 + + 0.0044800000000000005 + 0.0044800000000000005 + 0.00096 + 0.0 + 0.0 + 0.0 + + 0.28 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + 0.4 + + + 0.28 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + + + 0.08 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + + + + 0.08800000000000001 + + + 0.4800000000000001 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + + + + 0.08 + 0.4 + + + 0.28 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0 + + + + 
1.1102230246251565e-16 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + 0.2 + + 0.002986666666666667 + 0.002986666666666667 + 0.00064 + 0.0 + 0.0 + 0.0 + + 0.20000000000000012 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + 0.4 + + + 0.20000000000000012 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + + + 0.40000000000000013 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + + + + 0.08 + 0.4 + + + 0.20000000000000012 0.0 0.0 -5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + + + 0.40000000000000013 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + + + 0.0 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + 0.3 + + 0.0044800000000000005 + 0.0044800000000000005 + 0.00096 + 0.0 + 0.0 + 0.0 + + 0.28 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + 0.4 + + + 0.28 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + + + 0.08 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + + + + 0.08800000000000001 + + + 0.4800000000000001 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + + + + 0.08 + 0.4 + + + 0.28 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0 + + + + 1.1102230246251565e-16 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + 0.2 + + 0.002986666666666667 + 0.002986666666666667 + 0.00064 + 0.0 + 0.0 + 0.0 + + 0.20000000000000012 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + 0.4 + + + 0.20000000000000012 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + + + 0.40000000000000013 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + + + + 0.08 + 0.4 + + + 0.20000000000000012 0.0 0.0 5.551115123125783e-17 1.5707963267948966 0.0 + + + + + 0.08 + + + 0.40000000000000013 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + + + torso + leg_front_left_upper + 0.1767766952966369 0.1767766952966369 0.0 0.0 0.0 0.7853981633974484 + + 0 0 1 + + -0.5235987755982988 + 0.5235987755982988 + + + + + leg_front_left_upper + leg_front_left_lower + 0.48 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + 0 1 0 + + 0.5235987755982988 + 
1.2217304763960306 + + + + + torso + leg_front_right_upper + 0.1767766952966369 -0.1767766952966369 0.0 0.0 0.0 -0.7853981633974484 + + 0 0 1 + + -0.5235987755982988 + 0.5235987755982988 + + + + + leg_front_right_upper + leg_front_right_lower + 0.48 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + 0 1 0 + + 0.5235987755982988 + 1.2217304763960306 + + + + + torso + leg_back_left_upper + -0.1767766952966369 0.1767766952966369 0.0 0.0 0.0 2.356194490192345 + + 0 0 1 + + -0.5235987755982988 + 0.5235987755982988 + + + + + leg_back_left_upper + leg_back_left_lower + 0.48000000000000004 0.0 0.0 0.0 0.0 5.551115123125783e-17 + + 0 1 0 + + 0.5235987755982988 + 1.2217304763960306 + + + + + torso + leg_back_right_upper + -0.1767766952966369 -0.1767766952966369 0.0 0.0 0.0 -2.356194490192345 + + 0 0 1 + + -0.5235987755982988 + 0.5235987755982988 + + + + + leg_back_right_upper + leg_back_right_lower + 0.48000000000000004 0.0 0.0 0.0 0.0 -5.551115123125783e-17 + + 0 1 0 + + 0.5235987755982988 + 1.2217304763960306 + + + + + diff --git a/examples/resources/cartpole.urdf b/examples/resources/cartpole.urdf new file mode 100644 index 000000000..7df2fd4c9 --- /dev/null +++ b/examples/resources/cartpole.urdf @@ -0,0 +1,81 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/pyproject.toml b/pyproject.toml index 5e229ba9e..4c31bb012 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,3 +15,10 @@ line-length = 88 [tool.isort] profile = "black" multi_line_output = 3 + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-rsxX -v --strict-markers --forked" +testpaths = [ + "tests", +] diff --git a/setup.cfg b/setup.cfg index 398189d00..4bcb49574 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,6 +60,7 @@ install_requires = pptree rod scipy + typing_extensions; python_version < "3.11" [options.packages.find] where = src @@ -70,13 +71,10 @@ style = isort testing = 
class MujocoModel:
    """Convenience wrapper around a MuJoCo model and its simulation data.

    The wrapper loads a model from an XML file, allocates the matching
    ``mujoco.MjData`` and exposes name-based accessors for joints, bodies
    and geometries.
    """

    def __init__(self, xml_path: pathlib.Path) -> None:
        """Load the model stored in ``xml_path`` and populate its data.

        Args:
            xml_path: Path of the XML model description handed to MuJoCo.

        Raises:
            FileNotFoundError: If ``xml_path`` does not exist.
        """

        if not xml_path.exists():
            raise FileNotFoundError(f"Could not find file '{xml_path}'")

        self.model = mujoco.MjModel.from_xml_path(filename=str(xml_path), assets=None)
        self.data = mujoco.MjData(self.model)

        # Populate data
        mujoco.mj_forward(self.model, self.data)

    # ---------------
    # Private helpers
    # ---------------

    def _object_names(self, object_type: Any, count: int) -> List[str]:
        """Return the names of the objects of ``object_type`` with id < ``count``."""

        return [
            mujoco.mj_id2name(self.model, object_type, idx) for idx in range(count)
        ]

    @staticmethod
    def _require_name(kind: str, name: str, known_names: List[str]) -> None:
        """Raise ``ValueError`` if ``name`` is not one of ``known_names``."""

        if name not in known_names:
            raise ValueError(f"{kind} '{name}' not found")

    # ----------------
    # Global quantities
    # ----------------

    def time(self) -> float:
        """Return the simulation time stored in the data object."""

        return self.data.time

    def gravity(self) -> npt.NDArray:
        """Return the gravity vector stored in the model options."""

        return self.model.opt.gravity

    def number_of_joints(self) -> int:
        """Return the number of joints of the model."""

        return self.model.njnt

    def number_of_geometries(self) -> int:
        """Return the number of geometries of the model."""

        return self.model.ngeom

    def number_of_bodies(self) -> int:
        """Return the number of bodies of the model."""

        return self.model.nbody

    # ------
    # Joints
    # ------

    def joint_names(self) -> List[str]:
        """Return the joint names ordered by joint id."""

        return self._object_names(
            mujoco.mjtObj.mjOBJ_JOINT, self.number_of_joints()
        )

    def joint_dofs(self, joint_name: str) -> int:
        """Return the size of the position vector of the given joint.

        Raises:
            ValueError: If the model has no joint called ``joint_name``.
        """

        self._require_name("Joint", joint_name, self.joint_names())
        return self.data.joint(joint_name).qpos.size

    def joint_position(self, joint_name: str) -> npt.NDArray:
        """Return the position (``qpos``) of the given joint.

        Raises:
            ValueError: If the model has no joint called ``joint_name``.
        """

        self._require_name("Joint", joint_name, self.joint_names())
        return self.data.joint(joint_name).qpos

    def joint_velocity(self, joint_name: str) -> npt.NDArray:
        """Return the velocity (``qvel``) of the given joint.

        Raises:
            ValueError: If the model has no joint called ``joint_name``.
        """

        self._require_name("Joint", joint_name, self.joint_names())
        return self.data.joint(joint_name).qvel

    # ------
    # Bodies
    # ------

    def body_names(self) -> List[str]:
        """Return the body names ordered by body id."""

        return self._object_names(mujoco.mjtObj.mjOBJ_BODY, self.number_of_bodies())

    def body_position(self, body_name: str) -> npt.NDArray:
        """Return the position (``xpos``) of the given body.

        Raises:
            ValueError: If the model has no body called ``body_name``.
        """

        self._require_name("Body", body_name, self.body_names())
        return self.data.body(body_name).xpos

    def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray:
        """Return the orientation of the given body.

        Args:
            body_name: Name of the body.
            dcm: If True return a 3x3 rotation matrix, otherwise the
                quaternion stored by MuJoCo (``xquat``).

        Raises:
            ValueError: If the model has no body called ``body_name``.
        """

        self._require_name("Body", body_name, self.body_names())
        body = self.data.body(body_name)

        if dcm:
            # Reshape the stored matrix to 3x3 so that the result has the same
            # layout as the one returned by geometry_orientation(dcm=True).
            return np.reshape(body.xmat, (3, 3))

        return body.xquat

    # ----------
    # Geometries
    # ----------

    def geometry_names(self) -> List[str]:
        """Return the geometry names ordered by geometry id."""

        return self._object_names(
            mujoco.mjtObj.mjOBJ_GEOM, self.number_of_geometries()
        )

    def geometry_position(self, geometry_name: str) -> npt.NDArray:
        """Return the position (``xpos``) of the given geometry.

        Raises:
            ValueError: If the model has no geometry called ``geometry_name``.
        """

        self._require_name("Geometry", geometry_name, self.geometry_names())
        return self.data.geom(geometry_name).xpos

    def geometry_orientation(
        self, geometry_name: str, dcm: bool = False
    ) -> npt.NDArray:
        """Return the orientation of the given geometry.

        Args:
            geometry_name: Name of the geometry.
            dcm: If True return a 3x3 rotation matrix, otherwise a quaternion
                with the scalar component first.

        Raises:
            ValueError: If the model has no geometry called ``geometry_name``.
        """

        self._require_name("Geometry", geometry_name, self.geometry_names())

        R = np.reshape(self.data.geom(geometry_name).xmat, (3, 3))

        if dcm:
            return R

        # SciPy returns the scalar component last: move it to the front.
        q_xyzw = Rotation.from_matrix(R).as_quat()
        return q_xyzw[[3, 0, 1, 2]]

    # -------------
    # Serialization
    # -------------

    def to_string(self) -> Tuple[str, str]:
        """Convert a mujoco model to a string.

        Returns:
            A tuple with the text written by ``mujoco.mj_saveLastXML`` and the
            text written by ``mujoco.mj_printModel``.
        """

        import tempfile

        # MuJoCo writes to a path on its own, therefore use files inside a
        # temporary directory instead of a path that is still held open by
        # NamedTemporaryFile (which cannot be re-opened on every platform).
        with tempfile.TemporaryDirectory() as tmp_dir:
            mjcf_path = pathlib.Path(tmp_dir) / "model.xml"
            compiled_path = pathlib.Path(tmp_dir) / "model.txt"

            mujoco.mj_saveLastXML(str(mjcf_path), self.model)
            mujoco.mj_printModel(self.model, str(compiled_path))

            mjcf_string = mjcf_path.read_text()
            compiled_model_string = compiled_path.read_text()

        return mjcf_string, compiled_model_string
class CustomVecEnvSB(vec_env_sb.VecEnv):
    """Expose a ``JaxVectorEnv`` through the Stable-Baselines3 ``VecEnv`` API.

    The wrapped environment is stepped as a whole; observations, rewards and
    done flags are converted to NumPy arrays, and the batched step info is
    split into one entry per environment.
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        jax_vector_env: JaxVectorEnv | VectorWrapper,
        log_rewards: bool = False,
    ) -> None:
        """Build the adapter.

        Args:
            jax_vector_env: The vectorized environment (possibly wrapped) whose
                ``unwrapped`` attribute is a ``JaxVectorEnv``.
            log_rewards: If True, the mean reward of every step is appended to
                ``self.logger_rewards``; otherwise that attribute is None.

        Raises:
            TypeError: If the unwrapped environment is not a ``JaxVectorEnv``.
        """

        if not isinstance(jax_vector_env.unwrapped, JaxVectorEnv):
            raise TypeError(type(jax_vector_env))

        self.jax_vector_env = jax_vector_env

        single_env_action_space: PyTree = jax_vector_env.unwrapped.single_action_space

        single_env_observation_space: PyTree = (
            jax_vector_env.unwrapped.single_observation_space
        )

        super().__init__(
            num_envs=self.jax_vector_env.num_envs,
            action_space=single_env_action_space.to_box(),
            observation_space=single_env_observation_space.to_box(),
            render_mode=None,
        )

        # Buffer holding the actions received by step_async
        self.actions = np.zeros_like(self.jax_vector_env.action_space.sample())

        # Initialize the RNG seed
        self._seed = None
        self.seed()

        # Initialize the rewards logger
        self.logger_rewards = [] if log_rewards else None

    def reset(self) -> vec_env_sb.base_vec_env.VecEnvObs:
        """Reset the wrapped environment and return the observations.

        Note that the stored seed is passed to the wrapped environment at
        every call.
        """

        observations, state_infos = self.jax_vector_env.reset(seed=self._seed)
        return np.array(observations)

    def step_async(self, actions: np.ndarray) -> None:
        """Store the actions that the next ``step_wait`` call will apply."""

        self.actions = actions

    @staticmethod
    @functools.partial(jax.jit, static_argnames=("batch_size",))
    def tree_inverse_transpose(pytree: jtp.PyTree, batch_size: int) -> List[jtp.PyTree]:
        """Split a batched pytree into a list of ``batch_size`` pytrees.

        The i-th element of the returned list has the same structure of
        ``pytree`` and its leaves are the i-th entries of the original leaves.
        """

        return [
            jax.tree_util.tree_map(lambda leaf: leaf[i], pytree)
            for i in range(batch_size)
        ]

    def step_wait(self) -> vec_env_sb.base_vec_env.VecEnvStepReturn:
        """Step the wrapped environment with the stored actions.

        Returns:
            The observations, the rewards, the done flags (terminated or
            truncated) and the list of per-environment step infos, all
            converted to NumPy.
        """

        (
            observations,
            rewards,
            terminals,
            truncated,
            step_infos,
        ) = self.jax_vector_env.step(actions=self.actions)

        done = np.logical_or(terminals, truncated)

        # Convert the batched info pytree to a list with one info per environment
        list_of_step_infos = self.tree_inverse_transpose(
            pytree=step_infos, batch_size=self.jax_vector_env.num_envs
        )

        list_of_step_infos_numpy = [
            ToNumPyWrapper.pytree_to_numpy(pytree=pt) for pt in list_of_step_infos
        ]

        if self.logger_rewards is not None:
            self.logger_rewards.append(np.array(rewards).mean())

        return (
            np.array(observations),
            np.array(rewards),
            np.array(done),
            list_of_step_infos_numpy,
        )

    def close(self) -> None:
        """Close the wrapped environment."""

        return self.jax_vector_env.close()

    def get_attr(
        self, attr_name: str, indices: vec_env_sb.base_vec_env.VecEnvIndices = None
    ) -> List[Any]:
        """Not supported by this adapter."""

        raise NotImplementedError

    def set_attr(
        self,
        attr_name: str,
        value: Any,
        indices: vec_env_sb.base_vec_env.VecEnvIndices = None,
    ) -> None:
        """Not supported by this adapter."""

        raise NotImplementedError

    def env_method(
        self,
        method_name: str,
        *method_args,
        indices: vec_env_sb.base_vec_env.VecEnvIndices = None,
        **method_kwargs,
    ) -> List[Any]:
        """Not supported by this adapter."""

        raise NotImplementedError

    def env_is_wrapped(
        self,
        wrapper_class: Type[gym.Wrapper],
        indices: vec_env_sb.base_vec_env.VecEnvIndices = None,
    ) -> List[bool]:
        """Report that none of the environments is wrapped by ``wrapper_class``."""

        return [False] * self.num_envs

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        """Store the seed used by the next calls to ``reset``.

        Args:
            seed: A value representable as an unsigned 32-bit integer. If None,
                a random seed is drawn.

        Returns:
            A single-element list with the stored seed.

        Raises:
            ValueError: If ``seed`` is not an integer value in the range of
                'uint32'.
        """

        uint32_max = int(np.iinfo(np.uint32).max)
        error_message = "seed must be compatible with 'uint32' casting"

        if seed is None:
            seed = int(np.random.default_rng().integers(0, uint32_max, dtype="uint32"))

        # Validate with plain integers: casting an out-of-range value through
        # NumPy may raise OverflowError (or silently wrap) depending on the
        # NumPy version, instead of the ValueError documented here.
        try:
            seed_as_int = int(seed)
        except (TypeError, ValueError, OverflowError):
            raise ValueError(error_message) from None

        if seed_as_int != seed or not 0 <= seed_as_int <= uint32_max:
            raise ValueError(error_message)

        self._seed = seed_as_int
        return [seed_as_int]
def make_vec_env_stable_baselines(
    jax_dataclass_env: JaxDataclassEnv | JaxDataclassWrapper,
    n_envs: int = 1,
    seed: Optional[int] = None,
    vec_env_kwargs: Optional[Dict[str, Any]] = None,
    log_rewards: bool = True,
) -> vec_env_sb.VecEnv:
    """Create a Stable-Baselines3 vectorized environment from a functional env.

    Args:
        jax_dataclass_env: The functional environment to vectorize.
        n_envs: The number of parallel environments.
        seed: Optional seed stored in the returned environment.
        vec_env_kwargs: Optional keyword arguments forwarded to ``JaxVectorEnv``.
        log_rewards: Whether the returned environment logs the mean reward of
            every step.

    Returns:
        The ``CustomVecEnvSB`` wrapping the vectorized environment.
    """

    kwargs = vec_env_kwargs if vec_env_kwargs is not None else {}

    # Vectorize the environment.
    # Note: it automatically wraps the environment in a TimeLimit wrapper.
    # Note: the space must be PyTree.
    jax_vec_env = JaxVectorEnv(
        func_env=jax_dataclass_env,
        num_envs=n_envs,
        **kwargs,
    )

    # Flatten the PyTree spaces to regular Box spaces
    flat_vec_env = FlattenSpacesVecWrapper(env=jax_vec_env)

    # Do not call this variable 'vec_env_sb': that name belongs to the imported
    # stable_baselines3 module used in the return annotation.
    sb_vec_env = CustomVecEnvSB(jax_vector_env=flat_vec_env, log_rewards=log_rewards)

    if seed is not None:
        _ = sb_vec_env.seed(seed=seed)

    return sb_vec_env
def _make_jax_env(
    func_env_factory: Callable[[], Any],
    render_mode: Optional[str],
    max_episode_steps: Optional[int],
) -> JaxEnv:
    """Build a ``JaxEnv`` around the functional env created by the factory.

    The functional environment is wrapped, from the innermost to the
    outermost layer, with: NaN handling, optional time limit, action
    squashing, action clipping, space flattening, ``jax.jit`` and conversion
    to NumPy.
    """

    # TODO: single env -> time limit with stable_baselines?

    import os
    import warnings

    import torch

    # Hide the CUDA devices when torch reports that none is available.
    # This runs before the functional environment gets created.
    if not torch.cuda.is_available():
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        os.environ["NVIDIA_VISIBLE_DEVICES"] = ""

    warnings.simplefilter(action="ignore", category=FutureWarning)

    env = NaNHandlerWrapper(env=func_env_factory())

    if max_episode_steps is not None:
        env = TimeLimit(env=env, max_episode_steps=max_episode_steps)

    return JaxEnv(
        render_mode=render_mode,
        func_env=ToNumPyWrapper(
            env=JaxTransformWrapper(
                function=jax.jit,
                env=FlattenSpacesWrapper(
                    env=ClipActionWrapper(
                        env=SquashActionWrapper(env=env),
                    )
                ),
            )
        ),
    )


def make_jax_env_cartpole(
    render_mode: Optional[str] = None,
    max_episode_steps: Optional[int] = 500,
) -> JaxEnv:
    """Create the single cartpole swing-up environment.

    Args:
        render_mode: The render mode passed to ``JaxEnv``.
        max_episode_steps: The length of the time limit, or None to disable it.
    """

    return _make_jax_env(
        func_env_factory=CartpoleSwingUpFuncEnvV0,
        render_mode=render_mode,
        max_episode_steps=max_episode_steps,
    )


def make_jax_env_ant(
    render_mode: Optional[str] = None,
    max_episode_steps: Optional[int] = 1_000,
) -> JaxEnv:
    """Create the single ant reach-target environment.

    Args:
        render_mode: The render mode passed to ``JaxEnv``.
        max_episode_steps: The length of the time limit, or None to disable it.
    """

    return _make_jax_env(
        func_env_factory=AntReachTargetFuncEnvV0,
        render_mode=render_mode,
        max_episode_steps=max_episode_steps,
    )
def evaluate(
    env: gym.Env | Callable[..., gym.Env],
    num_episodes: int = 1,
    seed: int | None = None,
    render: bool = False,
    policy: Callable[[npt.NDArray], npt.NDArray] | None = None,
) -> None:
    """Roll out a policy and print the mean episode length and return.

    Args:
        env: The environment to evaluate on, or a callable creating it.
        num_episodes: The number of episodes to run.
        seed: Optional seed passed to the first reset of the environment.
        render: Whether to render the environment after the reset and after
            every step.
        policy: Callable mapping an observation to an action. If None, actions
            are sampled from the action space of the environment.
    """

    # Create the environment if a callable is passed
    env = env if isinstance(env, gym.Env) else env()

    # Initialize a random policy if none is passed
    policy = policy if policy is not None else lambda obs: env.action_space.sample()

    episodes_length = []
    cumulative_rewards = []

    for episode_idx in range(num_episodes):
        # Seed the environment only at the first reset: passing the same seed
        # at every reset would replay the same initial state in each episode.
        episode_seed = seed if episode_idx == 0 else None

        # Reset the environment
        observation, step_info = env.reset(seed=episode_seed)

        # Render the environment
        if render:
            env.render()

        episode_length = 0
        cumulative_reward = 0
        done = False

        # Evaluation loop
        while not done:
            # Increase episode length counter
            episode_length += 1

            # Predict the action
            action = policy(observation)

            # Step the environment received as argument (and not a global one)
            observation, reward, terminal, truncated, step_info = env.step(
                action=action
            )

            # Determine if the episode is done
            done = terminal or truncated

            # Store the cumulative reward
            cumulative_reward += reward

            # Render the environment
            if render:
                _ = env.render()

        episodes_length.append(episode_length)
        cumulative_rewards.append(cumulative_reward)

    print("ep_len_mean\t", np.array(episodes_length).mean())
    print("ep_rew_mean\t", np.array(cumulative_rewards).mean())
__name__ == "__main__cartpole_cpu_vec_env": + """""" + + max_episode_steps = 200 + func_env = NaNHandlerWrapper(env=CartpoleSwingUpFuncEnvV0()) + + if max_episode_steps is not None: + func_env = TimeLimit(env=func_env, max_episode_steps=max_episode_steps) + + func_env = ClipActionWrapper( + env=SquashActionWrapper(env=func_env), + ) + + vec_env = make_vec_env_stable_baselines( + jax_dataclass_env=func_env, + n_envs=10, + seed=42, + vec_env_kwargs=dict( + jit_compile=True, + ), + ) + + import torch as th + + model = PPO( + "MlpPolicy", + env=vec_env, + # n_steps=2048, + n_steps=256, # in the vector env -> real ones are x10 + batch_size=256, + n_epochs=10, + gamma=0.95, + gae_lambda=0.9, + clip_range=0.1, + normalize_advantage=True, + # target_kl=0.010, + target_kl=0.025, + verbose=1, + learning_rate=0.000_300, + policy_kwargs=dict( + activation_fn=th.nn.ReLU, + net_arch=dict(pi=[512, 512], vf=[512, 512]), + log_std_init=np.log(0.05), + # squash_output=True, + ), + ) + + print(model.policy) + + # Create the evaluation environment + env_eval = make_jax_env_cartpole( + render_mode="meshcat_viz", + max_episode_steps=500, + ) + + for _ in range(1): + # Train the model + model = model.learn(total_timesteps=50_000, progress_bar=False) + + # Create the policy closure + policy = lambda observation: model.policy.predict( + observation=observation, deterministic=True + )[0] + + # Evaluate the policy + print("Evaluating...") + evaluate( + env=env_eval, + num_episodes=10, + seed=None, + render=True, + policy=policy, + ) + +# Train with SB +if __name__ == "__main__cartpole_gpu_vec_env": + """""" + + max_episode_steps = 200 + func_env = NaNHandlerWrapper(env=CartpoleSwingUpFuncEnvV0()) + + if max_episode_steps is not None: + func_env = TimeLimit(env=func_env, max_episode_steps=max_episode_steps) + + func_env = ClipActionWrapper( + env=SquashActionWrapper( + # env=func_env + env=ActionNoiseWrapper(env=func_env) + ), + ) + + vec_env = make_vec_env_stable_baselines( + 
jax_dataclass_env=func_env, + # n_envs=10, + n_envs=512, + # n_envs=2048, # TODO + seed=42, + vec_env_kwargs=dict( + jit_compile=True, + ), + ) + + vec_env = VecMonitor( + venv=VecNormalize( + venv=vec_env, + training=True, + norm_obs=True, + norm_reward=True, + clip_obs=10.0, + clip_reward=10.0, + gamma=0.95, + epsilon=1e-8, + ) + ) + + actions = vec_env.jax_vector_env.action_space.sample() + + # 0: ok + # 1: ok + # 2: ok + # 3: ok + # 4: ok + # 5: ok + # 6: ok + # 7: -> now + # 8: + # 9: + vec_env.venv.venv.logger_rewards = [] + seed = vec_env.seed(seed=7)[0] + _ = vec_env.reset() + + import torch as th + + model = PPO( + "MlpPolicy", + env=vec_env, + n_steps=5, # in the vector env -> real ones are x512 + batch_size=256, + n_epochs=10, + gamma=0.95, + gae_lambda=0.9, + clip_range=0.1, + normalize_advantage=True, + target_kl=0.025, + verbose=1, + learning_rate=0.000_300, + policy_kwargs=dict( + activation_fn=th.nn.ReLU, + net_arch=dict(pi=[512, 512], vf=[512, 512]), + log_std_init=np.log(0.05), + # squash_output=True, + ), + ) + + print(model.policy) + + # Create the evaluation environment + env_eval = make_jax_env_cartpole( + render_mode="meshcat_viz", + max_episode_steps=200, + ) + + # from stable_baselines3.common. 
+ + # rewards = np.zeros((10, 982)) DO NOT EXEC + # rewards_7 = np.array(vec_env.venv.venv.logger_rewards) CHANGE _X + # rewards[seed, :] = np.array(vec_env.venv.venv.logger_rewards) + + # rewards = np.vstack([rewards, np.atleast_2d(vec_env.logger_rewards)]) + + # plt.plot( + # # np.arange(start=1, stop=len(vec_env.venv.venv.logger_rewards) + 1) * 512, + # # vec_env.venv.venv.logger_rewards, + # np.arange(start=1, stop=len(vec_env.venv.venv.logger_rewards) + 3) * 512, + # rewards.T, + # label=r"$\hat{r}$" + # ) + # # plt.plot(step_data[model_js.name()].tf, joint_positions_mj, label=["d", "theta"]) + # plt.grid(True) + # plt.legend() + # plt.xlabel("Time steps") + # plt.ylabel("Average reward over 512 environments") + # # plt.title("Trajectory of the model's base") + # plt.show() + + # import pickle + # with open(file=pathlib.Path.home() + # / "git" + # / "jaxsim" + # / "scripts" + # / f"ppo_cartpole_swingup_rewards.pickle", mode="w+b") as f: + # pickle.dump(rewards, f) + + # model.save( + # path=pathlib.Path.home() + # / "git" + # / "jaxsim" + # / "scripts" + # / f"ppo_cartpole_swing_up_seed={seed}.zip" + # ) + + for _ in range(5): + # Train the model + model = model.learn(total_timesteps=50_000, progress_bar=False) + # %time model = model.learn(total_timesteps=500_000, progress_bar=False) + + # Create the policy closure + policy = lambda observation: model.policy.predict( + # observation=observation, deterministic=True + observation=vec_env.normalize_obs(observation), + deterministic=True, + )[0] + + # Evaluate the policy + print("Evaluating...") + evaluate( + env=env_eval, + num_episodes=10, + seed=None, + render=True, + policy=policy, + # vec_env_norm=vec_env, + ) + + # for _ in range(n_steps): + # observation = mj_observation(mujoco_model=m) + # action = model.policy.predict(observation=observation, deterministic=True)[0] + # m.data.ctrl = np.atleast_1d(action) + # mujoco.mj_step(m.model, m.data) + # # mujoco.mj_forward(m.model, m.data) + + # import palettable 
+ # # https://jiffyclub.github.io/palettable/cartocolors/diverging/ + # colors = palettable.cartocolors.diverging.Geyser_5.mpl_colors + # + # r = rewards.copy() + # mean = r.mean(axis=0) + # std = r.std(axis=0) + # std_up = mean + std/2 + # std_down = mean - std/2 + # + # fig, ax1 = plt.subplots(1, 1) + # ax1.fill_between( + # np.arange(start=1, stop=mean.size + 1) * 512, + # std_down, + # std_up, + # label=r"$\pm \sigma$", + # color=colors[1], + # ) + # ax1.plot( + # np.arange(start=1, stop=mean.size + 1) * 512, + # mean, + # color=colors[0], + # ) + # ax1.grid() + # # ax1.legend(loc="lower right") + # ax1.set_title(r"\textbf{Average reward}") + # ax1.set_xlabel("Samples") + # + # # plt.show() + # + # import tikzplotlib + # tikzplotlib.clean_figure() + # print(tikzplotlib.get_tikz_code()) + + # ======================= + # Evaluation environments + # ======================= + # + # env_eval = make_jax_env_cartpole(render_mode="meshcat_viz", max_episode_steps=None) + # + # # observation, step_info = env_eval.reset(seed=42) + # observation, step_info = env_eval.reset() + # + # # Initialize done flag + # done = False + # + # env_eval.render() + # + # i = 0 + # cum_reward = 0.0 + # + # while not done: + # i += 1 + # + # if i == 2000: + # done = True + # + # # Sample a random action + # # action = 0.1* random_policy(env, observation) + # action, _ = model.policy.predict(observation=observation, deterministic=False) + # + # # Step the environment + # observation, reward, terminal, truncated, step_info = env_eval.step( + # action=action + # ) + # + # print(reward) + # cum_reward += reward + # + # # Render the environment + # _ = env_eval.render() + # + # # print(observation, reward, terminal, truncated, step_info) + # # print(env.state) + # + # print("reward =", cum_reward) + # print("episode length =", i) + # + # env_eval.close() + + +# Comparison with mujoco +if __name__ == "__main_comparison_mujoco": + """""" + + model = PPO.load( + path=pathlib.Path.home() + / "git" + 
/ "jaxsim" + / "scripts" + / "ppo_cartpole_swing_up_seed=7.zip" + ) + + # Create the evaluation environment + env_eval = make_jax_env_cartpole( + render_mode="meshcat_viz", + max_episode_steps=250, + ) + + # ================= + # Mujoco evaluation + # ================= + + def mj_observation(mujoco_model: MujocoModel) -> npt.NDArray: + """""" + + mujoco.mj_forward(mujoco_model.model, mujoco_model.data) + + pivot_pos = mujoco_model.joint_position(joint_name="pivot") + # θ = np.arctan2(np.sin(pivot_pos), np.cos(pivot_pos)) + θ = pivot_pos + + return ( + np.array( + [ + mujoco_model.joint_position(joint_name="linear"), + mujoco_model.joint_velocity(joint_name="linear"), + θ, + mujoco_model.joint_velocity(joint_name="pivot"), + ] + ) + .squeeze() + .copy() + ) + + def mj_reset( + mujoco_model: MujocoModel, observation: Optional[npt.NDArray] = None + ) -> npt.NDArray: + """""" + + observation = ( + observation + if observation is not None + else np.array([0.0, 0.0, np.deg2rad(180), 0.0]) + ) + + linear_pos = observation[0] + linear_vel = observation[1] + pivot_pos = observation[2] + pivot_vel = observation[3] + + mujoco_model.data.qpos = np.array([linear_pos, pivot_pos]) + mujoco_model.data.qvel = np.array([linear_vel, pivot_vel]) + mujoco.mj_forward(mujoco_model.model, mujoco_model.data) + + return mj_observation(mujoco_model=mujoco_model) + + def mj_step(action: npt.NDArray, mujoco_model: MujocoModel) -> None: + """""" + + n_steps = int(0.050 / mujoco_model.model.opt.timestep) + mujoco_model.data.ctrl = np.atleast_1d(action.squeeze()).copy() + mujoco.mj_step(mujoco_model.model, mujoco_model.data, n_steps) + + # m.data.qpos = np.array([0.0, np.deg2rad(180)]) + # m.data.qvel = np.array([0.0, 0.0]) + # mujoco.mj_forward(m.model, m.data) + + # # Create the policy closure + # policy = lambda observation: model.policy.predict( + # # observation=observation, deterministic=True + # observation=vec_env.normalize_obs(observation), deterministic=True + # )[0] + + # 
============== + # Mujoco regular + # ============== + + model_xml_path = ( + pathlib.Path.home() + / "git" + / "jaxsim" + / "examples" + / "resources" + / "cartpole_mj.xml" + ) + + self = m = MujocoModel(xml_path=model_xml_path) + + mj_action = [] + mj_pos_cart = [] + mj_pos_pole = [] + + done = False + iterations = 0 + mj_reset(mujoco_model=m) + + while not done: + iterations += 1 + observation = mj_observation(mujoco_model=m) + # action = model.policy.predict(observation=observation, deterministic=True)[0] + obs_policy = observation.copy() + obs_policy[2] = np.arctan2(np.sin(obs_policy[2]), np.cos(obs_policy[2])) + action = policy(obs_policy) + + mj_step(action=action, mujoco_model=m) + + mj_action.append(action * 50) + mj_pos_cart.append(observation[0]) + mj_pos_pole.append(observation[2]) + + import time + + time.sleep(0.050) + + print(observation, "\t", action) + + if iterations >= 201: + break + + # ====== + # Jaxsim + # ====== + + # js_action = [] + # js_pos_cart = [] + # js_pos_pole = [] + # + # done = False + # iterations = 0 + # observation, _ = env_eval.reset() + # + # while not done: + # iterations += 1 + # action = policy(observation) + # observation, _, _, _, _, = env_eval.step(action) + # + # js_action.append(action * 50) + # js_pos_cart.append(observation[0]) + # js_pos_pole.append(observation[2]) + # + # # import time + # # + # # time.sleep(0.050) + # + # print(observation, "\t", action) + # + # if iterations >= 201: + # break + + # ============ + # Mujoco alt 1 + # ============ + + model_xml_path = ( + pathlib.Path.home() + / "git" + / "jaxsim" + / "examples" + / "resources" + / "cartpole_mj.xml" + ) + + m = MujocoModel(xml_path=model_xml_path) + + mj_action_alt1 = [] + mj_pos_cart_alt1 = [] + mj_pos_pole_alt1 = [] + + done = False + iterations = 0 + mj_reset(mujoco_model=m) + + while not done: + iterations += 1 + observation = mj_observation(mujoco_model=m) + # action = policy(observation) + obs_policy = observation.copy() + obs_policy[2] = 
np.arctan2(np.sin(obs_policy[2]), np.cos(obs_policy[2])) + action = policy(obs_policy) + mj_step(action=action, mujoco_model=m) + + mj_action_alt1.append(action * 50) + mj_pos_cart_alt1.append(observation[0]) + mj_pos_pole_alt1.append(observation[2]) + + import time + + time.sleep(0.050) + + print(observation, "\t", action) + + if iterations >= 201: + break + + mj_action_alt1 = np.array(mj_action_alt1) + mj_pos_cart_alt1 = np.array(mj_pos_cart_alt1) + mj_pos_pole_alt1 = np.array(mj_pos_pole_alt1) + + # ============ + # Mujoco alt 2 + # ============ + + model_xml_path = ( + pathlib.Path.home() + / "git" + / "jaxsim" + / "examples" + / "resources" + / "cartpole_mj.xml" + ) + + m = MujocoModel(xml_path=model_xml_path) + + mj_action_alt2 = [] + mj_pos_cart_alt2 = [] + mj_pos_pole_alt2 = [] + + done = False + iterations = 0 + mj_reset(mujoco_model=m) + + while not done: + iterations += 1 + observation = mj_observation(mujoco_model=m) + # action = policy(observation) + obs_policy = observation.copy() + obs_policy[2] = np.arctan2(np.sin(obs_policy[2]), np.cos(obs_policy[2])) + action = policy(obs_policy) + mj_step(action=action, mujoco_model=m) + + mj_action_alt2.append(action * 50) + mj_pos_cart_alt2.append(observation[0]) + mj_pos_pole_alt2.append(observation[2]) + + import time + + time.sleep(0.050) + + print(observation, "\t", action) + + if iterations >= 201: + break + + mj_action_alt2 = np.array(mj_action_alt2) + mj_pos_cart_alt2 = np.array(mj_pos_cart_alt2) + mj_pos_pole_alt2 = np.array(mj_pos_pole_alt2) + + # ==== + # Plot + # ==== + + import palettable + + # https://jiffyclub.github.io/palettable/cartocolors/diverging/ + # colors = palettable.cartocolors.diverging.Geyser_5.mpl_colors + colors = palettable.cartocolors.qualitative.Prism_8.mpl_colors + + fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True) + time = np.arange(start=0, stop=len(mj_action)) * 0.050 + + # ax1.plot(time, js_pos_pole, label=r"Jaxsim", color=colors[1], linewidth=1) + # ax1.plot(time, 
mj_pos_pole, label=r"Mujoco", color=colors[7], linewidth=1) + # ax2.plot(time, js_pos_cart, label=r"Jaxsim", color=colors[1], linewidth=1) + # ax2.plot(time, mj_pos_cart, label=r"Mujoco", color=colors[7], linewidth=1) + # ax3.plot(time, js_action, label=r"Jaxsim", color=colors[1], linewidth=1) + # ax3.plot(time, mj_action, label=r"Mujoco", color=colors[7], linewidth=1) + + ax1.plot(time, mj_pos_pole, label=r"nominal", color=colors[1], linewidth=1) + ax1.plot(time, mj_pos_pole_alt1, label=r"mass", color=colors[3], linewidth=1) + ax1.plot( + time, mj_pos_pole_alt2, label=r"mass+friction", color=colors[7], linewidth=1 + ) + ax2.plot(time, mj_pos_cart, label=r"nominal", color=colors[1], linewidth=1) + ax2.plot(time, mj_pos_cart_alt1, label=r"mass", color=colors[3], linewidth=1) + ax2.plot( + time, mj_pos_cart_alt2, label=r"mass+friction", color=colors[7], linewidth=1 + ) + ax3.plot(time, mj_action, label=r"nominal", color=colors[1], linewidth=1) + ax3.plot(time, mj_action_alt1, label=r"mass", color=colors[3], linewidth=1) + ax3.plot(time, mj_action_alt2, label=r"mass+friction", color=colors[7], linewidth=1) + + ax1.grid() + ax1.set_ylabel(r"Pole angle $\theta$ [rad]") + ax2.grid() + ax2.set_ylabel(r"Cart position $d$ [m]") + ax3.grid() + ax3.set_ylabel(r"Force applied to cart $f$ [N]") + + # ax1.set_title(r"Pole angle $\theta$") + # ax2.set_title(r"Cart position $d$") + # ax3.set_title(r"Force applied to cart $f$") + + # ax1.legend() + # ax2.legend() + # ax3.legend() + + # plt.legend() + # plt.title(r"\textbf{Comparison of cartpole swing-up performance}") + # fig.suptitle(r"\textbf{Comparison of cartpole swing-up performance}") + fig.supxlabel("Time [s]") + # plt.show() + + import tikzplotlib + + tikzplotlib.clean_figure() + print(tikzplotlib.get_tikz_code()) + +# Train with SB +if __name__ == "__main__ant_vec_gpu_env": + """""" + + max_episode_steps = 1000 + func_env = NaNHandlerWrapper(env=AntReachTargetFuncEnvV0()) + + if max_episode_steps is not None: + func_env = 
TimeLimit(env=func_env, max_episode_steps=max_episode_steps) + + func_env = ClipActionWrapper( + env=SquashActionWrapper(env=func_env), + ) + + # TODO: rename _sb to prevent collision with module + vec_env_sb = make_vec_env_stable_baselines( + jax_dataclass_env=func_env, + # n_envs=10, + # n_envs=2048, # troppo -> JIT lungo + # n_envs=100, + # n_envs=1024, + # n_envs=2048, + n_envs=512, + seed=42, + vec_env_kwargs=dict( + jit_compile=True, + ), + ) + + # %time _ = vec_env_sb.reset() + # %time _ = vec_env_sb.reset() + # actions = vec_env_sb.jax_vector_env.action_space.sample() + # %time _ = vec_env_sb.step(actions) + # %time _ = vec_env_sb.step(actions) + + import torch as th + + # TODO: se ogni reset c'e' 1 sec di sim -> mega lento perche' ci sara' sempre + # un env che si sta resettando! + + model = PPO( + "MlpPolicy", + env=vec_env_sb, + # n_steps=2048, + # n_steps=512, # in the vector env -> real ones are x10 + # n_steps=10, # in the vector env -> real ones are x10 + n_steps=4, # in the vector env -> real ones are x2048 + # batch_size=256, + batch_size=1024, + n_epochs=10, + gamma=0.95, + gae_lambda=0.9, + clip_range=0.1, + normalize_advantage=True, + # target_kl=0.010, + target_kl=0.025, + verbose=1, + learning_rate=0.000_300, + policy_kwargs=dict( + activation_fn=th.nn.ReLU, + net_arch=dict(pi=[2048, 1024], vf=[1024, 1024, 256]), + # net_arch=dict(pi=[2048, 2048], vf=[2048, 1024, 512]), + log_std_init=np.log(0.05), + # squash_output=True, + ), + ) + + print(model.policy) + + # Create the evaluation environment + env_eval = make_jax_env_ant( + render_mode="meshcat_viz", + max_episode_steps=1000, + ) + + for _ in range(10): + # Train the model + model = model.learn(total_timesteps=500_000, progress_bar=False) + + # Create the policy closure + policy = lambda observation: model.policy.predict( + observation=observation, deterministic=True + )[0] + + # Evaluate the policy + print("Evaluating...") + evaluate( + env=env_eval, + num_episodes=10, + seed=None, + 
render=True, + # policy=policy, + ) + + # evaluate( + # env=env_eval, + # num_episodes=10, + # seed=None, + # render=True, + # # policy=policy, + # ) + +# ============= +# RANDOM POLICY +# ============= + +# TODO: generate a JaxVecEnv and validate with DummyVecEnv from SB + +if __name__ == "__main__+": + """""" + + # Create the environment + env = make_jax_env_ant(render_mode="meshcat_viz", max_episode_steps=None) + # env = make_jax_env_cartpole(render_mode="meshcat_viz", max_episode_steps=10) + + # Reset the environment + # observation, state_info = env.reset(seed=42) + observation, state_info = env.reset() + + # Initialize a random policy + random_policy = lambda env, obs: env.action_space.sample() + + # Initialize done flag + done = False + + env.render() + + # s = env.state + # env.func_env.env.transition(state=s, action=env.func_env.action_space.sample()) + # + # with env.func_env.unwrapped.jaxsim.editable(validate=True) as sim: + # sim.data = env.state["env"] + + # import time + # + # time.sleep(2) + + i = 0 + cum_reward = 0.0 + + while not done: + i += 1 + + if i == 2000: + done = True + + # Sample a random action + # action = 0.1* random_policy(env, observation) + action, _ = model.policy.predict(observation=observation, deterministic=False) + + # Step the environment + observation, reward, terminal, truncated, step_info = env.step(action=action) + + print(reward) + cum_reward += reward + + # Render the environment + _ = env.render() + + # print(observation, reward, terminal, truncated, step_info) + # print(env.state) + + print(cum_reward) + + env.close() + +# ================= +# TRAINING PPO/TRPO +# ================= + +if __name__ == "__main__)": + """Stable Baselines""" + + # Initialize properties + # seed = 42 + + # Create a single environment + # func_env = CartpoleSwingUpFuncEnvV0() + + # env = JaxEnv( + # func_env=ToNumPyWrapper( + # env=JaxTransformWrapper( + # function=jax.jit, + # env=FlattenSpacesWrapper(env=CartpoleSwingUpFuncEnvV0()), + # ) + 
# ) + # ) + + # TODO: try with single env first? + + # env = JaxEnv( + # func_env=ToNumPyWrapper( + # env=JaxTransformWrapper( + # function=jax.jit, + # env=FlattenSpacesWrapper( + # env=TimeLimit(env=func_env, max_episode_steps=1_000) + # ), + # ) + # ) + # ) + + # def make_env() -> gym.Env: + # def make_jax_env(max_episode_steps: Optional[int] = 500) -> JaxEnv: + # """""" + # + # # TODO: single env -> time limit with stable_baselines? + # + # if max_episode_steps is None: + # env = CartpoleSwingUpFuncEnvV0() + # else: + # env = TimeLimit( + # env=CartpoleSwingUpFuncEnvV0(), max_episode_steps=max_episode_steps + # ) + # + # return JaxEnv( + # func_env=ToNumPyWrapper( + # env=JaxTransformWrapper( + # function=jax.jit, + # env=FlattenSpacesWrapper(env=env), + # ) + # ) + # ) + + # check_env(env=env, warn=True, skip_render_check=True) + + # observation = env.reset() + # action = env.action_space.sample() + # observation, reward, terminated, truncate, info = env.step(action) + # print(observation, reward, terminated, truncate, info) + + # + # + # + + # vec_env = make_vec_env_stable_baselines( + # jax_dataclass_env=func_env, + # n_envs=10, + # seed=42, + # vec_env_kwargs=dict( + # max_episode_steps=5_000, + # jit_compile=True, + # ), + # ) + + # vec_env.reset() + + vec_env = make_vec_env( + # env_id=lambda: make_jax_env_cartpole(max_episode_steps=500), + env_id=lambda: make_jax_env_ant(max_episode_steps=2_000), + n_envs=10, + seed=42, + vec_env_cls=SubprocVecEnv, + ) + + import torch as th + + # ANT + model = PPO( + "MlpPolicy", + env=vec_env, + # n_steps=2048, + n_steps=512, # in the vector env -> real ones are x10 + batch_size=256, + n_epochs=10, + gamma=0.95, + gae_lambda=0.98, + clip_range=0.1, + normalize_advantage=True, + # target_kl=0.010, + target_kl=0.025, + verbose=1, + learning_rate=0.000_300, + policy_kwargs=dict( + activation_fn=th.nn.ReLU, + net_arch=dict(pi=[1024, 512], vf=[1024, 512]), + # log_std_init=0.10, + # log_std_init=2.5, + 
log_std_init=np.log(0.25), + # squash_output=True, + ), + ) + + print(model.policy) + model = model.learn(total_timesteps=200_000, progress_bar=False) + # model = model.learn(total_timesteps=500_000, progress_bar=False) + + # + # CARTPOLE + # + + # TODO: with squash_output -> do I need to apply tanh to the action? + # model = PPO( + # "MlpPolicy", + # env=vec_env, + # # n_steps=2048, + # n_steps=256, # in the vector env -> real ones are x10 + # batch_size=256, + # n_epochs=10, + # gamma=0.95, + # gae_lambda=0.9, + # clip_range=0.1, + # normalize_advantage=True, + # # target_kl=0.010, + # target_kl=0.025, + # verbose=1, + # learning_rate=0.000_300, + # policy_kwargs=dict( + # activation_fn=th.nn.ReLU, + # net_arch=dict(pi=[256, 256], vf=[256, 256]), + # log_std_init=0.05, + # # squash_output=True, + # ), + # ) + + model = TRPO( + "MlpPolicy", + env=vec_env, + n_steps=256, # in the vector env -> real ones are x10 + batch_size=1024, + gamma=0.95, + gae_lambda=0.95, + normalize_advantage=True, + target_kl=0.025, + verbose=1, + learning_rate=0.000_300, + policy_kwargs=dict( + activation_fn=th.nn.ReLU, + net_arch=dict(pi=[256, 256], vf=[256, 256]), + log_std_init=0.05, # TODO np.log()..., maybe 0.05 too large? + # squash_output=True, + ), + ) + + print(model.policy) + model = model.learn(total_timesteps=500_000, progress_bar=False) + # model.learn(total_timesteps=25_000) + # model.save("ppo_cartpole") + + # # Vectorize the environment. + # # Note: it automatically wraps the environment in a TimeLimit wrapper. 
+ # vec_env = JaxVectorEnv( + # func_env=env, + # num_envs=10, + # max_episode_steps=5_000, + # jit_compile=True, + # ) + # + # from jaxgym.vector.jax import FlattenSpacesVecWrapper + # + # vec_env = FlattenSpacesVecWrapper(env=vec_env) + # + # # test.reset(seed=0) + # # test.step(actions=test.action_space.sample()) + + # ========= + # Visualize + # ========= + + visualize = False + + if visualize: + rollout_visualizer = visualizer(env=lambda: make_jax_env(1_000), policy=model) + + import time + + time.sleep(3) + rollout_visualizer(None) + + +if __name__ == "__main___": + """""" + + # Initialize properties + seed = 42 + # num_envs = 3 + + # Create a single environment + func_env = CartpoleSwingUpFuncEnvV0() + + from jaxgym.jax.env import JaxEnv + from jaxgym.wrappers.jax import ( + ClipActionWrapper, + FlattenSpacesWrapper, + JaxTransformWrapper, + ) + + func_env = ClipActionWrapper(env=func_env) + func_env = FlattenSpacesWrapper(env=func_env) + func_env = JaxTransformWrapper(env=func_env, function=jax.jit) + + # state = func_env.initial(rng=jax.random.PRNGKey(seed=seed)) + # action = func_env.action_space.sample() + # func_env.transition(state, action) + + env = JaxEnv(func_env=func_env) + + observation, state_info = env.reset(seed=seed) + + +# TODO: specs kind of ok -> see bottom of cartpole +if __name__ == "__main___": + """""" + + # Initialize properties + seed = 42 + # num_envs = 3 + + # Create a single environment + env = CartpoleSwingUpFuncEnvV0() + + # Vectorize the environment. + # Note: it automatically wraps the environment in a TimeLimit wrapper. 
+ vec_env = JaxVectorEnv( + func_env=env, + num_envs=1_000, + max_episode_steps=10, + jit_compile=True, + ) + + # FRIDAY: + # from jaxgym.functional.jax.flatten_spaces import FlattenSpacesWrapper + # test = FlattenSpacesWrapper(env=env) + # test.transition( + # state=test.initial(rng=jax.random.PRNGKey(0)), action=test.action_space.sample() + # ) + + from jaxgym.vector.jax import FlattenSpacesVecWrapper + + test = FlattenSpacesVecWrapper(env=vec_env) + test.reset(seed=0) + test.step(actions=test.action_space.sample()) + + # import cProfile + # from pstats import SortKey + # + # with cProfile.Profile() as pr: + # + # for _ in range(1000): + # _ = test.step(actions=test.action_space.sample()) + # + # pr.print_stats(sort=SortKey.CUMULATIVE) + # + # exit(0) + + # o = test.env.observation_space.sample() + # test.env.observation_space.flatten_sample(o) + + # exit(0) + # raise # all good! + # TODO: benchmark non-jit sections + # TODO: benchmark from the code instead cmdline after jit compilation + + # Reset the environment. + # This has to be done only once since the vectorized environment supports autoreset. 
+ observations, state_infos = vec_env.reset(seed=seed) + + # Initialize a random policy + random_policy = lambda obs: vec_env.action_space.sample() + + for _ in range(1): + # Sample random actions + actions = random_policy(observations) + + # Step the environment + observations, rewards, terminals, truncated, step_infos = vec_env.step( + action=actions + ) + + print(observations, rewards, terminals, truncated, step_infos) diff --git a/src/jaxgym/_spaces/__init__.py b/src/jaxgym/_spaces/__init__.py new file mode 100644 index 000000000..36c133a4b --- /dev/null +++ b/src/jaxgym/_spaces/__init__.py @@ -0,0 +1 @@ +from .space import Space diff --git a/src/jaxgym/_spaces/pytree_orig.py b/src/jaxgym/_spaces/pytree_orig.py new file mode 100644 index 000000000..7403cf98a --- /dev/null +++ b/src/jaxgym/_spaces/pytree_orig.py @@ -0,0 +1,240 @@ +import gymnasium.spaces +import jax.flatten_util +import jax.numpy as jnp +import jax.tree_util +import numpy as np +import numpy.typing as npt +from gymnasium.spaces.utils import flatdim, flatten +from gymnasium.vector.utils.spaces import batch_space + +import jaxsim.typing as jtp +from jaxsim.utils import not_tracing, tracing + +from .space import Space + +# TODO: inherit from gymnasium.spaces? 
class PyTree(Space):
    """Space whose elements are pytrees bounded leaf-wise by `low` and `high`.

    The `low` and `high` pytrees define, leaf by leaf, the inclusive bounds of
    the space. Their leaves are converted to JAX arrays and stored in the
    `low` and `high` attributes; `shape` is the shape of the flattened `low`.
    """

    def __init__(self, low: jtp.PyTree, high: jtp.PyTree):
        """Build the space from its lower and upper bounds.

        Args:
            low: Pytree with the lower bound of each leaf.
            high: Pytree with the upper bound of each leaf.

        Raises:
            ValueError: If the flattened bounds differ in dtype or in shape.
        """

        # TODO: make the checks generic and per-leaf
        # (pytrees_with_same_dtype|shape|supported_dtype) and move them to utils.
        # The checks below only compare the flattened (raveled) bounds.

        # Flatten the pytrees
        low_flat, _ = jax.flatten_util.ravel_pytree(low)
        high_flat, _ = jax.flatten_util.ravel_pytree(high)

        if low_flat.dtype != high_flat.dtype:
            raise ValueError(low_flat.dtype, high_flat.dtype)

        if low_flat.shape != high_flat.shape:
            raise ValueError(low_flat.shape, high_flat.shape)

        # Transform all leaves to arrays and store them in the object
        self.low = jax.tree_util.tree_map(lambda leaf: jnp.array(leaf), low)
        self.high = jax.tree_util.tree_map(lambda leaf: jnp.array(leaf), high)
        self.shape = low_flat.shape

    # TODO: what if key is a vector?
    # NOTE(review): `jax.random.PRNGKeyArray` is no longer exposed by recent JAX
    # releases; confirm the JAX version pinned by the project before upgrading.
    def sample(self, key: jax.random.PRNGKeyArray) -> jtp.PyTree:
        """Sample an element of the space.

        Each leaf is sampled with its own subkey, so that leaves having the same
        shape and bounds do not receive identical values.

        Args:
            key: The PRNG key used to generate the sample.

        Returns:
            A pytree with the same structure of `low`, whose leaves are sampled
            within the inclusive bounds of the corresponding leaves.

        Raises:
            ValueError: If a leaf has a dtype that is not floating, integer, or bool.
        """

        def random_array(subkey, minval: jtp.Array, maxval: jtp.Array) -> jtp.Array:
            """Helper to select the right sampling function for the leaf dtype."""

            dtype = minval.dtype

            if jnp.issubdtype(dtype, jnp.floating):
                return jax.random.uniform(
                    key=subkey,
                    shape=minval.shape,
                    minval=minval,
                    maxval=maxval,
                    dtype=dtype,
                )

            if jnp.issubdtype(dtype, jnp.integer):
                # The upper bound of randint is exclusive
                return jax.random.randint(
                    key=subkey,
                    shape=minval.shape,
                    minval=minval,
                    maxval=maxval + 1,
                    dtype=dtype,
                )

            if jnp.issubdtype(dtype, jnp.bool_):
                # Sample integers in [low, high] and convert them back to bool
                return jax.random.randint(
                    key=subkey,
                    shape=minval.shape,
                    minval=minval.astype(int),
                    maxval=maxval.astype(int) + 1,
                ).astype(bool)

            raise ValueError(dtype)

        # Generate a pytree having a different subkey in each leaf.
        # Splitting over the number of leaves does not depend on the key layout.
        leaves, treedef = jax.tree_util.tree_flatten(self.low)
        subkeys = jax.random.split(key, num=len(leaves))
        subkey_pytree = jax.tree_util.tree_unflatten(treedef, list(subkeys))

        # Generate a pytree sampling leaves according to their dtype and using a
        # different key for each of them
        return jax.tree_util.tree_map(
            lambda low, high, subkey: random_array(
                subkey=subkey, minval=low, maxval=high
            ),
            self.low,
            self.high,
            subkey_pytree,
        )

    def contains(self, x: jtp.PyTree) -> bool:
        """Check whether all the leaves of `x` are inside the bounds of the space.

        Args:
            x: A pytree with the same structure of `low` and `high`.

        Returns:
            A boolean scalar array, True if every leaf is within its bounds.
        """

        def is_inside_bounds(x, low, high):
            return jnp.logical_and(jnp.all(x >= low), jnp.all(x <= high))

        contains_all_leaves = jax.tree_util.tree_map(
            lambda low, high, leaf: is_inside_bounds(x=leaf, low=low, high=high),
            self.low,
            self.high,
            x,
        )

        contains_all_leaves_flat, _ = jax.flatten_util.ravel_pytree(contains_all_leaves)

        return jnp.all(contains_all_leaves_flat)

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`."""

        return True

    def to_box(self) -> gymnasium.spaces.Box:
        """Convert the space to a flat :class:`gymnasium.spaces.Box`.

        Returns:
            A box whose bounds are the flattened `low` and `high` pytrees.
        """

        # Box infers its shape from NumPy bounds, therefore the flattened JAX
        # arrays are converted before building it.
        return gymnasium.spaces.Box(
            low=np.asarray(self.flatten_sample(x=self.low)),
            high=np.asarray(self.flatten_sample(x=self.high)),
        )

    def flatten_sample(self, x: jtp.PyTree) -> jtp.VectorJax:
        """Flatten a pytree to a 1D array by raveling its leaves."""

        x_flat, _ = jax.flatten_util.ravel_pytree(x)
        return x_flat

    def unflatten_sample(self, x: jtp.Vector) -> jtp.PyTree:
        """Rebuild a pytree with the structure of `low` from a 1D array."""

        _, unflatten_fn = jax.flatten_util.ravel_pytree(self.low)
        return unflatten_fn(x)

    def clip(self, x: jtp.PyTree) -> jtp.PyTree:
        """Clip each leaf of `x` inside the bounds, keeping the dtype of `low`."""

        # Positional arguments are used since the keyword names of jnp.clip
        # changed across JAX releases.
        return jax.tree_util.tree_map(
            lambda low, high, leaf: jnp.array(
                jnp.clip(leaf, low, high), dtype=low.dtype
            ),
            self.low,
            self.high,
            x,
        )

    # TODO: flatten()
    # TODO: unflatten() from float with proper type casting
    # TODO: normalize?
+ + +@flatdim.register(PyTree) +def _flatdim_pytree(space: PyTree) -> int: + """""" + + low_flat, _ = jax.flatten_util.ravel_pytree(space.low) + return low_flat.size + + +@flatten.register(PyTree) +def _flatten_pytree(space: PyTree, x: jtp.PyTree) -> npt.NDArray: + """""" + + assert x in space + x_flat, _ = jax.flatten_util.ravel_pytree(x) + + return x_flat + + +@batch_space.register(PyTree) +def _batch_space_pytree(space: PyTree, n: int = 1) -> PyTree: + """""" + + low_batched = jax.tree_util.tree_map(lambda l: jnp.stack([l] * n), space.low) + high_batched = jax.tree_util.tree_map(lambda l: jnp.stack([l] * n), space.high) + + # TODO: np_random + return PyTree(low=low_batched, high=high_batched) diff --git a/src/jaxgym/_spaces/space.py b/src/jaxgym/_spaces/space.py new file mode 100644 index 000000000..7303e941a --- /dev/null +++ b/src/jaxgym/_spaces/space.py @@ -0,0 +1,30 @@ +import abc + +import gymnasium.spaces +import jax + +import jaxsim.typing as jtp + + +class Space(abc.ABC): + """""" + + # TODO: add num for multiple samples? Or if multiple keys -> multiple samples? 
+ @abc.abstractmethod + def sample(self, key: jax.random.PRNGKey) -> jtp.PyTree: + """""" + pass + + @abc.abstractmethod + def contains(self, x: jtp.PyTree) -> bool: + """""" + pass + + # @abc.abstractmethod + # def to_gymnasium(self) -> gymnasium.Space: + # """""" + # pass + + def __contains__(self, x: jtp.PyTree) -> bool: + """""" + return self.contains(x) diff --git a/src/jaxgym/envs/__init__.py b/src/jaxgym/envs/__init__.py new file mode 100644 index 000000000..1aff17745 --- /dev/null +++ b/src/jaxgym/envs/__init__.py @@ -0,0 +1,19 @@ +from typing import Any + +from gymnasium.envs.registration import ( + load_plugin_envs, + make, + pprint_registry, + register, + registry, + spec, +) + +register( + id="CartpoleSwingUpEnv-V0", + entry_point="jaxgym.envs.cartpole:CartpoleSwingUpEnvV0", + vector_entry_point="jaxgym.envs.cartpole:CartpoleSwingUpVectorEnvV0", + # max_episode_steps=5_000, + # reward_threshold=195.0, + # kwargs=dict(max_episode_steps=5_000, blabla=True), +) diff --git a/src/jaxgym/envs/ant.py b/src/jaxgym/envs/ant.py new file mode 100644 index 000000000..de8b4d472 --- /dev/null +++ b/src/jaxgym/envs/ant.py @@ -0,0 +1,732 @@ +import dataclasses +import pathlib +from typing import Any, ClassVar, Optional + +import jax.numpy as jnp +import jax.random +import jax_dataclasses +import numpy as np +import numpy.typing as npt +import rod + +import jaxgym.jax.pytree_space as spaces +import jaxsim.typing as jtp +from jaxgym.jax import JaxDataclassEnv, JaxEnv +from jaxgym.vector.jax import JaxVectorEnv +from jaxsim import JaxSim +from jaxsim.physics.algos.soft_contacts import SoftContactsParams +from jaxsim.simulation import simulator_callbacks +from jaxsim.simulation.ode_integration import IntegratorType +from jaxsim.simulation.simulator import SimulatorData, VelRepr +from jaxsim.utils import JaxsimDataclass, Mutability + + +@jax_dataclasses.pytree_dataclass +class AntObservation(JaxsimDataclass): + """Observation of the Ant environment.""" + + base_height: 
jtp.Float + gravity_projection: jtp.Array + + joint_positions: jtp.Array + joint_velocities: jtp.Array + + base_linear_velocity: jtp.Array + base_angular_velocity: jtp.Array + + contact_state: jtp.Array + + distance_from_goal_x: jtp.Float + distance_from_goal_y: jtp.Float + + @staticmethod + def build( + base_height: jtp.Float, + gravity_projection: jtp.Array, + joint_positions: jtp.Array, + joint_velocities: jtp.Array, + base_linear_velocity: jtp.Array, + base_angular_velocity: jtp.Array, + contact_state: jtp.Array, + distance_from_goal_x: jtp.Float, + distance_from_goal_y: jtp.Float, + ) -> "AntObservation": + """Build an AntObservation object.""" + + return AntObservation( + base_height=jnp.array(base_height, dtype=float), + gravity_projection=jnp.array(gravity_projection, dtype=float), + joint_positions=jnp.array(joint_positions, dtype=float), + joint_velocities=jnp.array(joint_velocities, dtype=float), + base_linear_velocity=jnp.array(base_linear_velocity, dtype=float), + base_angular_velocity=jnp.array(base_angular_velocity, dtype=float), + contact_state=jnp.array(contact_state, dtype=bool), + distance_from_goal_x=jnp.array(distance_from_goal_x, dtype=float), + distance_from_goal_y=jnp.array(distance_from_goal_y, dtype=float), + ) + + +import multiprocessing + +from meshcat_viz import MeshcatWorld + + +@dataclasses.dataclass +class MeshcatVizRenderState: + """Render state of a meshcat-viz visualizer.""" + + world: MeshcatWorld = dataclasses.dataclass(init=False) + + _gui_process: Optional[multiprocessing.Process] = dataclasses.field( + default=None, init=False, repr=False, hash=False, compare=False + ) + + _jaxsim_to_meshcat_viz_name: dict[str, str] = dataclasses.field( + default_factory=dict, init=False, repr=False, hash=False, compare=False + ) + + def __post_init__(self) -> None: + """""" + + self.world = MeshcatWorld() + self.world.open() + + def close(self) -> None: + """""" + + if self.world is not None: + self.world.close() + + if self._gui_process is 
not None: + self._gui_process.terminate() + self._gui_process.close() + + @staticmethod + def open_window(web_url: str) -> None: + """Open a new window with the given web url.""" + + import webview + + print(web_url) + webview.create_window("meshcat", web_url) + webview.start(gui="qt") + + def open_window_in_process(self) -> None: + """""" + + if self._gui_process is not None: + self._gui_process.terminate() + self._gui_process.close() + + self._gui_process = multiprocessing.Process( + target=MeshcatVizRenderState.open_window, args=(self.world.web_url,) + ) + self._gui_process.start() + + +StateType = dict[str, SimulatorData | jtp.Array] +ActType = jnp.ndarray +ObsType = AntObservation +RewardType = float | jnp.ndarray +TerminalType = bool | jnp.ndarray +RenderStateType = MeshcatVizRenderState + + +@jax_dataclasses.pytree_dataclass +class AntReachTargetFuncEnvV0( + JaxDataclassEnv[ + StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType + ] +): + """Ant environment implementing a target reaching task.""" + + name: ClassVar = jax_dataclasses.static_field(default="AntReachTargetFuncEnvV0") + + # Store an instance of the JaxSim simulator. + # It gets initialized with SimulatorData with a functional approach. 
+ _simulator: JaxSim = jax_dataclasses.field(default=None) + + def __post_init__(self) -> None: + """Environment initialization.""" + + # Dummy initialization (not needed here) + with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + _ = self.jaxsim + model = self.jaxsim.get_model(model_name="ant") + + # simulator_data = self.initial(rng=jax.random.PRNGKey(seed=0)) + # dofs = simulator_data.models["ant"].dofs() + # dofs = model.dofs() + + # Create the action space (static attribute) + with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + high = jnp.array([25.0] * model.dofs(), dtype=float) + self._action_space = spaces.PyTree(low=-high, high=high) + + # Get joint limits + s_min, s_max = model.joint_limits() + s_range = s_max - s_min + + low = AntObservation.build( + base_height=0.25, + gravity_projection=-jnp.ones(3), + # joint_positions=s_min - 0.10 * s_range, + joint_positions=s_min, + # joint_velocities=-4.0 * jnp.ones_like(s_min), + joint_velocities=-50.0 * jnp.ones_like(s_min), + base_linear_velocity=-5.0 * jnp.ones(3), + base_angular_velocity=-10.0 * jnp.ones(3), + contact_state=jnp.array([False] * 4), + distance_from_goal_x=-10.0, + distance_from_goal_y=-10.0, + ) + + high = AntObservation.build( + base_height=1.0, + gravity_projection=jnp.ones(3), + # joint_positions=s_max + 0.10 * s_range, + joint_positions=s_max, + # joint_velocities=4.0 * jnp.ones_like(s_max), + joint_velocities=50.0 * jnp.ones_like(s_max), + base_linear_velocity=5.0 * jnp.ones(3), + base_angular_velocity=10.0 * jnp.ones(3), + contact_state=jnp.array([True] * 4), + distance_from_goal_x=10.0, + distance_from_goal_y=10.0, + ) + + # Create the observation space (static attribute) + with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + self._observation_space = spaces.PyTree(low=low, high=high) + + @property + def jaxsim(self) -> JaxSim: + """""" + + if self._simulator is not None: + return self._simulator + + # Create the jaxsim 
simulator. + # We use a small integration step so that contact detection is more accurate, + # and perform multiple integration steps when we apply the action. + simulator = JaxSim.build( + # Note: any change of either 'step_size' or 'steps_per_run' requires + # updating the number of integration steps in the 'transition' method. + # step_size=0.000_500, + step_size=0.000_250, + steps_per_run=1, + # velocity_representation=VelRepr.Inertial, # TODO + velocity_representation=VelRepr.Body, + integrator_type=IntegratorType.EulerSemiImplicit, + simulator_data=SimulatorData( + gravity=jnp.array([0, 0, -10.0]), + # contact_parameters=SoftContactsParams.build(K=5_000, D=10), + contact_parameters=SoftContactsParams.build(K=10_000, D=20), + ), + ).mutable(mutable=True, validate=False) + + # Get the SDF path + model_sdf_path = ( + pathlib.Path.home() + / "git" + / "jaxsim" + / "examples" + / "resources" + / "ant.sdf" + ) + + # TODO: load with rod and change the pos limits spring & friction params + + # Insert the model + _ = simulator.insert_model_from_description( + model_description=model_sdf_path, model_name="ant" + ) + + # Fix the pytree structure of the model data so that its corresponding shape + # does not change. This is important to keep enabled the shape validation + # checks for JIT compilation. + simulator.data.models = { + model_name: jax.tree_util.tree_map(lambda leaf: jnp.array(leaf), model_data) + for model_name, model_data in simulator.data.models.items() + } + + # Store the simulator object and configure it as immutable with enabled + # pytree structure validation. + # This is done to ensure that the corresponding pytree structure remains constant, + # preventing unwanted JIT recompilations due to mistakes when setting its data. 
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + self._simulator = simulator.mutable(mutable=True, validate=True) + + return self._simulator + + def initial(self, rng: Any = None) -> StateType: + """""" + + # Split the key + subkey1, subkey2 = jax.random.split(rng, num=2) + + # Sample an initial observation + initial_observation: AntObservation = self.observation_space.sample_with_key( + key=subkey1 + ) + + # Sample a goal position + goal_xy_position = jax.random.uniform( + key=subkey2, minval=-5.0, maxval=5.0, shape=(2,) + ) + + with self.jaxsim.editable(validate=False) as simulator: + # Reset the simulator and get the model + simulator.reset(remove_models=False) + model = simulator.get_model(model_name="ant") + + # Reset the joint positions + model.reset_joint_positions( + positions=initial_observation.joint_positions, + joint_names=model.joint_names(), + ) + + # Reset the joint velocities + # model.reset_joint_velocities( + # velocities=0.1 * initial_observation.joint_velocities, + # joint_names=model.joint_names(), + # ) + + # TODO: initialize s.t. 
there are no leg/terrain penetrations + # Reset the base position + model.reset_base_position( + # position=jnp.array([0, 0, initial_observation.base_height]) + position=jnp.array([0, 0, 0.5]) + ) + + # Reset the base velocity + model.reset_base_velocity( + base_velocity=jnp.hstack( + [ + 0.1 * initial_observation.base_linear_velocity, + 0.1 * initial_observation.base_angular_velocity, + ] + ) + ) + + # Simulate for 1s so that the model starts from a + # resting pose on the ground + # simulator = simulator.step_over_horizon( + # horizon_steps=2 * 1000, clear_inputs=True + # )s + + # Return the simulation state + return dict( + simulator_data=simulator.data, + goal=jnp.array(goal_xy_position, dtype=float), + ) + + def transition( + self, state: StateType, action: ActType, rng: Any = None + ) -> StateType: + """""" + + # Get the JaxSim simulator + simulator = self.jaxsim + + # Initialize the simulator with the environment state (containing SimulatorData) + with simulator.editable(validate=True) as simulator: + simulator.data = state["simulator_data"] + + @jax_dataclasses.pytree_dataclass + class SetTorquesOverHorizon(simulator_callbacks.PreStepCallback): + def pre_step(self, sim: JaxSim) -> JaxSim: + """""" + + model = sim.get_model(model_name="ant") + model.zero_input() + model.set_joint_generalized_force_targets( + forces=jnp.atleast_1d(action), joint_names=model.joint_names() + ) + + return sim + + # Compute the number of integration steps to perform + # transition_step_duration = 0.050 + # number_of_integration_steps = int(transition_step_duration / simulator.dt()) + # number_of_integration_steps = jnp.array( + # transition_step_duration / simulator.dt(), dtype=jnp.uint32 + # ) + + # number_of_integration_steps = 100 # 0.050 + number_of_integration_steps = 40 # 0.010 # TODO 20 for having 0.010 + + # Stepping logic + with simulator.editable(validate=True) as simulator: + # simulator, _ = simulator.step_over_horizon_plain( + simulator, _ = 
simulator.step_over_horizon( + horizon_steps=number_of_integration_steps, + clear_inputs=False, + callback_handler=SetTorquesOverHorizon(), + ) + + # Return the new environment state (updated SimulatorData) + return state | dict(simulator_data=simulator.data) + + def observation(self, state: StateType) -> ObsType: + """""" + + # Initialize the simulator with the environment state (containing SimulatorData) + # and get the simulated model + with self.jaxsim.editable(validate=True) as simulator: + simulator.data = state["simulator_data"] + model = simulator.get_model("ant") + + # Compute the normalized gravity projection in the body frame + W_R_B = model.base_orientation(dcm=True) + # W_gravity = state.simulator.gravity() + W_gravity = self.jaxsim.gravity() + B_gravity = W_R_B.T @ (W_gravity / jnp.linalg.norm(W_gravity)) + + W_p_B = model.base_position() + W_p_goal = jnp.hstack([state["goal"].squeeze(), 0]) + + # Compute the distance between the base and the goal in the body frame + B_p_distance = W_R_B.T @ (W_p_goal - W_p_B) + + # Build the observation from the state + return AntObservation.build( + base_height=model.base_position()[2], + gravity_projection=B_gravity, + joint_positions=model.joint_positions(), + joint_velocities=model.joint_velocities(), + base_linear_velocity=model.base_velocity()[0:3], + base_angular_velocity=model.base_velocity()[3:6], + contact_state=model.in_contact( + link_names=[ + name + for name in model.link_names() + if name.startswith("leg_") and name.endswith("_lower") + ] + ), + distance_from_goal_x=B_p_distance[0], + distance_from_goal_y=B_p_distance[1], + # contact_state=jnp.array( + # [ + # model.get_link(name).in_contact() + # for name in model.link_names() + # if name.startswith("leg_") and name.endswith("_lower") + # ], + # dtype=bool, + # ), + ) + + def reward( + self, state: StateType, action: ActType, next_state: StateType + ) -> RewardType: + """""" + + # with self.jaxsim.editable(validate=True) as simulator_pre: + # 
simulator_pre.data = state["simulator_data"] + # model_pre = simulator_pre.get_model("ant") + + with self.jaxsim.editable(validate=True) as simulator_next: + simulator_next.data = next_state["simulator_data"] + model_next = simulator_next.get_model("ant") + + # W_p_B_pre = model_pre.base_position() + # W_p_B_next = model_next.base_position() + # v_WB = (W_p_B_next - W_p_B_pre) / simulator_pre.dt() + + terminal = self.terminal(state=state) + obs_in_space = jax.lax.select( + pred=self.observation_space.contains(x=self.observation(state=state)), + on_true=1.0, + on_false=0.0, + ) + + # Position of the base + W_p_B = model_next.base_position() + W_p_xy_goal = state["goal"] + + reward = 0.0 + reward += 1.0 * (1.0 - jnp.array(terminal, dtype=float)) # alive + reward += 5.0 * obs_in_space # + # reward += 100.0 * v_WB[0] # forward velocity + reward -= jnp.linalg.norm(W_p_B[0:2] - W_p_xy_goal) # distance from goal + reward += 1.0 * model_next.in_contact( + link_names=[ + name + for name in model_next.link_names() + if name.startswith("leg_") and name.endswith("_lower") + ] + ).any().astype( + float + ) # contact status + reward -= 0.1 * jnp.linalg.norm(action) / action.size # control cost + + return reward + + def terminal(self, state: StateType) -> TerminalType: + """""" + + # Get the current observation + observation = self.observation(state=state) + + base_too_high = ( + observation.base_height >= self.observation_space.high.base_height + ) + # + # no_feet_in_contact = jnp.where(observation.contact_state.any(), False, True) + + # The state is terminal if the observation is outside is space + # return jax.lax.select( + # pred=self.observation_space.contains(x=observation), + # on_true=False, + # on_false=True, + # ) + # return jnp.array([base_too_high, no_feet_in_contact]).any() + # return no_feet_in_contact + return base_too_high + + # ========= + # Rendering + # ========= + + def render_image( + self, state: StateType, render_state: RenderStateType + ) -> 
tuple[RenderStateType, npt.NDArray]: + """Show the state.""" + + model_name = "ant" + + # Initialize the simulator with the environment state (containing SimulatorData) + # and get the simulated model + with self.jaxsim.editable(validate=False) as simulator: + simulator.data = state["simulator_data"] + model = simulator.get_model(model_name=model_name) + + # Insert the model lazily in the visualizer if it is not already there + if model_name not in render_state.world._meshcat_models.keys(): + from rod.urdf.exporter import UrdfExporter + + urdf_string = UrdfExporter.sdf_to_urdf_string( + sdf=rod.Sdf( + version="1.7", + model=model.physics_model.description.extra_info["sdf_model"], + ), + pretty=True, + gazebo_preserve_fixed_joints=False, + ) + + meshcat_viz_name = render_state.world.insert_model( + model_description=urdf_string, is_urdf=True, model_name=None + ) + + render_state._jaxsim_to_meshcat_viz_name[model_name] = meshcat_viz_name + + # Check that the model is in the visualizer + if ( + not render_state._jaxsim_to_meshcat_viz_name[model_name] + in render_state.world._meshcat_models.keys() + ): + raise ValueError(f"The '{model_name}' model is not in the meshcat world") + + # Update the model in the visualizer + render_state.world.update_model( + model_name=render_state._jaxsim_to_meshcat_viz_name[model_name], + joint_names=model.joint_names(), + joint_positions=model.joint_positions(), + base_position=model.base_position(), + base_quaternion=model.base_orientation(dcm=False), + ) + + return render_state, np.empty(0) + + def render_init(self, open_gui: bool = False, **kwargs) -> RenderStateType: + """Initialize the render state.""" + + # Initialize the render state + meshcat_viz_state = MeshcatVizRenderState() + + if open_gui: + meshcat_viz_state.open_window_in_process() + + return meshcat_viz_state + + def render_close(self, render_state: RenderStateType) -> None: + """Close the render state.""" + + render_state.close() + # render_state.world.close() + # + # if 
render_state._gui_process is not None: + # render_state._gui_process.terminate() + # render_state._gui_process.close() + + # TODO: write a generic class that, given a JaxSim, visualizes all the models + # -> vectorized?? put it inside JaxSim? Create a new class in jaxsim? + # def update_meshcat_world( + # self, world: "MeshcatWorld", state: StateType # TODO: how to handle the state?? + # ) -> "MeshcatWorld": + # """""" + # + # # Initialize the simulator with the environment state (containing SimulatorData) + # # and get the simulated model + # with self.jaxsim.editable(validate=False) as simulator: + # simulator.data = state + # model = simulator.get_model("ant") + # + # # Add the model to the world if not already present + # if "ant" not in world._meshcat_models.keys(): + # _ = world.insert_model( + # model_description=( + # pathlib.Path.home() + # / "git" + # / "jaxsim" + # / "examples" + # / "resources" + # / "ant.sdf" + # ), + # is_urdf=False, + # model_name="ant", + # ) + # + # # Update the model + # world.update_model( + # model_name="ant", + # joint_names=model.joint_names(), + # joint_positions=model.joint_positions(), + # base_position=model.base_position(), + # base_quaternion=model.base_orientation(dcm=False), + # ) + # + # return world + + +class AntReachTargetEnvV0(JaxEnv): + """""" + + def __init__(self, render_mode: str | None = None, **kwargs: Any) -> None: + """""" + + from jaxgym.wrappers.jax import ( + ClipActionWrapper, + FlattenSpacesWrapper, + JaxTransformWrapper, + TimeLimit, + ) + + func_env = AntReachTargetFuncEnvV0() + + func_env_wrapped = func_env + func_env_wrapped = TimeLimit( + env=func_env_wrapped, max_episode_steps=5_000 + ) # TODO + func_env_wrapped = ClipActionWrapper(env=func_env_wrapped) + func_env_wrapped = FlattenSpacesWrapper(env=func_env_wrapped) + func_env_wrapped = JaxTransformWrapper(env=func_env_wrapped, function=jax.jit) + + super().__init__( + func_env=func_env_wrapped, + metadata=self.metadata, + render_mode=render_mode, 
+ ) + + +class AntReachTargetVectorEnvV0(JaxVectorEnv): + """""" + + metadata = dict() + + def __init__( + self, + # func_env: JaxDataclassEnv[ + # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType + # ], + num_envs: int, + render_mode: str | None = None, + # max_episode_steps: int = 5_000, + jit_compile: bool = True, + **kwargs, + ) -> None: + """""" + + print("+++", kwargs) + + env = AntReachTargetFuncEnvV0() + + # Vectorize the environment. + # Note: it automatically wraps the environment in a TimeLimit wrapper. + super().__init__( + func_env=env, + num_envs=num_envs, + metadata=self.metadata, + render_mode=render_mode, + max_episode_steps=5_000, # TODO + jit_compile=jit_compile, + ) + + # from jaxgym.vector.jax import FlattenSpacesVecWrapper + # + # vec_env_wrapped = FlattenSpacesVecWrapper(env=vec_env) + + +if __name__ == "__main__": + """Stable Baselines""" + + from typing import Optional + + from jaxgym.wrappers.jax import ( + FlattenSpacesWrapper, + JaxTransformWrapper, + TimeLimit, + ToNumPyWrapper, + ) + + def make_jax_env( + max_episode_steps: Optional[int] = 500, jit: bool = True + ) -> JaxEnv: + """""" + + # TODO: single env -> time limit with stable_baselines? 
+ + if max_episode_steps in {None, 0}: + env = AntReachTargetFuncEnvV0() + else: + env = TimeLimit( + env=AntReachTargetFuncEnvV0(), max_episode_steps=max_episode_steps + ) + + return JaxEnv( + func_env=ToNumPyWrapper( + env=FlattenSpacesWrapper(env=env) + if not jit + else JaxTransformWrapper( + function=jax.jit, + env=FlattenSpacesWrapper(env=env), + ), + ), + render_mode="meshcat_viz", + ) + + env = make_jax_env(max_episode_steps=5, jit=False) + + obs, state_info = env.reset(seed=0) + _ = env.render() + raise + for _ in range(5): + action = env.action_space.sample() + # obs, reward, terminated, truncated, info = env.step(action=action) + obs, reward, terminated, truncated, info = env.step( + action=jnp.zeros_like(action) + ) + + # ========= + # Visualize + # ========= + + visualize = False + + if visualize: + rollout_visualizer = visualizer(env=lambda: make_jax_env(1_000), policy=model) + + import time + + time.sleep(3) + rollout_visualizer(None) diff --git a/src/jaxgym/envs/cartpole.py b/src/jaxgym/envs/cartpole.py new file mode 100644 index 000000000..cad6830be --- /dev/null +++ b/src/jaxgym/envs/cartpole.py @@ -0,0 +1,570 @@ +import pathlib +from typing import Any, ClassVar + +import jax.numpy as jnp +import jax.random +import jax_dataclasses +import numpy as np +import numpy.typing as npt + +import jaxgym.jax.pytree_space as spaces +import jaxsim.typing as jtp +from jaxgym.envs.ant import MeshcatVizRenderState +from jaxgym.jax import JaxDataclassEnv, JaxEnv +from jaxgym.vector.jax import JaxVectorEnv +from jaxsim import JaxSim, logging +from jaxsim.simulation.ode_integration import IntegratorType +from jaxsim.simulation.simulator import SimulatorData, VelRepr +from jaxsim.utils import JaxsimDataclass, Mutability + + +@jax_dataclasses.pytree_dataclass +class CartpoleObservation(JaxsimDataclass): + """Observation of the CartPole environment.""" + + linear_pos: jtp.Float + linear_vel: jtp.Float + + pivot_pos: jtp.Float + pivot_vel: jtp.Float + + @staticmethod + 
def build( + linear_pos: jtp.Float, + linear_vel: jtp.Float, + pivot_pos: jtp.Float, + pivot_vel: jtp.Float, + ) -> "CartpoleObservation": + """""" + + return CartpoleObservation( + linear_pos=jnp.array(linear_pos, dtype=float), + linear_vel=jnp.array(linear_vel, dtype=float), + pivot_pos=jnp.array(pivot_pos, dtype=float), + pivot_vel=jnp.array(pivot_vel, dtype=float), + ) + + +StateType = SimulatorData +ActType = jnp.ndarray +ObsType = CartpoleObservation +RewardType = float | jnp.ndarray +TerminalType = bool | jnp.ndarray +RenderStateType = MeshcatVizRenderState + + +@jax_dataclasses.pytree_dataclass +class CartpoleSwingUpFuncEnvV0( + JaxDataclassEnv[ + StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType + ] +): + """CartPole environment implementing a swing-up task.""" + + name: ClassVar = jax_dataclasses.static_field( + default="CartpoleSwingUpFunctionalEnvV0" + ) + + # Store an instance of the JaxSim simulator. + # It gets initialized with SimulatorData with a functional approach. 
+ _simulator: JaxSim = jax_dataclasses.field(default=None) + + def __post_init__(self) -> None: + """Environment initialization.""" + + # Dummy initialization (not needed here) + with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + _ = self.jaxsim + # _ = self.initial(rng=jax.random.PRNGKey(seed=0)) + + # Create the action space (static attribute) + with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + self._action_space = spaces.PyTree( + low=jnp.array(-50.0, dtype=float), high=jnp.array(50.0, dtype=float) + ) + + low = CartpoleObservation.build( + linear_pos=-2.4, + linear_vel=-10.0, + pivot_pos=-jnp.pi, + pivot_vel=-4 * jnp.pi, + ) + + high = CartpoleObservation.build( + linear_pos=2.4, + linear_vel=10.0, + pivot_pos=jnp.pi, + pivot_vel=4 * jnp.pi, + ) + + # Create the observation space (static attribute) + with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + self._observation_space = spaces.PyTree(low=low, high=high) + + @property + def jaxsim(self) -> JaxSim: + """""" + + if self._simulator is not None: + return self._simulator + + # T = 0.010 + # dt = 0.001 + T = 0.050 + dt = 0.000_500 + + # Create the jaxsim simulator + simulator = JaxSim.build( + # step_size=0.001, + # steps_per_run=10, + step_size=dt, + steps_per_run=int(T / dt), + velocity_representation=VelRepr.Inertial, + integrator_type=IntegratorType.EulerSemiImplicit, + simulator_data=SimulatorData(gravity=jnp.array([0, 0, -10.0])), + ).mutable(mutable=True, validate=False) + + # Get the URDF path + model_urdf_path = ( + pathlib.Path.home() + / "git" + / "jaxsim" + / "examples" + / "resources" + / "cartpole.urdf" + ) + + # Insert the model + _ = simulator.insert_model_from_description( + model_description=model_urdf_path, model_name="cartpole" + ) + + # Fix the pytree structure of the model data so that its corresponding shape + # does not change. This is important to keep enabled the shape validation + # checks for JIT compilation. 
+ simulator.data.models = { + model_name: jax.tree_util.tree_map(lambda leaf: jnp.array(leaf), model_data) + for model_name, model_data in simulator.data.models.items() + } + + # Store the simulator object and configure it as immutable with enabled + # pytree structure validation. + # This is done to ensure that the corresponding pytree structure remains constant, + # preventing unwanted JIT recompilations due to mistakes when setting its data. + with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + self._simulator = simulator.mutable(mutable=True, validate=True) + + return self._simulator + + def initial(self, rng: Any = None) -> StateType: + """""" + + # Sample an initial observation + initial_observation: CartpoleObservation = ( + self.observation_space.sample_with_key(key=rng) + ) + + with self.jaxsim.editable(validate=False) as simulator: + # Reset the simulator and get the model + simulator.reset(remove_models=False) + model = simulator.get_model(model_name="cartpole") + + # Reset the joint positions + model.reset_joint_positions( + positions=0.9 + * jnp.array( + [initial_observation.linear_pos, initial_observation.pivot_pos] + ), + joint_names=["linear", "pivot"], + ) + + # Reset the joint velocities + model.reset_joint_velocities( + velocities=0.9 + * jnp.array( + [initial_observation.linear_vel, initial_observation.pivot_vel] + ), + joint_names=["linear", "pivot"], + ) + + # TODO: reset the joint velocities + # logging.error("ZEROOO") + # model.reset_joint_positions(positions=jnp.array([0, jnp.deg2rad(180.0)])) + # model.reset_joint_velocities(velocities=jnp.array([0, 0.0])) + + # Return the simulation state + return simulator.data + + def transition( + self, state: StateType, action: ActType, rng: Any = None + ) -> StateType: + """""" + + # Get the JaxSim simulator + simulator = self.jaxsim + + # Initialize the simulator with the environment state (containing SimulatorData) + with simulator.editable(validate=True) as simulator: + 
simulator.data = state + + # Stepping logic + with simulator.editable(validate=True) as simulator: + # Get the simulated model + model = simulator.get_model(model_name="cartpole") + + # Zero all the inputs + model.zero_input() + + # print(action) + # action = action.squeeze() + + # Apply a linear force to the cart + model.set_joint_generalized_force_targets( + forces=jnp.atleast_1d(action), joint_names=["linear"] + ) + + # TODO: in multi-step -> reset action? + # Or always one step and handle multi-steps with callbacks e.g. controllers? + simulator.step(clear_inputs=False) + + # Return the new environment state (updated SimulatorData) + return simulator.data + + def observation(self, state: StateType) -> ObsType: + """""" + + # Initialize the simulator with the environment state (containing SimulatorData) + # and get the simulated model + with self.jaxsim.editable(validate=False) as simulator: + simulator.data = state + model = simulator.get_model("cartpole") + + # Extract the positions and velocities of the joints + linear_pos, pivot_pos = model.joint_positions() + linear_vel, pivot_vel = model.joint_velocities() + + # Build the observation from the state + return CartpoleObservation.build( + linear_pos=linear_pos, + linear_vel=linear_vel, + # Make sure that the pivot position is always in [-π, π] + pivot_pos=jnp.arctan2(jnp.sin(pivot_pos), jnp.cos(pivot_pos)), + pivot_vel=pivot_vel, + ) + + def reward( + self, state: StateType, action: ActType, next_state: StateType + ) -> RewardType: + """""" + + # Get the current observation + observation = self.observation(state=next_state) + # observation = type(self).observation(self=self, state=next_state) + + # Compute the reward terms + reward_alive = 1.0 - jnp.array(self.terminal(state=next_state), dtype=float) + # type(self).terminal(self=self, state=next_state), dtype=float + reward_pivot = jnp.cos(observation.pivot_pos) + cost_action = jnp.sqrt(action.dot(action)) + cost_pivot_vel = jnp.sqrt(observation.pivot_vel ** 
2) + cost_linear_pos = jnp.abs(observation.linear_pos) + + reward = 0 + reward += reward_alive + reward += reward_pivot + reward -= 0.001 * cost_action + reward -= 0.100 * cost_pivot_vel + reward -= 0.500 * cost_linear_pos + + return reward + + def terminal(self, state: StateType) -> TerminalType: + """""" + + # Get the current observation + observation = self.observation(state=state) + # observation = type(self).observation(self=self, state=state) + + # The state is terminal if the observation is outside its space + return jax.lax.select( + pred=self.observation_space.contains(x=observation), + on_true=False, + on_false=True, + ) + + # ========= + # Rendering + # ========= + + # ========= + # Rendering + # ========= + + def render_image( + self, state: StateType, render_state: RenderStateType + ) -> tuple[RenderStateType, npt.NDArray]: + """Show the state.""" + + model_name = "cartpole" + + # Initialize the simulator with the environment state (containing SimulatorData) + # and get the simulated model + with self.jaxsim.editable(validate=False) as simulator: + simulator.data = state + model = simulator.get_model(model_name=model_name) + + # Insert the model lazily in the visualizer if it is not already there + if model_name not in render_state.world._meshcat_models.keys(): + import rod + from rod.urdf.exporter import UrdfExporter + + urdf_string = UrdfExporter.sdf_to_urdf_string( + sdf=rod.Sdf( + version="1.7", + model=model.physics_model.description.extra_info["sdf_model"], + ), + pretty=True, + gazebo_preserve_fixed_joints=False, + ) + + meshcat_viz_name = render_state.world.insert_model( + model_description=urdf_string, is_urdf=True, model_name=None + ) + + render_state._jaxsim_to_meshcat_viz_name[model_name] = meshcat_viz_name + + # Check that the model is in the visualizer + if ( + not render_state._jaxsim_to_meshcat_viz_name[model_name] + in render_state.world._meshcat_models.keys() + ): + raise ValueError(f"The '{model_name}' model is not in the meshcat 
world") + + # Update the model in the visualizer + render_state.world.update_model( + model_name=render_state._jaxsim_to_meshcat_viz_name[model_name], + joint_names=model.joint_names(), + joint_positions=model.joint_positions(), + base_position=model.base_position(), + base_quaternion=model.base_orientation(dcm=False), + ) + + return render_state, np.empty(0) + + def render_init(self, open_gui: bool = False, **kwargs) -> RenderStateType: + """Initialize the render state.""" + + # Initialize the render state + meshcat_viz_state = MeshcatVizRenderState() + + if open_gui: + meshcat_viz_state.open_window_in_process() + + return meshcat_viz_state + + def render_close(self, render_state: RenderStateType) -> None: + """Close the render state.""" + + render_state.close() + + +class CartpoleSwingUpEnvV0(JaxEnv): + """""" + + def __init__(self, render_mode: str | None = None, **kwargs: Any) -> None: + """""" + + from jaxgym.wrappers.jax import ( + ClipActionWrapper, + FlattenSpacesWrapper, + JaxTransformWrapper, + TimeLimit, + ) + + func_env = CartpoleSwingUpFuncEnvV0() + + func_env_wrapped = func_env + func_env_wrapped = TimeLimit(env=func_env_wrapped, max_episode_steps=5_000) + func_env_wrapped = ClipActionWrapper(env=func_env_wrapped) + func_env_wrapped = FlattenSpacesWrapper(env=func_env_wrapped) + func_env_wrapped = JaxTransformWrapper(env=func_env_wrapped, function=jax.jit) + + super().__init__( + func_env=func_env_wrapped, + metadata=self.metadata, + render_mode=render_mode, + ) + + +class CartpoleSwingUpVectorEnvV0(JaxVectorEnv): + """""" + + metadata = dict() + + def __init__( + self, + # func_env: JaxDataclassEnv[ + # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType + # ], + num_envs: int, + render_mode: str | None = None, + # max_episode_steps: int = 5_000, + jit_compile: bool = True, + **kwargs, + ) -> None: + """""" + + # print("+++", kwargs) + + env = CartpoleSwingUpFuncEnvV0() + + # Vectorize the environment. 
+ # Note: it automatically wraps the environment in a TimeLimit wrapper. + super().__init__( + func_env=env, + num_envs=num_envs, + metadata=self.metadata, + render_mode=render_mode, + max_episode_steps=5_000, # TODO + jit_compile=jit_compile, + ) + + # from jaxgym.vector.jax import FlattenSpacesVecWrapper + # + # vec_env_wrapped = FlattenSpacesVecWrapper(env=vec_env) + + +if __name__ == "__main__REGISTER": + """""" + + import gymnasium as gym + + import jaxgym.envs + + gym.envs.registry.keys() + + # + # + # + + env = gym.make("CartpoleSwingUpEnv-V0") + env.spec.pprint(print_all=True) + + # + # + # + + vec_env = gym.make_vec( + "CartpoleSwingUpEnv-V0", num_envs=2, vectorization_mode="custom" + ) + vec_env.spec.pprint(print_all=True) + + from jaxgym.vector.jax.wrappers import FlattenSpacesVecWrapper + + vec_env_wrapped = FlattenSpacesVecWrapper(env=vec_env) + + +if __name__ == "__main__+": + """""" + + # env = CartpoleFunctionalEnvV0() + # state = env.initial(rng=jax.random.PRNGKey(0)) + # action = env.action_space.sample(key=jax.random.PRNGKey(1)) + + key = jax.random.PRNGKey(0) + num = 1000 + + # TODO next week: + # - this is ok + # - write Env wrapper for autoreset / check gymnasium -> lambda for get_action from pytorch? + # - figure out what the info dicts of gymnasium are + # - decide how to perform training loop -> rl algos from where (jax-based or pytorch)? + # + # Pytorch is ok if we sample in parallel only a single step (e.g. 
on thousands of envs) + + # from jaxgym.functional.wrappers.transform import TransformWrapper + # from jaxgym.functional.wrappers.jax.time_limit import TimeLimit + # + # env = CartpoleSwingUpFunctionalEnvV0() + # env = TimeLimit(env=env, max_episode_steps=100) + # # vec_env.transform(func=jax.vmap) + # # vec_env.transform(func=jax.jit) + # vec_env = TransformWrapper(env=env, function=jax.vmap) + # vec_env = TransformWrapper(env=vec_env, function=jax.jit) + # states = vec_env.initial(rng=jax.random.split(key, num=num)) + # _ = vec_env.observation(state=states) + # action = vec_env.action_space.sample(key=jax.random.split(key, num=1).squeeze()) + # actions = jnp.repeat(action, repeats=num, axis=0) + # next_states = vec_env.transition(state=states, action=actions) + # reward = vec_env.reward(state=states, action=actions, next_state=next_states) + # infos = vec_env.step_info(state=states, action=actions, next_state=next_states) + + # from jaxgym.functional.jax.vector import JaxVectorEnv + # from jaxgym.functional.jax.time_limit import TimeLimit + # from jaxgym.functional.wrappers.transform import TransformWrapper + # from jaxgym.functional.core import FuncWrapper + # + # env = CartpoleSwingUpFunctionalEnvV0() + # + # env_wrapped = TimeLimit(env=env, max_episode_steps=100) + # env_wrapped = TransformWrapper(env=env_wrapped, function=jax.jit) + # # CartpoleSwingUpFunctionalEnvV0.transform(self=env_wrapped, func=jax.jit) + # state = env_wrapped.initial(rng=key) + # _ = env_wrapped.observation(state=state) + # action = env_wrapped.action_space.sample(key=key) + # next_state = env_wrapped.transition(state=state, action=action) + # reward = env_wrapped.reward(state=state, action=action, next_state=next_state) + # info = env_wrapped.step_info(state=state, action=action, next_state=next_state) + # next_state = env_wrapped.transition(state=next_state, action=action) + + from jaxgym.functional.jax.time_limit import TimeLimit + from jaxgym.functional.jax.transform import 
JaxTransformWrapper + from jaxgym.functional.jax.vector import JaxVectorEnv + from jaxgym.functional.wrappers.transform import TransformWrapper + + env = CartpoleSwingUpFunctionalEnvV0() + # env = TimeLimit(env=env, max_episode_steps=3) + num_envs = 2 + vec_env = JaxVectorEnv( + func_env=env, num_envs=num_envs, max_episode_steps=4, jit_compile=True + ) + observations, state_infos = vec_env.reset() + + actions = vec_env.action_space.sample(key=key) + # actions = jnp.repeat(jnp.atleast_2d(action).T, repeats=num_envs, axis=1).T + + # TODO: the output dict misses the final observation when truncated + # final_observation | final_info + _ = vec_env.step(action=actions) + + # self = vec_env + # env = self.func_env + # states = self.states + # keys_1 = self.subkey(num=self.num_envs) + # keys_2 = self.subkey(num=self.num_envs) + # states, _ = jax.jit(JaxVectorEnv.step_autoreset_func)( + # env, states, actions, keys_1, keys_2 + # ) + + # @jax.jit + # def split(key: jax.random.PRNGKeyArray, num: int) -> jax.random.PRNGKeyArray: + # return jax.random.split(key=key, num=num) + # + # _ = split(key, 5) + + # _ = vec_env.func_env.transition(state=vec_env.states, action=actions) + # _ = vec_env.step(action=actions) + + # observation = env.observation(state=state) + # terminal = env.terminal(state=state) + # reward = env.reward(state=state, action=action) + # next_state = env.transition(state=state, action=action, rng=jax.random.PRNGKey(2)) + + # jax.tree_util.tree_structure(state) + # jax.tree_util.tree_structure(next_state) + # + # jax.tree_util.tree_leaves(state) + # jax.tree_util.tree_leaves(next_state) + + # with env.jaxsim.editable(validate=True) as simulator: + # simulator.data = next_state diff --git a/src/jaxgym/envs/ergocub.py b/src/jaxgym/envs/ergocub.py new file mode 100644 index 000000000..c3d45457c --- /dev/null +++ b/src/jaxgym/envs/ergocub.py @@ -0,0 +1,810 @@ +import dataclasses +import functools +import multiprocessing +import os +import warnings +from typing 
import Any, Dict, List, Optional, Type, Union + +import gymnasium as gym +import jax.numpy as jnp +import jax.random +import jax_dataclasses +import numpy as np +import numpy.typing as npt +import rod +from gymnasium.experimental.vector.vector_env import VectorWrapper +from meshcat_viz import MeshcatWorld +from resolve_robotics_uri_py import resolve_robotics_uri +from stable_baselines3 import PPO +from stable_baselines3.common import vec_env as vec_env_sb +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.vec_env import VecMonitor, VecNormalize +from torch import nn + +import jaxgym.jax.pytree_space as spaces +import jaxsim.typing as jtp +from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper, JaxEnv, PyTree +from jaxgym.vector.jax import FlattenSpacesVecWrapper, JaxVectorEnv +from jaxgym.wrappers.jax import ( + ActionNoiseWrapper, + ClipActionWrapper, + FlattenSpacesWrapper, + JaxTransformWrapper, + NaNHandlerWrapper, + SquashActionWrapper, + TimeLimit, + ToNumPyWrapper, +) +from jaxsim import JaxSim +from jaxsim.physics.algos.soft_contacts import SoftContactsParams +from jaxsim.simulation import simulator_callbacks +from jaxsim.simulation.ode_integration import IntegratorType +from jaxsim.simulation.simulator import SimulatorData, VelRepr +from jaxsim.utils import JaxsimDataclass, Mutability + +warnings.simplefilter(action="ignore", category=FutureWarning) + + +@jax_dataclasses.pytree_dataclass +class ErgoCubObservation(JaxsimDataclass): + """Observation of the ErgoCub environment.""" + + base_height: jtp.Float + gravity_projection: jtp.Array + + joint_positions: jtp.Array + joint_velocities: jtp.Array + + base_linear_velocity: jtp.Array + base_angular_velocity: jtp.Array + + contact_state: jtp.Array + + @staticmethod + def build( + base_height: jtp.Float, + gravity_projection: jtp.Array, + joint_positions: jtp.Array, + joint_velocities: jtp.Array, + base_linear_velocity: jtp.Array, + base_angular_velocity: jtp.Array, 
@dataclasses.dataclass
class MeshcatVizRenderState:
    """
    Render state of a meshcat-viz visualizer.

    Creating an instance builds a ``MeshcatWorld`` and opens it. A GUI window
    showing the world can optionally be spawned in a separate process with
    ``open_window_in_process``. Call ``close`` to release both resources.
    """

    # Visualizer world. It is created in __post_init__, therefore it must be a
    # regular dataclass field excluded from the generated __init__.
    world: MeshcatWorld = dataclasses.field(init=False)

    # Handle of the process running the GUI window (None if no window is open).
    _gui_process: Optional[multiprocessing.Process] = dataclasses.field(
        default=None, init=False, repr=False, hash=False, compare=False
    )

    # Maps the name of a jaxsim model to the name used by the visualizer.
    _jaxsim_to_meshcat_viz_name: dict[str, str] = dataclasses.field(
        default_factory=dict, init=False, repr=False, hash=False, compare=False
    )

    def __post_init__(self) -> None:
        """Create and open the meshcat world."""

        self.world = MeshcatWorld()
        self.world.open()

    def close(self) -> None:
        """Close the meshcat world and stop the GUI process, if any.

        The method can be called more than once.
        """

        if self.world is not None:
            self.world.close()

        self._stop_gui_process()

    def _stop_gui_process(self) -> None:
        """Terminate the GUI process (if running) and release its resources."""

        process = self._gui_process

        if process is None:
            return

        process.terminate()

        # Process.close() raises ValueError if the child is still running, and
        # terminate() only sends the signal: wait for the child to exit first.
        process.join()
        process.close()

        # Forget the handle so that later calls do not touch a closed process.
        self._gui_process = None

    @staticmethod
    def open_window(web_url: str) -> None:
        """Open a new window with the given web url.

        This call blocks until the window is closed, which is why it is meant
        to be executed in a dedicated process.
        """

        # Imported lazily so that the GUI toolkit is needed only when used.
        import webview

        print(web_url)
        webview.create_window("meshcat", web_url)
        webview.start(gui="qt")

    def open_window_in_process(self) -> None:
        """Open the GUI window in a new process, replacing any previous one."""

        self._stop_gui_process()

        self._gui_process = multiprocessing.Process(
            target=MeshcatVizRenderState.open_window, args=(self.world.web_url,)
        )
        self._gui_process.start()
jnp.ndarray +RenderStateType = MeshcatVizRenderState + + +@jax_dataclasses.pytree_dataclass +class ErgoCubWalkFuncEnvV0( + JaxDataclassEnv[ + StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType + ] +): + """ErgoCub environment implementing a target reaching task.""" + + name: ClassVar = jax_dataclasses.static_field(default="ErgoCubWalkFuncEnvV0") + + # Store an instance of the JaxSim simulator. + # It gets initialized with SimulatorData with a functional approach. + _simulator: JaxSim = jax_dataclasses.field(default=None) + + def __post_init__(self) -> None: + """Environment initialization.""" + + model = self.jaxsim.get_model(model_name="ErgoCub") + + # Create the action space (static attribute) + with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + high = jnp.array([25.0] * model.dofs(), dtype=float) + + with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + self._action_space = spaces.PyTree(low=-high, high=high) + + # Get joint limits + s_min, s_max = model.joint_limits() + s_range = s_max - s_min + + low = ErgoCubObservation.build( + base_height=0.25, + gravity_projection=-jnp.ones(3), + joint_positions=s_min, + joint_velocities=-50.0 * jnp.ones_like(s_min), + base_linear_velocity=-5.0 * jnp.ones(3), + base_angular_velocity=-10.0 * jnp.ones(3), + contact_state=jnp.array([False] * 4), + ) + + high = ErgoCubObservation.build( + base_height=1.0, + gravity_projection=jnp.ones(3), + joint_positions=s_max, + joint_velocities=50.0 * jnp.ones_like(s_max), + base_linear_velocity=5.0 * jnp.ones(3), + base_angular_velocity=10.0 * jnp.ones(3), + contact_state=jnp.array([True] * 4), + ) + + # Create the observation space (static attribute) + with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + self._observation_space = spaces.PyTree(low=low, high=high) + + @property + def jaxsim(self) -> JaxSim: + """""" + + if self._simulator is not None: + return self._simulator + + # Create the jaxsim 
simulator. + simulator = JaxSim.build( + # Note: any change of either 'step_size' or 'steps_per_run' requires + # updating the number of integration steps in the 'transition' method. + step_size=0.000_250, + steps_per_run=1, + velocity_representation=VelRepr.Body, + integrator_type=IntegratorType.EulerSemiImplicit, + simulator_data=SimulatorData( + gravity=jnp.array([0, 0, -10.0]), + contact_parameters=SoftContactsParams.build(K=10_000, D=20), + ), + ).mutable(mutable=True, validate=False) + + # Get the SDF path + model_sdf_path = resolve_robotics_uri( + "package://ergoCub/robots/ergoCubGazeboV1_minContacts/model.urdf" + ) + + # Insert the model + _ = simulator.insert_model_from_description( + model_description=model_sdf_path, model_name="ErgoCub" + ) + + simulator.data.models = { + model_name: jax.tree_util.tree_map(lambda leaf: jnp.array(leaf), model_data) + for model_name, model_data in simulator.data.models.items() + } + + with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + self._simulator = simulator.mutable(mutable=True, validate=True) + + return self._simulator + + def initial(self, rng: Any = None) -> StateType: + """""" + assert jax.dtypes.issubdtype(rng, jax.dtypes.prng_key) + + # Split the key + subkey1, subkey2 = jax.random.split(rng, num=2) + + # Sample an initial observation + initial_observation: ErgoCubObservation = ( + self.observation_space.sample_with_key(key=subkey1) + ) + + # Sample a goal position + goal_xy_position = jax.random.uniform( + key=subkey2, minval=-5.0, maxval=5.0, shape=(2,) + ) + + with self.jaxsim.editable(validate=False) as simulator: + # Reset the simulator and get the model + simulator.reset(remove_models=False) + model = simulator.get_model(model_name="ErgoCub") + + # Reset the joint positions + model.reset_joint_positions( + positions=initial_observation.joint_positions, + joint_names=model.joint_names(), + ) + + # Reset the base position + model.reset_base_position(position=jnp.array([0, 0, 0.5])) + + 
# Reset the base velocity + model.reset_base_velocity( + base_velocity=jnp.hstack( + [ + 0.1 * initial_observation.base_linear_velocity, + 0.1 * initial_observation.base_angular_velocity, + ] + ) + ) + + # Return the simulation state + return dict( + simulator_data=simulator.data, + goal=jnp.array(goal_xy_position, dtype=float), + ) + + def transition( + self, state: StateType, action: ActType, rng: Any = None + ) -> StateType: + """""" + + # Get the JaxSim simulator + simulator = self.jaxsim + + # Initialize the simulator with the environment state (containing SimulatorData) + with simulator.editable(validate=True) as simulator: + simulator.data = state["simulator_data"] + + @jax_dataclasses.pytree_dataclass + class SetTorquesOverHorizon(simulator_callbacks.PreStepCallback): + def pre_step(self, sim: JaxSim) -> JaxSim: + """""" + + model = sim.get_model(model_name="ErgoCub") + model.zero_input() + model.set_joint_generalized_force_targets( + forces=jnp.atleast_1d(action), joint_names=model.joint_names() + ) + + return sim + + number_of_integration_steps = 40 # 0.010 # TODO 20 for having 0.010 + + # Stepping logic + with simulator.editable(validate=True) as simulator: + simulator, _ = simulator.step_over_horizon( + horizon_steps=number_of_integration_steps, + clear_inputs=False, + callback_handler=SetTorquesOverHorizon(), + ) + + # Return the new environment state (updated SimulatorData) + return state | dict(simulator_data=simulator.data) + + def observation(self, state: StateType) -> ObsType: + """""" + + # Initialize the simulator with the environment state (containing SimulatorData) + # and get the simulated model + with self.jaxsim.editable(validate=True) as simulator: + simulator.data = state["simulator_data"] + model = simulator.get_model("ErgoCub") + + # Compute the normalized gravity projection in the body frame + W_R_B = model.base_orientation(dcm=True) + W_gravity = self.jaxsim.gravity() + B_gravity = W_R_B.T @ (W_gravity / jnp.linalg.norm(W_gravity)) + 
+ W_p_B = model.base_position() + W_p_goal = jnp.hstack([state["goal"].squeeze(), 0]) + + # Compute the distance between the base and the goal in the body frame + B_p_distance = W_R_B.T @ (W_p_goal - W_p_B) + + # Build the observation from the state + return ErgoCubObservation.build( + base_height=model.base_position()[2], + gravity_projection=B_gravity, + joint_positions=model.joint_positions(), + joint_velocities=model.joint_velocities(), + base_linear_velocity=model.base_velocity()[0:3], + base_angular_velocity=model.base_velocity()[3:6], + contact_state=model.in_contact( + link_names=[name for name in model.link_names() if "_ankle" in name] + ), + ) + + def reward( + self, state: StateType, action: ActType, next_state: StateType + ) -> RewardType: + """""" + + with self.jaxsim.editable(validate=True) as simulator_next: + simulator_next.data = next_state["simulator_data"] + model_next = simulator_next.get_model("ErgoCub") + + terminal = self.terminal(state=state) + obs_in_space = jax.lax.select( + pred=self.observation_space.contains(x=self.observation(state=state)), + on_true=1.0, + on_false=0.0, + ) + + # Position of the base + W_p_B = model_next.base_position() + W_p_xy_goal = state["goal"] + + reward = 0.0 + reward += 1.0 * (1.0 - jnp.array(terminal, dtype=float)) # alive + reward += 5.0 * obs_in_space # + # reward += 100.0 * v_WB[0] # forward velocity + reward -= jnp.linalg.norm(W_p_B[0:2] - W_p_xy_goal) # distance from goal + reward += 1.0 * model_next.in_contact( + link_names=[ + name + for name in model_next.link_names() + if name.startswith("leg_") and name.endswith("_lower") + ] + ).any().astype(float) + reward -= 0.1 * jnp.linalg.norm(action) / action.size # control cost + + return reward + + def terminal(self, state: StateType) -> TerminalType: + # Get the current observation + observation = self.observation(state=state) + + base_too_high = ( + observation.base_height >= self.observation_space.high.base_height + ) + return base_too_high + + # 
========= + # Rendering + # ========= + + def render_image( + self, state: StateType, render_state: RenderStateType + ) -> tuple[RenderStateType, npt.NDArray]: + """Show the state.""" + + model_name = "ErgoCub" + + # Initialize the simulator with the environment state (containing SimulatorData) + # and get the simulated model + with self.jaxsim.editable(validate=False) as simulator: + simulator.data = state["simulator_data"] + model = simulator.get_model(model_name=model_name) + + # Insert the model lazily in the visualizer if it is not already there + if model_name not in render_state.world._meshcat_models.keys(): + from rod.urdf.exporter import UrdfExporter + + urdf_string = UrdfExporter.sdf_to_urdf_string( + sdf=rod.Sdf( + version="1.7", + model=model.physics_model.description.extra_info["sdf_model"], + ), + pretty=True, + gazebo_preserve_fixed_joints=False, + ) + + meshcat_viz_name = render_state.world.insert_model( + model_description=urdf_string, is_urdf=True, model_name=None + ) + + render_state._jaxsim_to_meshcat_viz_name[model_name] = meshcat_viz_name + + # Check that the model is in the visualizer + if ( + not render_state._jaxsim_to_meshcat_viz_name[model_name] + in render_state.world._meshcat_models.keys() + ): + raise ValueError(f"The '{model_name}' model is not in the meshcat world") + + # Update the model in the visualizer + render_state.world.update_model( + model_name=render_state._jaxsim_to_meshcat_viz_name[model_name], + joint_names=model.joint_names(), + joint_positions=model.joint_positions(), + base_position=model.base_position(), + base_quaternion=model.base_orientation(dcm=False), + ) + + return render_state, np.empty(0) + + def render_init(self, open_gui: bool = False, **kwargs) -> RenderStateType: + """Initialize the render state.""" + + # Initialize the render state + meshcat_viz_state = MeshcatVizRenderState() + + if open_gui: + meshcat_viz_state.open_window_in_process() + + return meshcat_viz_state + + def render_close(self, 
render_state: RenderStateType) -> None: + """Close the render state.""" + + render_state.close() + + +class ErgoCubWalkEnvV0(JaxEnv): + """""" + + def __init__(self, render_mode: str | None = None, **kwargs: Any) -> None: + """""" + + from jaxgym.wrappers.jax import ( + ClipActionWrapper, + FlattenSpacesWrapper, + JaxTransformWrapper, + TimeLimit, + ) + + func_env = ErgoCubWalkFuncEnvV0() + + func_env_wrapped = func_env + func_env_wrapped = TimeLimit( + env=func_env_wrapped, max_episode_steps=5_000 + ) # TODO + func_env_wrapped = ClipActionWrapper(env=func_env_wrapped) + func_env_wrapped = FlattenSpacesWrapper(env=func_env_wrapped) + func_env_wrapped = JaxTransformWrapper(env=func_env_wrapped, function=jax.jit) + + super().__init__( + func_env=func_env_wrapped, + metadata=self.metadata, + render_mode=render_mode, + ) + + +class ErgoCubWalkVectorEnvV0(JaxVectorEnv): + """""" + + metadata = dict() + + def __init__( + self, + # func_env: JaxDataclassEnv[ + # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType + # ], + num_envs: int, + render_mode: str | None = None, + # max_episode_steps: int = 5_000, + jit_compile: bool = True, + **kwargs, + ) -> None: + """""" + + print("+++", kwargs) + + env = ErgoCubWalkFuncEnvV0() + + # Vectorize the environment. + # Note: it automatically wraps the environment in a TimeLimit wrapper. + super().__init__( + func_env=env, + num_envs=num_envs, + metadata=self.metadata, + render_mode=render_mode, + max_episode_steps=5_000, # TODO + jit_compile=jit_compile, + ) + + # from jaxgym.vector.jax import FlattenSpacesVecWrapper + # + # vec_env_wrapped = FlattenSpacesVecWrapper(env=vec_env) + + +if __name__ == "__main__": + """Stable Baselines""" + + def make_jax_env( + max_episode_steps: Optional[int] = 500, jit: bool = True + ) -> JaxEnv: + """""" + + # TODO: single env -> time limit with stable_baselines? 
def reset(self) -> vec_env_sb.base_vec_env.VecEnvObs:
    """
    Reset the wrapped vector environment.

    The stored seed is forwarded to the wrapped environment, and the returned
    infos are discarded.

    Returns:
        The initial observations converted to a NumPy array.
    """

    reset_output = self.jax_vector_env.reset(seed=self._seed)

    # The wrapped environment returns the (observations, infos) pair
    initial_observations, _ = reset_output

    return np.array(initial_observations)
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
    """
    Store the seed that is forwarded to the wrapped environment by ``reset``.

    Args:
        seed: An integer in the range of an unsigned 32-bit value.
            If None, a random seed is drawn.

    Returns:
        A list with a single element, the stored seed as a Python int.

    Raises:
        ValueError: If the seed is not an integer value in [0, 2**32 - 1].
    """

    if seed is None:
        seed = np.random.default_rng().integers(0, 2**32 - 1, dtype="uint32")

    error_message = "seed must be compatible with 'uint32' casting"

    # Validate the value explicitly instead of comparing it with its uint32 cast:
    # depending on the NumPy version, casting an out-of-range Python integer
    # either wraps around with a warning or raises OverflowError.
    try:
        seed_as_int = int(seed)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(error_message) from exc

    if seed_as_int != seed or not 0 <= seed_as_int <= np.iinfo(np.uint32).max:
        raise ValueError(error_message)

    # Store a plain Python int also when the seed was a NumPy scalar
    self._seed = seed_as_int
    return [seed_as_int]

    # _ = self.jax_vector_env.reset(seed=seed)
    # return [None]
vf=[512, 512]), + log_std_init=np.log(0.05), + ), + ) + + print(model.policy) + + model = model.learn(total_timesteps=50000, progress_bar=True) diff --git a/src/jaxgym/functional/__init__.py b/src/jaxgym/functional/__init__.py new file mode 100644 index 000000000..f8a909b60 --- /dev/null +++ b/src/jaxgym/functional/__init__.py @@ -0,0 +1,4 @@ +from .func_env import FuncEnv +from .func_wrapper import FuncWrapper + +# from .func_space import FuncSpace diff --git a/src/jaxgym/functional/_jax/__init__.py b/src/jaxgym/functional/_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/jaxgym/functional/_jax/_autoreset.py b/src/jaxgym/functional/_jax/_autoreset.py new file mode 100644 index 000000000..a20faa5fd --- /dev/null +++ b/src/jaxgym/functional/_jax/_autoreset.py @@ -0,0 +1,82 @@ +from typing import Any, Generic + +import jax.numpy as jnp +import jax_dataclasses +from gymnasium.experimental.functional import ( + ActType, + ObsType, + RenderStateType, + RewardType, + StateType, + TerminalType, +) + +from jaxgym.functional.core import FuncWrapper + +WrapperStateType = StateType +WrapperObsType = ObsType +WrapperActType = ActType +WrapperRewardType = RewardType + +# TODO: non si puo' fare wrappando FuncWrapper perche' observation deve chiamare +# initial, e initial ha bisogno di rng +# Implementare sopra env.FunctionalJaxEnv? -> Il JaxVecEnv fa gia' autoreset di default. 
@jax_dataclasses.pytree_dataclass
class AutoResetWrapper(
    FuncWrapper[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
        WrapperStateType,
        WrapperObsType,
        WrapperActType,
        WrapperRewardType,
    ],
    Generic[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
    ],
):
    """
    Draft of a wrapper meant to reset the environment automatically.

    Note:
        The implementation is incomplete: no method of this class resets the
        wrapped environment. The original author noted that this cannot be done
        by wrapping FuncWrapper, because observation would have to call initial,
        and initial needs an rng. The same note asks whether to implement it on
        top of env.FunctionalJaxEnv, and states that JaxVecEnv already performs
        the autoreset by default.
    """

    def is_done(self, state: WrapperStateType) -> bool:
        """
        Check whether the given state is either terminal or truncated.

        Args:
            state: The state of the wrapper.

        Returns:
            True if either of the two checked conditions holds.
        """

        # NOTE(review): step_info is called without arguments, while the
        # step_info method of this class forwards state, action and next_state
        # to the wrapped environment. This call likely fails, verify.
        info = self.env.step_info()

        # NOTE(review): the identity check 'is True' holds only for the Python
        # bool singleton. Confirm the type of the "truncated" entry.
        return jnp.array(
            [
                self.terminal(state=state),
                "truncated" in info and info["truncated"] is True,
            ],
            dtype=bool,
        ).any()

    def observation(self, state: WrapperStateType) -> WrapperObsType:
        """
        Compute the observation of the wrapped environment.

        Args:
            state: The state of the wrapper, converted to the state of the
                wrapped environment before computing the observation.

        Returns:
            The observation computed by the wrapped environment.
        """

        return self.env.observation(
            state=self.wrapper_state_to_environment_state(wrapper_state=state)
        )

    def step_info(
        self, state: WrapperStateType, action: ActType, next_state: WrapperStateType
    ) -> dict[str, Any]:
        """Info dict about a full transition.

        Both states are converted to states of the wrapped environment, whose
        step_info result is returned unchanged.
        """

        return self.env.step_info(
            state=self.wrapper_state_to_environment_state(wrapper_state=state),
            action=action,
            next_state=self.wrapper_state_to_environment_state(
                wrapper_state=next_state
            ),
        )
def __init__(
    self,
    func_env: FuncEnv,
    metadata: dict[str, Any] | None = None,
    render_mode: str | None = None,
    reward_range: tuple[float, float] = (-float("inf"), float("inf")),
    spec: EnvSpec | None = None,
):
    """
    Initialize the environment from a FuncEnv.

    Args:
        func_env: The functional environment to convert.
        metadata: The metadata of the environment. If None, a metadata
            dictionary declaring no render modes is used.
        render_mode: The render mode. The render state is initialized only
            if it is "rgb_array".
        reward_range: The range of the reward.
        spec: The specification of the environment.
    """

    if metadata is None:
        # Gymnasium reads the supported render modes from the plural
        # "render_modes" key of the metadata dictionary.
        metadata = {"render_modes": []}

    self.func_env = func_env

    self.observation_space = func_env.observation_space
    self.action_space = func_env.action_space

    self.metadata = metadata
    self.render_mode = render_mode
    self.reward_range = reward_range

    self.spec = spec

    # Cached since it selects how actions are validated when stepping
    self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)

    if self.render_mode == "rgb_array":
        self.render_state = self.func_env.render_init()
    else:
        self.render_state = None

    # Draw a random seed to initialize the JAX PRNG key of the environment
    np_random, _ = seeding.np_random()
    seed = np_random.integers(0, 2**32 - 1, dtype="uint32")

    self.rng = jrng.PRNGKey(seed)
self.func_env.observation(next_state) + reward = self.func_env.reward(self.state, action, next_state) + terminated = self.func_env.terminal(next_state) + info = self.func_env.step_info(self.state, action, next_state) + self.state = next_state + + observation = jax_to_numpy(observation) + + return observation, float(reward), bool(terminated), False, info + + # def render(self): + # """Returns the render state if `render_mode` is "rgb_array".""" + # if self.render_mode == "rgb_array": + # self.render_state, image = self.func_env.render_image( + # self.state, self.render_state + # ) + # return image + # else: + # raise NotImplementedError + # + # def close(self): + # """Closes the environments and render state if set.""" + # if self.render_state is not None: + # self.func_env.render_close(self.render_state) + # self.render_state = None diff --git a/src/jaxgym/functional/func_env.py b/src/jaxgym/functional/func_env.py new file mode 100644 index 000000000..65ef80963 --- /dev/null +++ b/src/jaxgym/functional/func_env.py @@ -0,0 +1,184 @@ +import abc +from typing import Any, Generic + +import gymnasium as gym +import numpy.typing as npt +from gymnasium.experimental.functional import ( + ActType, + ObsType, + RenderStateType, + RewardType, + StateType, + TerminalType, +) + + +# Similar to https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/experimental/functional.py +class FuncEnv( + Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType], + abc.ABC, +): + """ + Base class for functional environments. + + Note: + This class is meant to be kept stateless. + The state of the environment (possibly encapsulated with the states of wrappers) + should be stored in the `state` argument of the methods. + + Note: + This functional approach mainly targets JAX-based environments, but has been + formulated in a generic way so that it can be implemented in other frameworks. + """ + + # These spaces have to be populated in the __post_init__ method. 
@abc.abstractmethod
def initial(self, rng: Any = None) -> StateType:
    """
    Initialize the environment returning its initial state.

    Args:
        rng: A resource to initialize the RNG of the functional environment.

    Returns:
        The initial state of the environment.
    """
    pass
@abc.abstractmethod
def reward(
    self, state: StateType, action: ActType, next_state: StateType
) -> RewardType:
    """
    Compute the reward of a transition.

    Args:
        state: The state of the environment before the transition.
        action: The action applied in the given state.
        next_state: The state of the environment after the transition.

    Returns:
        The reward of the transition.
    """
    pass
+ # "state_info": self.state_info(state=state), + # "next_state_info": self.state_info(state=next_state), + } + + # TODO: remove + # def transform(self, func: Callable[[Callable], Callable]) -> None: + # """Functional transformations.""" + # raise NotImplementedError + # # with self.mutable_context(mutability=Mutability.MUTABLE): + # self.initial = func(self.initial) + # self.transition = func(self.transition) + # self.observation = func(self.observation) + # self.reward = func(self.reward) + # self.terminal = func(self.terminal) + # self.state_info = func(self.state_info) + # self.step_info = func(self.step_info) + + def __str__(self) -> str: + """""" + + return f"<{type(self).__name__}>" + + def __repr__(self) -> str: + """""" + + return str(self) diff --git a/src/jaxgym/functional/func_space.py b/src/jaxgym/functional/func_space.py new file mode 100644 index 000000000..cb199a4c0 --- /dev/null +++ b/src/jaxgym/functional/func_space.py @@ -0,0 +1,6 @@ +import gymnasium as gym + + +class FuncSpace(gym.Space): + def __init__(self) -> None: + raise NotImplementedError diff --git a/src/jaxgym/functional/func_wrapper.py b/src/jaxgym/functional/func_wrapper.py new file mode 100644 index 000000000..e6ada8798 --- /dev/null +++ b/src/jaxgym/functional/func_wrapper.py @@ -0,0 +1,225 @@ +from typing import Any, Generic, TypeVar + +import gymnasium as gym +import numpy.typing as npt +from gymnasium.experimental.functional import ( + ActType, + ObsType, + RenderStateType, + RewardType, + StateType, + TerminalType, +) + +from jaxgym.functional import FuncEnv + +WrapperStateType = TypeVar("WrapperStateType") +WrapperObsType = TypeVar("WrapperObsType") +WrapperActType = TypeVar("WrapperActType") +WrapperRewardType = TypeVar("WrapperRewardType") + + +class FuncWrapper( + FuncEnv[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType], + Generic[ + # FuncEnv types + StateType, + ObsType, + ActType, + RewardType, + TerminalType, + RenderStateType, + # FuncWrapper 
types + WrapperStateType, + WrapperObsType, + WrapperActType, + WrapperRewardType, + ], +): + """Base wrapper of a functional environment that forwards every call to the wrapped `env`.""" + + env: FuncEnv[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType] + + _action_space: gym.Space | None = None + _observation_space: gym.Space | None = None + + @property + def unwrapped( + self, + ) -> FuncEnv[ + StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType + ]: + """Return the innermost environment of the wrapped chain.""" + + return self.env.unwrapped + + @property + def action_space(self) -> gym.Space[ActType]: + """Return the action space set on this wrapper or, if unset, the one of the wrapped environment.""" + + return ( + self._action_space + if self._action_space is not None + else self.env.action_space + ) + + @property + def observation_space(self) -> gym.Space[ObsType]: + """Return the observation space set on this wrapper or, if unset, the one of the wrapped environment.""" + + return ( + self._observation_space + if self._observation_space is not None + else self.env.observation_space + ) + + @action_space.setter + def action_space(self, space: gym.Space[ActType]) -> None: + """Override the action space exposed by this wrapper.""" + + self._action_space = space + + @observation_space.setter + def observation_space(self, space: gym.Space[ObsType]) -> None: + """Override the observation space exposed by this wrapper.""" + + self._observation_space = space + + def __str__(self): + """Return the wrapper class name followed by the string of the wrapped environment.""" + + return f"<{type(self).__name__}{self.env}>" + + def __repr__(self): + """Return the same string as `__str__`.""" + + return str(self) + + # =================================== + # Implementation of FuncEnv interface + # =================================== + + def initial(self, rng: Any = None) -> WrapperStateType: + """Return the initial state computed by the wrapped environment.""" + + return self.env.initial(rng=rng) + + def transition( + self, state: WrapperStateType, action: WrapperActType, rng: Any = None + ) -> WrapperStateType: + """Return the next state computed by the wrapped environment.""" + + return self.env.transition(state=state, action=action, rng=rng) + + def observation(self, state: WrapperStateType) -> WrapperObsType: + """Return the observation computed by the wrapped environment.""" + + return self.env.observation(state=state) + + def reward( + self, + state: WrapperStateType, + action: WrapperActType, + next_state: WrapperStateType, + ) -> WrapperRewardType: + """Return the reward computed by the wrapped environment.""" + + return self.env.reward(state=state, action=action, next_state=next_state) + + def
terminal(self, state: WrapperStateType) -> TerminalType: + """""" + + return self.env.terminal(state=state) + + def state_info(self, state: WrapperStateType) -> dict[str, Any]: + """""" + + return self.env.state_info(state=state) + + def step_info( + self, + state: WrapperStateType, + action: WrapperActType, + next_state: WrapperStateType, + ) -> dict[str, Any]: + """""" + + return self.env.step_info(state=state, action=action, next_state=next_state) + + # ========= + # Rendering + # ========= + + def render_image( + self, state: StateType, render_state: RenderStateType + ) -> tuple[RenderStateType, npt.NDArray]: + """""" + return self.env.render_image(state=state, render_state=render_state) + + def render_init(self, **kwargs) -> RenderStateType: + """""" + return self.env.render_init(**kwargs) + + def render_close(self, render_state: RenderStateType) -> None: + """""" + return self.env.render_close(render_state=render_state) + + +# class ActionFuncWrapper( +# FuncWrapper[ +# # FuncEnv types +# StateType, +# ObsType, +# ActType, +# RewardType, +# TerminalType, +# RenderStateType, +# # FuncWrapper types +# StateType, +# ObsType, +# WrapperActType, +# RewardType, +# ], +# Generic[WrapperActType], +# abc.ABC, +# ): +# """""" +# +# @abc.abstractmethod +# def action(self, action: WrapperActType) -> ActType: +# """""" +# +# pass +# +# def reward( +# self, +# state: WrapperStateType, +# action: WrapperActType, +# next_state: WrapperStateType, +# ) -> WrapperRewardType: +# """""" +# +# return self.env.reward( +# state=state, action=self.action(action=action), next_state=next_state +# ) +# +# def transition( +# self, state: WrapperStateType, action: WrapperActType, rng: Any = None +# ) -> WrapperStateType: +# """""" +# +# return self.env.transition( +# state=state, action=self.action(action=action), rng=rng +# ) +# +# def step_info( +# self, +# state: WrapperStateType, +# action: WrapperActType, +# next_state: WrapperStateType, +# ) -> dict[str, Any]: +# """""" +# +# return 
self.env.step_info( +# state=state, action=self.action(action=action), next_state=next_state +# ) diff --git a/src/jaxgym/jax/__init__.py b/src/jaxgym/jax/__init__.py new file mode 100644 index 000000000..3f85d53fb --- /dev/null +++ b/src/jaxgym/jax/__init__.py @@ -0,0 +1,6 @@ +from .dataclass_func_env import JaxDataclassEnv +from .dataclass_func_env_wrapper import JaxDataclassActionWrapper, JaxDataclassWrapper +from .env import JaxEnv +from .pytree_space import PyTree + +# from .jaxsim_func_env import JaxSimFuncEnv diff --git a/src/jaxgym/jax/dataclass_func_env.py b/src/jaxgym/jax/dataclass_func_env.py new file mode 100644 index 000000000..90086b5f3 --- /dev/null +++ b/src/jaxgym/jax/dataclass_func_env.py @@ -0,0 +1,41 @@ +import jax_dataclasses +from gymnasium.experimental.functional import ( + ActType, + ObsType, + RenderStateType, + RewardType, + StateType, + TerminalType, +) + +import jaxgym.jax.pytree_space as spaces +from jaxgym.functional import FuncEnv +from jaxsim.utils import JaxsimDataclass + + +@jax_dataclasses.pytree_dataclass +class JaxDataclassEnv( + FuncEnv[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType], + JaxsimDataclass, +): + """ + Base class for JAX-based functional environments. + + Note: + All environments implementing this class must be pytree_dataclasses. + """ + + # Override spaces for JAX, storing them as static fields. + # Note: currently only PyTree spaces are supported. + # Note: always sample from these spaces using functional methods in order to + # avoid incurring in JIT recompilations. 
+ _action_space: spaces.PyTree | None = jax_dataclasses.static_field(init=False) + _observation_space: spaces.PyTree | None = jax_dataclasses.static_field(init=False) + + @property + def action_space(self) -> spaces.PyTree: + return self._action_space + + @property + def observation_space(self) -> spaces.PyTree: + return self._observation_space diff --git a/src/jaxgym/jax/dataclass_func_env_wrapper.py b/src/jaxgym/jax/dataclass_func_env_wrapper.py new file mode 100644 index 000000000..757399013 --- /dev/null +++ b/src/jaxgym/jax/dataclass_func_env_wrapper.py @@ -0,0 +1,174 @@ +import abc +from typing import Any, Generic + +import jax_dataclasses +from gymnasium.experimental.functional import ( + ActType, + ObsType, + RenderStateType, + RewardType, + StateType, + TerminalType, +) + +import jaxgym.jax.pytree_space as spaces +from jaxgym.functional.func_wrapper import ( + FuncWrapper, + WrapperActType, + WrapperObsType, + WrapperRewardType, + WrapperStateType, +) +from jaxgym.jax import JaxDataclassEnv +from jaxsim.utils import JaxsimDataclass + + +@jax_dataclasses.pytree_dataclass +class JaxDataclassWrapper( + FuncWrapper[ + # + StateType, + ObsType, + ActType, + RewardType, + TerminalType, + RenderStateType, + # + WrapperStateType, + WrapperObsType, + WrapperActType, + WrapperRewardType, + ], + # TODO + Generic[ + # + StateType, + ObsType, + ActType, + RewardType, + TerminalType, + RenderStateType, + # + WrapperStateType, + WrapperObsType, + WrapperActType, + WrapperRewardType, + ], + JaxsimDataclass, +): + """""" + + env: JaxDataclassEnv[ + StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType + ] + + _action_space: spaces.PyTree | None = jax_dataclasses.static_field(init=False) + _observation_space: spaces.PyTree | None = jax_dataclasses.static_field(init=False) + + def __post_init__(self) -> None: + """""" + + if not isinstance(self.env.unwrapped, JaxDataclassEnv): + raise TypeError(type(self.env.unwrapped), JaxDataclassEnv) + + @property + def 
action_space(self) -> spaces.PyTree: + """""" + + return ( + self._action_space + if self._action_space is not None + else self.env.action_space + ) + + @property + def observation_space(self) -> spaces.PyTree: + """""" + + return ( + self._observation_space + if self._observation_space is not None + else self.env.observation_space + ) + + @action_space.setter + def action_space(self, space: spaces.PyTree) -> None: + """""" + + self._action_space = space + + @observation_space.setter + def observation_space(self, space: spaces.PyTree) -> None: + """""" + + self._observation_space = space + + +@jax_dataclasses.pytree_dataclass +class JaxDataclassActionWrapper( + JaxDataclassWrapper[ + # FuncEnv types + StateType, + ObsType, + ActType, + RewardType, + TerminalType, + RenderStateType, + # FuncWrapper types + StateType, + ObsType, + ActType, + RewardType, + ], + # TODO + Generic[ + # + StateType, + ObsType, + ActType, + RewardType, + TerminalType, + RenderStateType, + ], + abc.ABC, +): + """""" + + @abc.abstractmethod + def action(self, action: WrapperActType) -> ActType: + """""" + + pass + + def reward( + self, + state: WrapperStateType, + action: WrapperActType, + next_state: WrapperStateType, + ) -> WrapperRewardType: + """""" + + return self.env.reward( + state=state, action=self.action(action=action), next_state=next_state + ) + + def transition( + self, state: WrapperStateType, action: WrapperActType, rng: Any = None + ) -> WrapperStateType: + """""" + + return self.env.transition( + state=state, action=self.action(action=action), rng=rng + ) + + def step_info( + self, + state: WrapperStateType, + action: WrapperActType, + next_state: WrapperStateType, + ) -> dict[str, Any]: + """""" + + return self.env.step_info( + state=state, action=self.action(action=action), next_state=next_state + ) diff --git a/src/jaxgym/jax/env.py b/src/jaxgym/jax/env.py new file mode 100644 index 000000000..87856e63e --- /dev/null +++ b/src/jaxgym/jax/env.py @@ -0,0 +1,265 @@ +# import 
multiprocessing +from typing import Any, Generic, SupportsFloat + +import gymnasium as gym +import jax.numpy as jnp +import jax.random +import numpy as np +from gymnasium.core import ActType, ObsType, RenderFrame +from gymnasium.envs.registration import EnvSpec +from gymnasium.utils import seeding + +import jaxgym.jax.pytree_space as spaces +from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper +from jaxsim import logging + +# from meshcat_viz import MeshcatWorld + + +class JaxEnv(gym.Env[ObsType, ActType], Generic[ObsType, ActType]): + """Stateful `gym.Env` adapter around a JAX functional environment.""" + + action_space: spaces.PyTree + observation_space: spaces.PyTree + + metadata: dict[str, Any] = {"render_modes": ["meshcat_viz", "meshcat_viz_gui"]} + + def __init__( + self, + # func_env: FuncEnv | FuncWrapper, + func_env: JaxDataclassEnv | JaxDataclassWrapper, + metadata: dict[str, Any] | None = None, + render_mode: str | None = None, + reward_range: tuple[float, float] = (-float("inf"), float("inf")), + spec: EnvSpec | None = None, + ) -> None: + """ + Build the environment from a functional environment. + + Args: + func_env: The functional environment (or wrapper) to expose; its unwrapped environment must be a `JaxDataclassEnv`. + metadata: The environment metadata. If None, a dictionary with an empty list of `render_modes` is used. + render_mode: The render mode, stored as-is. + reward_range: The range of the reward, stored as-is. + spec: The environment specs, stored as-is. + + Raises: + TypeError: If the unwrapped functional environment is not a `JaxDataclassEnv`. + """ + + if not isinstance(func_env.unwrapped, JaxDataclassEnv): + raise TypeError(type(func_env.unwrapped), JaxDataclassEnv) + + # Note: the metadata key is "render_modes" (plural), matching the class attribute + metadata = metadata if metadata is not None else dict(render_modes=list()) + + # Store the jax environment + self.func_env = func_env + + # Initialize the state of the environment + self.state = None + + # Expose the same spaces + self.action_space = func_env.action_space + self.observation_space = func_env.observation_space + # assert isinstance(self.action_space, spaces.PyTree) + # assert isinstance(self.observation_space, spaces.PyTree) + + # Store the other mandatory attributes that gym.Env expects + self.metadata = metadata + self.render_mode = render_mode + self.reward_range = reward_range + + # Store the environment specs + self.spec = spec + + # self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box) + # + # if self.render_mode == "rgb_array": + # self.render_state = self.func_env.render_init() + # else: + #
self.render_state = None + self.render_state = None + self._meshcat_world = None # old + self._meshcat_window = None # old + + # Initialize the RNGs with a random seed + seed = np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32") + self._np_random, _ = seeding.np_random(seed=int(seed)) + self.rng = jax.random.PRNGKey(seed=seed) + + def subkey(self, num: int = 1) -> jax.random.PRNGKeyArray: + """ + Generate one or multiple sub-keys from the internal key. + + Note: + The internal key is automatically updated, there's no need to handle + the environment key externally. + + Args: + num: Number of keys to generate. + + Returns: + The generated sub-keys. + """ + + self.rng, *sub_keys = jax.random.split(self.rng, num=num + 1) + return jnp.stack(sub_keys).squeeze() + + def step( + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """""" + + # TODO: clip action with wrapper + # assert isinstance(self.action_space, spaces.PyTree) + # action = self.action_space.clip(x=action) + + # if self._is_box_action_space: + # assert isinstance(self.action_space, gym.spaces.Box) # For typing + # action = np.clip(action, self.action_space.low, self.action_space.high) + # else: # Discrete + # # For now we assume jax envs don't use complex spaces + # err_msg = f"{action!r} ({type(action)}) invalid" + # assert self.action_space.contains(action), err_msg + + # rng, self.rng = jrng.split(self.rng) + + # Advance the functional environment + next_state = self.func_env.transition( + state=self.state, action=action, rng=self.subkey(num=1) + ) + + # Extract updated data from the advanced environment + observation = self.func_env.observation(state=next_state) + reward = self.func_env.reward( + state=self.state, action=action, next_state=next_state + ) + info = self.func_env.step_info( + state=self.state, action=action, next_state=next_state + ) + + # Detect if the environment reached a terminal state + terminated = 
self.func_env.terminal(state=next_state) + truncated = ( + # False if "truncated" not in info else type(terminated)(info["truncated"]) + type(terminated)(False) + if "truncated" not in info + else type(terminated)(info["truncated"]) + ) + + # Remove the redundant "truncated" entry from info if present + _ = info.pop("truncated", None) + + # Store the updated state + self.state = next_state + + # observation = jax_to_numpy(observation) + + return observation, reward, terminated, truncated, info + + def reset( + self, *, seed: int | None = None, options: dict | None = None + ) -> tuple[ObsType, dict[str, Any]]: + """Resets the environment using the seed.""" + + super().reset(seed=seed) + self.rng = jax.random.PRNGKey(seed) if seed is not None else self.rng + + # Seed the spaces + self.action_space.seed(seed=seed) + self.observation_space.seed(seed=seed) + + # Generate initial state + self.state = self.func_env.initial(rng=self.subkey(num=1)) + + # Sample the initial observation and info + obs = self.func_env.observation(state=self.state) + info = self.func_env.state_info(state=self.state) + + # obs = jax_to_numpy(obs) + + # assert self.observation_space.contains(obs), obs + if obs not in self.observation_space: + logging.warning(f"Initial observation not in observation space") + logging.debug(obs) + + return obs, info + + # @property + # def visualizer(self) -> MeshcatWorld: + # """Returns the visualizer if `render_mode` is 'meshcat_viz'.""" + # + # if self._meshcat_world is not None: + # return self._meshcat_world + # + # world = MeshcatWorld() + # world.open() + # + # def open_window(web_url: str) -> None: + # import tkinter + # + # import webview + # + # # Create an instance of tkinter frame or window + # win = tkinter.Tk() + # win.geometry("700x350") + # + # webview.create_window("meshcat", web_url) + # webview.start(gui="qt") + # + # # TODO: nothing opens when run in a subprocess!
+ # p = multiprocessing.Process(target=open_window, args=(world.web_url,)) + # + # self._meshcat_window = p + # self._meshcat_world = world + # return self._meshcat_world + + # def update_meshcat_world(self) -> None: + # """""" + # + # return None + + def render(self) -> RenderFrame | list[RenderFrame] | None: + """Render the current state through the functional environment and return the resulting image. The render state is initialized on the first call. Returns None if `render_mode` is None and raises NotImplementedError for unsupported render modes.""" + + if self.render_mode not in {None, "meshcat_viz", "meshcat_viz_gui"}: + raise NotImplementedError(self.render_mode) + + if self.render_mode is None: + return None + + if self.render_state is None: + if self.render_mode == "meshcat_viz": + self.render_state = self.func_env.render_init(open_gui=False) + elif self.render_mode == "meshcat_viz_gui": + self.render_state = self.func_env.render_init(open_gui=True) + else: + raise ValueError(self.render_mode) + + self.render_state, image = self.func_env.render_image( + self.state, self.render_state + ) + + return image + + # # TODO: how to create proper interfaces? + # self._meshcat_world = self.func_env.unwrapped.update_meshcat_world( + # world=self.visualizer, state=self.state["env"] + # ) + # + # return None + + # if self.render_mode == "rgb_array": + # self.render_state, image = self.func_env.render_image( + # self.state, self.render_state + # ) + # return image + # else: + # raise NotImplementedError + + def close(self) -> None: + """Close the render state, if one was initialized, and forget it.""" + + # import meshcat_viz.meshcat + # + # if self._meshcat_world is not None: + # # self._meshcat_world.close() + # self._meshcat_world = None + # self._meshcat_window.kill() + # self._meshcat_window = None + + if self.render_state is not None: + self.func_env.render_close(self.render_state) + self.render_state = None + + def __str__(self): + """Returns the class name followed by the string of the wrapped :attr:`func_env`.""" + return f"<{type(self).__name__}{self.func_env}>" diff --git a/src/jaxgym/jax/jaxsim_func_env.py b/src/jaxgym/jax/jaxsim_func_env.py new file mode 100644 index 000000000..fb7c15af0 --- /dev/null +++
b/src/jaxgym/jax/jaxsim_func_env.py @@ -0,0 +1,3 @@ +class JaxSimFuncEnv: + def __init__(self) -> None: + raise NotImplementedError diff --git a/src/jaxgym/jax/pytree_space.py b/src/jaxgym/jax/pytree_space.py new file mode 100644 index 000000000..6cdee9a3b --- /dev/null +++ b/src/jaxgym/jax/pytree_space.py @@ -0,0 +1,378 @@ +import copy +import logging +from typing import Any + +import gymnasium as gym +import jax.flatten_util +import jax.numpy as jnp +import jax.tree_util +import numpy as np +import numpy.typing as npt +from gymnasium.spaces.utils import flatdim, flatten +from gymnasium.vector.utils.spaces import batch_space + +import jaxsim.typing as jtp +from jaxsim.utils import not_tracing + + +class PyTree(gym.Space[jtp.PyTree]): + """A generic space operating on JAX PyTree objects.""" + + def __init__( + self, + low: jtp.PyTree, + high: jtp.PyTree, + seed: int | None = None, + # TODO: + vectorize: int | None = None, + ) -> None: + """""" + + # ==================== + # Handle vectorization + # ==================== + + self.vectorize = vectorize + self.vectorized = False + + if vectorize is not None and vectorize < 2: + msg = f"Ignoring 'vectorize={vectorize}' argument since it is < 2" + logging.warning(msg=msg) + + if vectorize is not None and vectorize >= 2: + self.vectorized = True + low = jax.tree_util.tree_map(lambda l: jnp.stack([l] * vectorize), low) + high = jax.tree_util.tree_map(lambda l: jnp.stack([l] * vectorize), high) + + # ========================== + # Check low and high pytrees + # ========================== + # TODO: make generic (pytrees_with_same_dtype|shape|supported_dtype) and move + # to utils + + def check() -> None: + supported_dtypes = { + jnp.array(0, dtype=jnp.float32).dtype, + jnp.array(0, dtype=jnp.float64).dtype, + jnp.array(0, dtype=int).dtype, + jnp.array(0, dtype=bool).dtype, + } + + dtypes_supported = self.flatten_pytree( + pytree=jax.tree_util.tree_map( + lambda l1, l2: jnp.array(l1).dtype in supported_dtypes + and 
jnp.array(l2).dtype in supported_dtypes, + low, + high, + ) + ) + + if not jnp.all(dtypes_supported): + raise ValueError( + "Either low or high pytrees have attributes with unsupported dtype" + ) + + shape_match = self.flatten_pytree( + pytree=jax.tree_util.tree_map( + lambda l1, l2: jnp.array(l1).shape == jnp.array(l2).shape, low, high + ) + ) + + if not jnp.all(shape_match): + raise ValueError("Wrong shape of low and high attributes") + + dtype_match = self.flatten_pytree( + pytree=jax.tree_util.tree_map( + lambda l1, l2: jnp.array(l1).dtype == jnp.array(l2).dtype, low, high + ) + ) + + if not jnp.all(dtype_match): + raise ValueError("Wrong dtype of low and high attributes") + + if not_tracing(var=low): + check() + + # =============== + # Build the space + # =============== + + # Flatten the pytrees + low_flat = self.flatten_pytree(pytree=low) + high_flat = self.flatten_pytree(pytree=high) + + if low_flat.dtype != high_flat.dtype: + raise ValueError(low_flat.dtype, high_flat.dtype) + + if low_flat.shape != high_flat.shape: + raise ValueError(low_flat.shape, high_flat.shape) + + # Transform all leafs to array and store them in the object + self.low = jax.tree_util.tree_map(lambda l: jnp.array(l), low) + self.high = jax.tree_util.tree_map(lambda l: jnp.array(l), high) + + # Initialize the seed if not given + seed = ( + seed + if seed is not None + else np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32") + ) + + # Initialize the JAX random key + self.key = jax.random.PRNGKey(seed=seed) + + # Initialize parent class + super().__init__(shape=None, dtype=None, seed=int(seed)) + + def subkey(self, num: int = 1) -> jax.random.PRNGKeyArray: + """ + Generate one or multiple sub-keys from the internal key. + + Note: + The internal key is automatically updated, there's no need to handle + the environment key externally. + + Args: + num: Number of keys to generate. + + Returns: + The generated sub-keys. 
+ """ + + self.key, *sub_keys = jax.random.split(self.key, num=num + 1) + return jnp.stack(sub_keys).squeeze() + + # TODO: what if key is a vector? -> multiple outputs? + def sample_with_key(self, key: jax.random.PRNGKeyArray) -> jtp.PyTree: + """""" + + def random_array( + key, shape: tuple, min: jtp.PyTree, max: jtp.PyTree, dtype + ) -> jtp.Array: + """Helper to select the right sampling function for the supported dtypes""" + + match dtype: + case jnp.float32.dtype | jnp.float64.dtype: + return jax.random.uniform( + key=key, + shape=shape, + minval=min, + maxval=max, + dtype=dtype, + ) + case jnp.int16.dtype | jnp.int32.dtype | jnp.int64.dtype: + return jax.random.randint( + key=key, + shape=shape, + minval=min, + maxval=max + 1, + dtype=dtype, + ) + case jnp.bool_.dtype: + return jax.random.randint( + key=key, + shape=shape, + minval=min, + maxval=max + 1, + ).astype(bool) + case _: + raise ValueError(dtype) + + # Create and flatten a tree having a PRNGKey for each leaf. + # We do this just to get the number of keys we need to generate, and the + # function to unflatten the ravelled tree. + dummy_pytree = jax.tree_util.tree_map(lambda l: jax.random.PRNGKey(0), self.low) + dummy_pytree_flat, unflatten_fn = jax.flatten_util.ravel_pytree(dummy_pytree) + + # Use the subkey to generate new keys, one for each leaf. + # Note: the division by 2 is needed because keys are vector of 2 elements. 
+ subkey_flat = jax.random.split(key=key, num=dummy_pytree_flat.size // 2) + + # Generate a pytree having a different subkey in each leaf + subkey_pytree = unflatten_fn(jnp.array(subkey_flat).flatten()) + + # Generate a pytree by sampling leafs according to their dtype and using a + # different subkey for each of them + return jax.tree_util.tree_map( + lambda low, high, subkey: random_array( + key=subkey, shape=low.shape, min=low, max=high, dtype=low.dtype + ), + self.low, + self.high, + subkey_pytree, + ) + + def sample(self, mask: Any | None = None) -> jtp.PyTree: + """Sample a pytree from the space using a new sub-key of the internal key. The `mask` argument is not used.""" + + # Generate a subkey + subkey = self.subkey(num=1) + + return self.sample_with_key(key=subkey) + + def seed(self, seed: int | None = None) -> list[int]: + """ + Seed the JAX key and the NumPy generator of the space. + + Args: + seed: The seed to use. If None, a random 32-bit seed is drawn. + + Returns: + The list of seeds returned by the parent `gym.Space.seed`. + """ + + seed = ( + seed + if seed is not None + else np.random.default_rng().integers(0, 2 ** 32 - 1, dtype="uint32") + ) + + self.key = jax.random.PRNGKey(seed=seed) + + # A randomly drawn seed is a NumPy scalar, while gymnasium's seeding + # expects a plain Python int (__init__ performs the same conversion). + return super().seed(seed=int(seed)) + + def contains(self, x: jtp.PyTree) -> bool: + """Check that every leaf of `x` lies within the inclusive [low, high] bounds; empty leaves are accepted.""" + + def is_inside_bounds(x, low, high): + return jax.lax.select( + pred=x.size == 0 or jnp.all((x >= low) & (x <= high)), + on_true=True, + on_false=False, + ) + + contains_all_leaves = jax.tree_util.tree_map( + lambda low, high, l: is_inside_bounds(x=l, low=low, high=high), + self.low, + self.high, + x, + ) + + contains_all_leaves_flat = self.flatten_pytree(pytree=contains_all_leaves) + + return jnp.all(contains_all_leaves_flat) + + @property + def is_np_flattenable(self) -> bool: + """Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`.""" + + return True + + def to_box(self) -> gym.spaces.Box: + """Convert the space to a :class:`gymnasium.spaces.Box` built from the flattened bounds.""" + + get_first_element = lambda pytree: jax.tree_util.tree_map( + lambda l: l[0], pytree + ) + + low = self.low if not self.vectorized else get_first_element(self.low) + high = self.high if not self.vectorized else get_first_element(self.high) + + low_flat = np.array(self.flatten_pytree(pytree=low)) + high_flat = np.array(self.flatten_pytree(pytree=high)) + +
if self.vectorized: + assert self.vectorize >= 2 + + repeats = tuple([self.vectorize] + [1] * low_flat.ndim) + + low_flat = np.tile(low_flat, repeats) + high_flat = np.tile(high_flat, repeats) + + return gym.spaces.Box( + low=np.array(low_flat, dtype=np.float32), + high=np.array(high_flat, dtype=np.float32), + seed=copy.deepcopy(self.np_random), + ) + + def to_dict(self) -> gym.spaces.Dict: + # if low/high are a dataclass -> convert to dict + raise NotImplementedError + + @staticmethod + def flatten_pytree(pytree: jtp.PyTree) -> jtp.VectorJax: + """""" + + # print("flatten_pytree") + pytree_flat, _ = jax.flatten_util.ravel_pytree(pytree) + return pytree_flat + + def flatten_sample(self, pytree: jtp.PyTree) -> jtp.VectorJax: + """""" + + if not self.vectorized: + return self.flatten_pytree(pytree=pytree) + + # @jax.jit + # def flatten_pytree(pytree: jtp.PyTree) -> jtp.ArrayJax: + # print("compiling") + # return jax.vmap(self.flatten_pytree)(pytree) + # + # return flatten_pytree(pytree=pytree) + + # TODO: this trigger recompilation -> do some trick + # return jax.jit(jax.vmap(self.flatten_pytree))(pytree) + return PyTree._flatten_sample_vmap(pytree) + + @staticmethod + @jax.jit + def _flatten_sample_vmap(pytree: jtp.PyTree) -> jtp.VectorJax: + return jax.vmap(PyTree.flatten_pytree)(pytree) + + def unflatten_sample(self, x: jtp.Vector) -> jtp.PyTree: + """""" + + if not self.vectorized: + _, unflatten_fn = jax.flatten_util.ravel_pytree(self.low) + return unflatten_fn(x) + + # low_1d = jax.tree_util.tree_map(lambda l: l[0], self.low) + # low_1d_flat, unflatten_fn = jax.flatten_util.ravel_pytree(low_1d) + # + # @jax.jit + # def unflatten_sample(x: jtp.Vector) -> jtp.PyTree: + # return jax.vmap(unflatten_fn)(x) + # + # return unflatten_sample(x=x) + return PyTree._unflatten_sample_vmap(x=x, low=self.low) + + @staticmethod + @jax.jit + def _unflatten_sample_vmap(x: jtp.Vector, low: jtp.PyTree) -> jtp.PyTree: + """""" + + low_1d = jax.tree_util.tree_map(lambda l: l[0], 
low) + low_1d_flat, unflatten_fn = jax.flatten_util.ravel_pytree(low_1d) + + return jax.vmap(unflatten_fn)(x) + + def clip(self, x: jtp.PyTree) -> jtp.PyTree: + """""" + + # TODO: prevent recompilation + # @jax.jit + def _clip(pytree: jtp.PyTree) -> jtp.PyTree: + return jax.tree_util.tree_map( + lambda low, high, leaf: jnp.array( + jnp.clip(a=leaf, a_min=low, a_max=high), dtype=jnp.array(low).dtype + ), + self.low, + self.high, + pytree, + ) + + return _clip(pytree=x) + + +@flatdim.register(PyTree) +def _flatdim_pytree(space: PyTree) -> int: + """""" + + low_flat = space.flatten_sample(pytree=space.low) + return low_flat.size + + +@flatten.register(PyTree) +def _flatten_pytree(space: PyTree, x: jtp.PyTree) -> npt.NDArray: + """""" + + assert x in space + return space.flatten_sample(pytree=x) + + +@batch_space.register(PyTree) +def _batch_space_pytree(space: PyTree, n: int = 1) -> PyTree: + """""" + + return PyTree(low=space.low, high=space.high, vectorize=n) diff --git a/src/jaxgym/jax/utils.py b/src/jaxgym/jax/utils.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/jaxgym/stable_baselines.py b/src/jaxgym/stable_baselines.py new file mode 100644 index 000000000..c6526ab7c --- /dev/null +++ b/src/jaxgym/stable_baselines.py @@ -0,0 +1,209 @@ +import functools +from typing import Any, Dict, List, Optional, Type, Union + +import gymnasium as gym +import jax.random +import numpy as np +from gymnasium.experimental.vector.vector_env import VectorWrapper +from stable_baselines3.common import vec_env as vec_env_sb + +import jaxsim.typing as jtp +from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper, PyTree +from jaxgym.vector.jax import FlattenSpacesVecWrapper, JaxVectorEnv +from jaxgym.wrappers.jax import ToNumPyWrapper + + +class CustomVecEnvSB(vec_env_sb.VecEnv): + """Custom vectorized environment for SB3.""" + + metadata = {"render_modes": []} + + def __init__( + self, + jax_vector_env: JaxVectorEnv | VectorWrapper, + ) -> None: + """ + Create a 
custom vectorized environment for SB3 from a JaxVectorEnv. + + Args: + jax_vector_env: The JaxVectorEnv to wrap. + """ + + if not isinstance(jax_vector_env.unwrapped, JaxVectorEnv): + raise TypeError(type(jax_vector_env)) + + self.jax_vector_env = jax_vector_env + + single_env_action_space: PyTree = jax_vector_env.unwrapped.single_action_space + + single_env_observation_space: PyTree = ( + jax_vector_env.unwrapped.single_observation_space + ) + + super().__init__( + num_envs=self.jax_vector_env.num_envs, + action_space=single_env_action_space.to_box(), + observation_space=single_env_observation_space.to_box(), + render_mode=None, + ) + + self.actions = np.zeros_like(self.jax_vector_env.action_space.sample()) + + # Initialize the RNG seed + self._seed = None + self.seed() + + def reset(self) -> vec_env_sb.base_vec_env.VecEnvObs: + """Reset all the environments.""" + + observations, state_infos = self.jax_vector_env.reset(seed=self._seed) + return np.array(observations) + + def step_async(self, actions: np.ndarray) -> None: + self.actions = actions + + @staticmethod + @functools.partial(jax.jit, static_argnames=("batch_size",)) + def tree_inverse_transpose(pytree: jtp.PyTree, batch_size: int) -> List[jtp.PyTree]: + """ + Utility function to perform the inverse of a pytree transpose operation. + + It converts a pytree having the batch size in the first dimension of its leaves + to a list of pytrees having a single batch sample in their leaves. + + Note: Check the direct transpose operation in the following link: + https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees + + Args: + pytree: The batched pytree. + batch_size: The batch size. + + Returns: + A list of pytrees having a single batch sample in their leaves. 
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
    """
    Set the random seed that is forwarded to the environments on reset.

    Args:
        seed: The seed to store. It must be an integer value representable as
            'uint32'. If None, a random seed is drawn.

    Returns:
        A single-element list containing the stored seed.

    Raises:
        ValueError: If the seed is not an integer value in [0, 2**32 - 1].
    """

    if seed is None:
        seed = np.random.default_rng().integers(0, 2**32 - 1, dtype="uint32")

    msg = "seed must be compatible with 'uint32' casting"

    # Validate with plain Python integers so that out-of-range values always
    # produce a ValueError, regardless of how NumPy handles the overflow.
    try:
        seed_int = int(seed)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(msg) from exc

    if seed_int != seed or not 0 <= seed_int <= 2**32 - 1:
        raise ValueError(msg)

    # Store a builtin int (not a NumPy scalar): the seed is later passed as
    # `reset(seed=...)`, and JaxVectorEnv.__init__ likewise converts its own
    # seed with int() before handing it to the seeding utilities.
    self._seed = seed_int
    return [seed_int]
def make_vec_env_stable_baselines(
    jax_dataclass_env: JaxDataclassEnv | JaxDataclassWrapper,
    n_envs: int = 1,
    seed: Optional[int] = None,
    # monitor_dir: Optional[str] = None,
    vec_env_kwargs: Optional[Dict[str, Any]] = None,
) -> vec_env_sb.VecEnv:
    """
    Build a SB3 vectorized environment out of a single `JaxDataclassEnv`.

    Args:
        jax_dataclass_env: The individual `JaxDataclassEnv`.
        n_envs: Number of parallel environments.
        seed: The seed for the vectorized environment.
        vec_env_kwargs: Additional arguments to pass upon environment creation.

    Returns:
        The SB3 vectorized environment.
    """

    extra_kwargs = dict() if vec_env_kwargs is None else vec_env_kwargs

    # Step 1: vectorize the functional environment.
    # Note: it automatically wraps the environment in a TimeLimit wrapper.
    # Note: the space must be PyTree.
    batched_env = JaxVectorEnv(
        func_env=jax_dataclass_env, num_envs=n_envs, **extra_kwargs
    )

    # Step 2: flatten the PyTree spaces to regular Box spaces.
    flattened_env = FlattenSpacesVecWrapper(env=batched_env)

    # Step 3: expose the result through the SB3 vectorized environment API.
    sb3_env = CustomVecEnvSB(jax_vector_env=flattened_env)

    # Optionally override the randomly initialized seed.
    if seed is not None:
        sb3_env.seed(seed=seed)

    return sb3_env
def test_box_pytree() -> None:
    """Check `spaces.Box` with PyTree bounds, including mismatching bounds."""

    @jax_dataclasses.pytree_dataclass
    class SimplePyTree:
        flag: jtp.Bool
        value: jtp.Float
        position: jtp.Vector
        velocity: jtp.Vector

        @staticmethod
        def zero() -> "SimplePyTree":
            return SimplePyTree(
                flag=False,
                value=0,
                position=jnp.zeros(5),
                velocity=jnp.zeros(10),
            )

    template = SimplePyTree.zero()
    ones_position = jnp.ones_like(template.position)
    ones_velocity = jnp.ones_like(template.velocity)

    low = SimplePyTree(
        flag=False,
        value=-42.0,
        position=-10.0 * ones_position,
        velocity=0.1 * ones_velocity,
    )

    high = SimplePyTree(
        flag=True,
        value=42.0,
        position=10.0 * ones_position,
        velocity=5.0 * ones_velocity,
    )

    box = spaces.Box(low=low, high=high)

    compare(low=low, high=high, box=box)

    # Upper bound whose 'position' has the wrong dimension
    with pytest.raises(ValueError):
        _ = spaces.Box(
            low=low,
            high=SimplePyTree(
                flag=True,
                value=42.0,
                position=jnp.zeros(6),
                velocity=jnp.zeros(10),
            ),
        )

    # Upper bound whose 'position' and 'value' have the wrong type
    with pytest.raises(ValueError):
        _ = spaces.Box(
            low=low,
            high=SimplePyTree(
                flag=True,
                value=int(42),
                position=jnp.zeros(5, dtype=int),
                velocity=jnp.zeros(10),
            ),
        )
b/src/jaxgym/vector/jax/__init__.py @@ -0,0 +1,2 @@ +from .vector_env import JaxVectorEnv +from .wrappers import FlattenSpacesVecWrapper diff --git a/src/jaxgym/vector/jax/vector_env.py b/src/jaxgym/vector/jax/vector_env.py new file mode 100644 index 000000000..096db16a8 --- /dev/null +++ b/src/jaxgym/vector/jax/vector_env.py @@ -0,0 +1,359 @@ +import copy +from typing import Any, Sequence + +import jax.flatten_util +import jax.numpy as jnp +import jax.random +import numpy as np +from gymnasium.envs.registration import EnvSpec +from gymnasium.experimental.functional import ( + ActType, + ObsType, + RenderStateType, + RewardType, + StateType, + TerminalType, +) +from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv +from gymnasium.utils import seeding +from gymnasium.vector.utils import batch_space + +import jaxgym.jax.pytree_space as spaces +import jaxsim.typing as jtp +from jaxgym.jax import JaxDataclassEnv, JaxDataclassWrapper +from jaxgym.wrappers.jax import JaxTransformWrapper, TimeLimit +from jaxsim import logging +from jaxsim.utils import not_tracing + + +# https://github.com/Farama-Foundation/Gymnasium/blob/e0cd42f77504060e770ab52932bf7eba45ff1976/gymnasium/experimental/functional_jax_env.py#L116 +# TODO: allow num_envs = 1 so we have automatically autoreset? +# class JaxVectorEnv(VectorEnv[VectorObsType, VectorActType, ArrayType]): +# Note no dataclass here on all stuff related to VectorEnv +class JaxVectorEnv(VectorEnv[ObsType, ActType, ArrayType]): + """ + A vectorized version of JAX-based functional environments exposing `VectorEnv` APIs. 
def __init__(
    self,
    func_env: JaxDataclassEnv[
        StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
    ]
    | JaxDataclassWrapper,
    num_envs: int,
    max_episode_steps: int = 0,
    metadata: dict[str, Any] | None = None,
    render_mode: str | None = None,
    reward_range: tuple[float, float] = (-float("inf"), float("inf")),
    spec: EnvSpec | None = None,
    jit_compile: bool = True,
) -> None:
    """
    Build a vectorized environment from a single functional JAX environment.

    Args:
        func_env: The functional environment (possibly wrapped) to vectorize.
        num_envs: The number of parallel environments.
        max_episode_steps: Forwarded to the `TimeLimit` wrapper that is added
            when the environment is not already wrapped in one.
        metadata: The environment metadata.
        render_mode: The render mode, stored as attribute.
        reward_range: The reward range, stored as attribute.
        spec: The environment spec, stored as attribute.
        jit_compile: Whether to JIT-compile the vectorized resources.

    Raises:
        TypeError: If the unwrapped environment is not a `JaxDataclassEnv`.
    """

    if not isinstance(func_env.unwrapped, JaxDataclassEnv):
        raise TypeError(type(func_env.unwrapped), JaxDataclassEnv)

    if metadata is None:
        metadata = dict(render_mode=list())

    # Store the original functional environment and its single-env spaces
    self.num_envs = num_envs
    self.func_env_single = func_env
    self.single_observation_space = func_env.observation_space
    self.single_action_space = func_env.action_space

    # TODO: convert other spaces to their PyTree equivalent
    assert isinstance(func_env.action_space, spaces.PyTree)
    assert isinstance(func_env.observation_space, spaces.PyTree)

    # Batched version of the single-env spaces
    self.action_space = batch_space(self.single_action_space, n=num_envs)
    self.observation_space = batch_space(self.single_observation_space, n=num_envs)

    # TODO: attributes below
    self.metadata = metadata
    self.render_mode = render_mode
    self.reward_range = reward_range
    self.spec = spec

    def has_wrapper(
        func_env: JaxDataclassEnv | JaxDataclassWrapper,
        wrapper_cls: type,
    ) -> bool:
        """Return True if any layer wrapping the base environment is a `wrapper_cls`."""

        layer = func_env

        # Walk the chain of wrappers until the base environment is reached
        while not isinstance(layer, JaxDataclassEnv):
            if isinstance(layer, wrapper_cls):
                return True

            layer = layer.env

        return False

    # Always wrap the environment in a TimeLimit wrapper, that automatically counts
    # the number of steps and issues a "truncated" flag.
    # Note: the TimeLimit wrapper is a no-op if max_episode_steps is 0.
    # Note: the state of the wrapped environment now is different. The state of
    #       the original environment is now encapsulated in a dictionary.
    # TODO: make this optional?
    already_time_limited = has_wrapper(
        func_env=self.func_env_single, wrapper_cls=TimeLimit
    )

    if not already_time_limited:
        logging.debug(
            "[JaxVectorEnv] Wrapping the environment in a 'TimeLimit' wrapper"
        )
        self.func_env_single = TimeLimit(
            env=self.func_env_single, max_episode_steps=max_episode_steps
        )

    # Attribute that will store the state of the environments after reset()
    self.states = None

    # Initialize the RNGs with a random seed
    seed = np.random.default_rng().integers(0, 2**32 - 1, dtype="uint32")
    self._np_random, _ = seeding.np_random(seed=int(seed))
    self._key = jax.random.PRNGKey(seed=seed)

    # Vectorize the functions of the single environment
    self.func_env = JaxTransformWrapper(env=self.func_env_single, function=jax.vmap)

    if not jit_compile:
        return

    # Compile resources in JIT if requested.
    # Note: this wrapper will override any other JIT wrapper already present.
    self.step_autoreset_func = jax.jit(self.step_autoreset_func)
    self.func_env = JaxTransformWrapper(env=self.func_env, function=jax.jit)
@staticmethod
def binary_mask_pytree(
    pytree_a: jtp.PyTree, pytree_b: jtp.PyTree, mask: Sequence[bool]
) -> jtp.PyTree:
    """
    Compute a new vectorized PyTree selecting elements from either of the
    two input PyTrees according to the boolean mask.

    Note:
        The shapes of pytree_a and pytree_b must match, and they must be vectorized,
        meaning that all their leaves share the dimension of the first axis.
        The mask should have as many elements as this shared dimension.

    Args:
        pytree_a: the first vectorized PyTree object.
        pytree_b: the second vectorized PyTree object.
        mask: the boolean mask to select elements either from pytree_a (when True)
            or pytree_b (when False).

    Returns:
        A new PyTree having elements taken either from pytree_a or pytree_b
        according to mask.

    Raises:
        ValueError: If, outside of a tracing context, the leading dimensions of
            the leaves of the two PyTrees differ, are not all equal, or do not
            match the length of the mask.
    """

    def check() -> None:
        """Validate the leading dimension of all leaves against the mask."""

        # Shapes are static, therefore plain Python containers are enough
        dims_a = [l.shape[0] for l in jax.tree_util.tree_leaves(pytree_a)]
        dims_b = [l.shape[0] for l in jax.tree_util.tree_leaves(pytree_b)]

        # Check that the input PyTrees have the same first dimension of their leaves
        if dims_a != dims_b:
            raise ValueError(
                "The leaves of the two PyTrees have different leading dimensions"
            )

        # Check that all leaves share the same first dimension and that it
        # matches the length of the mask
        unique_dims = set(dims_a)

        if unique_dims and unique_dims != {len(mask)}:
            raise ValueError(
                "The leading dimension of the leaves must be unique and match "
                "the length of the mask"
            )

    # The validation is skipped when the inputs are being traced
    if not_tracing(var=pytree_a):
        check()

    # Convert the boolean mask to a PyTree having boolean leaves.
    # True elements of the leaves are taken from pytree_a, False ones from pytree_b.
    # The mask is extended with trailing singleton axes so that it broadcasts
    # over all the dimensions of each leaf following the first one.
    mask_pytree = jax.tree_util.tree_map(
        lambda l: jnp.ones_like(l, dtype=bool)
        * mask[(...,) + (jnp.newaxis,) * (l.ndim - 1)],
        pytree_a,
    )

    # Create the output pytree taking elements from either pytree_a or pytree_b
    # according to the boolean PyTree built from the mask
    tree_out = jax.tree_util.tree_map(
        lambda a, b, m: jnp.where(m, a, b), pytree_a, pytree_b, mask_pytree
    )

    return tree_out
rng=key1) + rewards = env.reward(state=states, action=actions, next_state=next_states) + terminals = env.terminal(state=next_states) + step_infos = env.step_info(state=states, action=actions, next_state=next_states) + truncated = step_infos["truncated"] + + # Check if any environment is done + dones = jnp.logical_or(terminals, truncated) + + # Add into step_infos the information about the final state even if the + # environments are not done. + # This is necessary for having a constant structure of the output pytree. + # The _final_observation|_final_info masks can be used to filter out the + # actual final data from the final_observation|final_info dictionaries. + # + # Note: the step_info dictionary of done environments that have been + # automatically reset shouldn't be consumed. It refers to the environment + # before being reset. Users in this case should read final_info. + step_infos |= ( + dict( + final_observation=env.observation(state=next_states), + terminal_observation=env.observation(state=next_states), # sb3 + final_info=copy.deepcopy(step_infos), + _final_observation=dones, + _final_info=dones, + is_done=dones, + ) + # Backward compatibility (?) -> SB3 (TODO: done in TimeLimit) + # | { + # "TimeLimit.truncated": truncated, + # "terminal_observation": env.observation(state=next_states), + # } + ) + + # Compute the new state and new state_infos for all environments. + # We return this data only for those that are done. + # new_states = env.initial(rng=key2) + # new_state_infos = env.state_info(state=new_states) + + # Compute the new states for all environments. + # We return this data only for those that are done. + new_states = env.initial(rng=key2) + + # Merge the environment states + new_env_states = JaxVectorEnv.binary_mask_pytree( + mask=dones, + # If done, return the new initial states + pytree_a=new_states, + # If not done, return the next states + pytree_b=next_states, + ) + + # Compute the new observations. 
def step(
    self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
    """
    Step all the environments using the given batch of actions.

    The actions are clipped to the action space before being applied, and the
    stored states of the environments are replaced with the ones returned by
    the autoreset step function.

    Args:
        actions: The batched actions.

    Returns:
        The tuple (observations, rewards, terminals, truncated, step_infos)
        produced by the autoreset step function.
    """

    # TODO: clip as method of Space.clip(x)? use here with hasattr(space, clip)?
    assert isinstance(self.action_space, spaces.PyTree)
    actions = self.action_space.clip(x=actions)

    # TODO: move these inside autoreset as soon as jax.random.split
    #       supports jit compilation
    keys_1 = self.subkey(num=self.num_envs)
    keys_2 = self.subkey(num=self.num_envs)

    # Resolve the step function through the instance: when JIT compilation is
    # enabled, __init__ stores the compiled function as instance attribute.
    # Calling it through the class would always execute the non-compiled
    # static method instead.
    self.states, out = self.step_autoreset_func(
        env=self.func_env,
        states=self.states,
        actions=actions,
        key1=keys_1,
        key2=keys_2,
    )

    return out
# TODO: not dataclass when operating on VectorWrapper -> check other ones
class FlattenSpacesVecWrapper(VectorWrapper):
    """
    Vector wrapper exchanging flat arrays with a `JaxVectorEnv` that operates
    on PyTree observations and actions.
    """

    # TODO: vec_env?
    env: JaxVectorEnv

    def __init__(self, env: JaxVectorEnv) -> None:
        """
        Wrap the vectorized environment and expose its spaces as boxes.

        Args:
            env: The vectorized environment to wrap.

        Raises:
            TypeError: If `env` is not a `JaxVectorEnv`.
        """

        if not isinstance(env, JaxVectorEnv):
            raise TypeError(type(env))

        # Spaces of the wrapper: the box version of the spaces of the environment
        self.action_space = env.action_space.to_box()
        self.observation_space = env.observation_space.to_box()

        super().__init__(env=env)

    def reset(self, **kwargs):
        """Reset the wrapped environment and return flattened observations."""

        observations, state_infos = self.env.reset(**kwargs)
        flat_observations = self.env.observation_space.flatten_sample(
            pytree=observations
        )

        return flat_observations, state_infos

    def step(self, actions):
        """Step the wrapped environment with flat actions, returning flat observations."""

        observation_space = self.env.observation_space

        # The wrapped environment expects actions with the PyTree structure
        pytree_actions = self.env.action_space.unflatten_sample(x=actions)

        observations, rewards, terminals, truncated, step_infos = self.env.step(
            actions=pytree_actions
        )

        # The info dictionary might carry additional observations to flatten
        for key in ("final_observation", "terminal_observation"):
            if key in step_infos:
                step_infos[key] = observation_space.flatten_sample(
                    pytree=step_infos[key]
                )

        return (
            observation_space.flatten_sample(pytree=observations),
            rewards,
            terminals,
            truncated,
            step_infos,
        )
100644 index 000000000..9926dee6e --- /dev/null +++ b/src/jaxgym/wrappers/__init__.py @@ -0,0 +1,2 @@ +from .state import StateWrapper +from .transform import TransformWrapper diff --git a/src/jaxgym/wrappers/clip_action.py b/src/jaxgym/wrappers/clip_action.py new file mode 100644 index 000000000..04e5d8896 --- /dev/null +++ b/src/jaxgym/wrappers/clip_action.py @@ -0,0 +1,25 @@ +# from typing import Generic +# from jaxgym.functional import ActionFuncWrapper +# from jaxgym.functional.func_wrapper import WrapperActType +# from gymnasium.experimental.functional import ActType +# import gymnasium as gym +# import numpy as np +# +# +# class ClipActionWrapper( +# ActionFuncWrapper[WrapperActType], +# Generic[WrapperActType], +# ): +# """""" +# +# def action(self, action: WrapperActType) -> ActType: +# """""" +# +# if self.action_space.contains(x=action): +# return action +# +# assert isinstance(self.action_space, gym.spaces.Box) +# +# return np.clip( +# action, a_min=self.action_space.low, a_max=self.action_space.high +# ) diff --git a/src/jaxgym/wrappers/jax/__init__.py b/src/jaxgym/wrappers/jax/__init__.py new file mode 100644 index 000000000..66e767ad2 --- /dev/null +++ b/src/jaxgym/wrappers/jax/__init__.py @@ -0,0 +1,10 @@ +from .action_noise import ActionNoiseWrapper +from .clip_action import ClipActionWrapper +from .flatten_spaces import FlattenSpacesWrapper +from .nan_handler import NaNHandlerWrapper +from .squash_action import SquashActionWrapper +from .time_limit import TimeLimit +from .to_numpy import ToNumPyWrapper +from .transform import JaxTransformWrapper + +# from .time_limit_sb import TimeLimitStableBaselines diff --git a/src/jaxgym/wrappers/jax/action_noise.py b/src/jaxgym/wrappers/jax/action_noise.py new file mode 100644 index 000000000..5fe4f214f --- /dev/null +++ b/src/jaxgym/wrappers/jax/action_noise.py @@ -0,0 +1,77 @@ +from typing import Any, Callable, Generic + +import jax.flatten_util +import jax.numpy as jnp +import jax.tree_util +import 
@jax_dataclasses.pytree_dataclass
class ActionNoiseWrapper(
    JaxDataclassWrapper[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
        WrapperStateType,
        WrapperObsType,
        WrapperActType,
        WrapperRewardType,
    ],
    Generic[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
    ],
):
    """
    Wrapper that applies a noise function to the flattened action before
    forwarding it to the transition of the wrapped environment.
    """

    # Function receiving the flattened action and a PRNG key, and returning the
    # noisy flattened action. The default adds normal noise scaled by 0.05.
    noise_fn: Callable[
        [npt.NDArray, jax.random.PRNGKeyArray], npt.NDArray
    ] = jax_dataclasses.static_field(
        default=lambda action, rng: action
        + 0.05 * jax.random.normal(key=rng, shape=action.shape)
    )

    def __post_init__(self) -> None:
        """Run the parent post-init logic and log that the wrapper is enabled."""

        super().__post_init__()

        logging.debug(msg=f"[{self.__class__.__name__}] enabled")

    def transition(
        self, state: WrapperStateType, action: WrapperActType, rng: Any = None
    ) -> WrapperStateType:
        """
        Step the wrapped environment with a noisy version of the action.

        Args:
            state: The current state.
            action: The action to perturb.
            rng: The PRNG key. It is split in two: one key is consumed by the
                noise function, the other is forwarded to the environment.

        Returns:
            The state returned by the transition of the wrapped environment.
        """

        keys = jax.random.split(rng, num=2)
        env_key, noise_key = keys

        # The noise function operates on the 1D representation of the action
        flat_action, unflatten = jax.flatten_util.ravel_pytree(pytree=action)
        noisy_action = unflatten(self.noise_fn(flat_action, noise_key))

        return self.env.transition(state=state, action=noisy_action, rng=env_key)
@jax_dataclasses.pytree_dataclass
class ClipActionWrapper(
    JaxDataclassActionWrapper[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
    ],
    Generic[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
    ],
):
    """Action wrapper that clips the action using the `clip` method of the action space."""

    def __post_init__(self) -> None:
        """Run the parent post-init logic and log that the wrapper is enabled."""

        super().__post_init__()

        logging.debug(msg=f"[{self.__class__.__name__}] enabled")

    def action(self, action: WrapperActType) -> ActType:
        """
        Clip the action.

        Args:
            action: The action to clip.

        Returns:
            The value returned by `self.action_space.clip`.
        """

        clipped_action = self.action_space.clip(x=action)
        return clipped_action
@jax_dataclasses.pytree_dataclass
class FlattenSpacesWrapper(
    JaxDataclassWrapper[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
        WrapperStateType,
        WrapperObsType,
        WrapperActType,
        WrapperRewardType,
    ]
):
    """
    Wrapper exposing flat observations and actions for an environment whose
    spaces provide `to_box`, `flatten_sample` and `unflatten_sample`.
    """

    # Propagate to other Jax wrappers
    def __post_init__(self):
        """Run the parent post-init logic and log that the wrapper is enabled."""

        super().__post_init__()

        msg = f"[{self.__class__.__name__}] enabled"
        logging.debug(msg=msg)

    @property
    def action_space(self) -> gym.Space:
        """The box version of the action space of the wrapped environment."""

        return self.env.action_space.to_box()

    @property
    def observation_space(self) -> gym.Space:
        """The box version of the observation space of the wrapped environment."""

        return self.env.observation_space.to_box()

    def transition(
        self, state: WrapperStateType, action: WrapperActType, rng: Any = None
    ) -> WrapperStateType:
        """
        Step the wrapped environment with a flat action.

        Args:
            state: The current state.
            action: The flat action, converted back to the structure of the
                action space of the wrapped environment before being applied.
            rng: The PRNG key forwarded to the wrapped environment.

        Returns:
            The state returned by the transition of the wrapped environment.
        """

        action_pytree = self.env.action_space.unflatten_sample(x=action)
        return self.env.transition(state=state, action=action_pytree, rng=rng)

    def observation(self, state: WrapperStateType) -> WrapperObsType:
        """
        Compute the flat observation of the given state.

        Args:
            state: The state to observe.

        Returns:
            The observation of the wrapped environment flattened through its
            observation space.
        """

        observation_pytree = self.env.observation(state=state)

        # Use `flatten_sample`, the flattening method that the other users of
        # the space rely on (counterpart of `unflatten_sample` used above).
        return self.env.observation_space.flatten_sample(pytree=observation_pytree)
def transition(
    self, state: WrapperStateType, action: WrapperActType, rng: Any = None
) -> WrapperStateType:
    """
    Step the wrapped environment, discarding the leaves of the new state
    that contain NaN values.

    Args:
        state: The current wrapper state.
        action: The action to apply.
        rng: The PRNG key forwarded to the wrapped environment.

    Returns:
        The new wrapper state: a dictionary storing a flag that reports whether
        the stepped state contained NaNs, and the environment state in which
        every leaf containing NaNs is replaced with its previous value.
    """

    env_key = NaNHandlerWrapper.EnvironmentStateKey

    # Snapshot of the environment state before stepping
    previous_env_state = jax.tree_util.tree_map(lambda leaf: leaf, state[env_key])

    # Step the environment
    stepped_env_state = self.env.transition(
        state=self.wrapper_state_to_environment_state(wrapper_state=state),
        action=action,
        rng=rng,
    )

    def keep_valid(leaf_new, leaf_old):
        # Fall back to the old leaf if the new one contains any NaN
        return jax.lax.select(
            pred=self.pytree_has_nan_values(pytree=leaf_new),
            on_true=leaf_old,
            on_false=leaf_new,
        )

    sanitized_env_state = jax.tree_util.tree_map(
        keep_valid, stepped_env_state, previous_env_state
    )

    return {
        NaNHandlerWrapper.HasNanKey: self.pytree_has_nan_values(
            pytree=stepped_env_state
        ),
        env_key: sanitized_env_state,
    }
# TODO: noise cannot be injected here yet, because only transition() receives
#       an rng while observation() does not.
@jax_dataclasses.pytree_dataclass
class ObservationNoiseWrapper(
    JaxDataclassWrapper[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
        WrapperStateType,
        WrapperObsType,
        WrapperActType,
        WrapperRewardType,
    ],
    Generic[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
    ],
):
    """
    Placeholder wrapper meant to add noise to the observations.

    The wrapper stores ``noise_fn`` but none of the methods defined in this
    class applies it: ``transition`` simply forwards to the wrapped
    environment (see the TODO above the class for the reason).
    """

    # Callable mapping an observation to a (noisy) observation.
    # Stored as a static field; not used by the methods below.
    noise_fn: Callable[[ObsType], ObsType] = jax_dataclasses.static_field()

    def __post_init__(self) -> None:
        """Run the parent post-init and log that the wrapper is enabled."""

        super().__post_init__()
        logging.debug(msg=f"[{self.__class__.__name__}] enabled")

    def transition(
        self, state: WrapperStateType, action: WrapperActType, rng: Any = None
    ) -> WrapperStateType:
        """
        Forward the transition to the wrapped environment unchanged.

        Args:
            state: The current state.
            action: The action to apply.
            rng: Optional random number generator state, passed through as-is.

        Returns:
            The next state computed by the wrapped environment.
        """

        next_state = self.env.transition(state=state, action=action, rng=rng)
        return next_state
@jax_dataclasses.pytree_dataclass
class SquashActionWrapper(
    JaxDataclassActionWrapper[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
    ],
    Generic[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
    ],
):
    """
    Action wrapper exposing an action space squashed to [-1, 1].

    Actions received by this wrapper are mapped back to the bounds of the
    wrapped environment's action space before being forwarded.

    Note:
        The wrapped action space is assumed to be entirely bounded (no +-inf)
        and composed of floats only. A dimension with ``low == high`` would
        make ``squash`` divide by zero.
    """

    def __post_init__(self) -> None:
        """Replace the exposed action space with its [-1, 1] counterpart."""

        super().__post_init__()
        logging.debug(msg=f"[{self.__class__.__name__}] enabled")

        import copy

        def minus_ones(leaf):
            return -1.0 * jnp.ones_like(leaf)

        def plus_ones(leaf):
            return 1.0 * jnp.ones_like(leaf)

        wrapped_space = self.env.action_space

        with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            # Work on a copy so that the space of the wrapped env is untouched
            self.action_space = copy.deepcopy(wrapped_space)

            # Then overwrite both bounds, keeping the original PyTree structure
            self.action_space.low = jax.tree_util.tree_map(
                minus_ones, wrapped_space.low
            )
            self.action_space.high = jax.tree_util.tree_map(
                plus_ones, wrapped_space.high
            )

    def action(self, action: WrapperActType) -> ActType:
        """
        Convert a squashed action to an action of the wrapped environment.

        Args:
            action: The action expressed in the squashed space.

        Returns:
            The action rescaled to the bounds of the wrapped action space.
        """

        wrapped_space = self.env.action_space

        return self.unsquash(
            pytree=action, low=wrapped_space.low, high=wrapped_space.high
        )

    @staticmethod
    def squash(pytree: jtp.PyTree, low: jtp.PyTree, high: jtp.PyTree) -> jtp.PyTree:
        """
        Map each leaf from [low, high] to [-1, 1].

        Args:
            pytree: The PyTree to squash.
            low: PyTree of lower bounds, matching the structure of ``pytree``.
            high: PyTree of upper bounds, matching the structure of ``pytree``.

        Returns:
            The squashed PyTree.
        """

        def squash_leaf(leaf, l, h):
            return 2 * (leaf - l) / (h - l) - 1

        return jax.tree_util.tree_map(squash_leaf, pytree, low, high)

    @staticmethod
    def unsquash(pytree: jtp.PyTree, low: jtp.PyTree, high: jtp.PyTree) -> jtp.PyTree:
        """
        Map each leaf from [-1, 1] back to [low, high].

        Args:
            pytree: The PyTree to unsquash.
            low: PyTree of lower bounds, matching the structure of ``pytree``.
            high: PyTree of upper bounds, matching the structure of ``pytree``.

        Returns:
            The unsquashed PyTree.
        """

        def unsquash_leaf(leaf, l, h):
            return (leaf + 1) * (h - l) / 2 + l

        return jax.tree_util.tree_map(unsquash_leaf, pytree, low, high)
@jax_dataclasses.pytree_dataclass
class TimeLimit(
    StateWrapper[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
        WrapperStateType,
    ],
    Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType],
):
    """
    Truncate an episode after a fixed number of transitions.

    The wrapper state is a dictionary storing the state of the wrapped
    environment under ``EnvironmentStateKey`` and the number of transitions
    performed so far under ``ElapsedStepsKey``.

    The truncation flag is exposed in the step info under both the
    ``"truncated"`` and ``"TimeLimit.truncated"`` keys.
    """

    env: FuncEnv[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType]

    # Maximum number of transitions per episode. 0 disables the time limit.
    max_episode_steps: int = jax_dataclasses.static_field()

    ElapsedStepsKey: ClassVar[str] = "elapsed_steps"
    EnvironmentStateKey: ClassVar[str] = "env"

    def __post_init__(self) -> None:
        """
        Validate the configuration and log it.

        Raises:
            ValueError: If ``max_episode_steps`` is negative.
        """

        # A negative limit would make every step truncated since the
        # comparison `elapsed_steps >= max_episode_steps` would always hold.
        if self.max_episode_steps < 0:
            raise ValueError(
                f"max_episode_steps must be >= 0, got {self.max_episode_steps}"
            )

        msg = f"[{self.__class__.__name__}] max_episode_steps={self.max_episode_steps}"
        logging.debug(msg=msg)

    def wrapper_state_to_environment_state(
        self, wrapper_state: WrapperStateType
    ) -> StateType:
        """Extract the state of the wrapped environment from the wrapper state."""

        return wrapper_state[TimeLimit.EnvironmentStateKey]

    def initial(self, rng: Any = None) -> WrapperStateType:
        """
        Build the initial wrapper state.

        Args:
            rng: Optional random number generator state forwarded to the env.

        Returns:
            The wrapper state with the step counter set to zero.
        """

        environment_state = self.env.initial(rng=rng)

        return {
            TimeLimit.EnvironmentStateKey: environment_state,
            TimeLimit.ElapsedStepsKey: 0,
        }

    def transition(
        self, state: WrapperStateType, action: WrapperActType, rng: Any = None
    ) -> WrapperStateType:
        """
        Step the wrapped environment and increase the step counter.

        Args:
            state: The current wrapper state.
            action: The action to apply.
            rng: Optional random number generator state forwarded to the env.

        Returns:
            The next wrapper state.
        """

        # Build a new counter instead of using `+=`, so that a counter backed
        # by a mutable array stored in the input state is never altered.
        elapsed_steps = state[TimeLimit.ElapsedStepsKey] + 1

        environment_state = self.env.transition(
            state=self.wrapper_state_to_environment_state(wrapper_state=state),
            action=action,
            rng=rng,
        )

        return {
            TimeLimit.EnvironmentStateKey: environment_state,
            TimeLimit.ElapsedStepsKey: elapsed_steps,
        }

    def step_info(
        self, state: WrapperStateType, action: ActType, next_state: WrapperStateType
    ) -> dict[str, Any]:
        """
        Extend the step info of the wrapped environment with truncation flags.

        Args:
            state: The wrapper state before the transition.
            action: The applied action.
            next_state: The wrapper state after the transition.

        Returns:
            The step info of the wrapped environment, extended with the
            ``"truncated"`` and ``"TimeLimit.truncated"`` entries.
        """

        # Get the step info from the environment
        info = self.env.step_info(
            state=self.wrapper_state_to_environment_state(wrapper_state=state),
            action=action,
            next_state=self.wrapper_state_to_environment_state(
                wrapper_state=next_state
            ),
        )

        # TODO: make a specific wrapper for stable-baselines3?
        #       1. add TimeLimit.truncated
        #       2. add terminal_observation
        #       3. step_dict -> list[step_dict]
        #       4. all to numpy

        if self.max_episode_steps != 0:
            # Activate the truncation flag if the time limit has been reached
            truncated = jnp.array(
                next_state[TimeLimit.ElapsedStepsKey] >= self.max_episode_steps,
                dtype=bool,
            )
        else:
            # If max_episode_steps=0, this wrapper is a no-op
            truncated = False

        # Check if any other wrapper already truncated the environment
        truncated = jnp.logical_or(truncated, info.get("truncated", False))

        # A terminal state takes precedence: when the environment is done,
        # the step is reported as not truncated.
        truncated = jax.lax.select(
            pred=self.terminal(state=next_state),
            on_true=False,
            on_false=truncated,
        )

        # Return the extended step info.
        # Note: the terminal observation has to be added by the vectorized
        # wrapper, since other wrappers might change the observation structure.
        return info | {"truncated": truncated, "TimeLimit.truncated": truncated}
# DO NOT USE: the "TimeLimit.truncated" entry is already added directly
# by the TimeLimit wrapper.


@jax_dataclasses.pytree_dataclass
class TimeLimitStableBaselines(
    JaxDataclassWrapper[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
        WrapperStateType,
        WrapperObsType,
        WrapperActType,
        WrapperRewardType,
    ],
    Generic[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
    ],
):
    """
    Mirror the ``"truncated"`` step info entry as ``"TimeLimit.truncated"``.

    Note:
        According to the comment above the class, this wrapper should not be
        used since ``TimeLimit`` already provides the same entry.
    """

    env: FuncEnv[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType]

    def __post_init__(self) -> None:
        """Log that the wrapper is enabled."""

        # NOTE(review): unlike other wrappers in this package, the parent
        # __post_init__ is not called here -- confirm whether it is intended.
        logging.debug(msg=f"[{self.__class__.__name__}] enabled")

    def step_info(
        self, state: WrapperStateType, action: ActType, next_state: WrapperStateType
    ) -> dict[str, Any]:
        """
        Extend the step info of the wrapped environment.

        Args:
            state: The state before the transition.
            action: The applied action.
            next_state: The state after the transition.

        Returns:
            The step info of the wrapped environment with the additional
            ``"TimeLimit.truncated"`` entry, which is ``False`` when the
            wrapped environment does not report ``"truncated"``.
        """

        env_info = self.env.step_info(
            state=state, action=action, next_state=next_state
        )

        sb3_entry = {"TimeLimit.truncated": env_info.get("truncated", False)}

        return env_info | sb3_entry
@jax_dataclasses.pytree_dataclass
class ToNumPyWrapper(
    JaxDataclassWrapper[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
        WrapperStateType,
        WrapperObsType,
        WrapperActType,
        WrapperRewardType,
    ]
):
    """
    Convert the outputs of the wrapped environment to NumPy / Python types.

    Observations are returned as NumPy arrays, rewards as Python floats, and
    the leaves of the terminal flag and of the info dictionaries are converted
    with ``pytree_to_numpy``. The state and the transition are left untouched.
    """

    def __post_init__(self) -> None:
        """Run the parent post-init and check the spaces of the wrapped env."""

        super().__post_init__()
        logging.debug(msg=f"[{self.__class__.__name__}] enabled")

        # Both spaces of the wrapped environment must sample NumPy arrays
        assert isinstance(self.env.action_space.sample(), np.ndarray)
        assert isinstance(self.env.observation_space.sample(), np.ndarray)

    def observation(self, state: WrapperStateType) -> WrapperObsType:
        """Return the observation as an array with the observation space dtype."""

        env_observation = self.env.observation(state=state)
        converted = ToNumPyWrapper.pytree_to_numpy(env_observation)

        return np.array(converted, dtype=self.env.observation_space.dtype)

    def reward(
        self,
        state: WrapperStateType,
        action: WrapperActType,
        next_state: WrapperStateType,
    ) -> WrapperRewardType:
        """Return the reward of the transition as a Python float."""

        env_reward = self.env.reward(
            state=state, action=action, next_state=next_state
        )

        return float(ToNumPyWrapper.pytree_to_numpy(env_reward))

    def terminal(self, state: WrapperStateType) -> TerminalType:
        """Return the converted terminal flag of the wrapped environment."""

        env_terminal = self.env.terminal(state=state)
        return ToNumPyWrapper.pytree_to_numpy(env_terminal)

    def state_info(self, state: WrapperStateType) -> dict[str, Any]:
        """Return the converted state info of the wrapped environment."""

        env_info = self.env.state_info(state=state)
        return ToNumPyWrapper.pytree_to_numpy(env_info)

    def step_info(
        self,
        state: WrapperStateType,
        action: WrapperActType,
        next_state: WrapperStateType,
    ) -> dict[str, Any]:
        """Return the converted step info of the wrapped environment."""

        env_info = self.env.step_info(
            state=state, action=action, next_state=next_state
        )

        return ToNumPyWrapper.pytree_to_numpy(env_info)

    @staticmethod
    def pytree_to_numpy(pytree: Any) -> Any:
        """
        Convert all the leaves of a PyTree.

        Args:
            pytree: The PyTree to convert.

        Returns:
            A PyTree with the same structure, in which array leaves holding a
            single boolean become Python ``bool`` and any other leaf is passed
            to ``np.array``.
        """

        def convert_leaf(leaf: Any) -> Any:
            """Convert a single leaf."""

            is_single_bool = (
                isinstance(leaf, (np.ndarray, jnp.ndarray))
                and leaf.size == 1
                and leaf.dtype == "bool"
            )

            return bool(leaf) if is_single_bool else np.array(leaf)

        return jax.tree_util.tree_map(convert_leaf, pytree)
@jax_dataclasses.pytree_dataclass
class JaxTransformWrapper(TransformWrapper, JaxDataclassWrapper):
    """
    Dataclass flavour of ``TransformWrapper``.

    The ``transform_*`` flags are declared as ``InitVar``: they are accepted
    by the generated ``__init__`` and handed to ``__post_init__``, but they
    are not stored as dataclass fields.
    """

    # Function applied to the selected environment methods (static field).
    function: Callable[[Callable], Callable] = jax_dataclasses.static_field()

    # Init-only flags selecting which methods get transformed.
    transform_initial: dataclasses.InitVar[bool] = dataclasses.field(default=True)
    transform_transition: dataclasses.InitVar[bool] = dataclasses.field(default=True)
    transform_observation: dataclasses.InitVar[bool] = dataclasses.field(default=True)
    transform_reward: dataclasses.InitVar[bool] = dataclasses.field(default=True)
    transform_terminal: dataclasses.InitVar[bool] = dataclasses.field(default=True)
    transform_state_info: dataclasses.InitVar[bool] = dataclasses.field(default=True)
    transform_step_info: dataclasses.InitVar[bool] = dataclasses.field(default=True)

    def __post_init__(  # noqa
        self,
        transform_initial: bool,
        transform_transition: bool,
        transform_observation: bool,
        transform_reward: bool,
        transform_terminal: bool,
        transform_state_info: bool,
        transform_step_info: bool,
    ) -> None:
        """
        Initialize both parents of the wrapper.

        Args:
            transform_initial: Whether to transform ``initial``.
            transform_transition: Whether to transform ``transition``.
            transform_observation: Whether to transform ``observation``.
            transform_reward: Whether to transform ``reward``.
            transform_terminal: Whether to transform ``terminal``.
            transform_state_info: Whether to transform ``state_info``.
            transform_step_info: Whether to transform ``step_info``.
        """

        # Explicitly run the post-init of the dataclass parent first
        JaxDataclassWrapper.__post_init__(self)

        msg = f"[{self.__class__.__name__}] function={self.function}"
        logging.debug(msg=msg)

        # TransformWrapper is the first base, hence the super().__init__ call
        # below reaches its __init__, which assigns attributes on the instance.
        # NOTE(review): the mutable context is presumably required because the
        # dataclass does not allow setting attributes otherwise -- confirm.
        with self.mutable_context(mutability=Mutability.MUTABLE):
            super().__init__(
                env=self.env,
                function=self.function,
                transform_initial=transform_initial,
                transform_transition=transform_transition,
                transform_observation=transform_observation,
                transform_reward=transform_reward,
                transform_terminal=transform_terminal,
                transform_state_info=transform_state_info,
                transform_step_info=transform_step_info,
            )
class StateWrapper(
    FuncWrapper[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
        WrapperStateType,
        WrapperObsType,
        WrapperActType,
        WrapperRewardType,
    ],
    Generic[
        StateType,
        ObsType,
        ActType,
        RewardType,
        TerminalType,
        RenderStateType,
        WrapperStateType,
    ],
    abc.ABC,
):
    """
    Base class of wrappers that replace the state of the wrapped environment.

    Subclasses define how the wrapper state is created (``initial``), how it
    evolves (``transition``), and how to recover the state of the wrapped
    environment from it (``wrapper_state_to_environment_state``). All the
    other methods unwrap the state(s) and forward the call to the wrapped
    environment.
    """

    @abc.abstractmethod
    def wrapper_state_to_environment_state(
        self, wrapper_state: WrapperStateType
    ) -> StateType:
        """
        Extract the state of the wrapped environment from the wrapper state.

        Args:
            wrapper_state: The state of the wrapper.

        Returns:
            The state of the wrapped environment.
        """

    @abc.abstractmethod
    def initial(self, rng: Any = None) -> WrapperStateType:
        """
        Build the initial wrapper state.

        Args:
            rng: Optional random number generator state.

        Returns:
            The initial state of the wrapper.
        """

    @abc.abstractmethod
    def transition(
        self, state: WrapperStateType, action: WrapperActType, rng: Any = None
    ) -> WrapperStateType:
        """
        Compute the next wrapper state.

        Args:
            state: The current wrapper state.
            action: The action to apply.
            rng: Optional random number generator state.

        Returns:
            The next state of the wrapper.
        """

    def observation(self, state: WrapperStateType) -> WrapperObsType:
        """Return the observation of the wrapped environment for ``state``."""

        return self.env.observation(
            state=self.wrapper_state_to_environment_state(wrapper_state=state)
        )

    def reward(
        self,
        state: WrapperStateType,
        action: WrapperActType,
        next_state: WrapperStateType,
    ) -> WrapperRewardType:
        """Return the reward of the wrapped environment for the transition."""

        return self.env.reward(
            state=self.wrapper_state_to_environment_state(wrapper_state=state),
            action=action,
            next_state=self.wrapper_state_to_environment_state(
                wrapper_state=next_state
            ),
        )

    def terminal(self, state: WrapperStateType) -> TerminalType:
        """Return the terminal flag of the wrapped environment for ``state``."""

        return self.env.terminal(
            state=self.wrapper_state_to_environment_state(wrapper_state=state)
        )

    def state_info(self, state: WrapperStateType) -> dict[str, Any]:
        """Info dict about a single state."""

        return self.env.state_info(
            state=self.wrapper_state_to_environment_state(wrapper_state=state)
        )

    def step_info(
        self,
        state: WrapperStateType,
        action: WrapperActType,
        next_state: WrapperStateType,
    ) -> dict[str, Any]:
        """Info dict about a full transition."""

        return self.env.step_info(
            state=self.wrapper_state_to_environment_state(wrapper_state=state),
            action=action,
            next_state=self.wrapper_state_to_environment_state(
                wrapper_state=next_state
            ),
        )

    def render_image(
        self, state: WrapperStateType, render_state: RenderStateType
    ) -> tuple[RenderStateType, npt.NDArray]:
        """
        Render the state.

        Args:
            state: The wrapper state; it is unwrapped before being forwarded
                to the wrapped environment.
            render_state: The render state forwarded as-is.

        Returns:
            The output of ``render_image`` of the wrapped environment.
        """

        return self.env.render_image(
            state=self.wrapper_state_to_environment_state(wrapper_state=state),
            render_state=render_state,
        )
transform_step_info: bool = True, + ): + """""" + + self.env = env + + # Here to show up in repr + self.function = function + + if transform_initial: + self.initial = function(self.initial) + + if transform_transition: + self.transition = function(self.transition) + + if transform_observation: + self.observation = function(self.observation) + + if transform_reward: + self.reward = function(self.reward) + + if transform_terminal: + self.terminal = function(self.terminal) + + if transform_state_info: + self.state_info = function(self.state_info) + + if transform_step_info: + self.step_info = function(self.step_info) + + # @staticmethod + # def transform_env( + # env: FuncEnv[ + # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType + # ], + # function: Callable[[Callable], Callable], + # transform_initial: bool = True, + # transform_transition: bool = True, + # transform_observation: bool = True, + # transform_reward: bool = True, + # transform_terminal: bool = True, + # transform_state_info: bool = True, + # transform_step_info: bool = True, + # ) -> FuncWrapper[ + # StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType + # ]: + # """""" + # + # if transform_initial: + # self.initial = function(self.initial) + # + # if transform_transition: + # self.transition = function(self.transition) + # + # if transform_observation: + # self.observation = function(self.observation) + # + # if transform_reward: + # self.reward = function(self.reward) + # + # if transform_terminal: + # self.terminal = function(self.terminal) + # + # if transform_state_info: + # self.state_info = function(self.state_info) + # + # if transform_step_info: + # self.step_info = function(self.step_info) diff --git a/src/jaxsim/simulation/integrators.py b/src/jaxsim/simulation/integrators.py index 71ab42c48..450c7bedc 100644 --- a/src/jaxsim/simulation/integrators.py +++ b/src/jaxsim/simulation/integrators.py @@ -338,7 +338,7 @@ def odeint_euler( t: TimeHorizon, *args, 
num_sub_steps: int = 1, - return_aux: bool = False + return_aux: bool = False, ) -> Union[State, Tuple[State, Dict[str, Any]]]: """ Integrate a system of ODEs using the Euler method. @@ -378,7 +378,7 @@ def odeint_euler_semi_implicit( t: TimeHorizon, *args, num_sub_steps: int = 1, - return_aux: bool = False + return_aux: bool = False, ) -> Union[State, Tuple[State, Dict[str, Any]]]: """ Integrate a system of ODEs using the Semi-Implicit Euler method. @@ -418,7 +418,7 @@ def odeint_rk4( t: TimeHorizon, *args, num_sub_steps: int = 1, - return_aux: bool = False + return_aux: bool = False, ) -> Union[State, Tuple[State, Dict[str, Any]]]: """ Integrate a system of ODEs using the Runge-Kutta 4 method. diff --git a/src/jaxsim/training/__init__.py b/src/jaxsim/training/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/jaxsim/training/agent.py b/src/jaxsim/training/agent.py new file mode 100644 index 000000000..3fc9edb40 --- /dev/null +++ b/src/jaxsim/training/agent.py @@ -0,0 +1,791 @@ +import copy +import dataclasses +import datetime +import functools +import pathlib +from typing import Any, Callable, Dict, Tuple, Union + +import flax.training.checkpoints +import flax.training.train_state +import gym.spaces +import jax +import jax.experimental.loops +import jax.numpy as jnp +import jax_dataclasses +import optax +from flax.core.frozen_dict import FrozenDict +from optax._src.alias import ScalarOrSchedule + +import jaxsim.typing as jtp +from jaxsim import logging +from jaxsim.utils import JaxsimDataclass + +from .memory import DataLoader, Memory +from .networks import ActorCriticNetworks, ActorNetwork, CriticNetwork + + +@jax_dataclasses.pytree_dataclass +class PPOParams: + # Gradient descent + alpha: float = jax_dataclasses.field(default=0.0003) + optimizer: Callable[ + [ScalarOrSchedule], optax.GradientTransformation + ] = jax_dataclasses.static_field(default=optax.adam) + + # RL Algorithm + gamma: float = jax_dataclasses.field(default=0.99) + 
@dataclasses.dataclass(frozen=True)
class CheckpointManager:
    """
    Save and restore train states below a root checkpoint folder.

    The latest checkpoints are stored directly in ``checkpoint_path``, while
    the best checkpoints of a given measure are stored in the
    ``best_<measure>`` sub-folder. All files share the ``checkpoint_`` prefix.
    """

    # Root folder of all the checkpoints handled by this manager.
    checkpoint_path: pathlib.Path

    def _best_dir(self, measure: str) -> pathlib.Path:
        """Return the folder storing the best checkpoints of ``measure``."""

        return self.checkpoint_path / f"best_{measure}"

    def save_best(
        self,
        train_state: PPOTrainState,
        measure: str,
        metric: Union[int, float],
        keep: int = 1,
    ) -> None:
        """
        Save a checkpoint in the folder of the given measure.

        Args:
            train_state: The train state to save.
            measure: The name of the measure, selecting the sub-folder.
            metric: The value of the measure, used as checkpoint step.
            keep: The number of checkpoints to keep.
        """

        saved_to = flax.training.checkpoints.save_checkpoint(
            ckpt_dir=self._best_dir(measure),
            target=train_state,
            step=metric,
            prefix="checkpoint_",
            keep=keep,
            overwrite=True,
        )

        logging.info(msg=f"Saved checkpoint: {saved_to}")

    def save_latest(
        self, train_state: PPOTrainState, keep_every_n_steps: int = None
    ) -> None:
        """
        Save a checkpoint in the root folder, using the train state step.

        Args:
            train_state: The train state to save.
            keep_every_n_steps: Forwarded as-is to flax ``save_checkpoint``.
        """

        saved_to = flax.training.checkpoints.save_checkpoint(
            ckpt_dir=self.checkpoint_path,
            target=train_state,
            step=train_state.step,
            prefix="checkpoint_",
            keep=1,
            keep_every_n_steps=keep_every_n_steps,
        )

        logging.info(msg=f"Saved checkpoint: {saved_to}")

    def load_best(
        self, dummy_train_state: PPOTrainState, measure: str
    ) -> PPOTrainState:
        """
        Restore a checkpoint from the folder of the given measure.

        Args:
            dummy_train_state: The train state used as restore target.
            measure: The name of the measure, selecting the sub-folder.

        Returns:
            The object returned by flax ``restore_checkpoint``.
        """

        return flax.training.checkpoints.restore_checkpoint(
            ckpt_dir=self._best_dir(measure),
            target=dummy_train_state,
            prefix="checkpoint_",
        )

    def load_latest(self, dummy_train_state: PPOTrainState) -> PPOTrainState:
        """
        Restore a checkpoint from the root folder.

        Args:
            dummy_train_state: The train state used as restore target.

        Returns:
            The object returned by flax ``restore_checkpoint``.
        """

        return flax.training.checkpoints.restore_checkpoint(
            ckpt_dir=self.checkpoint_path,
            target=dummy_train_state,
            prefix="checkpoint_",
        )
+@jax_dataclasses.pytree_dataclass +class Agent(JaxsimDataclass): + key: jax.random.PRNGKey = jax_dataclasses.field( + default_factory=lambda: jax.random.PRNGKey(seed=0), repr=False + ) + + params: PPOParams = jax_dataclasses.field(default_factory=PPOParams) + train_state: PPOTrainState = jax_dataclasses.field(default=None, repr=False) + + action_space: gym.spaces.Box = jax_dataclasses.static_field(default=None) + observation_space: gym.spaces.Box = jax_dataclasses.static_field(default=None) + + checkpoint_manager: CheckpointManager = jax_dataclasses.static_field(default=None) + + _num_timesteps: int = jax_dataclasses.field(default=0) + _num_iterations: int = jax_dataclasses.field(default=0) + + _min_reward: float = jnp.finfo(jnp.float32).max + _max_reward: float = jnp.finfo(jnp.float32).min + + @staticmethod + def build( + actor: ActorNetwork = None, + critic: CriticNetwork = None, + action_space: gym.spaces.Space = None, + observation_space: gym.spaces.Space = None, + key: jax.random.PRNGKey = jax.random.PRNGKey(seed=0), + params: PPOParams = PPOParams(), + train_state: PPOTrainState = None, # TODO remove? 
+ checkpoint_label: str = "training", + load_checkpoint_from_path: pathlib.Path = None, + ) -> "Agent": + date_str = datetime.datetime.now().strftime("%y%m%d_%H%M") + checkpoint_folder = f"{checkpoint_label}_{date_str}" + + checkpoint_path = ( + pathlib.Path("~/jaxsim_results").expanduser() / checkpoint_folder + ) + + agent = Agent( + key=key, + params=params, + train_state=train_state, + action_space=action_space, + observation_space=observation_space, + checkpoint_manager=CheckpointManager(checkpoint_path), + ) + + if agent.train_state is None: + # Generate a new network RNG key + key = agent.advance_key(num_sub_keys=1) + + # Get a dummy observation + observation_dummy = jnp.zeros(shape=observation_space.shape) + + # Create the actor/critic network + actor_critic = ActorCriticNetworks(actor=actor, critic=critic) + + # Initialize the actor/critic network + out, params_critic = actor_critic.init_with_output(key, observation_dummy) + distribution_dummy, value_dummy = out + + # Check value shape + if value_dummy.shape != (1,): + raise ValueError(value_dummy.shape, (1,)) + + # Check action shape + dummy_action = distribution_dummy.sample(seed=key) + if dummy_action.shape != action_space.shape: + raise ValueError(dummy_action.shape, action_space.shape) + + # Initialize the actor/critic train state + train_state = PPOTrainState.create( + apply_fn=actor_critic.apply, + params=params_critic, + tx=agent.params.optimizer(agent.params.alpha), + beta_kl=agent.params.beta_kl, + ) + + with agent.editable(validate=False) as agent: + agent.train_state = train_state + + # Replace the actor/critic train state + # agent = jax_dataclasses.replace(agent, train_state=train_state) # noqa + + if load_checkpoint_from_path is not None: + load_checkpoint_from_path = ( + load_checkpoint_from_path.expanduser().absolute() + ) + + if not load_checkpoint_from_path.exists(): + raise FileExistsError(load_checkpoint_from_path) + + with agent.editable(validate=False) as agent: + agent.train_state = 
flax.training.checkpoints.restore_checkpoint( + ckpt_dir=load_checkpoint_from_path, + target=agent.train_state, + prefix=f"checkpoint_", + ) + + return agent + + # /tmp/jaxsim/cartpole_20220614_1756/checkpoint_#i + # /tmp/jaxsim/cartpole_20220614_1756/reward/checkpoint_REWARD + # /tmp/jaxsim/cartpole_20220614_1756/episode_steps/checkpoint_REWARD + + def save_checkpoint( + self, + prefix: str = "checkpoint_", + checkpoint_path: pathlib.Path = pathlib.Path.cwd() / "checkpoints", + ) -> None: + path_to_checkpoint = flax.training.checkpoints.save_checkpoint( + target=self.train_state, + prefix=prefix, + ckpt_dir=checkpoint_path, + step=self.train_state.step, + ) + + logging.info(msg=f"Save checkpoint: {path_to_checkpoint}") + + def load_checkpoint( + self, checkpoint_path: pathlib.Path = pathlib.Path.cwd() / "checkpoints" + ) -> "Agent": + train_state = flax.training.checkpoints.restore_checkpoint( + ckpt_dir=checkpoint_path, target=self.train_state + ) + + with self.editable(validate=True) as agent: + agent.train_state = train_state + + agent._set_mutability(mutability=self._mutability()) + return agent + + # return jax_dataclasses.replace(self, train_state=train_state) + + def advance_key(self, num_sub_keys: int = 1) -> jax.random.PRNGKey: + keys = jax.random.split(key=self.key, num=num_sub_keys + 1) + + object.__setattr__(self, "key", keys[0]) + + return keys[1:].squeeze() + + def advance_key2(self, num_sub_keys: int = 1) -> Tuple[jax.random.PRNGKey, "Agent"]: + keys = jax.random.split(key=self.key, num=num_sub_keys + 1) + return keys[1:].squeeze(), jax_dataclasses.replace(self, key=keys[0]) # noqa + + def choose_action( + self, + observation: jtp.Vector, + explore: bool = True, + key: jax.random.PRNGKey = None, + ) -> Tuple[jtp.VectorJax, jtp.VectorJax, jtp.VectorJax]: + # Get a new key if not passed + key = key if key is not None else self.advance_key() + + # Infer π_θ(⋅|sₜ) and V_ϕ(sₜ) + distribution, value = self.train_state.apply_fn( + self.train_state.params, 
@staticmethod
@jax.jit
def estimate_advantage_gae_jit(
    train_state: PPOTrainState, memory: Memory, gamma: float, lambda_gae: float
) -> jtp.VectorJax:
    """
    JIT-compiled entry point of ``Agent.estimate_advantage_gae``.

    It builds the value function V(o) from ``train_state`` and forwards it,
    together with the other arguments, to ``Agent.estimate_advantage_gae``.

    Args:
        train_state: The train state providing ``apply_fn`` and ``params``.
        memory: The memory holding the sampled data.
        gamma: The discount factor.
        lambda_gae: The λ parameter of GAE.

    Returns:
        The advantage estimates computed by ``Agent.estimate_advantage_gae``.
    """

    def critic_value(observation):
        # The second output of apply_fn is the value; it is used to bootstrap
        # the return when necessary.
        outputs = train_state.apply_fn(train_state.params, data=observation)
        return outputs[1].squeeze()

    return Agent.estimate_advantage_gae(
        memory=memory, gamma=gamma, lambda_gae=lambda_gae, V=critic_value
    )
In this case, since we have computed the value + # of the next observation, we estimate A of the last sample using TD(0) + # returns, i.e. Âₕ = δₕ = rₕ + γ⋅V(sₕ₊₁) - V(sₕ). + # This can be done either setting Âₕ₊₁ = 0 or, equivalently, λₕ = 0. + # We proceed with the former case. + s.A = jnp.zeros_like(Vs) + + # Iteration same as: reversed(mask.size) + for t in s.range(mask.size - 1, -1, -1): + # Classic TD(0) boostrap: δₜ = rₜ + γ⋅V(sₜ₊₁) - V(sₜ) + delta_t_not_done = r[t] + gamma * Vs[t + 1] - Vs[t] + + # Use one step of Monte Carlo if terminal, TD(0) otherwise. + # This allows to consider two different types of termination: + # - MC: Reached a terminal state s_T. All the rewards following s_T + # are considered to be 0. + # - TD(0): The trajectory reached the maximum length, and it was + # truncated early (common in continuous control). We cannot + # assume that the reward would be 0 after truncation, therefore + # we boostrap the return with TD(0) (similarly to what we do + # for the last truncated trajectory of the memory). 
+ delta_t_done = jax.lax.select( + pred=is_terminal.astype(dtype=bool)[t], + on_true=r[t] - Vs[t], + on_false=r[t] + gamma * V(next_observations[t]) - Vs[t], + ) + + # Select the right δₜ + delta_t = jax.lax.select( + pred=memory.dones.astype(dtype=bool)[t], + on_true=delta_t_done, + on_false=delta_t_not_done, + ) + + # Compute the advantage estimate Âₜ + A_t = delta_t + gamma * lambda_gae * s.A[t + 1] * mask[t] + s.A = s.A.at[t].set(A_t.squeeze()) + + # Remove the trailing value we added to handle the last entry of the memory + A = s.A[:-1] + + return jax.lax.stop_gradient(jnp.vstack(A)) + + @staticmethod + def compute_reward_to_go( + memory: Memory, + gamma: float = 1.0, + V: Callable[[jtp.ArrayJax], jtp.ArrayJax] = lambda _: 0.0, + ) -> jtp.VectorJax: + assert memory.flat is True + + r = memory.rewards.squeeze() + r_to_go = jnp.zeros_like(r) + + dones = memory.dones.astype(dtype=bool).squeeze() + is_terminal = memory.infos["is_terminal"].astype(dtype=bool).squeeze() + next_observations = memory.infos["terminal_observation"].squeeze() + + with jax.experimental.loops.Scope() as s: + # Approximate the return of the state following the last one with its value. + # Note: we store next_observation in the info dict for this reason. + s.r_to_go = jnp.hstack([r_to_go, V(next_observations[-1])]) + + # Iteration same as: reversed(dones.size) + for t in s.range(dones.size - 1, -1, -1): + # Monte Carlo discounted accumulation. + # Note: considering r_to_go[-1], this is TD(0) for the last memory entry. 
@staticmethod
def explained_variance(y_hat: jtp.Array, y: jtp.Array) -> jtp.Array:
    """
    Compute the explained variance ``1 - Var[y - y_hat] / Var[y]``.

    Args:
        y_hat: The 1D array of predictions.
        y: The 1D array of targets.

    Returns:
        The explained variance, or NaN when the variance of ``y`` is zero.
    """
    assert y.ndim == 1 and y_hat.ndim == 1

    target_variance = jnp.var(y)
    residual_variance = jnp.var(y - y_hat)

    # Both branches of the select are evaluated; the ratio is simply discarded
    # when the variance of the targets is zero.
    score = 1 - residual_variance / target_variance

    return jax.lax.select(
        pred=(target_variance == 0.0),
        on_true=jnp.nan,
        on_false=score,
    )
mem_log_prob_actions = jnp.vstack(memory.log_prob_actions.squeeze()) + + # Loss function for both the actor and critic networks. + # Note: we do not support sharing layers, therefore weighting differently + # the two losses should not be relevant. + def loss_fn(params: flax.core.FrozenDict[str, Any]) -> Tuple[float, Dict]: + # Infer π_θₙ(⋅|sₜ) and V_ϕₙ(sₜ) with the new parameters (θₙ, ϕₙ) + new_distributions, new_values = train_state_target.apply_fn( + params, + data=mem_observations, + ) + + # ====== + # Critic + # ====== + + # The loss uses the returns as targets. Returns could be computed with + # a (possibly discounted) reward-to-go or from the estimated advantage. + + # Compute the MSE loss + new_values = new_values.squeeze() + critic_loss = jnp.linalg.norm(new_values - returns_target, ord=2) + # TODO: should this be norm squared? + + # Compute the explained variance. It should start as a very negative number + # and progressively converge towards 1.0 when the value function learned to + # approximate correctly the sampled return. + returns = policy_gradient_loss_weight.flatten() + mem_values.flatten() + explained_variance = Agent.explained_variance( + y=returns, y_hat=mem_values.flatten() + ) + + # ===== + # Actor + # ===== + + # Refer to https://arxiv.org/abs/1707.06347 for the surrogate functions + # used for both the CLIP and the KLPEN versions of PPO. + + # Infer π_θₒ(⋅|sₜ) with the old parameters θₒ. Used only for KLPEN. 
+ old_distributions, old_values = train_state_behavior.apply_fn( + train_state_behavior.params, + data=mem_observations, + ) + + # Compute new log-likelihood of actions: log[π_θₙ(aₜ|sₜ)] + new_log_prob_actions = new_distributions.log_prob(value=mem_actions) + + # Rename the old log-likelihood of actions: log[π_θₒ(aₜ|sₜ)] + old_log_prob_actions = mem_log_prob_actions.squeeze() + + # Compute the ratio rₜ(θₙ) of the likelihoods + prob_action_ratio = jnp.exp(new_log_prob_actions - old_log_prob_actions) + + # Compute the CPI surrogate objective + L = prob_action_ratio * policy_gradient_loss_weight + + # Compute the clipped version of the ratio of the likelihoods + prob_action_ratio_clipped = jnp.clip( + a=prob_action_ratio, + a_min=(1.0 - ppo_params.epsilon_clip), + a_max=(1.0 + ppo_params.epsilon_clip), + ) + + # Compute the clip ratio + clipped_elements = jnp.where( + prob_action_ratio != prob_action_ratio_clipped, 1, 0 + ) + clip_ratio = clipped_elements.sum() / clipped_elements.size + + # Apply the CLIP surrogate objective. + # Note: 'epsilon_clip' is zero if CLIP is not enabled. + L = jax.lax.select( + pred=(ppo_params.epsilon_clip == 0), + on_true=L, + on_false=jnp.minimum( + L, prob_action_ratio_clipped * policy_gradient_loss_weight + ), + ).mean() + + # Compute the additional KLPEN surrogate objective term. + # Note: 'beta_kl' is zero if KLPEN is not enabled. 
@staticmethod
@jax.jit
def adaptive_update_beta_ppo_kl_pen(
    train_state_behavior: PPOTrainState,
    train_state_target: PPOTrainState,
    ppo_params: PPOParams,
    memory: Memory,
) -> PPOTrainState:
    """
    Adapt the β_KL weight of the PPO-KLPEN surrogate objective.

    The update rule follows https://arxiv.org/abs/1707.06347 (Sec. 4), using
    the heuristic 1.5 and 2 parameters reported in the paper: β_KL is doubled
    when the mean KL exceeds 1.5 times the target, and halved when it is below
    the target divided by 1.5.

    Args:
        train_state_behavior: The train state with the old parameters, which
            also provides the old β_KL.
        train_state_target: The train state with the new parameters.
        ppo_params: The PPO parameters providing ``target_kl``.
        memory: The memory whose states are used to evaluate both policies.

    Returns:
        The target train state, with β_KL replaced if one of the two
        thresholds was crossed.
    """
    observations = memory.states

    # Infer π_θₒ(⋅|sₜ) with the old parameters
    pi_old, _ = train_state_behavior.apply_fn(
        train_state_behavior.params, data=observations
    )

    # Infer π_θₙ(⋅|sₜ) with the new parameters
    pi_new, _ = train_state_target.apply_fn(
        train_state_target.params, data=observations
    )

    # Mean KL divergence of the old policy from the new policy
    mean_kl = pi_old.kl_divergence(other_dist=pi_new).mean()

    # The old β_KL used as weight of the PPO-KLPEN surrogate objective
    previous_beta = train_state_behavior.beta_kl

    def doubled(_):
        return train_state_target.replace(beta_kl=previous_beta * 2.0)

    def untouched(_):
        return train_state_target

    # Increase β_KL
    state_after_increase = jax.lax.cond(
        pred=(mean_kl > ppo_params.target_kl * 1.5),
        true_fun=doubled,
        false_fun=untouched,
        operand=(),
    )

    def halved(_):
        return state_after_increase.replace(beta_kl=previous_beta * 0.5)

    def kept(_):
        return state_after_increase

    # Decrease β_KL
    return jax.lax.cond(
        pred=(mean_kl < ppo_params.target_kl / 1.5),
        true_fun=halved,
        false_fun=kept,
        operand=(),
    )
+ memory_flat = memory.truncate_last_trajectory().flatten() + + mean_reward = memory_flat.rewards.mean() + self.checkpoint_manager.save_best( + train_state=self.train_state, measure="reward", metric=mean_reward, keep=5 + ) + + self.checkpoint_manager.save_latest( + train_state=self.train_state, keep_every_n_steps=25 + ) + + # Update the training metadata + with self.editable(validate=True) as agent: + agent._num_iterations += 1 + agent._num_timesteps += len(memory) + + # Update the training metadata + # agent = jax_dataclasses.replace( + # self, # noqa + # _num_iterations=(self._num_iterations + 1), + # _num_timesteps=(self._num_timesteps + len(memory)), + # ) + + # Log the min/max reward ever seen + min_reward = jnp.array([agent._min_reward, memory.rewards.min()]).min() + max_reward = jnp.array([agent._max_reward, memory.rewards.max()]).max() + agent = jax_dataclasses.replace( + agent, _min_reward=min_reward, _max_reward=max_reward # noqa + ) + + # Estimate the advantages Â_π_θₒ(sₜ, aₜ) with GAE used to train the actor. + advantages = Agent.estimate_advantage_gae_jit( + train_state=agent.train_state, + memory=memory_flat, + gamma=agent.params.gamma, + lambda_gae=agent.params.lambda_gae, + ) + + # Compute the rewards-to-go R̂ₜ used to train the critic. + # Note: same of estimating Â_GAE with λ=1.0 and γ=1.0. + rewards_to_go = Agent.compute_reward_to_go_jit( + train_state=agent.train_state, + memory=memory_flat, + gamma=agent.params.gamma, + ) + + # Select the returns to use. We can use any of the following: + # - Rewards-to-go: R = R̂ₜ + # - GAE advantages plus values: R = Â_GAE + V + # returns = rewards_to_go + returns = advantages + memory_flat.values + + # Store the behavior train state (old parameters θₒ and ϕₒ). + # Note: for safety, we make sure that params do not get overridden by the + # next optimization by taking their deep copy. 
+ train_state_behaviour = agent.train_state.replace( + params=copy.deepcopy(agent.train_state.params) + ) + + for epoch_idx in range(num_epochs): + for batch_slice in DataLoader(memory=memory_flat).batch_slices_generator( + batch_size=batch_size, + shuffle=True, + seed=epoch_idx, + allow_partial_batch=False, + ): + # Perform one step of gradient descent + train_state, extra_data = Agent.train_actor_critic( + train_state_behavior=train_state_behaviour, + train_state_target=agent.train_state, + memory=memory_flat[batch_slice], + returns_target=returns[batch_slice], + policy_gradient_loss_weight=advantages[batch_slice], + ppo_params=agent.params, + ) + + # Update the agent with the new train state + agent = jax_dataclasses.replace(agent, train_state=train_state) + + # TODO + # Create the log data of the training step + log_data = FrozenDict( + extra_data, reward_range=(agent._min_reward, agent._max_reward) + ) + + # Print to output + # TODO: logging? + print(log_data) + + # ====================================== + # Update the KL parameters for PPO-KLPEN + # ====================================== + + # @jax.jit + # def update_beta_kl(agent: Agent) -> Agent: + # + # # Update the train state with the adjusted β_KL + # train_state_kl = Agent.adaptive_update_beta_ppo_kl_pen( + # train_state_behavior=train_state_behaviour, + # train_state_target=agent.train_state, + # ppo_params=agent.params, + # memory=memory_flat, + # ) + # + # # Update the agent with the new train state + # return jax_dataclasses.replace(agent, train_state=train_state_kl) + + # agent = jax.lax.cond( + # pred=agent.params.beta_kl != 0.0, + # true_fun=update_beta_kl, + # false_fun=lambda agent: agent, + # operand=agent, + # ) + + if agent.params.beta_kl != 0.0: + # Update the train state with the adjusted β_KL + train_state_kl = Agent.adaptive_update_beta_ppo_kl_pen( + train_state_behavior=train_state_behaviour, + train_state_target=agent.train_state, + ppo_params=agent.params, + memory=memory_flat, + ) + 
class SquashedMultivariateNormalDiagBase(distribution.Distribution):
    """
    Base class of diagonal multivariate normal distributions whose samples are
    squashed into the ``[low, high]`` interval.

    Subclasses must implement the squashing function (``_squash``), its inverse
    (``_unsquash``), the correction term of the log-prob due to the squashing
    (``_log_abs_det_grad_squash``), and ``entropy``.
    """

    def __init__(
        self,
        low: distribution.Array,
        high: distribution.Array,
        loc: Optional[distribution.Array] = None,
        scale_diag: Optional[distribution.Array] = None,
    ):
        """
        Args:
            low: The lower bound of the squashed samples.
            high: The upper bound of the squashed samples.
            loc: The mean of the underlying normal distribution.
            scale_diag: The diagonal scale of the underlying normal distribution.
        """
        # The underlying (unsquashed) normal distribution
        self.normal = MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)

        # Broadcast the bounds against the mean of the underlying distribution
        self.low, _ = jnp.broadcast_arrays(low, self.normal.loc)
        self.high, _ = jnp.broadcast_arrays(high, self.normal.loc)

    # NOTE(review): this is a plain method, while the underlying distribution
    # exposes event_shape as an attribute. Confirm how callers access it before
    # turning it into a property.
    def event_shape(self) -> Tuple[int, ...]:
        """Return the event shape of the underlying normal distribution."""
        return self.normal.event_shape

    def _sample_n(self, key: jax.random.PRNGKey, n: int) -> distribution.Array:
        """Draw ``n`` samples from the underlying normal and squash them."""
        samples_normal = self.normal._sample_n(key=key, n=n)
        return self._squash(unsquashed_sample=samples_normal)

    def _squash(self, unsquashed_sample: distribution.Array) -> distribution.Array:
        """Map a sample of the underlying normal into ``[low, high]``."""
        raise NotImplementedError

    def _unsquash(self, squashed_sample: distribution.Array) -> distribution.Array:
        """Map a sample in ``[low, high]`` back to the space of the normal."""
        raise NotImplementedError

    def _log_abs_det_grad_squash(
        self, unsquashed_sample: distribution.Array
    ) -> distribution.Array:
        """Return the correction term subtracted from the normal log-prob."""
        raise NotImplementedError

    def entropy(self) -> distribution.Array:
        raise NotImplementedError

    def mean(self) -> distribution.Array:
        """Return the squashed mean of the underlying normal distribution."""
        mean_normal = self.normal.mean()
        return self._squash(unsquashed_sample=mean_normal)

    def median(self) -> distribution.Array:
        """Return the squashed median of the underlying normal distribution."""
        median_normal = self.normal.median()
        return self._squash(unsquashed_sample=median_normal)

    def variance(self) -> distribution.Array:
        raise NotImplementedError

    def stddev(self) -> distribution.Array:
        raise NotImplementedError

    def mode(self) -> distribution.Array:
        """Return the squashed mode of the underlying normal distribution."""
        mode_normal = self.normal.mode()
        return self._squash(unsquashed_sample=mode_normal)

    def cdf(self, value: distribution.Array) -> distribution.Array:
        raise NotImplementedError

    def log_cdf(self, value: distribution.Array) -> distribution.Array:
        raise NotImplementedError

    def sample(
        self,
        *,
        seed: Union[distribution.IntLike, distribution.PRNGKey],
        sample_shape: Union[distribution.IntLike, Sequence[distribution.IntLike]] = (),
    ) -> distribution.Array:
        """
        Sample from the squashed distribution.

        Args:
            seed: The seed or PRNG key forwarded to the underlying normal.
            sample_shape: The sample shape forwarded to the underlying normal.

        Returns:
            The squashed samples, with the same shape as the unsquashed ones.
        """
        # Sample from the MultivariateNormal distribution
        unsquashed_sample = self.normal.sample(seed=seed, sample_shape=sample_shape)

        # Squash the sample into [low, high]
        squashed_sample = self._squash(unsquashed_sample=unsquashed_sample)
        assert squashed_sample.shape == unsquashed_sample.shape

        return squashed_sample

    def log_prob(self, value: distribution.Array) -> distribution.Array:
        """
        Compute the log-probability of a squashed value.

        Args:
            value: The value in the squashed space ``[low, high]``.

        Returns:
            The log-prob of the unsquashed value under the underlying normal,
            minus the correction term due to the squashing function.
        """
        # Unsquash from [low, high] to ]-∞, ∞[
        value_unsquashed = self._unsquash(squashed_sample=value)

        # Compute the log-prob of the underlying normal distribution
        log_prob_norm = self.normal.log_prob(value=value_unsquashed)

        # Compute the correction term due to squashing
        log_abs_det_grad_squash = self._log_abs_det_grad_squash(
            unsquashed_sample=value_unsquashed
        )

        # Adjust the log-prob with the gradient of the squashing function
        return log_prob_norm - log_abs_det_grad_squash

    def kl_divergence(
        self, other_dist: "SquashedMultivariateNormalDiagBase", **kwargs
    ) -> distribution.Array:
        """
        Compute the KL divergence from another squashed distribution.

        Args:
            other_dist: A distribution of the same type as this one.

        Returns:
            The KL divergence between the two underlying normal distributions.

        Raises:
            TypeError: If ``other_dist`` is not an instance of this class.
        """
        if not isinstance(other_dist, type(self)):
            # Raise a real exception instance: raising a tuple makes Python
            # complain that exceptions must derive from BaseException.
            raise TypeError(
                f"Expected an instance of '{type(self).__name__}', "
                f"got '{type(other_dist).__name__}'"
            )

        # The squashing function does not influence the KL divergence
        return self.normal.kl_divergence(other_dist=other_dist.normal)
+ self.squash_dist = distrax.MultivariateNormalDiag( + loc=jnp.array([0.0]), scale_diag=jnp.array([self.ScaleOfSquashingGaussian]) + ) + + def _squash_dist_cdfi(self, value: distribution.Array) -> distribution.Array: + def unstandardize( + dist: MultivariateNormalDiag, value_std: distribution.Array + ) -> distribution.Array: + return value_std * dist.scale_diag + dist.loc + + return unstandardize( + dist=self.squash_dist, value_std=jax.scipy.special.ndtri(p=value) + ) + + def _squash(self, unsquashed_sample: distribution.Array) -> distribution.Array: + # Squash the input sample into the [0, 1] interval + squashed_sample = jnp.vectorize(self.squash_dist.cdf)(unsquashed_sample) + + # Adjust boundaries in order to prevent getting infinite log-prob + clipped_squashed_sample = jnp.clip( + squashed_sample, a_min=self.Epsilon, a_max=(1 - self.Epsilon) + ) + + # Project the squashed sample into the output space [low, high] + return clipped_squashed_sample * (self.high - self.low) + self.low + + def _unsquash(self, squashed_sample: distribution.Array) -> distribution.Array: + import jaxsim + + # if ( + # not jaxsim.utils.tracing(squashed_sample) + # and jnp.where(squashed_sample > self.high, True, False).any() + # ): + # raise ValueError(squashed_sample, self.high) + # + # if ( + # not jaxsim.utils.tracing(squashed_sample) + # and jnp.where(squashed_sample < self.low, True, False).any() + # ): + # raise ValueError(squashed_sample, self.low) + # Project the squashed sample into the normalized space [0, 1] + normalized_squashed_example = (squashed_sample - self.low) / ( + self.high - self.low + ) + + # Unsquash the sample + return self._squash_dist_cdfi(value=normalized_squashed_example) + + def _log_abs_det_grad_squash( + self, unsquashed_sample: distribution.Array + ) -> distribution.Array: + # Compute the log-grad of the squashing function + log_grad = jnp.vectorize(self.squash_dist.log_prob)(unsquashed_sample) + log_grad += jnp.log(self.high - self.low) # TODO: add this 
class TanhSquashedMultivariateNormalDiag(SquashedMultivariateNormalDiagBase):
    """
    Diagonal multivariate normal distribution squashed into ``[low, high]``
    with a tanh function.
    """

    # Small constant used to keep arctanh and log away from their singularities
    Epsilon: float = 1e-6

    def __init__(
        self,
        low: distribution.Array,
        high: distribution.Array,
        loc: Optional[distribution.Array] = None,
        scale_diag: Optional[distribution.Array] = None,
    ):
        """
        Args:
            low: The lower bound of the squashed samples.
            high: The upper bound of the squashed samples.
            loc: The mean of the underlying normal, clipped into [-4, 4].
            scale_diag: The diagonal scale of the underlying normal.
        """
        # Clip the mean of the underlying gaussian in a valid range.
        # The arguments are passed positionally because the keyword names of
        # jnp.clip differ between JAX versions. A missing mean is forwarded
        # unchanged to the base class.
        if loc is not None:
            loc = jnp.clip(loc, -4, 4)

        # Initialize base class
        super().__init__(low=low, high=high, loc=loc, scale_diag=scale_diag)

    def _squash(self, unsquashed_sample: distribution.Array) -> distribution.Array:
        """Map an unsquashed sample into ``[low, high]`` through tanh."""
        # Squash the input sample into the [0, 1] interval
        squashed_sample = (jnp.tanh(unsquashed_sample) + 1) / 2

        # Project the squashed sample into the output space [low, high]
        return squashed_sample * (self.high - self.low) + self.low

    def _unsquash(self, squashed_sample: distribution.Array) -> distribution.Array:
        """Map a sample in ``[low, high]`` back through arctanh."""
        # Project the squashed sample into the normalized space [0, 1]
        normalized_squashed_sample = (squashed_sample - self.low) / (
            self.high - self.low
        )

        # Project into [-1, 1]
        normalized_squashed_sample = normalized_squashed_sample * 2 - 1.0

        # Clip so that the arctanh output stays finite at the boundaries
        clipped_squashed_sample = jnp.clip(
            normalized_squashed_sample, -1.0 + self.Epsilon, 1.0 - self.Epsilon
        )

        # Unsquash the sample
        return jnp.arctanh(clipped_squashed_sample)

    def _log_abs_det_grad_squash(
        self, unsquashed_sample: distribution.Array
    ) -> distribution.Array:
        """
        Return the log of the absolute gradient of the squashing function,
        summed over the sample dimension.
        """
        # Compute the log-grad of the squashing function
        log_grad = jnp.log(1 - jnp.tanh(unsquashed_sample) ** 2 + self.Epsilon)
        log_grad += jnp.log(self.high - self.low) - jnp.log(2)

        # Adjust size
        log_grad = log_grad if log_grad.ndim > 1 else jnp.array([log_grad])

        # Sum over sample dimension (i.e. return a single value for each sample)
        return log_grad.sum(axis=1)

    def entropy(self) -> distribution.Array:
        """
        Return a zero entropy for each distribution.

        The sum runs over the last axis so that it also works when the bounds
        are a single 1D array instead of a batch.
        """
        return jnp.zeros_like(self.low).sum(axis=-1)
self.dones.at[:, -1].set(True) + memory = jax_dataclasses.replace(self, dones=dones) # noqa + + return memory + + def flatten(self) -> "Memory": + if self.flat: + return self + + self.check_valid() + return jax.tree_map(lambda x: jnp.vstack(x), self) + + def unflatten(self) -> "Memory": + if not self.flat: + return self + + # return jax.tree_map(lambda x: jnp.stack([jnp.vstack(x)]), self) + + self.check_valid() + + memory = jax.tree_map(lambda x: x[jnp.newaxis, :], self) + memory.check_valid() + return memory + + @staticmethod + def build( + states: jtp.MatrixJax, + actions: jtp.MatrixJax, + rewards: jtp.VectorJax, + dones: jtp.VectorJax, + values: jtp.VectorJax, + log_prob_actions: jtp.VectorJax, + infos: FrozenDict[str, jtp.ArrayJax] = FrozenDict(), + ) -> "Memory": + memory = Memory( + states=states, + actions=actions, + rewards=rewards, + dones=dones, + values=values, + log_prob_actions=log_prob_actions, + infos=infos, + ) + + # We work on (N, 1) 1D arrays + # Transform scalars to 1D arrays + memory = jax.tree_map(lambda x: x.squeeze(), memory) + # memory = jax.tree_map(lambda x: jnp.vstack(x.squeeze()), memory) + memory = jax.tree_map(lambda x: jnp.array([x]) if x.ndim == 0 else x, memory) + # memory = jax.tree_map(lambda x: jnp.vstack(x), memory) # TODO + # memory = jax.tree_map(lambda x: x[jnp.newaxis,:] if x.ndim == 1 else x, memory) + + # D, (S, D), (L, S, D) + + # If there's a single sample (S=1), add a new trivial S axis: (D,) -> (1, D) + # if memory.dones.ndim == 1: + if memory.dones.size == 1: + memory = jax.tree_map(lambda x: x[jnp.newaxis, :], memory) + + # memory = jax.tree_map(lambda x: jnp.vstack(x) if x.ndim == 1 else x, memory) + # TODO: vstack necessary?? 
+ + memory.check_valid() + return memory + + def __len__(self) -> int: + self.check_valid() + return self.dones.flatten().size + + def __getitem__(self, item) -> "Memory": + try: + item = operator.index(item) + except TypeError: + pass + + def get_slice(s: Union[slice, np.ndarray, jnp.ndarray, Tuple]) -> Memory: + return jax.tree_map(lambda x: x[s], self) + # squeezed = jax.tree_map(lambda x: jnp.squeeze(x[s]), self) + # return jax.tree_map(lambda x: jnp.vstack(x[s]), squeezed) + + if isinstance(item, int): + if item > len(self): + raise IndexError( + f"index {item} is out of bounds for axis 0 with size {len(self)}" + ) + + return get_slice(jnp.s_[item]) + + if isinstance(item, slice): + return get_slice(item) + + if isinstance(item, (np.ndarray, jnp.ndarray)): + return get_slice(item) + + if isinstance(item, tuple): + return get_slice(item) + + raise TypeError(item) + + def flat_level(self, level: int) -> "Memory": + if self.flat: + return self + + if level > self.number_of_levels - 1: + raise ValueError(level, self.number_of_levels) + + return jax.tree_map(lambda x: x[level], self) + + def has_nans(self) -> jtp.Bool: + has_nan = jax.tree_map(lambda l: jnp.isnan(l).any(), self) + return jax.flatten_util.ravel_pytree(has_nan)[0].any() + + def check_valid(self) -> None: + # if self.has_nans(): + # raise ValueError("Found NaN values") + + # L, S, D = (None, None, None) + + if self.dones.ndim < 2 or self.dones.ndim > 3: + raise ValueError(self.dones.shape) + + # if self.dones.ndim == 2: + # L, S, D = None, self.dones.shape[0], self.dones.shape[1] + # + # if self.dones.ndim == 3: + # L, S, D = self.dones[0].shape, self.dones[1].shape, self.dones.shape[2] + # + # if D != 1: + # raise ValueError(D, self.dones.shape) + + # return True + shape_of_leaves = jax.tree_map( + lambda x: x.shape, jax.tree_util.tree_leaves(self) + ) + + # (L, S, D) + # (S, D) + + # Check (S, ⋅) TODO: same as check below with L and S? can be removed? 
class DataLoader:
    """Generate mini-batches of a memory in the form of boolean masks."""

    def __init__(self, memory: "Memory"):
        """
        Args:
            memory: The memory to split in batches. It is flattened if needed.
        """
        self.memory = memory if memory.flat else memory.flatten()

    def batch_slices_generator(
        self,
        batch_size: int,
        shuffle: bool = False,
        seed: int = None,
        allow_partial_batch: bool = False,
    ) -> Generator[npt.NDArray, None, None]:
        """
        Yield boolean masks selecting the samples of each batch.

        Each mask has one entry per sample of the memory, and every sample is
        selected by at most one mask.

        Args:
            batch_size: The number of samples of each full batch.
            shuffle: Whether to assign the samples to the batches randomly.
            seed: The seed of the shuffling. The same seed always produces the
                same batches. Defaults to 0 when not specified.
            allow_partial_batch: Whether to also yield a final smaller batch
                with the samples that do not fill a full batch.

        Yields:
            A NumPy boolean mask for each batch.

        Raises:
            ValueError: If ``batch_size`` is not positive, or if it is larger
                than the memory and partial batches are not allowed.
        """
        num_samples = len(self.memory)

        # A non-positive size would never advance through the samples
        if batch_size <= 0:
            raise ValueError(batch_size)

        if batch_size > num_samples and not allow_partial_batch:
            raise ValueError(batch_size, num_samples)

        # Position of each sample in the (possibly shuffled) ordering
        positions = np.arange(0, num_samples, dtype=int)

        if shuffle:
            rng = np.random.default_rng(seed if seed is not None else 0)
            positions = rng.permutation(positions)

        # Full batches: a sample belongs to the batch containing its position
        first_leftover = (num_samples // batch_size) * batch_size

        for start in range(0, first_leftover, batch_size):
            yield (positions >= start) & (positions < start + batch_size)

        # Trailing partial batch, built with the same ordering of the full ones
        # and skipped when there are no leftover samples
        if allow_partial_batch and first_leftover < num_samples:
            yield positions >= first_leftover
class ActorNetwork(nn.Module):
    """
    Fully-connected actor network producing a squashed Gaussian policy.

    The network outputs the mean of the action distribution, while the log of
    the standard deviation is a state-independent learnable parameter.
    """

    # Sizes of the hidden layers (the output layer is added automatically)
    layer_sizes: Sequence[int]
    # Space of the actions, used to size the output and to bound the distribution
    action_space: gym.spaces.Box

    activation: Callable[[jtp.Array], jtp.Array] = nn.relu
    kernel_init: Callable[..., Any] = jax.nn.initializers.normal()

    bias: bool = True
    activate_final: bool = False

    # Limits of log(σ). A limit set to None is not enforced.
    log_std_min: float = None
    log_std_max: float = jnp.log(1.0)
    log_std_init_val: float = jnp.log(0.50)

    def setup(self) -> None:
        logging.debug("Configuring actor network...")

        # Automatically add a final layer so that the NN outputs an array containing
        # the means μ of the action distribution
        all_layer_sizes = np.hstack([self.layer_sizes, self.action_space.sample().size])

        # Get the index of the last layer
        last_layer_idx = all_layer_sizes.size - 1

        def get_kernel_init(layer_idx: int) -> Callable[..., Any]:
            # For the last layer we use a normal initializer with a standard
            # deviation of 1e-4, i.e. 1/100 of the default stddev (1e-2) of
            # jax.nn.initializers.normal
            last_layer_kernel_init = jax.nn.initializers.normal(stddev=1e-2 / 100)

            return (
                self.kernel_init
                if layer_idx != last_layer_idx
                else last_layer_kernel_init
            )

        # Fully-connected network for the mean μ
        self.layers = [
            nn.Dense(
                features=size,
                kernel_init=get_kernel_init(layer_idx=layer_idx),
                use_bias=self.bias,
            )
            for layer_idx, size in enumerate(all_layer_sizes)
        ]

        # State-independent learnable variables storing log(σ)
        self.log_std = self.param(
            "log_std",
            lambda rng, shape: self.log_std_init_val * jnp.ones(shape),
            self.action_space.sample().size,
        )

        logging.debug(f" Layers: {all_layer_sizes.tolist()}")
        activated = [
            self.activate_layer(layer_idx=idx) for idx, _ in enumerate(all_layer_sizes)
        ]
        logging.debug(f" Activated: {activated}")
        logging.debug(f" Activation function: {self.activation.__name__}")
        logging.debug(f" Kernel initializer: {self.kernel_init}")
        logging.debug(
            " log(σ): {:.4f} (min={}, max={})".format(
                self.log_std_init_val, self.log_std_min, self.log_std_max
            )
        )

    def activate_layer(self, layer_idx: int) -> bool:
        """Return True if the output of the given layer must be activated."""

        # All layers are activated, the last one only if explicitly requested
        return (layer_idx != len(self.layers) - 1) or self.activate_final

    def __call__(self, data: jtp.Array) -> distrax.Distribution:
        """
        Compute the action distribution corresponding to the input data.

        Args:
            data: The input of the network.

        Returns:
            The tanh-squashed diagonal Gaussian distribution of the actions,
            bounded by the limits of the action space.
        """

        x = data

        for idx, layer in enumerate(self.layers):
            x = layer(x)
            x = self.activation(x) if self.activate_layer(layer_idx=idx) else x

        # The mean μ is the NN output
        mu = x

        # ==========================
        # Distributional exploration
        # ==========================

        # Clip log(σ) if at least one limit is defined. The limits are compared
        # with None rather than tested for truthiness, since 0.0 is a valid limit
        # (it is the default upper limit, corresponding to σ <= 1).
        if self.log_std_min is None and self.log_std_max is None:
            log_std = self.log_std
        else:
            log_std = jnp.clip(self.log_std, self.log_std_min, self.log_std_max)

        # Compute σ taking the exponential
        std = jnp.exp(log_std)

        # Return the actor distribution
        return distributions.TanhSquashedMultivariateNormalDiag(
            loc=mu,
            scale_diag=std,
            low=self.action_space.low,
            high=self.action_space.high,
        )
+ + Returns: + The JIT-compiled sampling function. + """ + + if self._sample_function is not None: + return self._sample_function + + # We split the arguments batched and non-batched. + # The non-batched arguments will be stored in a kwargs dictionary. + func = lambda x0, key, kwargs: EnvironmentSampler.step_environment_over_horizon( + x0=x0, key=key, **kwargs + ) + + if not self.parallel: + self._sample_function = func + return self._sample_function + + # Thanks to the definition of the lambda, specifying the batched axes + # is much more simple in vmap + self._sample_function = jax.jit(jax.vmap(fun=func, in_axes=(0, 0, None))) + return self._sample_function + + def _sample( + self, x0: FrozenDict, key: jtp.Array, extra_args: Dict[str, Any] = None + ) -> Tuple[Memory, sys.EnvironmentSystem]: + """ + Call the JIT-compiled sampling functions. + + The first call this method logs the time spent to JIT-compile the function. + + Args: + x0: The state of the (possibly parallelized) system. + key: The key of the (possibly parallelized) system. + extra_args: The non-batch arguments of the system + + Returns: + The sampled memory and the system with the final state. 
@staticmethod
def build(
    environment: sys.EnvironmentSystem,
    t0: float = 0.0,
    tf: float = 5.0,
    dt: float = 0.001,
    seed: int = 0,
    parallel_environments: int = 1,
) -> "EnvironmentSampler":
    """
    Build a sampler operating on either a single or multiple environments.

    Args:
        environment: The environment to sample from.
        t0: The initial time of the sampling horizon.
        tf: The final time of the sampling horizon.
        dt: The time step of the sampling horizon.
        seed: The seed of the environment. With parallel environments, the
            i-th environment is seeded with ``seed + i + 1``.
        parallel_environments: The number of environments sampled in parallel.

    Returns:
        The sampler, whose buffer stores the initial state and key of the
        environment(s).
    """

    def build_state_and_key_parallel() -> EnvironmentSamplerBuffer:
        # Seed each environment differently. The offset depends on the seed
        # argument so that different seeds produce different rollouts, while
        # the default seed keeps yielding the seeds 1..N.
        env_list = [
            environment.reset_environment(environment.seed(seed=seed + i + 1))
            for i in range(parallel_environments)
        ]

        def tree_transpose(list_of_trees):
            # Stack the leaves of all trees along a new leading batch axis
            return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *list_of_trees)

        keys_list = [env.key for env in env_list]
        states_list = [env.state_subsystem for env in env_list]

        key = tree_transpose(keys_list)
        state = tree_transpose(states_list)

        return EnvironmentSamplerBuffer(state=state, key=key)  # noqa

    def build_state_and_key_not_parallel() -> EnvironmentSamplerBuffer:
        env = environment.reset_environment(environment.seed(seed=seed))

        key = env.key
        state = env.state_subsystem

        return EnvironmentSamplerBuffer(state=state, key=key)  # noqa

    buffer = (
        build_state_and_key_parallel()
        if parallel_environments > 1
        else build_state_and_key_not_parallel()
    )

    # Build the sampler
    sampler = EnvironmentSampler(
        environment=environment,
        parallel=parallel_environments > 1,
        horizon=jnp.arange(start=t0, stop=tf, step=dt),
        duration=tf - t0,
        _buffer=buffer,
    )

    return sampler
@staticmethod
@functools.partial(jax.jit)
def step_environment_over_horizon(
    system: sys.EnvironmentSystem,
    agent: Agent,
    t: npt.NDArray,
    x0: FrozenDict = None,
    key: jax.random.PRNGKey = None,
    explore: Union[bool, jtp.VectorJax] = jnp.array(True),
) -> Tuple[Memory, sys.EnvironmentSystem]:
    """
    Roll out the environment over a time horizon, collecting the samples.

    At each step the actor is executed on the current system, the system is
    advanced with the resulting action, and its output is stored in a buffer.

    Args:
        system: The environment system to step.
        agent: The agent used to select the actions.
        t: The time horizon. One sample is collected for each of its elements.
        x0: The optional state replacing the state of the system.
        key: The optional key replacing the key of the system.
        explore: Flag forwarded to the actor when selecting the actions.

    Returns:
        The memory built from the collected samples and the system after the
        last step.
    """

    # Handle state and key
    system = system if x0 is None else system.update_subsystem_state(new_state=x0)
    system = system if key is None else system.update_key(key=key)

    # Compute a dummy output of the system, used to initialize the buffer.
    # We flatten the output (dict) in order to allocate a dense jax array
    # used inside a JIT-compiled for loop.
    out, system = system(t0=t[0], tf=t[0], u0=None)
    out_flattened, restore_output_fn = jax.flatten_util.ravel_pytree(out)

    # Create the buffers storing environment data
    values = jnp.zeros(shape=(t.size, 1))
    log_prob_actions = jnp.zeros(shape=(t.size, 1))
    system_output = jnp.zeros(shape=(t.size, out_flattened.size))

    # Generate a new key from the environment used for sampling actions
    agent_key, system = system.generate_subkey()
    agent = jax_dataclasses.replace(agent, key=agent_key)  # noqa

    # Initialize the loop carry
    carry_init = (system_output, log_prob_actions, values, system, agent)

    def body_fun(i: int, carry: Tuple) -> Tuple:
        # Unpack the loop carry
        system_output, log_prob_actions, values, system, agent = carry

        # Execute the actor
        (action, log_prob_action, value, agent) = EnvironmentSampler.run_actor(
            agent=agent, environment=system, explore=explore
        )

        # Update values and log_prob
        values = values.at[i].set(value)
        log_prob_actions = log_prob_actions.at[i].set(log_prob_action)

        # Advance the environment and get its output.
        # NOTE(review): the loop runs for i in [0, t.size), so in the last
        # iteration t[i + 1] reads past the end of t. JAX does not raise on
        # out-of-bounds reads, so the final step presumably uses a clamped
        # final time (tf == t0) -- TODO confirm this is intended.
        out, system = system(t0=t[i], tf=t[i + 1], u0=FrozenDict(action=action))

        # Store the environment output in the buffer
        out_flattened, _ = jax.flatten_util.ravel_pytree(out)
        system_output = system_output.at[i, :].set(out_flattened)

        # Return the loop carry
        return system_output, log_prob_actions, values, system, agent

    # Execute the rollout. The environment automatically resets when it's done.
    system_output, log_prob_actions, values, system, agent = jax.lax.fori_loop(
        lower=0,
        upper=t.size,
        body_fun=body_fun,
        init_val=carry_init,
    )

    # Unflatten the output, restoring one output dict per row of the buffer
    output_horizon: FrozenDict = jax.vmap(lambda b: restore_output_fn(b))(
        system_output
    )

    # Create the memory object
    memory = Memory.build(
        states=output_horizon["observation"],
        actions=output_horizon["action"],
        rewards=output_horizon["reward"],
        dones=output_horizon["done"],
        values=values,
        log_prob_actions=log_prob_actions,
        infos=output_horizon["info"],
    )

    # Return objects after stepping over the horizon
    return memory, system
def number_of_environments(self) -> int:
    """
    Return the number of environments handled by the sampler.

    The number is inferred from the shape of the key stored in the state: a
    1D key corresponds to a single environment, otherwise the size of the
    leading axis is the number of environments.
    """

    key_shape = self.state.key.shape

    # A 1D key belongs to a single, non-batched environment
    if len(key_shape) == 1:
        return 1

    return key_shape[0]
@staticmethod
def Sample_trajectory(
    env: Env,
    state: EnvironmentState,
    agent: Agent,
    horizon_length: int = 1,
    explore: bool = True,
) -> Tuple[Memory, EnvironmentState]:
    """
    Sample a trajectory of fixed length from the environment.

    Args:
        env: The environment to sample from.
        state: The initial state of the environment.
        agent: The agent used to select the actions.
        horizon_length: The number of samples to collect.
        explore: Flag forwarded to the agent when selecting the actions.

    Returns:
        The memory storing the collected samples and the environment state
        after the last step.

    Raises:
        ValueError: If the horizon length is smaller than 1.
    """

    # The memory buffer is built by stacking at least one sample
    if horizon_length < 1:
        raise ValueError(horizon_length)

    # Create the memory object of the entire batch by stacking zero samples
    zero_sample = TrajectorySampler.Zero_memory_sample(
        env=env, state=state, agent=agent
    )
    memory = jax.tree_util.tree_map(
        lambda *x0: jnp.stack(x0), *[zero_sample] * horizon_length
    )

    carry_init = (state, memory)

    def body_fun(idx: int, carry: Tuple) -> Tuple:
        # Unpack the carry
        state, memory = carry

        # Get the current observation
        observation = env.get_observation(state=state).flatten()

        # Sample a new action
        subkey, state = state.generate_key()

        action, log_prob_action, value = agent.choose_action(
            observation=observation, explore=explore, key=subkey
        )

        # Step the environment with automatic reset
        state, (_, reward, is_done, info) = TrajectorySampler.Step_environment(
            env=env, state=state, action=action
        )

        # Build a single-entry memory object with the sample
        sample = Memory.build(
            states=observation,
            actions=action,
            rewards=reward,
            dones=is_done,
            values=value,
            log_prob_actions=log_prob_action,
            infos=info,
        )

        # Store the new sample in the idx-th entry of the stacked memory
        memory = jax.tree_util.tree_map(
            lambda stacked, leaf: stacked.at[idx].set(leaf),
            memory,
            TrajectorySampler.Memory_to_memory_1D(sample=sample),
        )

        return state, memory

    state, memory = jax.lax.fori_loop(
        lower=0,
        upper=horizon_length,
        body_fun=body_fun,
        init_val=carry_init,
    )

    return memory, state
@staticmethod
def Memory_to_memory_1D(sample: "Memory") -> "Memory":
    """
    Convert the leaves of a single-sample memory to 1D arrays.

    Every leaf is squeezed, and the leaves that become scalars are wrapped
    into arrays with a single element.

    Args:
        sample: The memory (or any pytree of arrays) storing a single sample.

    Returns:
        The pytree with the same structure and squeezed leaves.
    """

    def memory_to_sample(leaf):
        squeezed = leaf.squeeze()

        # Scalars are promoted to 1D so that no leaf is 0-dimensional
        return jnp.array([squeezed]) if squeezed.ndim == 0 else squeezed

    return jax.tree_util.tree_map(memory_to_sample, sample)
100644 --- a/src/jaxsim/utils.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -1,44 +1,32 @@ import abc import contextlib import copy -from typing import Any, ContextManager, TypeVar +import dataclasses +from typing import Generator import jax.abstract_arrays import jax.flatten_util import jax.interpreters.partial_eval import jax_dataclasses -from jax_dataclasses._copy_and_mutate import _Mutability as Mutability import jaxsim.typing as jtp -T = TypeVar("T") +from . import Mutability - -def tracing(var: Any) -> bool: - """Returns True if the variable is being traced by JAX, False otherwise.""" - - return jax.numpy.array( - [ - isinstance(var, t) - for t in ( - jax.abstract_arrays.ShapedArray, - jax.interpreters.partial_eval.DynamicJaxprTracer, - ) - ] - ).any() - - -def not_tracing(var: Any) -> bool: - """Returns True if the variable is not being traced by JAX, False otherwise.""" - - return True if tracing(var) is False else False +try: + from typing import Self +except ImportError: + from typing_extensions import Self class JaxsimDataclass(abc.ABC): """""" + # This attribute is set by jax_dataclasses + __mutability__ = None + @contextlib.contextmanager - def editable(self: T, validate: bool = True) -> ContextManager[T]: + def editable(self: Self, validate: bool = True) -> Generator[Self, None, None]: """""" mutability = ( @@ -48,23 +36,34 @@ def editable(self: T, validate: bool = True) -> ContextManager[T]: with JaxsimDataclass.mutable_context(self.copy(), mutability=mutability) as obj: yield obj - # with jax_dataclasses.copy_and_mutate(self, validate=validate) as self_rw: - # yield self_rw - # - # self_rw._set_mutability(self._mutability()) - @contextlib.contextmanager - def mutable_context(self: T, mutability: Mutability) -> ContextManager[T]: + def mutable_context( + self: Self, mutability: Mutability, restore_after_exception: bool = True + ) -> Generator[Self, None, None]: """""" original_mutability = self._mutability() - self._set_mutability(mutability) - 
yield self - - self._set_mutability(original_mutability) - - def is_mutable(self: T, validate: bool = False) -> bool: + if restore_after_exception: + self_copy = copy.copy(self) + + def restore_self(): + self._set_mutability(mutability=Mutability.MUTABLE) + for f in dataclasses.fields(self_copy): + setattr(self, f.name, getattr(self_copy, f.name)) + + try: + self._set_mutability(mutability) + yield self + except Exception as e: + if restore_after_exception: + restore_self() + self._set_mutability(original_mutability) + raise e + finally: + self._set_mutability(original_mutability) + + def is_mutable(self, validate: bool = False) -> bool: """""" return ( @@ -91,21 +90,21 @@ def _set_mutability(self, mutability: Mutability) -> None: self, mutable=mutability, visited=set() ) - def mutable(self: T, mutable: bool = True, validate: bool = False) -> T: + def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self: self.set_mutability(mutable=mutable, validate=validate) return self - def copy(self: T) -> T: - obj = copy.deepcopy(self) + def copy(self: Self) -> Self: + obj = jax.tree_util.tree_map(lambda leaf: leaf, self) obj._set_mutability(mutability=self._mutability()) return obj - def replace(self: T, validate: bool = True, **kwargs) -> T: + def replace(self: Self, validate: bool = True, **kwargs) -> Self: with self.editable(validate=validate) as obj: _ = [obj.__setattr__(k, copy.copy(v)) for k, v in kwargs.items()] obj._set_mutability(mutability=self._mutability()) return obj - def flatten(self: T) -> jtp.VectorJax: + def flatten(self) -> jtp.VectorJax: return jax.flatten_util.ravel_pytree(self)[0] diff --git a/src/jaxsim/utils/oop.py b/src/jaxsim/utils/oop.py new file mode 100644 index 000000000..cc4099446 --- /dev/null +++ b/src/jaxsim/utils/oop.py @@ -0,0 +1,497 @@ +import contextlib +import dataclasses +import functools +import inspect +import os +from typing import Any, Callable, Generator + +import jax +import jax.flatten_util + +from jaxsim 
import logging +from jaxsim.utils import tracing + +from . import Mutability, Vmappable + + +class jax_tf: + """ + Class containing decorators applicable to methods of Vmappable objects. + """ + + # Environment variables that can be used to disable the transformations + EnvVarOOP: str = "JAXSIM_OOP_DECORATORS" + EnvVarJitOOP: str = "JAXSIM_OOP_DECORATORS_JIT" + EnvVarVmapOOP: str = "JAXSIM_OOP_DECORATORS_VMAP" + + @staticmethod + def method_ro( + fn: Callable, + jit: bool = True, + static_argnames: tuple[str, ...] | list[str] = (), + vmap: bool | None = None, + vmap_in_axes: tuple[int, ...] | int | None = None, + vmap_out_axes: tuple[int, ...] | int | None = None, + ): + """ + Decorator for r/o methods of classes inheriting from Vmappable. + """ + + return jax_tf.method( + fn=fn, + read_only=True, + validate=True, + jit_enabled=jit, + static_argnames=static_argnames, + vmap_enabled=vmap, + vmap_in_axes=vmap_in_axes, + vmap_out_axes=vmap_out_axes, + ) + + @staticmethod + def method_rw( + fn: Callable, + validate: bool = True, + jit: bool = True, + static_argnames: tuple[str, ...] | list[str] = (), + vmap: bool | None = None, + vmap_in_axes: tuple[int, ...] | int | None = None, + vmap_out_axes: tuple[int, ...] | int | None = None, + ): + """ + Decorator for r/w methods of classes inheriting from Vmappable. + """ + + return jax_tf.method( + fn=fn, + read_only=False, + validate=validate, + jit_enabled=jit, + static_argnames=static_argnames, + vmap_enabled=vmap, + vmap_in_axes=vmap_in_axes, + vmap_out_axes=vmap_out_axes, + ) + + @staticmethod + def method( + fn: Callable, + read_only: bool = True, + validate: bool = True, + jit_enabled: bool = True, + static_argnames: tuple[str, ...] | list[str] = (), + vmap_enabled: bool | None = None, + vmap_in_axes: tuple[int, ...] | int | None = None, + vmap_out_axes: tuple[int, ...] | int | None = None, + ): + """ + Decorator for methods of classes inheriting from Vmappable. 
+ + This decorator enables executing the methods on an object characterized by a + desired mutability, that is selected considering the r/o and validation flags. + It also allows to transform the method with the jit/vmap transformations. + If the Vmappable object is vectorized, the method is automatically vmapped, and + the in_axes are properly post-processed to simplify the combination with jit. + + Args: + fn: The method to decorate. + read_only: Whether the method operates on a read-only object. + validate: Whether r/w methods should preserve the pytree structure. + jit_enabled: Whether to apply the jit transformation. + static_argnames: The names of the arguments that should be static. + vmap_enabled: Whether to apply the vmap transformation. + vmap_in_axes: The in_axes to use for the vmap transformation. + vmap_out_axes: The out_axes to use for the vmap transformation. + + Returns: + The decorated method. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + """The wrapper function that is returned by this decorator.""" + + # Methods of classes inheriting from Vmappable decorated by this wrapper + # automatically support jit/vmap/mutability features when called standalone. + # However, when objects are arguments of plain functions transformed with + # jit/vmap, and decorated methods are called inside those functions, we need + # to disable this decorator to avoid double wrapping and execution errors. + # We do so by iterating over the arguments, and checking whether they are + # being traced by JAX. 
+ for argument in list(args) + list(kwargs.values()): + try: + argument_flat, _ = jax.flatten_util.ravel_pytree(argument) + + if tracing(argument_flat): + return fn(*args, **kwargs) + except: + continue + + # =============================================================== + # Wrap fn so that jit/vmap/mutability transformations are applied + # =============================================================== + + # Initialize the mutability of the instance over which the method is running. + # * In r/o methods, this approach prevents any type of mutation. + # * In r/w methods, this approach allows to catch early JIT recompilations + # caused by unwanted changes in the pytree structure. + if read_only: + mutability = Mutability.FROZEN + else: + mutability = ( + Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION + ) + + # Extract the class instance over which fn is called + instance: Vmappable = args[0] + assert isinstance(instance, Vmappable) + + # Inspect the environment to detect whether to enforce disabling jit/vmap + deco_on = jax_tf.env_var_on(jax_tf.EnvVarOOP) + jit_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarJitOOP) and deco_on + vmap_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarVmapOOP) and deco_on + + # Get the transformed function (possibly cached by functools.cache). + # Note that all the arguments of the following methods, when hashed, should + # uniquely identify the returned function so that a new function is built + # when arguments change and either jit or vmap have to be called again. 
+ fn_db = jax_tf.wrap_fn( + fn=fn, # noqa + mutability=mutability, + jit=jit_enabled_env and jit_enabled, + static_argnames=tuple(static_argnames), + vmap=vmap_enabled_env + and ( + vmap_enabled is True + or (vmap_enabled is None and instance.vectorized) + ), + in_axes=vmap_in_axes, + out_axes=vmap_out_axes, + ) + + # Call the transformed (mutable/jit/vmap) method + out, obj = fn_db(*args, **kwargs) + + if read_only: + return out + + # ================================================================= + # From here we assume that the wrapper is operating on a r/w method + # ================================================================= + + from jax_dataclasses._dataclasses import JDC_STATIC_MARKER + + # Select the right runtime mutability. The only difference here is when a r/w + # method is called on a frozen object. In this case, we enable updating the + # pytree data and preserve its structure only if validation is enabled. + mutability_dict = { + Mutability.MUTABLE_NO_VALIDATION: Mutability.MUTABLE_NO_VALIDATION, + Mutability.MUTABLE: Mutability.MUTABLE, + Mutability.FROZEN: Mutability.MUTABLE + if validate + else Mutability.MUTABLE_NO_VALIDATION, + } + + # We need to replace all the dynamic leafs of the original instance with those + # computed by the functional transformation. + # We do so by iterating over the fields of the jax_dataclasses and ignoring + # all the fields that are marked as static. 
@staticmethod
@functools.cache
def wrap_fn(
    fn: Callable,
    mutability: Mutability,
    jit: bool,
    static_argnames: tuple[str, ...] | list[str],
    vmap: bool,
    in_axes: tuple[int, ...] | int | None,
    out_axes: tuple[int, ...] | int | None,
) -> Callable:
    """
    Transform a method with jit/vmap and execute it on an object characterized
    by the desired mutability.

    Note:
        The method should take the object (self) as first argument.

    Note:
        The returned transformed method is cached by `functools.cache`, which
        hashes all the arguments of this function. The jit/vmap wrappers are
        therefore rebuilt only when one of the arguments changes. As a
        consequence all arguments must be hashable: pass `static_argnames`
        as a tuple, since a list cannot be hashed.

    Args:
        fn: The method to consider.
        mutability: The mutability of the object on which the method is called.
        jit: Whether to apply jit transformations.
        static_argnames: The names of the arguments that should be considered static.
        vmap: Whether to apply vmap transformations.
        in_axes: The axes along which to vmap input arguments. `None` maps
            every argument over axis 0, an int maps every argument over that
            axis, a tuple provides one entry per argument of `fn`.
        out_axes: The axes along which to vmap output arguments. When `None`,
            the `out_axes` argument is not forwarded to `jax.vmap`.

    Note:
        In order to simplify the application of vmap, we close the method arguments
        over all the non-mapped input arguments. Furthermore, for improving the
        compatibility with jit, we also close the vmap application over the static
        arguments.

    Returns:
        The transformed method operating on an object with the desired mutability.
        We maintain the same signature of the original method.

    Raises:
        ValueError: If a static argument is not a parameter of `fn`, if the
            length of a tuple `in_axes` does not match the number of parameters
            of `fn`, or if a static argument is mapped with vmap.
    """

    # Extract the signature of the function
    sig = inspect.signature(fn)

    # All static arguments must be actual arguments of fn
    for name in static_argnames:
        if name not in sig.parameters:
            raise ValueError(f"Static argument '{name}' not found in {fn}")

    # If in_axes is a tuple, its dimension should match the number of arguments
    if isinstance(in_axes, tuple) and len(in_axes) != len(sig.parameters):
        msg = "The length of 'in_axes' must match the number of arguments ({})"
        raise ValueError(msg.format(len(sig.parameters)))

    # Check that static arguments are not mapped with vmap.
    # This case would not work since static arguments are not traced and vmap
    # needs to trace arguments in order to map them.
    if isinstance(in_axes, tuple):
        for mapped_axis, arg_name in zip(in_axes, sig.parameters.keys()):
            if mapped_axis is not None and arg_name in static_argnames:
                raise ValueError(
                    f"Static argument '{arg_name}' cannot be mapped with vmap"
                )

    def fn_tf_vmap(function_to_vmap: Callable, *args, **kwargs):
        """Wrapper applying the vmap transformation"""

        # Canonicalize the arguments so that all of them are kwargs
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()

        # Build a dictionary mapping every argument name to its mapped axis.
        # When in_axes is None every argument is mapped over axis 0, when it is
        # an int every argument is mapped over that axis, and when it is a tuple
        # each argument takes the entry at its own position.
        match in_axes:
            case None:
                argname_to_mapped_axis = {name: 0 for name in bound.arguments}
            case tuple():
                argname_to_mapped_axis = {
                    name: in_axes[i] for i, name in enumerate(bound.arguments)
                }
            case int():
                argname_to_mapped_axis = {name: in_axes for name in bound.arguments}
            case _:
                raise ValueError(in_axes)

        # Build a dictionary (argument_name -> argument) for all mapped arguments.
        # Note that a mapped argument is an argument whose axis is not None and
        # is not a static jit argument.
        vmap_mapped_args = {
            arg: value
            for arg, value in bound.arguments.items()
            if argname_to_mapped_axis[arg] is not None
            and arg not in static_argnames
        }

        # Build a dictionary (argument_name -> argument) for all unmapped arguments
        vmap_unmapped_args = {
            arg: value
            for arg, value in bound.arguments.items()
            if arg not in vmap_mapped_args
        }

        # Close the function over the unmapped arguments of vmap
        fn_closed = functools.partial(function_to_vmap, **vmap_unmapped_args)

        # Create the in_axes tuple of only the mapped arguments
        in_axes_mapped = tuple(
            argname_to_mapped_axis[name] for name in vmap_mapped_args
        )

        # If all in_axes are the same, simplify in_axes tuple to be just an integer
        if len(set(in_axes_mapped)) == 1:
            in_axes_mapped = list(set(in_axes_mapped))[0]

        # If, instead, in_axes has different elements, we need to replace the mapped
        # axis of "self" with a pytree having as leaves the mapped axis.
        # This is because the vmap in_axes specification must be a tree prefix of
        # the corresponding value.
        if isinstance(in_axes_mapped, tuple) and "self" in vmap_mapped_args:
            # The lambda reads the old integer axis of "self": tree_map completes
            # before the dictionary entry is overwritten with the resulting pytree.
            argname_to_mapped_axis["self"] = jax.tree_util.tree_map(
                lambda _: argname_to_mapped_axis["self"], vmap_mapped_args["self"]
            )
            in_axes_mapped = tuple(
                argname_to_mapped_axis[name] for name in vmap_mapped_args
            )

        # Apply the vmap transformation and call the function passing only the
        # mapped arguments. The unmapped arguments have been closed over.
        # Note: we altered the "in_axes" tuple so that it does not have any
        # None elements.
        # Note: if in_axes_mapped is a tuple, the following fails if we pass kwargs,
        # we need to pass the unpacked args tuple instead.
        return jax.vmap(
            fn_closed,
            in_axes=in_axes_mapped,
            **dict(out_axes=out_axes) if out_axes is not None else {},
        )(*list(vmap_mapped_args.values()))

    def fn_tf_jit(function_to_jit: Callable, *args, **kwargs):
        """Wrapper applying the jit transformation"""

        # Canonicalize the arguments so that all of them are kwargs
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()

        # Apply the jit transformation and call the function passing all arguments
        # as keyword arguments
        return jax.jit(function_to_jit, static_argnames=static_argnames)(
            **bound.arguments
        )

    # First applied wrapper that executes fn in a mutable context
    fn_mutable = functools.partial(
        jax_tf.call_class_method_in_mutable_context, fn, jit, mutability
    )

    # Second applied wrapper that transforms fn with vmap
    fn_vmap = fn_mutable if not vmap else functools.partial(fn_tf_vmap, fn_mutable)

    # Third applied wrapper that transforms fn with jit
    fn_jit_vmap = fn_vmap if not jit else functools.partial(fn_tf_jit, fn_vmap)

    return fn_jit_vmap
@staticmethod
def call_class_method_in_mutable_context(
    fn: Callable, jit: bool, mutability: Mutability, *args, **kwargs
) -> tuple[Any, Vmappable]:
    """
    Run a method on its object inside the requested mutable context.

    Args:
        fn: The method to call.
        jit: Whether the method is being jit compiled or not (only used for logging).
        mutability: The desired mutability context.
        *args: The positional arguments to pass to the method (including self).
        **kwargs: The keyword arguments to pass to the method.

    Returns:
        A tuple containing the return value of the method and the object
        possibly updated by the method if it is in read-write.

    Note:
        Returning the object together with the output is what allows to
        jit-compile methods of a stateful object without leaking traces,
        therefore obtaining a jax-compatible OOP pattern.
    """

    # This message is emitted only while Python executes the function (i.e. not
    # from already compiled code), which makes it a reliable compilation marker.
    if jit:
        logging.debug(msg=f"JIT compiling {fn}")

    # Turn every argument into a keyword argument
    call_args = inspect.signature(fn).bind(*args, **kwargs)
    call_args.apply_defaults()

    # The object over which fn operates
    target: Vmappable = call_args.arguments["self"]

    # An object that is already mutable without validation keeps that mode, so
    # that a mismatched tree structure does not make the call fail.
    if target._mutability() is Mutability.MUTABLE_NO_VALIDATION:
        effective_mutability = Mutability.MUTABLE_NO_VALIDATION
    else:
        effective_mutability = mutability

    # fn could call other decorated methods, whose decorators would apply jit
    # and vmap again inside an already transformed method. The decorators are
    # therefore disabled for the whole duration of the call. The wrapper already
    # detects traced arguments, but the decorators are active also when
    # jit=False and vmap=False, in which case they only enforce the mutability.
    with (
        target.mutable_context(mutability=effective_mutability),
        jax_tf.disabled_oop_decorators(),
    ):
        result = fn(**call_args.arguments)

    return result, target
@staticmethod
def env_var_on(var_name: str, default_value: str = "1") -> bool:
    """
    Check whether an environment variable is set to a value that is considered on.

    The comparison is case-insensitive and the values considered on are
    "1", "true", "on" and "yes".

    Args:
        var_name: The name of the environment variable.
        default_value: The default variable value to consider if the variable has not
            been exported.

    Returns:
        True if the environment variable contains an on value, False otherwise.
    """

    on_values = {"1", "true", "on", "yes"}
    return os.environ.get(var_name, default_value).lower() in on_values


@staticmethod
@contextlib.contextmanager
def disabled_oop_decorators() -> Generator[None, None, None]:
    """
    Context manager to disable the application of jax transformations performed by
    the decorators of this class.

    The transformations are disabled by setting the `jax_tf.EnvVarOOP` environment
    variable to "0". The previous state of the variable is restored on exit, also
    when the code running in the context raises.

    Note: when the transformations are disabled, the only logic still applied is
    the selection of the object mutability over which the method is running.
    """

    # Save the current value of the variable (None if it is not exported) so
    # that it can be restored before exiting the context.
    previous_value = os.environ.get(jax_tf.EnvVarOOP)

    # Disable both jit and vmap transformations
    os.environ[jax_tf.EnvVarOOP] = "0"

    try:
        # Execute the code in the context with disabled transformations
        yield

    finally:
        if previous_value is not None:
            os.environ[jax_tf.EnvVarOOP] = previous_value
        else:
            # The variable was not exported before entering the context.
            # Use a default so that cleanup never raises (and never masks an
            # exception propagating from the context) if the variable has
            # already been removed by the code that ran inside the context.
            os.environ.pop(jax_tf.EnvVarOOP, None)
+ """ + + on_values = {"1", "true", "on", "yes"} + return os.environ.get(var_name, default_value).lower() in on_values + + @staticmethod + @contextlib.contextmanager + def disabled_oop_decorators() -> Generator[None, None, None]: + """ + Context manager to disable the application of jax transformations performed by + the decorators of this class. + + Note: when the transformations are disabled, the only logic still applied is + the selection of the object mutability over which the method is running. + """ + + # Check whether the environment variable is part of the environment and + # save its value. We restore the original value before exiting the context. + env_cache = ( + None if jax_tf.EnvVarOOP not in os.environ else os.environ[jax_tf.EnvVarOOP] + ) + + # Disable both jit and vmap transformations + os.environ[jax_tf.EnvVarOOP] = "0" + + try: + # Execute the code in the context with disabled transformations + yield + + finally: + # Restore the original value of the environment variable or remove it if + # it was not present before entering the context + if env_cache is not None: + os.environ[jax_tf.EnvVarOOP] = env_cache + else: + _ = os.environ.pop(jax_tf.EnvVarOOP) diff --git a/src/jaxsim/utils/tracing.py b/src/jaxsim/utils/tracing.py new file mode 100644 index 000000000..e85deeb4e --- /dev/null +++ b/src/jaxsim/utils/tracing.py @@ -0,0 +1,26 @@ +from typing import Any + +import jax._src.core +import jax.abstract_arrays +import jax.flatten_util +import jax.interpreters.partial_eval + + +def tracing(var: Any) -> bool | jax.Array: + """Returns True if the variable is being traced by JAX, False otherwise.""" + + return jax.numpy.array( + [ + isinstance(var, t) + for t in ( + jax._src.core.Tracer, + jax.interpreters.partial_eval.DynamicJaxprTracer, + ) + ] + ).any() + + +def not_tracing(var: Any) -> bool: + """Returns True if the variable is not being traced by JAX, False otherwise.""" + + return True if tracing(var) is False else False diff --git 
@jax_dataclasses.pytree_dataclass
class Vmappable(JaxsimDataclass):
    """
    Abstract class with utilities for vmappable pytrees.

    A pytree is considered vectorized when its static `batch_size` is greater
    than zero. In this case, the leaves are expected to store the batch along
    their first dimension.
    """

    # Size of the batch dimension. Zero means that the pytree is not vectorized.
    batch_size: jax_dataclasses.Static[int] = dataclasses.field(
        default=0, repr=False, compare=False, hash=False, kw_only=True
    )

    @property
    def vectorized(self) -> bool:
        """Whether this pytree is vectorized, i.e. its batch size is positive."""

        return self.batch_size > 0

    @classmethod
    def build_from_list(cls: Type[Self], list_of_obj: list[Self]) -> Self:
        """
        Build a vectorized pytree from a list of pytree of the same type.

        Args:
            list_of_obj: The list of pytrees to vectorize.

        Returns:
            The vectorized pytree having as leaves the stacked leaves of the input list.

        Raises:
            ValueError: If the list is empty or contains objects whose type is
                not exactly this class.
        """

        if {type(el) for el in list_of_obj} != {cls}:
            msg = "The input list must contain only objects of type '{}'"
            raise ValueError(msg.format(cls.__name__))

        # Create a pytree by stacking all the leaves of the input list.
        # Note: `jax.tree_util.tree_map` is used since the top-level
        # `jax.tree_map` alias is deprecated.
        data_vec: Vmappable = jax.tree_util.tree_map(
            lambda *leaves: jnp.array(leaves), *list_of_obj
        )

        # Store the batch dimension
        with data_vec.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            data_vec.batch_size = len(list_of_obj)

        # Detect the most common mutability in the input list
        mutabilities = [e._mutability() for e in list_of_obj]
        mutability = max(set(mutabilities), key=mutabilities.count)

        # Update the mutability of the vectorized pytree
        data_vec._set_mutability(mutability)

        return data_vec

    def vectorize(self: Self, batch_size: int) -> Self:
        """
        Return a vectorized version of this pytree.

        Args:
            batch_size: The batch size. A zero batch size returns a copy of
                this pytree.

        Returns:
            A vectorized version of this pytree obtained by stacking the leaves of the
            original pytree along a new batch dimension (the first one).

        Raises:
            RuntimeError: If this pytree is already vectorized.
            ValueError: If the batch size is negative.
        """

        if self.vectorized:
            raise RuntimeError("Cannot vectorize an already vectorized object")

        # Fail early with a clear message: a negative size would otherwise
        # produce an empty list and a misleading error about the object types.
        if batch_size < 0:
            raise ValueError("The batch size cannot be negative")

        if batch_size == 0:
            return self.copy()

        # TODO validate if mutability is maintained

        return self.__class__.build_from_list(list_of_obj=[self] * batch_size)

    def extract_element(self: Self, index: int) -> Self:
        """
        Extract the i-th element from a vectorized pytree.

        Args:
            index: The index of the element to extract.

        Returns:
            A non vectorized pytree obtained by extracting the i-th element from the
            vectorized pytree. Extracting the element 0 of a non-vectorized
            pytree returns a copy of it.

        Raises:
            ValueError: If the index is negative or not smaller than the batch size.
            RuntimeError: If a non-zero index is requested from a non-vectorized pytree.
        """

        if index < 0:
            raise ValueError("The index of the desired element cannot be negative")

        if index == 0 and self.batch_size == 0:
            return self.copy()

        if not self.vectorized:
            raise RuntimeError("Cannot extract elements from a non-vectorized object")

        if index >= self.batch_size:
            raise ValueError("The index must be smaller than the batch size")

        # Get the i-th pytree by extracting the i-th element from the vectorized pytree
        data = jax.tree_util.tree_map(lambda leaf: leaf[index], self)

        # Update the batch size of the extracted scalar pytree
        with data.mutable_context(mutability=Mutability.MUTABLE):
            data.batch_size = 0

        return data
@jax_dataclasses.pytree_dataclass
class AlgoData(Vmappable):
    """Class storing vmappable data of a given algorithm."""

    # Scalar counter stored as an unsigned 64-bit jax array
    counter: jax.Array = dataclasses.field(
        default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
    )

    @classmethod
    def build(cls: Type[Self], counter: jax.typing.ArrayLike) -> Self:
        """
        Builder method. Helpful for enforcing type and shape of fields.

        Args:
            counter: The initial value of the counter. It can be an int, a scalar
                numpy array, a scalar jax array, etc.

        Returns:
            The object storing the counter as a scalar jax array.

        Raises:
            ValueError: If the counter is not a scalar.
        """

        # Counter can be int / scalar numpy array / scalar jax array / etc.
        if jnp.array(counter).squeeze().size != 1:
            raise ValueError("The counter must be a scalar")

        # Create the object enforcing `counter` to be a scalar jax array.
        # Use `cls` instead of the hard-coded class name so that subclasses
        # calling this builder get an instance of their own type.
        return cls(
            counter=jnp.array(counter, dtype=jnp.uint64).squeeze(),
        )
def test_data():
    """Check building, vectorizing and indexing of `AlgoData` objects."""

    scalar_from_int = AlgoData.build(counter=0)
    scalar_from_numpy = AlgoData.build(counter=np.array(10))
    scalar_from_jax = AlgoData.build(counter=jnp.array(50))
    scalars = [scalar_from_int, scalar_from_numpy, scalar_from_jax]

    # Regardless of the input type, the counter is stored as a uint64 jax array
    for scalar in scalars:
        assert isinstance(scalar.counter, jax.Array)
        assert scalar.counter.dtype == jnp.uint64

    # Objects created with the builder are not vectorized
    for scalar in scalars:
        assert scalar.batch_size == 0

    # ==================
    # Vectorizing PyTree
    # ==================

    for size in (0, 10, 100):
        vectorized = scalar_from_int.vectorize(batch_size=size)
        assert vectorized.batch_size == size

        if size > 0:
            assert vectorized.counter.shape[0] == size

    # =========================================
    # Extracting element from vectorized PyTree
    # =========================================

    stacked = AlgoData.build_from_list(list_of_obj=scalars)
    assert stacked.batch_size == 3

    for position, expected in enumerate(scalars):
        assert stacked.extract_element(index=position) == expected

    # Out-of-range index
    with pytest.raises(ValueError):
        _ = stacked.extract_element(index=3)

    # Element 0 of a scalar object is an equal, but distinct, object
    extracted = scalar_from_int.extract_element(index=0)
    assert extracted == scalar_from_int
    assert id(extracted) != id(scalar_from_int)

    # Any other index cannot be extracted from a scalar object
    with pytest.raises(RuntimeError):
        _ = scalar_from_int.extract_element(index=1)

    # All the elements of the list must have the right type
    with pytest.raises(ValueError):
        _ = AlgoData.build_from_list(list_of_obj=[*scalars, 42])
@jax_dataclasses.pytree_dataclass
class MyClassWithAlgorithms(Vmappable):
    """
    Class to demonstrate how to use `Vmappable`.

    It exposes a read-only and a read-write method decorated with the
    `oop.jax_tf` decorators. Both methods print a marker string that the tests
    use to detect when the Python body of the method is executed.
    """

    # Dynamic data of the algorithm
    data: AlgoData = dataclasses.field(default=None)

    # Static attribute of the pytree (triggers recompilation if changed)
    double_input: jax_dataclasses.Static[bool] = dataclasses.field(default=None)

    # Non-static attribute of the pytree that is not transparently vmap-able.
    done: jax.typing.ArrayLike = dataclasses.field(
        default_factory=lambda: jnp.array(False, dtype=bool)
    )

    # Additional leaves to test the behaviour of mutable and immutable python objects
    my_tuple: tuple[int] = dataclasses.field(default=tuple(jnp.array([1, 2, 3])))
    my_list: list[int] = dataclasses.field(
        default_factory=lambda: [4, 5, 6], init=False
    )
    my_array: jax.Array = dataclasses.field(
        default_factory=lambda: jnp.array([10, 20, 30])
    )

    @classmethod
    def build(cls: Type[Self], double_input: bool = False) -> Self:
        """
        Build an object with a zero counter and the given `double_input` flag.

        Args:
            double_input: Whether the algorithms should double the `advance` input.

        Returns:
            The new object.
        """

        obj = MyClassWithAlgorithms()

        # Validation is disabled since the fields are initialized to None and
        # are replaced here with objects of a different type.
        with obj.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            obj.data = AlgoData.build(counter=0)
            obj.double_input = jnp.array(double_input)

        return obj

    @oop.jax_tf.method_ro
    def algo_ro(self, advance: int | jax.typing.ArrayLike) -> Any:
        """
        This is a read-only algorithm. It does not alter any pytree leaf.

        Args:
            advance: The amount to add to the counter (doubled if `double_input`
                is set).

        Returns:
            The value of the advanced counter. The counter stored in the object
            is not modified.
        """

        # This should be printed only the first execution since it is disabled
        # in the execution of the JIT-compiled function.
        print("__algo_ro__")

        # Select the multiplier: 2 if the input has to be doubled, 1 otherwise
        mul = jax.lax.select(self.double_input, 2, 1)

        # Compute the advanced counter without storing it
        counter_old = jnp.atleast_1d(self.data.counter)[0]
        counter_new = counter_old + mul * advance

        # Return the updated counter
        return counter_new

    @oop.jax_tf.method_rw
    def algo_rw(self, advance: int | jax.typing.ArrayLike) -> Any:
        """
        This is a read-write algorithm. It may alter pytree leaves either belonging
        to the vmappable data or generic non-static dataclass attributes.

        Args:
            advance: The amount to add to the counter (doubled if `double_input`
                is set).

        Returns:
            The updated counter, which is also stored in the object together
            with the `done` flag (True when the counter exceeds 100).
        """

        print(self)

        # This should be printed only the first execution since it is disabled
        # in the execution of the JIT-compiled function.
        print("__algo_rw__")

        # Select the multiplier: 2 if the input has to be doubled, 1 otherwise
        mul = jax.lax.select(self.double_input, 2, 1)

        # Increase the internal counter
        counter_old = jnp.atleast_1d(self.data.counter)[0]
        self.data.counter = jnp.array(counter_old + mul * advance, dtype=jnp.uint64)

        # Update the non-static and non-vmap-able attribute
        self.done = jax.lax.cond(
            pred=self.data.counter > 100,
            true_fun=lambda _: jnp.array(True),
            false_fun=lambda _: jnp.array(False),
            operand=None,
        )

        print(self)

        # Return the updated counter
        return self.data.counter
def test_mutability():
    """Test the mutability handling of the MyClassWithAlgorithms class."""

    # Build the object
    obj_ro = MyClassWithAlgorithms.build(double_input=True)

    # By default, pytrees built with jax_dataclasses are frozen (read-only)
    assert obj_ro._mutability() == Mutability.FROZEN
    with pytest.raises(dataclasses.FrozenInstanceError):
        obj_ro.data.counter = 42

    # Data can be changed through a context manager, in this case operating on a copy...
    with obj_ro.editable(validate=True) as obj_ro_copy:
        obj_ro_copy.data.counter = jnp.array(42, dtype=obj_ro.data.counter.dtype)
    assert obj_ro_copy.data.counter == pytest.approx(42)
    assert obj_ro.data.counter != pytest.approx(42)

    # ... or a context manager that does not copy the pytree...
    with obj_ro.mutable_context(mutability=Mutability.MUTABLE):
        obj_ro.data.counter = jnp.array(42, dtype=obj_ro.data.counter.dtype)
    assert obj_ro.data.counter == pytest.approx(42)

    # ... that raises if the leaves change type
    with pytest.raises(AssertionError):
        with obj_ro.mutable_context(mutability=Mutability.MUTABLE):
            obj_ro.data.counter = 42

    # Pytrees can be copied...
    obj_ro_copy = obj_ro.copy()
    assert id(obj_ro) != id(obj_ro_copy)
    # ... operation that does not copy the leaves
    # TODO describe
    assert id(obj_ro.done) == id(obj_ro_copy.done)
    assert id(obj_ro.data.counter) == id(obj_ro_copy.data.counter)
    assert id(obj_ro.my_array) == id(obj_ro_copy.my_array)
    assert id(obj_ro.my_tuple) != id(obj_ro_copy.my_tuple)
    assert id(obj_ro.my_list) != id(obj_ro_copy.my_list)

    # They can be converted as mutable pytrees to update their values without
    # using context managers (maybe useful for debugging or quick prototyping)
    obj_rw = obj_ro.copy().mutable(validate=True)
    assert obj_rw._mutability() == Mutability.MUTABLE
    obj_rw.data.counter = jnp.array(42, dtype=obj_rw.data.counter.dtype)

    # However, with validation enabled, this works only if the leaf does not
    # change its type (shape, dtype, weakness, ...)
    with pytest.raises(AssertionError):
        obj_rw.data.counter = 100
    with pytest.raises(AssertionError):
        obj_rw.data.counter = jnp.array(100, dtype=float)
    with pytest.raises(AssertionError):
        obj_rw.data.counter = jnp.array([100, 200], dtype=obj_rw.data.counter.dtype)

    # Instead, with validation disabled, the pytree structure can be altered
    # (and this might cause JIT recompilations, so use it at your own risk)
    obj_rw_noval = obj_ro.copy().mutable(validate=False)
    assert obj_rw_noval._mutability() == Mutability.MUTABLE_NO_VALIDATION
    obj_rw_noval.data.counter = jnp.array(42, dtype=obj_rw.data.counter.dtype)

    # Now this should work without exceptions
    obj_rw_noval.data.counter = 100
    obj_rw_noval.data.counter = jnp.array(100, dtype=float)
    obj_rw_noval.data.counter = jnp.array([100, 200], dtype=obj_rw.data.counter.dtype)

    # Build another object and check mutability changes
    obj_ro = MyClassWithAlgorithms.build(double_input=True)
    assert obj_ro.is_mutable(validate=True) is False
    assert obj_ro.is_mutable(validate=False) is False

    obj_rw_val = obj_ro.mutable(validate=True)
    assert id(obj_ro) == id(obj_rw_val)
    assert obj_rw_val.is_mutable(validate=True) is True
    assert obj_rw_val.is_mutable(validate=False) is False

    obj_rw_noval = obj_rw_val.mutable(validate=False)
    assert id(obj_rw_noval) == id(obj_rw_val)
    assert obj_rw_noval.is_mutable(validate=True) is False
    assert obj_rw_noval.is_mutable(validate=False) is True

    # Checking mutable leaves behavior
    obj_rw = MyClassWithAlgorithms.build(double_input=True).mutable(validate=True)
    obj_rw_copy = obj_rw.copy()

    # Memory of JAX arrays cannot be altered in place so this is safe
    obj_rw.my_array = obj_rw.my_array.at[1].set(-20)
    assert obj_rw_copy.my_array[1] != -20

    # Tuples are immutable so this should be safe too.
    # The check must target the tuple of the copy: comparing `my_array` here
    # would pass regardless of how tuples are handled.
    obj_rw.my_tuple = tuple(jnp.array([1, -2, 3]))
    assert obj_rw_copy.my_tuple[1] != -2

    # Lists are treated as tuples (they are not leaves) but since they are mutable,
    # their id changes
    obj_rw.my_list[1] = -5
    assert obj_rw_copy.my_list[1] != -5

    # Check that exceptions in mutable context do not alter the object
    obj_ro = MyClassWithAlgorithms.build(double_input=True)
    assert obj_ro.data.counter == 0
    assert obj_ro.double_input == jnp.array(True)

    with pytest.raises(RuntimeError):
        with obj_ro.mutable_context(mutability=Mutability.MUTABLE):
            obj_ro.double_input = jnp.array(False, dtype=obj_ro.double_input.dtype)
            obj_ro.data.counter = jnp.array(33, dtype=obj_ro.data.counter.dtype)
            raise RuntimeError
    assert obj_ro.data.counter == 0
    assert obj_ro.double_input == jnp.array(True)
def test_decorators_jit_compilation():
    """Test JIT features of MyClassWithAlgorithms class."""

    def printed_by(*methods) -> str:
        """Call each bound method with advance=1 and return the captured stdout."""

        with io.StringIO() as buf, redirect_stdout(buf):
            for method in methods:
                _ = method(advance=1)

            # The buffer must be read before it gets closed by the context
            return buf.getvalue()

    obj = MyClassWithAlgorithms.build(double_input=False)
    assert obj.data.counter == 0
    assert obj.is_mutable(validate=True) is False
    assert obj.is_mutable(validate=False) is False

    # JIT compilation should happen only the first function call.
    # We test this by checking that the first execution prints some output.
    assert "__algo_ro__" in printed_by(obj.algo_ro)
    assert "__algo_ro__" not in printed_by(obj.algo_ro)

    # JIT compilation should happen only the first function call.
    # We test this by checking that the first execution prints some output.
    assert "__algo_rw__" in printed_by(obj.algo_rw)
    assert "__algo_rw__" not in printed_by(obj.algo_rw)

    # Create a new object
    obj = MyClassWithAlgorithms.build(double_input=False)

    # New objects should be able to re-use the JIT-compiled functions from other objects
    printed = printed_by(obj.algo_ro, obj.algo_rw)
    assert "__algo_ro__" not in printed
    assert "__algo_rw__" not in printed

    # Create a new object
    obj = MyClassWithAlgorithms.build(double_input=False)

    # Read-only methods can be called on r/o objects
    out = obj.algo_ro(advance=1)
    assert out == obj.data.counter + 1
    out = obj.algo_ro(advance=1)
    assert out == obj.data.counter + 1

    # Read-write methods can be called too on r/o objects since they are marked as r/w
    out = obj.algo_rw(advance=1)
    assert out == 1
    out = obj.algo_rw(advance=1)
    assert out == 2
    out = obj.algo_rw(advance=2)
    assert out == 4

    # Create a new object with a different dynamic attribute
    obj_dyn = MyClassWithAlgorithms.build(double_input=False).mutable(validate=True)
    obj_dyn.done = jnp.array(not obj_dyn.done, dtype=bool)

    # New objects with different dynamic attributes should be able to re-use the
    # JIT-compiled functions from other objects.
    # The methods must be called on `obj_dyn`, otherwise this scenario would
    # not be exercised at all.
    printed = printed_by(obj_dyn.algo_ro, obj_dyn.algo_rw)
    assert "__algo_ro__" not in printed
    assert "__algo_rw__" not in printed

    # Create a new object with a different static attribute
    obj_stat = MyClassWithAlgorithms.build(double_input=True)

    # New objects with different static attributes trigger the recompilation of the
    # JIT-compiled functions...
    printed = printed_by(obj_stat.algo_ro, obj_stat.algo_rw)
    assert "__algo_ro__" in printed
    assert "__algo_rw__" in printed

    # ... that are cached as well by jax
    printed = printed_by(obj_stat.algo_ro, obj_stat.algo_rw)
    assert "__algo_ro__" not in printed
    assert "__algo_rw__" not in printed
def test_decorators_vmap():
    """Check that decorated methods are automatically vectorized on batched objects."""

    # Start from an object storing scalar data
    scalar_obj = MyClassWithAlgorithms.build(double_input=False)

    # Vectorize the entire object
    batched = scalar_obj.vectorize(batch_size=10)
    assert batched.vectorized is True
    assert batched.batch_size == 10
    assert id(batched) != id(scalar_obj)

    # Scalar arguments are rejected by the methods of a vectorized object
    with pytest.raises(ValueError):
        _ = batched.algo_ro(advance=1)
    with pytest.raises(ValueError):
        _ = batched.algo_rw(advance=1)

    # The r/o method accepts vectorized input and returns vectorized output
    result = batched.algo_ro(advance=jnp.array([1] * batched.batch_size))
    assert result.shape[0] == 10
    assert set(result.tolist()) == {1}

    # The r/w method accepts vectorized input and returns vectorized output...
    result = batched.algo_rw(advance=jnp.array([1] * batched.batch_size))
    assert result.shape[0] == 10
    assert set(result.tolist()) == {1}

    # ... and the state updated by the first call is used by the second one
    result = batched.algo_rw(advance=jnp.array([1] * batched.batch_size))
    assert set(result.tolist()) == {2}

    # Extract a single object from the vectorized object
    element = batched.extract_element(index=5)
    assert element.vectorized is False
    assert element.data.counter == batched.data.counter[5]
@jax_dataclasses.pytree_dataclass
class ErgoCubObservation(JaxsimDataclass):
    """Observation of the ErgoCub environment."""

    # Base quantities
    base_height: jtp.Float
    gravity_projection: jtp.Array

    # Joint quantities
    joint_positions: jtp.Array
    joint_velocities: jtp.Array

    # Base velocity
    base_linear_velocity: jtp.Array
    base_angular_velocity: jtp.Array

    # Boolean contact flags
    contact_state: jtp.Array

    @staticmethod
    def build(
        base_height: jtp.Float,
        gravity_projection: jtp.Array,
        joint_positions: jtp.Array,
        joint_velocities: jtp.Array,
        base_linear_velocity: jtp.Array,
        base_angular_velocity: jtp.Array,
        contact_state: jtp.Array,
    ) -> "ErgoCubObservation":
        """
        Build an ErgoCubObservation object.

        All the inputs are converted to jax arrays: `contact_state` is stored
        with a boolean dtype, all the other quantities with a float dtype.
        """

        def as_float(value):
            """Convert the input to a float jax array."""
            return jnp.array(value, dtype=float)

        observation = ErgoCubObservation(
            base_height=as_float(base_height),
            gravity_projection=as_float(gravity_projection),
            joint_positions=as_float(joint_positions),
            joint_velocities=as_float(joint_velocities),
            base_linear_velocity=as_float(base_linear_velocity),
            base_angular_velocity=as_float(base_angular_velocity),
            contact_state=jnp.array(contact_state, dtype=bool),
        )

        return observation
@dataclasses.dataclass
class MeshcatVizRenderState:
    """
    Render state of a meshcat-viz visualizer.

    A `MeshcatWorld` is created and opened as soon as the object is built.
    Optionally, a window showing the world can be opened in a separate process.
    """

    # World of the visualizer. It is always created in `__post_init__`, hence it
    # is declared with `dataclasses.field(init=False)` so that it is not an
    # argument of the generated `__init__`.
    world: MeshcatWorld = dataclasses.field(init=False)

    # Process running the GUI window, if any
    _gui_process: Optional[multiprocessing.Process] = dataclasses.field(
        default=None, init=False, repr=False, hash=False, compare=False
    )

    _jaxsim_to_meshcat_viz_name: dict[str, str] = dataclasses.field(
        default_factory=dict, init=False, repr=False, hash=False, compare=False
    )

    def __post_init__(self) -> None:
        """Create and open the meshcat world."""

        self.world = MeshcatWorld()
        self.world.open()

    def _stop_gui_process(self) -> None:
        """Terminate the GUI process, if any, and release its resources."""

        process = self._gui_process

        if process is None:
            return

        # `Process.close()` raises ValueError if the process is still running,
        # therefore we need to wait for its termination before closing it.
        process.terminate()
        process.join()
        process.close()

        # Forget the closed process so that this method can be called again
        self._gui_process = None

    def close(self) -> None:
        """Close the meshcat world and the GUI window, if it was opened."""

        if self.world is not None:
            self.world.close()

        self._stop_gui_process()

    @staticmethod
    def open_window(web_url: str) -> None:
        """Open a new window with the given web url."""

        # Imported here since it is needed only by the process running the window
        import webview

        print(web_url)
        webview.create_window("meshcat", web_url)
        webview.start(gui="qt")

    def open_window_in_process(self) -> None:
        """Open the window showing the world in a new process.

        If a window process is already stored, it is terminated first.
        """

        self._stop_gui_process()

        self._gui_process = multiprocessing.Process(
            target=MeshcatVizRenderState.open_window, args=(self.world.web_url,)
        )
        self._gui_process.start()


# Type aliases used by the environment defined below
StateType = dict[str, SimulatorData | jtp.Array]
ActType = jnp.ndarray
ObsType = ErgoCubObservation
RewardType = float | jnp.ndarray
TerminalType = bool | jnp.ndarray
RenderStateType = MeshcatVizRenderState

@jax_dataclasses.pytree_dataclass
class ErgoCubWalkFuncEnvV0(
    JaxDataclassEnv[
        StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType
    ]
):
    """ErgoCub environment implementing a target reaching task."""

    name: ClassVar = jax_dataclasses.static_field(default="ErgoCubWalkFuncEnvV0")

    # Store an instance of the JaxSim simulator.
    # It gets initialized with SimulatorData with a functional approach.
    _simulator: JaxSim = jax_dataclasses.field(default=None)

    def __post_init__(self) -> None:
        """Environment initialization."""

        # Dummy initialization (not needed here)
        with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            _ = self.jaxsim
            model = self.jaxsim.get_model(model_name="ErgoCub")

        # Create the action space (static attribute)
        with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            action_bound = jnp.array([25.0] * model.dofs(), dtype=float)
            self._action_space = spaces.PyTree(low=-action_bound, high=action_bound)

        # Get joint limits
        s_min, s_max = model.joint_limits()

        # NOTE(review): the bounds below assume four contact flags, while
        # `observation` selects the links whose name ends with "_ankle".
        # Confirm that the two agree for the ErgoCub model.
        obs_low = ErgoCubObservation.build(
            base_height=0.25,
            gravity_projection=-jnp.ones(3),
            joint_positions=s_min,
            joint_velocities=-50.0 * jnp.ones_like(s_min),
            base_linear_velocity=-5.0 * jnp.ones(3),
            base_angular_velocity=-10.0 * jnp.ones(3),
            contact_state=jnp.array([False] * 4),
        )

        obs_high = ErgoCubObservation.build(
            base_height=1.0,
            gravity_projection=jnp.ones(3),
            joint_positions=s_max,
            joint_velocities=50.0 * jnp.ones_like(s_max),
            base_linear_velocity=5.0 * jnp.ones(3),
            base_angular_velocity=10.0 * jnp.ones(3),
            contact_state=jnp.array([True] * 4),
        )

        # Create the observation space (static attribute)
        with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            self._observation_space = spaces.PyTree(low=obs_low, high=obs_high)

    @property
    def jaxsim(self) -> JaxSim:
        """
        Return the JaxSim simulator, building and caching it on first access.

        The simulator is created with the "ErgoCub" model already inserted.
        """

        if self._simulator is not None:
            return self._simulator

        # Create the jaxsim simulator.
        simulator = JaxSim.build(
            # Note: any change of either 'step_size' or 'steps_per_run' requires
            # updating the number of integration steps in the 'transition' method.
            step_size=0.000_250,
            steps_per_run=1,
            velocity_representation=VelRepr.Body,
            integrator_type=IntegratorType.EulerSemiImplicit,
            simulator_data=SimulatorData(
                gravity=jnp.array([0, 0, -10.0]),
                contact_parameters=SoftContactsParams.build(K=10_000, D=20),
            ),
        ).mutable(mutable=True, validate=False)

        # Get the path of the model description (a URDF file)
        model_description_path = resolve_robotics_uri(
            "package://ergoCub/robots/ergoCubGazeboV1_minContacts/model.urdf"
        )

        # Insert the model
        _ = simulator.insert_model_from_description(
            model_description=model_description_path, model_name="ErgoCub"
        )

        # Convert every leaf of the models data to a JAX array
        simulator.data.models = {
            model_name: jax.tree_util.tree_map(jnp.array, model_data)
            for model_name, model_data in simulator.data.models.items()
        }

        with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            self._simulator = simulator.mutable(mutable=True, validate=True)

        return self._simulator

    def initial(self, rng: Any = None) -> StateType:
        """
        Sample the initial state of the environment.

        Args:
            rng: The JAX PRNG key used to sample the state. It is required.

        Returns:
            A dictionary with the simulator data ("simulator_data") and the
            sampled xy goal position ("goal").

        Raises:
            ValueError: If no PRNG key is passed.
        """

        if rng is None:
            raise ValueError("A PRNG key is required to sample the initial state")

        # Split the key
        subkey1, subkey2 = jax.random.split(rng, num=2)

        # Sample an initial observation
        initial_observation: ErgoCubObservation = (
            self.observation_space.sample_with_key(key=subkey1)
        )

        # Sample a goal position
        goal_xy_position = jax.random.uniform(
            key=subkey2, minval=-5.0, maxval=5.0, shape=(2,)
        )

        with self.jaxsim.editable(validate=False) as simulator:
            # Reset the simulator and get the model
            simulator.reset(remove_models=False)
            model = simulator.get_model(model_name="ErgoCub")

            # Reset the joint positions
            model.reset_joint_positions(
                positions=initial_observation.joint_positions,
                joint_names=model.joint_names(),
            )

            # Reset the base position
            model.reset_base_position(position=jnp.array([0, 0, 0.5]))

            # Reset the base velocity, scaling down the sampled velocities
            model.reset_base_velocity(
                base_velocity=jnp.hstack(
                    [
                        0.1 * initial_observation.base_linear_velocity,
                        0.1 * initial_observation.base_angular_velocity,
                    ]
                )
            )

        # Return the simulation state
        return dict(
            simulator_data=simulator.data,
            goal=jnp.array(goal_xy_position, dtype=float),
        )

    def transition(
        self, state: StateType, action: ActType, rng: Any = None
    ) -> StateType:
        """
        Advance the environment by applying the action as joint forces.

        Args:
            state: The current environment state.
            action: The generalized forces to apply to the joints.
            rng: Unused.

        Returns:
            The input state with updated simulator data.
        """

        # Get the JaxSim simulator
        simulator = self.jaxsim

        # Initialize the simulator with the environment state (containing SimulatorData)
        with simulator.editable(validate=True) as simulator:
            simulator.data = state["simulator_data"]

        @jax_dataclasses.pytree_dataclass
        class SetTorquesOverHorizon(simulator_callbacks.PreStepCallback):
            """Callback setting the joint force targets before each step."""

            def pre_step(self, sim: JaxSim) -> JaxSim:
                """Zero the model input and set the action as force targets."""

                model = sim.get_model(model_name="ErgoCub")
                model.zero_input()
                model.set_joint_generalized_force_targets(
                    forces=jnp.atleast_1d(action), joint_names=model.joint_names()
                )

                return sim

        # With the 'step_size' of 0.000_250 and 'steps_per_run' of 1 configured
        # in the 'jaxsim' property, 40 steps correspond to 0.010 seconds.
        number_of_integration_steps = 40

        # Stepping logic
        with simulator.editable(validate=True) as simulator:
            simulator, _ = simulator.step_over_horizon(
                horizon_steps=number_of_integration_steps,
                clear_inputs=False,
                callback_handler=SetTorquesOverHorizon(),
            )

        # Return the new environment state (updated SimulatorData)
        return state | dict(simulator_data=simulator.data)

    def observation(self, state: StateType) -> ObsType:
        """
        Compute the observation corresponding to the given state.

        Args:
            state: The environment state.

        Returns:
            The observation built from the simulated model.
        """

        # Initialize the simulator with the environment state (containing SimulatorData)
        # and get the simulated model
        with self.jaxsim.editable(validate=True) as simulator:
            simulator.data = state["simulator_data"]
            model = simulator.get_model("ErgoCub")

        # Compute the normalized gravity projection in the body frame
        W_R_B = model.base_orientation(dcm=True)
        W_gravity = self.jaxsim.gravity()
        B_gravity = W_R_B.T @ (W_gravity / jnp.linalg.norm(W_gravity))

        # Build the observation from the state
        return ErgoCubObservation.build(
            base_height=model.base_position()[2],
            gravity_projection=B_gravity,
            joint_positions=model.joint_positions(),
            joint_velocities=model.joint_velocities(),
            base_linear_velocity=model.base_velocity()[0:3],
            base_angular_velocity=model.base_velocity()[3:6],
            contact_state=model.in_contact(
                link_names=[
                    name for name in model.link_names() if name.endswith("_ankle")
                ]
            ),
        )

    def reward(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> RewardType:
        """
        Compute the reward of a transition.

        The reward sums a bonus for not being terminal, a bonus for an
        observation inside the observation space, a bonus for contact,
        a penalty on the xy distance from the goal, and a control cost.

        Args:
            state: The state before the transition.
            action: The applied action.
            next_state: The state after the transition.

        Returns:
            The scalar reward.
        """

        with self.jaxsim.editable(validate=True) as simulator_next:
            simulator_next.data = next_state["simulator_data"]
            model_next = simulator_next.get_model("ErgoCub")

        # NOTE(review): both quantities below are evaluated on `state`, not on
        # `next_state` — confirm this is intended.
        terminal = self.terminal(state=state)
        obs_in_space = jax.lax.select(
            pred=self.observation_space.contains(x=self.observation(state=state)),
            on_true=1.0,
            on_false=0.0,
        )

        # Position of the base
        W_p_B = model_next.base_position()
        W_p_xy_goal = state["goal"]

        # NOTE(review): this link filter differs from the "_ankle" one used in
        # `observation` — confirm it matches links of the ErgoCub model.
        contact_link_names = [
            name
            for name in model_next.link_names()
            if name.startswith("leg_") and name.endswith("_lower")
        ]

        reward = 0.0
        reward += 1.0 * (1.0 - jnp.array(terminal, dtype=float))  # alive
        reward += 5.0 * obs_in_space  # observation inside its space
        reward -= jnp.linalg.norm(W_p_B[0:2] - W_p_xy_goal)  # distance from goal
        reward += 1.0 * model_next.in_contact(
            link_names=contact_link_names
        ).any().astype(float)
        reward -= 0.1 * jnp.linalg.norm(action) / action.size  # control cost

        return reward

    def terminal(self, state: StateType) -> TerminalType:
        """
        Check whether the given state is terminal.

        The state is terminal when the base height reaches the upper bound
        of the observation space.
        """

        # Get the current observation
        observation = self.observation(state=state)

        base_too_high = (
            observation.base_height >= self.observation_space.high.base_height
        )
        return base_too_high

    # =========
    # Rendering
    # =========

    def render_image(
        self, state: StateType, render_state: RenderStateType
    ) -> tuple[RenderStateType, npt.NDArray]:
        """
        Show the state in the meshcat visualizer.

        Args:
            state: The environment state to show.
            render_state: The render state holding the meshcat world.

        Returns:
            The render state and an empty NumPy array.

        Raises:
            ValueError: If the model is not in the meshcat world.
        """

        model_name = "ErgoCub"

        # Initialize the simulator with the environment state (containing SimulatorData)
        # and get the simulated model
        with self.jaxsim.editable(validate=False) as simulator:
            simulator.data = state["simulator_data"]
            model = simulator.get_model(model_name=model_name)

        viz_names = render_state._jaxsim_to_meshcat_viz_name

        # Insert the model lazily in the visualizer if it is not already there.
        # The visualizer returns the name it uses for the model, which is not
        # necessarily the JaxSim one, hence the check on the names map.
        if model_name not in viz_names:
            from rod.urdf.exporter import UrdfExporter

            urdf_string = UrdfExporter.sdf_to_urdf_string(
                sdf=rod.Sdf(
                    version="1.7",
                    model=model.physics_model.description.extra_info["sdf_model"],
                ),
                pretty=True,
                gazebo_preserve_fixed_joints=False,
            )

            viz_names[model_name] = render_state.world.insert_model(
                model_description=urdf_string, is_urdf=True, model_name=None
            )

        # Check that the model is in the visualizer
        if viz_names[model_name] not in render_state.world._meshcat_models.keys():
            raise ValueError(f"The '{model_name}' model is not in the meshcat world")

        # Update the model in the visualizer
        render_state.world.update_model(
            model_name=viz_names[model_name],
            joint_names=model.joint_names(),
            joint_positions=model.joint_positions(),
            base_position=model.base_position(),
            base_quaternion=model.base_orientation(dcm=False),
        )

        return render_state, np.empty(0)

    def render_init(self, open_gui: bool = False, **kwargs) -> RenderStateType:
        """
        Initialize the render state.

        Args:
            open_gui: Whether to open the GUI window in a separate process.
            **kwargs: Unused.

        Returns:
            The new render state.
        """

        # Initialize the render state
        meshcat_viz_state = MeshcatVizRenderState()

        if open_gui:
            meshcat_viz_state.open_window_in_process()

        return meshcat_viz_state

    def render_close(self, render_state: RenderStateType) -> None:
        """Close the render state."""

        render_state.close()


class ErgoCubWalkEnvV0(JaxEnv):
    """Environment wrapping a time-limited, jit-compiled ErgoCubWalkFuncEnvV0."""

    def __init__(self, render_mode: str | None = None, **kwargs: Any) -> None:
        """
        Build the wrapped functional environment.

        Args:
            render_mode: The render mode forwarded to the parent class.
            **kwargs: Unused.
        """

        from jaxgym.wrappers.jax import ClipActionWrapper

        # TODO: make the maximum number of episode steps configurable.
        func_env_wrapped = TimeLimit(
            env=ErgoCubWalkFuncEnvV0(), max_episode_steps=5_000
        )
        func_env_wrapped = ClipActionWrapper(env=func_env_wrapped)
        func_env_wrapped = FlattenSpacesWrapper(env=func_env_wrapped)
        func_env_wrapped = JaxTransformWrapper(env=func_env_wrapped, function=jax.jit)

        super().__init__(
            func_env=func_env_wrapped,
            metadata=self.metadata,
            render_mode=render_mode,
        )


class ErgoCubWalkVectorEnvV0(JaxVectorEnv):
    """Vectorized environment running ErgoCubWalkFuncEnvV0."""

    metadata = dict()

    def __init__(
        self,
        num_envs: int,
        render_mode: str | None = None,
        jit_compile: bool = True,
        **kwargs,
    ) -> None:
        """
        Build the vectorized environment.

        Args:
            num_envs: The number of environments forwarded to the parent class.
            render_mode: The render mode forwarded to the parent class.
            jit_compile: Forwarded to the parent class.
            **kwargs: Accepted for compatibility and ignored.
        """

        env = ErgoCubWalkFuncEnvV0()

        # Vectorize the environment.
        # Note: it automatically wraps the environment in a TimeLimit wrapper.
        super().__init__(
            func_env=env,
            num_envs=num_envs,
            metadata=self.metadata,
            render_mode=render_mode,
            max_episode_steps=5_000,  # TODO: make it configurable
            jit_compile=jit_compile,
        )


if __name__ == "__main__":

    def make_jax_env(
        max_episode_steps: Optional[int] = 500, jit: bool = True
    ) -> JaxEnv:
        """
        Build a JaxEnv returning NumPy data from the functional environment.

        Args:
            max_episode_steps: The episode time limit. No limit is applied
                when it is either None or 0.
            jit: Whether to jit-compile the environment.

        Returns:
            The environment, using the "meshcat_viz" render mode.
        """

        if max_episode_steps in {None, 0}:
            env = ErgoCubWalkFuncEnvV0()
        else:
            env = TimeLimit(
                env=ErgoCubWalkFuncEnvV0(), max_episode_steps=max_episode_steps
            )

        flat_env = FlattenSpacesWrapper(env=env)

        if jit:
            flat_env = JaxTransformWrapper(function=jax.jit, env=flat_env)

        return JaxEnv(
            func_env=ToNumPyWrapper(env=flat_env),
            render_mode="meshcat_viz",
        )

    env = make_jax_env(max_episode_steps=5, jit=True)

    obs, state_info = env.reset(seed=0)
    _ = env.render()
    print(obs)