From a4372cd932b16a257897819a9ca2bdc208cc5ac3 Mon Sep 17 00:00:00 2001
From: Huihuo Zheng
Date: Thu, 6 Feb 2025 12:58:16 -0600
Subject: [PATCH] fixed issue with datatype

---
 dlio_benchmark/checkpointing/base_checkpointing.py    | 2 +-
 dlio_benchmark/checkpointing/pytorch_checkpointing.py | 6 ++++--
 dlio_benchmark/checkpointing/tf_checkpointing.py      | 6 ++++--
 dlio_benchmark/utils/config.py                        | 4 ++--
 4 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/dlio_benchmark/checkpointing/base_checkpointing.py b/dlio_benchmark/checkpointing/base_checkpointing.py
index 6ce5e45e..d7a6b6e8 100644
--- a/dlio_benchmark/checkpointing/base_checkpointing.py
+++ b/dlio_benchmark/checkpointing/base_checkpointing.py
@@ -34,7 +34,7 @@ def get_datatype_size(datatype):
     elif datatype == "fp64":
         return 8
     else:
-        raise Exception("Unsupported datatype")
+        raise Exception(f"Unsupported datatype {datatype}")
 
 
 class BaseCheckpointing(ABC):
diff --git a/dlio_benchmark/checkpointing/pytorch_checkpointing.py b/dlio_benchmark/checkpointing/pytorch_checkpointing.py
index 8018acb8..e9bd2174 100644
--- a/dlio_benchmark/checkpointing/pytorch_checkpointing.py
+++ b/dlio_benchmark/checkpointing/pytorch_checkpointing.py
@@ -31,8 +31,10 @@ def get_torch_datatype(datatype):
         return torch.float16
     if datatype == "f64":
         return torch.float64
-    if datatype == "i8":
+    if datatype == "int8":
         return torch.int8
+    if datatype == "uint8":
+        return torch.uint8
     if datatype == "bf16":
         # bfloat16
         return torch.bfloat16
@@ -54,7 +56,7 @@ def __init__(self):
         super().__init__("pt")
 
     @dlp.log
-    def get_tensor(self, size, datatype="i8"):
+    def get_tensor(self, size, datatype="int8"):
         return torch.ones(size, dtype=get_torch_datatype(datatype))
 
     @dlp.log
diff --git a/dlio_benchmark/checkpointing/tf_checkpointing.py b/dlio_benchmark/checkpointing/tf_checkpointing.py
index 6e0cb140..df9af7bc 100644
--- a/dlio_benchmark/checkpointing/tf_checkpointing.py
+++ b/dlio_benchmark/checkpointing/tf_checkpointing.py
@@ -31,8 +31,10 @@ def get_tf_datatype(datatype):
         return tf.dtypes.float16
     if datatype == "f64":
         return tf.dtypes.float64
-    if datatype == "i8":
+    if datatype == "int8":
         return tf.dtypes.int8
+    if datatype == "uint8":
+        return tf.dtypes.uint8
     if datatype == "bf16":
         # bfloat16
         return tf.dtypes.bfloat16
@@ -54,7 +56,7 @@ def __init__(self):
         super().__init__("pb")
 
     @dlp.log
-    def get_tensor(self, size, datatype="i8"):
+    def get_tensor(self, size, datatype="int8"):
         return tf.random.uniform((size), maxval=100, dtype=get_tf_datatype(datatype))
 
     @dlp.log
diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py
index 0f37e6b0..a0d137af 100644
--- a/dlio_benchmark/utils/config.py
+++ b/dlio_benchmark/utils/config.py
@@ -101,8 +101,8 @@ class ConfigArguments:
     epochs_between_evals: int = 1
     checkpoint_type: CheckpointLocationType = CheckpointLocationType.RANK_ZERO
     checkpoint_mechanism: CheckpointMechanismType = CheckpointMechanismType.NONE
-    checkpoint_model_datatype: str = "f16"
-    checkpoint_optimizer_datatype: str = "f32"
+    checkpoint_model_datatype: str = "fp16"
+    checkpoint_optimizer_datatype: str = "fp32"
     model_size: int = 10240
     vocab_size: int = 32000
     hidden_size: int = 2048
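
A minimal sketch for sanity-checking the renamed datatype keys after this
patch, assuming dlio_benchmark and PyTorch are installed and that
get_torch_datatype is importable from the module touched above (the tensor
size below is arbitrary):

    import torch
    from dlio_benchmark.checkpointing.pytorch_checkpointing import get_torch_datatype

    # The old "i8" key is renamed to "int8", and "uint8" is newly supported.
    assert get_torch_datatype("int8") == torch.int8
    assert get_torch_datatype("uint8") == torch.uint8

    # Mirrors the get_tensor() change above: torch.ones with the new
    # default datatype "int8".
    tensor = torch.ones(1024, dtype=get_torch_datatype("int8"))
    print(tensor.dtype)  # torch.int8

An unknown key such as the old "i8" falls through the if-chain, so callers
relying on the previous spelling should update their configs to "int8".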