Skip to content

Commit

Permalink
fixed datatype issue
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghh04 committed Feb 7, 2025
1 parent 843cabb commit 9690df2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
4 changes: 2 additions & 2 deletions dlio_benchmark/checkpointing/base_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,9 @@ def __init__(self, ext):
logging.info(f"{utcnow()} Layer size: {ss} GB")
logging.info(f"{utcnow()} Optimizer state size: {opt} GB")
logging.info(f"{utcnow()} Total checkpoint size: {self.checkpoint_size} GB")

@abstractmethod
def get_tensor(self, size, datatype="int8"):
    """Produce a tensor with ``size`` elements of the given ``datatype``.

    Abstract hook: concrete checkpointing backends (e.g. the PyTorch one)
    override this to allocate a real tensor. The datatype string defaults
    to "int8" to stay backward compatible with callers that only pass size.
    This placeholder implementation yields an empty container.
    """
    return list()

@abstractmethod
Expand Down
15 changes: 9 additions & 6 deletions dlio_benchmark/checkpointing/pytorch_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,21 @@
from dlio_benchmark.utils.utility import DLIOMPI

# Lookup table from DLIO datatype strings to torch dtypes; a single dict
# lookup replaces the long if/elif chain and keeps the mapping in one place.
_TORCH_DATATYPES = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "fp64": torch.float64,
    "int8": torch.int8,
    "uint8": torch.uint8,
    "bf16": torch.bfloat16,  # bfloat16
}


def get_torch_datatype(datatype):
    """Map a DLIO datatype string (e.g. "fp32", "bf16") to a torch dtype.

    Raises ValueError for an unrecognized string. ValueError is a subclass
    of Exception (the type raised previously), so existing callers that
    catch Exception continue to work.
    """
    try:
        return _TORCH_DATATYPES[datatype]
    except KeyError:
        # Same message text as before; `from None` hides the KeyError chain.
        raise ValueError(f"Invalid datatype {datatype}") from None


dlp = Profile(MODULE_CHECKPOINT)

Expand Down

0 comments on commit 9690df2

Please sign in to comment.