From 4a01fc4c4c7456602a7dc16fd0dba5fac57a385f Mon Sep 17 00:00:00 2001 From: Wesley Maa Date: Fri, 6 Dec 2024 16:38:02 +0800 Subject: [PATCH] fix default standing --- sim/model_export.py | 3 +++ sim/sim2sim.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sim/model_export.py b/sim/model_export.py index 09c190a..e1f8db3 100644 --- a/sim/model_export.py +++ b/sim/model_export.py @@ -242,6 +242,8 @@ def get_actor_policy(model_path: str, cfg: ActorCfg) -> Tuple[nn.Module, dict, T num_actions = a_model.num_actions num_observations = a_model.num_observations + default_standing = list(a_model.robot.default_standing().values()) + return ( a_model, { @@ -250,6 +252,7 @@ def get_actor_policy(model_path: str, cfg: ActorCfg) -> Tuple[nn.Module, dict, T "robot_damping": robot_damping, "num_actions": num_actions, "num_observations": num_observations, + "default_standing": default_standing, }, input_tensors, ) diff --git a/sim/sim2sim.py b/sim/sim2sim.py index b512952..fcef1d7 100755 --- a/sim/sim2sim.py +++ b/sim/sim2sim.py @@ -319,7 +319,7 @@ def run_mujoco( policy_cfg = ActorCfg(embodiment=args.embodiment) if args.embodiment == "gpr": - policy_cfg.cycle_time = 0.25 + policy_cfg.cycle_time = 0.5 cfg = Sim2simCfg( sim_duration=10.0, dt=0.001, @@ -337,7 +337,7 @@ def run_mujoco( cycle_time=policy_cfg.cycle_time, ) - if args.load_model.endswith(".kinfer"): + if args.load_model.endswith(".kinfer") or args.load_model.endswith(".onnx"): policy = ONNXModel(args.load_model) else: actor_model, sim2sim_info, input_tensors = get_actor_policy(args.load_model, policy_cfg)