Skip to content

Commit

Permalink
per_epoch_stats reported from all ranks
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghh04 committed Feb 21, 2025
1 parent d695c60 commit b535005
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
1 change: 1 addition & 0 deletions dlio_benchmark/checkpointing/pytorch_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from dlio_benchmark.common.constants import MODULE_CHECKPOINT
from dlio_benchmark.common.enumerations import CheckpointLocationType
from dlio_benchmark.utils.utility import DLIOMPI
import logging

def get_torch_datatype(datatype):
if datatype == "fp32":
Expand Down
6 changes: 3 additions & 3 deletions dlio_benchmark/utils/statscounter.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,10 +421,10 @@ def finalize(self):
def save_data(self):
# Dump statistic counters to files for postprocessing
# Overall stats
with open(os.path.join(self.output_folder, f'{self.my_rank}_per_epoch_stats.json'), 'w') as outfile:
json.dump(self.per_epoch_stats, outfile, indent=4)
outfile.flush()
if self.my_rank == 0:
with open(os.path.join(self.output_folder, 'per_epoch_stats.json'), 'w') as outfile:
json.dump(self.per_epoch_stats, outfile, indent=4)
outfile.flush()
with open(os.path.join(self.output_folder, 'summary.json'), 'w') as outfile:
json.dump(self.summary, outfile, indent=4)
self.output['hostname'] = socket.gethostname()
Expand Down

0 comments on commit b535005

Please sign in to comment.