From 833f03a73515f564f667267d831775d10798eb4c Mon Sep 17 00:00:00 2001 From: alik-git Date: Wed, 20 Nov 2024 22:26:37 +0000 Subject: [PATCH] save current actions to h5 --- sim/h5_logger.py | 6 ++++-- sim/sim2sim.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sim/h5_logger.py b/sim/h5_logger.py index e455ae49..81f123bb 100644 --- a/sim/h5_logger.py +++ b/sim/h5_logger.py @@ -33,7 +33,7 @@ def _create_h5_file(self): h5_file = h5py.File(h5_file_path, "w") # Create datasets for logging actions and observations - dset_actions = h5_file.create_dataset("prev_actions", (self.max_timesteps, self.num_actions), dtype=np.float32) + dset_prev_actions = h5_file.create_dataset("prev_actions", (self.max_timesteps, self.num_actions), dtype=np.float32) dset_2D_command = h5_file.create_dataset("observations/2D_command", (self.max_timesteps, 2), dtype=np.float32) dset_3D_command = h5_file.create_dataset("observations/3D_command", (self.max_timesteps, 3), dtype=np.float32) dset_q = h5_file.create_dataset("observations/q", (self.max_timesteps, self.num_actions), dtype=np.float32) @@ -42,10 +42,12 @@ def _create_h5_file(self): dset_euler = h5_file.create_dataset("observations/euler", (self.max_timesteps, 3), dtype=np.float32) dset_t = h5_file.create_dataset("observations/t", (self.max_timesteps, 1), dtype=np.float32) dset_buffer = h5_file.create_dataset("observations/buffer", (self.max_timesteps, self.num_observations), dtype=np.float32) + dset_curr_actions = h5_file.create_dataset("curr_actions", (self.max_timesteps, self.num_actions), dtype=np.float32) # Map datasets for easy access h5_dict = { - "prev_actions": dset_actions, + "prev_actions": dset_prev_actions, + "curr_actions": dset_curr_actions, "2D_command": dset_2D_command, "3D_command": dset_3D_command, "joint_pos": dset_q, diff --git a/sim/sim2sim.py b/sim/sim2sim.py index 9799964f..3474fbfa 100755 --- a/sim/sim2sim.py +++ b/sim/sim2sim.py @@ -252,6 +252,7 @@ def run_mujoco( "joint_pos": cur_pos_obs.astype(np.float32), "joint_vel": cur_vel_obs.astype(np.float32), "prev_actions": actions.astype(np.float32), + "curr_actions": target_q.astype(np.float32), "ang_vel": omega.astype(np.float32), "euler_rotation": eu_ang.astype(np.float32), "buffer": hist_obs.astype(np.float32)