Skip to content

Commit

Permalink
save current actions to h5
Browse files Browse the repository at this point in the history
  • Loading branch information
alik-git committed Nov 20, 2024
1 parent cae7a62 commit 833f03a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
6 changes: 4 additions & 2 deletions sim/h5_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _create_h5_file(self):
h5_file = h5py.File(h5_file_path, "w")

# Create datasets for logging actions and observations
dset_actions = h5_file.create_dataset("prev_actions", (self.max_timesteps, self.num_actions), dtype=np.float32)
dset_prev_actions = h5_file.create_dataset("prev_actions", (self.max_timesteps, self.num_actions), dtype=np.float32)
dset_2D_command = h5_file.create_dataset("observations/2D_command", (self.max_timesteps, 2), dtype=np.float32)
dset_3D_command = h5_file.create_dataset("observations/3D_command", (self.max_timesteps, 3), dtype=np.float32)
dset_q = h5_file.create_dataset("observations/q", (self.max_timesteps, self.num_actions), dtype=np.float32)
Expand All @@ -42,10 +42,12 @@ def _create_h5_file(self):
dset_euler = h5_file.create_dataset("observations/euler", (self.max_timesteps, 3), dtype=np.float32)
dset_t = h5_file.create_dataset("observations/t", (self.max_timesteps, 1), dtype=np.float32)
dset_buffer = h5_file.create_dataset("observations/buffer", (self.max_timesteps, self.num_observations), dtype=np.float32)
dset_curr_actions = h5_file.create_dataset("curr_actions", (self.max_timesteps, self.num_actions), dtype=np.float32)

# Map datasets for easy access
h5_dict = {
"prev_actions": dset_actions,
"prev_actions": dset_prev_actions,
"curr_actions": dset_curr_actions,
"2D_command": dset_2D_command,
"3D_command": dset_3D_command,
"joint_pos": dset_q,
Expand Down
1 change: 1 addition & 0 deletions sim/sim2sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def run_mujoco(
"joint_pos": cur_pos_obs.astype(np.float32),
"joint_vel": cur_vel_obs.astype(np.float32),
"prev_actions": actions.astype(np.float32),
"curr_actions": target_q.astype(np.float32),
"ang_vel": omega.astype(np.float32),
"euler_rotation": eu_ang.astype(np.float32),
"buffer": hist_obs.astype(np.float32)
Expand Down

0 comments on commit 833f03a

Please sign in to comment.