diff --git a/sim/model_export.py b/sim/model_export.py index 9f893791..51a812c5 100644 --- a/sim/model_export.py +++ b/sim/model_export.py @@ -205,6 +205,52 @@ def forward( return actions_scaled, actions, x +def get_actor_policy(model_path: str, cfg: ActorCfg) -> Tuple[nn.Module, dict, Tuple[Tensor, ...]]: + all_weights = torch.load(model_path, map_location="cpu", weights_only=True) + weights = all_weights["model_state_dict"] + num_actor_obs = weights["actor.0.weight"].shape[1] + num_critic_obs = weights["critic.0.weight"].shape[1] + num_actions = weights["std"].shape[0] + actor_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"actor\.\d+\.weight", k)] + critic_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"critic\.\d+\.weight", k)] + actor_hidden_dims = actor_hidden_dims[:-1] + critic_hidden_dims = critic_hidden_dims[:-1] + + ac_model = ActorCritic(num_actor_obs, num_critic_obs, num_actions, actor_hidden_dims, critic_hidden_dims) + ac_model.load_state_dict(weights) + + a_model = Actor(ac_model.actor, cfg) + + # Gets the model input tensors. + x_vel = torch.randn(1) + y_vel = torch.randn(1) + rot = torch.randn(1) + t = torch.randn(1) + dof_pos = torch.randn(a_model.num_actions) + dof_vel = torch.randn(a_model.num_actions) + prev_actions = torch.randn(a_model.num_actions) + imu_ang_vel = torch.randn(3) + imu_euler_xyz = torch.randn(3) + buffer = a_model.get_init_buffer() + input_tensors = (x_vel, y_vel, rot, t, dof_pos, dof_vel, prev_actions, imu_ang_vel, imu_euler_xyz, buffer) + + jit_model = torch.jit.script(a_model) + + # Add sim2sim metadata + robot_effort = list(a_model.robot.effort().values()) + robot_stiffness = list(a_model.robot.stiffness().values()) + robot_damping = list(a_model.robot.damping().values()) + num_actions = a_model.num_actions + num_observations = a_model.num_observations + + return a_model, { + "robot_effort": robot_effort, + "robot_stiffness": robot_stiffness, + "robot_damping": robot_damping, + "num_actions": num_actions, + "num_observations": num_observations, + }, input_tensors + def convert_model_to_onnx(model_path: str, cfg: ActorCfg, save_path: Optional[str] = None) -> ort.InferenceSession: """Converts a PyTorch model to a ONNX format. diff --git a/sim/play.py b/sim/play.py index 3e596a66..b5f135fc 100755 --- a/sim/play.py +++ b/sim/play.py @@ -22,9 +22,10 @@ from sim.env import run_dir # noqa: E402 from sim.envs import task_registry # noqa: E402 -from sim.model_export import ActorCfg, convert_model_to_onnx # noqa: E402 +from sim.model_export import ActorCfg, get_actor_policy # noqa: E402 from sim.utils.helpers import get_args # noqa: E402 from sim.utils.logger import Logger # noqa: E402 +from kinfer.export.pytorch import export_to_onnx import torch # isort: skip @@ -81,8 +82,19 @@ def play(args: argparse.Namespace) -> None: # export policy as a onnx module (used to run it on web) if args.export_onnx: path = ppo_runner.alg.actor_critic - convert_model_to_onnx(path, ActorCfg(), save_path="policy.onnx") - print("Exported policy as onnx to: ", path) + policy_cfg = ActorCfg() + actor_model, sim2sim_info, input_tensors = get_actor_policy(path, policy_cfg) + + # Merge policy_cfg and sim2sim_info into a single config object + export_config = {**vars(policy_cfg), **sim2sim_info} + + policy = export_to_onnx( + actor_model, + input_tensors=input_tensors, + config=export_config, + save_path="kinfer_policy.onnx" + ) + print("Exported policy as kinfer-compatible onnx to: ", path) # Prepare for logging env_logger = Logger(env.dt) diff --git a/sim/requirements.txt b/sim/requirements.txt index a131c312..53fb29da 100755 --- a/sim/requirements.txt +++ b/sim/requirements.txt @@ -11,3 +11,5 @@ wandb tensorboard==2.14.0 onnxscript # onnxruntime + +kinfer diff --git a/sim/sim2sim.py b/sim/sim2sim.py index 292462b5..59955505 100755 --- a/sim/sim2sim.py +++ b/sim/sim2sim.py @@ -17,10 +17,13 @@ import onnxruntime as ort import pygame from scipy.spatial.transform import Rotation as R +import torch from tqdm import tqdm from sim.h5_logger import HDF5Logger -from sim.model_export import ActorCfg, convert_model_to_onnx +from sim.model_export import ActorCfg, get_actor_policy +from kinfer.export.pytorch import export_to_onnx +from kinfer.inference.python import ONNXModel @dataclass @@ -238,7 +241,11 @@ def run_mujoco( input_data["buffer.1"] = hist_obs.astype(np.float32) - positions, curr_actions, hist_obs = policy.run(None, input_data) + policy_output = policy(input_data) + positions = policy_output["actions_scaled"] + curr_actions = policy_output["actions"] + hist_obs = policy_output["x.3"] + target_q = positions if log_h5: @@ -290,32 +297,6 @@ def run_mujoco( if log_h5: logger.close() - -def parse_modelmeta( - modelmeta: List[Tuple[str, str]], - verbose: bool = False, -) -> Dict[str, Union[float, List[float], str]]: - parsed_meta: Dict[str, Union[float, List[float], str]] = {} - for key, value in modelmeta: - if value.startswith("[") and value.endswith("]"): - parsed_meta[key] = list(map(float, value.strip("[]").split(","))) - else: - try: - parsed_meta[key] = float(value) - try: - if int(value) == parsed_meta[key]: - parsed_meta[key] = int(value) - except ValueError: - pass - except ValueError: - print(f"Failed to convert {value} to float") - parsed_meta[key] = value - if verbose: - for key, value in parsed_meta.items(): - print(f"{key}: {value}") - return parsed_meta - - if __name__ == "__main__": parser = argparse.ArgumentParser(description="Deployment script.") parser.add_argument("--embodiment", type=str, required=True, help="Embodiment name.") @@ -355,16 +336,31 @@ def parse_modelmeta( ) if args.load_model.endswith(".onnx"): - policy = ort.InferenceSession(args.load_model) + policy = ONNXModel(args.load_model) else: - policy = convert_model_to_onnx( - args.load_model, policy_cfg, save_path="policy.onnx" + actor_model, sim2sim_info, input_tensors = get_actor_policy(args.load_model, policy_cfg) + + # Merge policy_cfg and sim2sim_info into a single config object + export_config = {**vars(policy_cfg), **sim2sim_info} + print(export_config) + export_to_onnx( + actor_model, + input_tensors=input_tensors, + config=export_config, + save_path="kinfer_test.onnx" ) + policy = ONNXModel("kinfer_test.onnx") + + metadata = policy.get_metadata() + + model_info = { + "num_actions": metadata["num_actions"], + "num_observations": metadata["num_observations"], + "robot_effort": metadata["robot_effort"], + "robot_stiffness": metadata["robot_stiffness"], + "robot_damping": metadata["robot_damping"], + } - model_info = parse_modelmeta( - policy.get_modelmeta().custom_metadata_map.items(), - verbose=True, - ) run_mujoco( args.embodiment,