From fcebac40bc40f479b1cf31d7ff43e360e60dff0b Mon Sep 17 00:00:00 2001 From: Jash Shah Date: Sat, 14 Dec 2024 19:12:31 -0800 Subject: [PATCH] improved logging - data shape val, metadata storage, error handling --- sim/h5_logger.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/sim/h5_logger.py b/sim/h5_logger.py index 8e96a554..6e881b6b 100644 --- a/sim/h5_logger.py +++ b/sim/h5_logger.py @@ -71,6 +71,16 @@ def _create_h5_file(self) -> Tuple[h5py.File, Dict[str, h5py.Dataset]]: "t": dset_t, "buffer": dset_buffer, } + + metadata = { + "data_name": self.data_name, + "num_actions": self.num_actions, + "num_observations": self.num_observations, + "max_timesteps": self.max_timesteps, + "creation_time": timestamp, + } + h5_file.attrs['metadata'] = metadata + return h5_file, h5_dict def log_data(self, data: Dict[str, np.ndarray]) -> None: @@ -80,6 +90,9 @@ def log_data(self, data: Dict[str, np.ndarray]) -> None: for key, dataset in self.h5_dict.items(): if key in data: + if data[key].shape != dataset.shape[1:]: + print(f"Warning: Data shape mismatch for {key}. Expected {dataset.shape[1:]}, got {data[key].shape}.") + continue dataset[self.current_timestep] = data[key] self.current_timestep += 1 @@ -98,25 +111,26 @@ def close(self) -> None: self.h5_file.close() @staticmethod - def visualize_h5(h5_file_path: str) -> None: + def visualize_h5(h5_file_path: str, variable: str = None) -> None: """Visualizes the data from an HDF5 file by plotting each variable one by one. Args: h5_file_path (str): Path to the HDF5 file. + variable (str, optional): Specific variable to visualize. If None, all variables are plotted. """ try: - # Open the HDF5 file with h5py.File(h5_file_path, "r") as h5_file: - # Extract all datasets - for key in h5_file.keys(): - group = h5_file[key] - if isinstance(group, h5py.Group): - for subkey in group.keys(): - dataset = group[subkey][:] - HDF5Logger._plot_dataset(f"{key}/{subkey}", dataset) - else: - dataset = group[:] - HDF5Logger._plot_dataset(key, dataset) + keys = [variable] if variable else h5_file.keys() + for key in keys: + if key in h5_file: + group = h5_file[key] + if isinstance(group, h5py.Group): + for subkey in group.keys(): + dataset = group[subkey][:] + HDF5Logger._plot_dataset(f"{key}/{subkey}", dataset) + else: + dataset = group[:] + HDF5Logger._plot_dataset(key, dataset) except Exception as e: print(f"Failed to visualize HDF5 file: {e}")