Skip to content

Commit

Permalink
Adding kinfer (#118)
Browse files Browse the repository at this point in the history
* hacky

* export with params

* add export script

* use kinfer for inference

* delete

* clean
  • Loading branch information
WT-MM authored Dec 3, 2024
1 parent e39f93a commit df886c1
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 38 deletions.
46 changes: 46 additions & 0 deletions sim/model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,52 @@ def forward(

return actions_scaled, actions, x

def get_actor_policy(model_path: str, cfg: ActorCfg) -> Tuple[nn.Module, dict, Tuple[Tensor, ...]]:
all_weights = torch.load(model_path, map_location="cpu", weights_only=True)
weights = all_weights["model_state_dict"]
num_actor_obs = weights["actor.0.weight"].shape[1]
num_critic_obs = weights["critic.0.weight"].shape[1]
num_actions = weights["std"].shape[0]
actor_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"actor\.\d+\.weight", k)]
critic_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"critic\.\d+\.weight", k)]
actor_hidden_dims = actor_hidden_dims[:-1]
critic_hidden_dims = critic_hidden_dims[:-1]

ac_model = ActorCritic(num_actor_obs, num_critic_obs, num_actions, actor_hidden_dims, critic_hidden_dims)
ac_model.load_state_dict(weights)

a_model = Actor(ac_model.actor, cfg)

# Gets the model input tensors.
x_vel = torch.randn(1)
y_vel = torch.randn(1)
rot = torch.randn(1)
t = torch.randn(1)
dof_pos = torch.randn(a_model.num_actions)
dof_vel = torch.randn(a_model.num_actions)
prev_actions = torch.randn(a_model.num_actions)
imu_ang_vel = torch.randn(3)
imu_euler_xyz = torch.randn(3)
buffer = a_model.get_init_buffer()
input_tensors = (x_vel, y_vel, rot, t, dof_pos, dof_vel, prev_actions, imu_ang_vel, imu_euler_xyz, buffer)

jit_model = torch.jit.script(a_model)

# Add sim2sim metadata
robot_effort = list(a_model.robot.effort().values())
robot_stiffness = list(a_model.robot.stiffness().values())
robot_damping = list(a_model.robot.damping().values())
num_actions = a_model.num_actions
num_observations = a_model.num_observations

return a_model, {
"robot_effort": robot_effort,
"robot_stiffness": robot_stiffness,
"robot_damping": robot_damping,
"num_actions": num_actions,
"num_observations": num_observations,
}, input_tensors


def convert_model_to_onnx(model_path: str, cfg: ActorCfg, save_path: Optional[str] = None) -> ort.InferenceSession:
"""Converts a PyTorch model to a ONNX format.
Expand Down
18 changes: 15 additions & 3 deletions sim/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

from sim.env import run_dir # noqa: E402
from sim.envs import task_registry # noqa: E402
from sim.model_export import ActorCfg, convert_model_to_onnx # noqa: E402
from sim.model_export import ActorCfg, get_actor_policy # noqa: E402
from sim.utils.helpers import get_args # noqa: E402
from sim.utils.logger import Logger # noqa: E402
from kinfer.export.pytorch import export_to_onnx

import torch # isort: skip

Expand Down Expand Up @@ -81,8 +82,19 @@ def play(args: argparse.Namespace) -> None:
# export policy as a onnx module (used to run it on web)
if args.export_onnx:
path = ppo_runner.alg.actor_critic
convert_model_to_onnx(path, ActorCfg(), save_path="policy.onnx")
print("Exported policy as onnx to: ", path)
policy_cfg = ActorCfg()
actor_model, sim2sim_info, input_tensors = get_actor_policy(path, policy_cfg)

# Merge policy_cfg and sim2sim_info into a single config object
export_config = {**vars(policy_cfg), **sim2sim_info}

policy = export_to_onnx(
actor_model,
input_tensors=input_tensors,
config=export_config,
save_path="kinfer_policy.onnx"
)
print("Exported policy as kinfer-compatible onnx to: ", path)

# Prepare for logging
env_logger = Logger(env.dt)
Expand Down
2 changes: 2 additions & 0 deletions sim/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ wandb
tensorboard==2.14.0
onnxscript
# onnxruntime

kinfer
66 changes: 31 additions & 35 deletions sim/sim2sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
import onnxruntime as ort
import pygame
from scipy.spatial.transform import Rotation as R
import torch
from tqdm import tqdm

from sim.h5_logger import HDF5Logger
from sim.model_export import ActorCfg, convert_model_to_onnx
from sim.model_export import ActorCfg, get_actor_policy
from kinfer.export.pytorch import export_to_onnx
from kinfer.inference.python import ONNXModel


@dataclass
Expand Down Expand Up @@ -238,7 +241,11 @@ def run_mujoco(

input_data["buffer.1"] = hist_obs.astype(np.float32)

positions, curr_actions, hist_obs = policy.run(None, input_data)
policy_output = policy(input_data)
positions = policy_output["actions_scaled"]
curr_actions = policy_output["actions"]
hist_obs = policy_output["x.3"]

target_q = positions

if log_h5:
Expand Down Expand Up @@ -290,32 +297,6 @@ def run_mujoco(
if log_h5:
logger.close()


def parse_modelmeta(
modelmeta: List[Tuple[str, str]],
verbose: bool = False,
) -> Dict[str, Union[float, List[float], str]]:
parsed_meta: Dict[str, Union[float, List[float], str]] = {}
for key, value in modelmeta:
if value.startswith("[") and value.endswith("]"):
parsed_meta[key] = list(map(float, value.strip("[]").split(",")))
else:
try:
parsed_meta[key] = float(value)
try:
if int(value) == parsed_meta[key]:
parsed_meta[key] = int(value)
except ValueError:
pass
except ValueError:
print(f"Failed to convert {value} to float")
parsed_meta[key] = value
if verbose:
for key, value in parsed_meta.items():
print(f"{key}: {value}")
return parsed_meta


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Deployment script.")
parser.add_argument("--embodiment", type=str, required=True, help="Embodiment name.")
Expand Down Expand Up @@ -355,16 +336,31 @@ def parse_modelmeta(
)

if args.load_model.endswith(".onnx"):
policy = ort.InferenceSession(args.load_model)
policy = ONNXModel(args.load_model)
else:
policy = convert_model_to_onnx(
args.load_model, policy_cfg, save_path="policy.onnx"
actor_model, sim2sim_info, input_tensors = get_actor_policy(args.load_model, policy_cfg)

# Merge policy_cfg and sim2sim_info into a single config object
export_config = {**vars(policy_cfg), **sim2sim_info}
print(export_config)
export_to_onnx(
actor_model,
input_tensors=input_tensors,
config=export_config,
save_path="kinfer_test.onnx"
)
policy = ONNXModel("kinfer_test.onnx")

metadata = policy.get_metadata()

model_info = {
"num_actions": metadata["num_actions"],
"num_observations": metadata["num_observations"],
"robot_effort": metadata["robot_effort"],
"robot_stiffness": metadata["robot_stiffness"],
"robot_damping": metadata["robot_damping"],
}

model_info = parse_modelmeta(
policy.get_modelmeta().custom_metadata_map.items(),
verbose=True,
)

run_mujoco(
args.embodiment,
Expand Down

0 comments on commit df886c1

Please sign in to comment.