diff --git a/dlio_benchmark/checkpointing/pytorch_checkpointing.py b/dlio_benchmark/checkpointing/pytorch_checkpointing.py index bed9c7b7..c704fe5c 100644 --- a/dlio_benchmark/checkpointing/pytorch_checkpointing.py +++ b/dlio_benchmark/checkpointing/pytorch_checkpointing.py @@ -23,6 +23,7 @@ from dlio_benchmark.common.constants import MODULE_CHECKPOINT from dlio_benchmark.common.enumerations import CheckpointLocationType from dlio_benchmark.utils.utility import DLIOMPI +import logging def get_torch_datatype(datatype): if datatype == "fp32": diff --git a/dlio_benchmark/utils/statscounter.py b/dlio_benchmark/utils/statscounter.py index d6c30288..47e1b618 100644 --- a/dlio_benchmark/utils/statscounter.py +++ b/dlio_benchmark/utils/statscounter.py @@ -421,10 +421,10 @@ def finalize(self): def save_data(self): # Dump statistic counters to files for postprocessing # Overall stats + with open(os.path.join(self.output_folder, f'{self.my_rank}_per_epoch_stats.json'), 'w') as outfile: + json.dump(self.per_epoch_stats, outfile, indent=4) + outfile.flush() if self.my_rank == 0: - with open(os.path.join(self.output_folder, 'per_epoch_stats.json'), 'w') as outfile: - json.dump(self.per_epoch_stats, outfile, indent=4) - outfile.flush() with open(os.path.join(self.output_folder, 'summary.json'), 'w') as outfile: json.dump(self.summary, outfile, indent=4) self.output['hostname'] = socket.gethostname()