
Commit

fixed issue with datatype
zhenghh04 committed Feb 6, 2025
1 parent e6a4697 commit a4372cd
Showing 4 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion dlio_benchmark/checkpointing/base_checkpointing.py
@@ -34,7 +34,7 @@ def get_datatype_size(datatype):
    elif datatype == "fp64":
        return 8
    else:
-        raise Exception("Unsupported datatype")
+        raise Exception(f"Unsupported datatype {datatype}")

class BaseCheckpointing(ABC):

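A minimal usage sketch of the lookup touched above; only the fp64 branch and the error path are visible in the hunk, and the import path is assumed from the file name.

from dlio_benchmark.checkpointing.base_checkpointing import get_datatype_size  # assumed import path

assert get_datatype_size("fp64") == 8   # bytes per element, per the branch shown above

try:
    get_datatype_size("f64")            # hypothetical unsupported spelling
except Exception as err:
    print(err)                          # e.g. "Unsupported datatype f64"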
6 changes: 4 additions & 2 deletions dlio_benchmark/checkpointing/pytorch_checkpointing.py
@@ -31,8 +31,10 @@ def get_torch_datatype(datatype):
        return torch.float16
    if datatype == "f64":
        return torch.float64
-    if datatype == "i8":
+    if datatype == "int8":
        return torch.int8
+    if datatype == "uint8":
+        return torch.uint8
    if datatype == "bf16": # bfloat16
        return torch.bfloat16

@@ -54,7 +56,7 @@ def __init__(self):
        super().__init__("pt")

    @dlp.log
-    def get_tensor(self, size, datatype="i8"):
+    def get_tensor(self, size, datatype="int8"):
        return torch.ones(size, dtype=get_torch_datatype(datatype))

    @dlp.log
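A short sketch of the renamed PyTorch datatype keys, assuming get_torch_datatype is importable as a module-level function from the file above.

import torch

from dlio_benchmark.checkpointing.pytorch_checkpointing import get_torch_datatype  # assumed import path

assert get_torch_datatype("int8") is torch.int8      # previously spelled "i8"
assert get_torch_datatype("uint8") is torch.uint8    # newly supported
assert get_torch_datatype("bf16") is torch.bfloat16

# Standalone equivalent of the new get_tensor default:
buf = torch.ones(1024, dtype=get_torch_datatype("int8"))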
6 changes: 4 additions & 2 deletions dlio_benchmark/checkpointing/tf_checkpointing.py
@@ -31,8 +31,10 @@ def get_tf_datatype(datatype):
        return tf.dtypes.float16
    if datatype == "f64":
        return tf.dtypes.float64
-    if datatype == "i8":
+    if datatype == "int8":
        return tf.dtypes.int8
+    if datatype == "uint8":
+        return tf.dtypes.uint8
    if datatype == "bf16": # bfloat16
        return tf.dtypes.bfloat16

@@ -54,7 +56,7 @@ def __init__(self):
        super().__init__("pb")

    @dlp.log
-    def get_tensor(self, size, datatype="i8"):
+    def get_tensor(self, size, datatype="int8"):
        return tf.random.uniform((size), maxval=100, dtype=get_tf_datatype(datatype))

    @dlp.log
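The TensorFlow path mirrors the PyTorch change; a sketch under the same assumption about the import path.

import tensorflow as tf

from dlio_benchmark.checkpointing.tf_checkpointing import get_tf_datatype  # assumed import path

assert get_tf_datatype("int8") == tf.dtypes.int8      # previously spelled "i8"
assert get_tf_datatype("uint8") == tf.dtypes.uint8    # newly supported

# Allocating a buffer with one of the new keys (get_tensor itself draws from tf.random.uniform):
buf = tf.ones((1024,), dtype=get_tf_datatype("uint8"))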
4 changes: 2 additions & 2 deletions dlio_benchmark/utils/config.py
@@ -101,8 +101,8 @@ class ConfigArguments:
    epochs_between_evals: int = 1
    checkpoint_type: CheckpointLocationType = CheckpointLocationType.RANK_ZERO
    checkpoint_mechanism: CheckpointMechanismType = CheckpointMechanismType.NONE
-    checkpoint_model_datatype: str = "f16"
-    checkpoint_optimizer_datatype: str = "f32"
+    checkpoint_model_datatype: str = "fp16"
+    checkpoint_optimizer_datatype: str = "fp32"
    model_size: int = 10240
    vocab_size: int = 32000
    hidden_size: int = 2048
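The renamed defaults match the fpXX spellings that get_datatype_size accepts, so the old "f16"/"f32" values would presumably have fallen into the unsupported-datatype branch. A hedged sketch of the resulting size arithmetic, assuming fp16 and fp32 resolve to 2 and 4 bytes (only the fp64 branch is visible in the first hunk) and using a hypothetical element count.

from dlio_benchmark.checkpointing.base_checkpointing import get_datatype_size  # assumed import path

checkpoint_model_datatype = "fp16"        # new default
checkpoint_optimizer_datatype = "fp32"    # new default
num_elements = 1_000_000                  # hypothetical element count, for illustration only

model_bytes = num_elements * get_datatype_size(checkpoint_model_datatype)          # assumed 2 bytes/element
optimizer_bytes = num_elements * get_datatype_size(checkpoint_optimizer_datatype)  # assumed 4 bytes/element
print(model_bytes, optimizer_bytes)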
