Skip to content

Commit

Permalink
update export
Browse files Browse the repository at this point in the history
  • Loading branch information
budzianowski committed Dec 10, 2024
1 parent 6141d6a commit 77a95aa
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 673 deletions.
Binary file modified examples/gpr_walking.kinfer
Binary file not shown.
2 changes: 1 addition & 1 deletion sim/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from sim.envs.humanoids.dora_env import DoraFreeEnv
from sim.envs.humanoids.g1_config import G1Cfg, G1CfgPPO
from sim.envs.humanoids.g1_env import G1FreeEnv
from sim.envs.humanoids.gpr_env import GprFreeEnv
from sim.envs.humanoids.gpr_config import GprCfg, GprCfgPPO, GprStandingCfg
from sim.envs.humanoids.gpr_env import GprFreeEnv
from sim.envs.humanoids.h1_config import H1Cfg, H1CfgPPO
from sim.envs.humanoids.h1_env import H1FreeEnv
from sim.envs.humanoids.xbot_config import XBotCfg, XBotCfgPPO
Expand Down
29 changes: 14 additions & 15 deletions sim/model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,20 @@

@dataclass
class ActorCfg:
embodiment: str = "gpr"
cycle_time: float = 0.4 # Cycle time for sinusoidal command input
action_scale: float = 0.25 # Scale for actions
lin_vel_scale: float = 2.0 # Scale for linear velocity
ang_vel_scale: float = 1.0 # Scale for angular velocity
quat_scale: float = 1.0 # Scale for quaternion
dof_pos_scale: float = 1.0 # Scale for joint positions
dof_vel_scale: float = 0.05 # Scale for joint velocities
frame_stack: int = 15 # Number of frames to stack for the policy input
clip_observations: float = 18.0 # Clip observations to this value
clip_actions: float = 18.0 # Clip actions to this value
embodiment: str
cycle_time: float # Cycle time for sinusoidal command input
action_scale: float # Scale for actions
lin_vel_scale: float # Scale for linear velocity
ang_vel_scale: float # Scale for angular velocity
quat_scale: float # Scale for quaternion
dof_pos_scale: float # Scale for joint positions
dof_vel_scale: float # Scale for joint velocities
frame_stack: int # Number of frames to stack for the policy input
clip_observations: float # Clip observations to this value
clip_actions: float # Clip actions to this value
sim_dt: float # Simulation time step
sim_decimation: int # Simulation decimation
tau_factor: float # Torque limit factor


class ActorCritic(nn.Module):
Expand Down Expand Up @@ -341,7 +344,3 @@ def convert_model_to_onnx(model_path: str, cfg: ActorCfg, save_path: Optional[st
buffer2.seek(0)

return ort.InferenceSession(buffer2.read())


if __name__ == "__main__":
convert_model_to_onnx("model_3000.pt", ActorCfg(), "policy.onnx")
35 changes: 20 additions & 15 deletions sim/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,11 @@
from kinfer.export.pytorch import export_to_onnx
from tqdm import tqdm

# Local imports third
from sim.env import run_dir # noqa: E402
from sim.envs import task_registry # noqa: E402
from sim.model_export import ( # noqa: E402
ActorCfg,
convert_model_to_onnx,
get_actor_policy,
)

# Local imports third
from sim.model_export import ActorCfg, get_actor_policy
from sim.utils.helpers import get_args # noqa: E402
from sim.utils.logger import Logger # noqa: E402

Expand Down Expand Up @@ -91,15 +88,22 @@ def play(args: argparse.Namespace) -> None:
if args.export_onnx:
path = ppo_runner.load_path
embodiment = ppo_runner.cfg["experiment_name"].lower()
policy_cfg = ActorCfg(embodiment=embodiment)

if embodiment == "gpr":
policy_cfg.cycle_time = 0.4
elif embodiment == "zeroth":
policy_cfg.cycle_time = 0.2
else:
print(f"Specific policy cfg for {embodiment} not implemented")

policy_cfg = ActorCfg(
embodiment=embodiment,
cycle_time=env_cfg.rewards.cycle_time,
sim_dt=env_cfg.sim.dt,
sim_decimation=env_cfg.control.decimation,
tau_factor=env_cfg.safety.torque_limit,
action_scale=env_cfg.control.action_scale,
lin_vel_scale=env_cfg.normalization.obs_scales.lin_vel,
ang_vel_scale=env_cfg.normalization.obs_scales.ang_vel,
quat_scale=env_cfg.normalization.obs_scales.quat,
dof_pos_scale=env_cfg.normalization.obs_scales.dof_pos,
dof_vel_scale=env_cfg.normalization.obs_scales.dof_vel,
frame_stack=env_cfg.env.frame_stack,
clip_observations=env_cfg.normalization.clip_observations,
clip_actions=env_cfg.normalization.clip_actions,
)
actor_model, sim2sim_info, input_tensors = get_actor_policy(path, policy_cfg)

# Merge policy_cfg and sim2sim_info into a single config object
Expand All @@ -117,6 +121,7 @@ def play(args: argparse.Namespace) -> None:
now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
if args.log_h5:
from sim.h5_logger import HDF5Logger

# Create directory for HDF5 files
h5_dir = run_dir() / "h5_out" / args.task / now
h5_dir.mkdir(parents=True, exist_ok=True)
Expand Down
231 changes: 0 additions & 231 deletions sim/play_old.py

This file was deleted.

2 changes: 1 addition & 1 deletion sim/resources/gpr/robot_fixed.xml
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,6 @@
</sensor>

<keyframe>
<key name="default" qpos="0 0 1.051 1. 0.0 0.0 0.0 -0.23 0.0 0.0 -0.441 -0.195 0.23 0.0 0.0 0.441 0.195"/>
<key name="default" qpos="0 0 1.05 1. 0.0 0.0 0.0 -0.23 0.0 0.0 -0.441 -0.195 0.23 0.0 0.0 0.441 0.195"/>
</keyframe>
</mujoco>
Loading

0 comments on commit 77a95aa

Please sign in to comment.