Support for custom checkpointing. (#137)
* support for custom checkpointing.

* add docs for custom checkpointing.

* add finalization code.
hariharan-devarajan authored Jan 30, 2024
1 parent 937e6c2 commit 95fdd59
Showing 19 changed files with 504 additions and 129 deletions.
Empty file.
90 changes: 90 additions & 0 deletions dlio_benchmark/checkpointing/base_checkpointing.py
@@ -0,0 +1,90 @@
"""
Copyright (c) 2022, UChicago Argonne, LLC
All Rights Reserved
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from abc import ABC, abstractmethod

from dlio_benchmark.common.enumerations import CheckpointLocationType
from dlio_benchmark.storage.storage_factory import StorageFactory
from dlio_benchmark.utils.config import ConfigArguments
from dlio_benchmark.utils.utility import DLIOMPI


class BaseCheckpointing(ABC):

def __init__(self, ext):
self.ext = ext
self.args = ConfigArguments.get_instance()
checkpoint_storage = StorageFactory().get_storage(self.args.storage_type, self.args.checkpoint_folder,
self.args.framework)
checkpoint_storage.create_namespace(exist_ok=True)
rank_to_checkpoint = self.args.my_rank
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
rank_to_checkpoint = 0
if rank_to_checkpoint == self.args.my_rank:
self.model_state = None
if self.args.model_size > 0:
self.model_state = {"a": self.get_tensor(self.args.model_size)}
self.optimization_state = None
if len(self.args.optimization_groups) > 0:
self.optimization_state = dict()
tensor_array_size = 0
for index, state in enumerate(self.args.optimization_groups):
if state > 0:
self.optimization_state[str(index)] = {'a': self.get_tensor(state),
'b': self.get_tensor(state)}
tensor_array_size += state
self.optimization_state["combined"] = self.get_tensor(tensor_array_size)
self.layer_state = None
if len(self.args.layer_parameters) > 0:
self.layer_state = dict()
for index, state in enumerate(self.args.layer_parameters):
if state > 0:
self.layer_state[str(index)] = self.get_tensor(state)

@abstractmethod
def get_tensor(self, size):
return []

@abstractmethod
def save_state(self, suffix, state):
pass

def get_name(self, suffix):
return os.path.join(self.args.checkpoint_folder, f"{suffix}.{self.ext}")

@abstractmethod
def checkpoint(self, epoch, step_number):
rank_to_checkpoint = DLIOMPI.get_instance().rank()
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
rank_to_checkpoint = 0
if rank_to_checkpoint == DLIOMPI.get_instance().rank():
my_rank = DLIOMPI.get_instance().rank()
if self.model_state:
self.save_state(suffix=f"model-{epoch}-{step_number}-{my_rank}", state=self.model_state)
if self.optimization_state:
self.save_state(suffix=f"optimizer-{epoch}-{step_number}-{my_rank}", state=self.optimization_state)
if rank_to_checkpoint % self.args.pipeline_parallelism == 0:
if self.layer_state and self.args.num_layers > 0:
total_layers = self.args.num_layers
if self.args.tensor_parallelism > 1:
total_layers = total_layers + self.args.tensor_parallelism
for layer in range(total_layers):
self.save_state(suffix=f"layer-{layer}-{epoch}-{step_number}-{my_rank}", state=self.layer_state)

@abstractmethod
def finalize(self):
pass
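
To add a custom mechanism, a subclass only needs to supply the four abstract methods above plus the get_instance() hook that the factory in the next file calls. The sketch below is illustrative only; the NumPy-based class and its .npy extension are assumptions, not part of this commit. Note also the layer loop in checkpoint(): with tensor_parallelism > 1, each participating rank (my_rank % pipeline_parallelism == 0) writes num_layers + tensor_parallelism layer files per call.

import numpy as np

from dlio_benchmark.checkpointing.base_checkpointing import BaseCheckpointing


class NumPyCheckpointing(BaseCheckpointing):
    """Illustrative custom mechanism (hypothetical; not in this commit)."""
    __instance = None

    @staticmethod
    def get_instance():
        # The factory resolves custom mechanisms through this static hook.
        if NumPyCheckpointing.__instance is None:
            NumPyCheckpointing.__instance = NumPyCheckpointing()
        return NumPyCheckpointing.__instance

    def __init__(self):
        # "npy" becomes the file extension used by get_name().
        super().__init__("npy")

    def get_tensor(self, size):
        # One int8 element per requested byte of checkpoint data.
        return np.zeros(size, dtype=np.int8)

    def save_state(self, suffix, state):
        # state is a dict of tensors (possibly nested); persist it as a pickled object array.
        with open(self.get_name(suffix), "wb") as f:
            np.save(f, state, allow_pickle=True)

    def checkpoint(self, epoch, step_number):
        super().checkpoint(epoch, step_number)

    def finalize(self):
        super().finalize()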
43 changes: 43 additions & 0 deletions dlio_benchmark/checkpointing/checkpointing_factory.py
@@ -0,0 +1,43 @@
"""
Copyright (c) 2022, UChicago Argonne, LLC
All Rights Reserved
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging

from dlio_benchmark.common.enumerations import CheckpointMechanismType
from dlio_benchmark.common.error_code import ErrorCodes
from dlio_benchmark.utils.config import ConfigArguments
from dlio_benchmark.utils.utility import utcnow


class CheckpointingFactory(object):
def __init__(self):
pass

@staticmethod
def get_mechanism(checkpoint_mechanism_type):
_args = ConfigArguments.get_instance()
if _args.checkpoint_mechanism_class is not None:
logging.info(f"{utcnow()} Running DLIO with custom checkpointing mechanism "
f"class {_args.checkpoint_mechanism_class.__name__}")
return _args.checkpoint_mechanism_class.get_instance()
elif checkpoint_mechanism_type == CheckpointMechanismType.TF_SAVE:
from dlio_benchmark.checkpointing.tf_checkpointing import TFCheckpointing
return TFCheckpointing.get_instance()
elif checkpoint_mechanism_type == CheckpointMechanismType.PT_SAVE:
from dlio_benchmark.checkpointing.pytorch_checkpointing import PyTorchCheckpointing
return PyTorchCheckpointing.get_instance()
else:
raise Exception(str(ErrorCodes.EC1005))
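
A short usage sketch, assuming ConfigArguments has already been populated by a DLIO run. When checkpoint_mechanism_class is set it takes precedence over the enum; otherwise the factory dispatches on the mechanism type:

from dlio_benchmark.common.enumerations import CheckpointMechanismType
from dlio_benchmark.checkpointing.checkpointing_factory import CheckpointingFactory
from dlio_benchmark.utils.config import ConfigArguments

# Built-in path: dispatch on the mechanism type.
mechanism = CheckpointingFactory.get_mechanism(CheckpointMechanismType.PT_SAVE)
mechanism.checkpoint(epoch=1, step_number=100)  # writes model/optimizer/layer files
mechanism.finalize()

# Custom path: the configured class wins regardless of the enum argument.
# In a real run this attribute is filled in from the workload configuration;
# NumPyCheckpointing is the hypothetical sketch from the previous file.
ConfigArguments.get_instance().checkpoint_mechanism_class = NumPyCheckpointing
custom = CheckpointingFactory.get_mechanism(CheckpointMechanismType.NONE)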
61 changes: 61 additions & 0 deletions dlio_benchmark/checkpointing/pytorch_checkpointing.py
@@ -0,0 +1,61 @@
"""
Copyright (c) 2022, UChicago Argonne, LLC
All Rights Reserved
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import torch

from dlio_benchmark.checkpointing.base_checkpointing import BaseCheckpointing
from dlio_profiler.logger import fn_interceptor as Profile

from dlio_benchmark.common.constants import MODULE_CHECKPOINT
from dlio_benchmark.common.enumerations import CheckpointLocationType
from dlio_benchmark.utils.utility import DLIOMPI

dlp = Profile(MODULE_CHECKPOINT)


class PyTorchCheckpointing(BaseCheckpointing):
__instance = None

@staticmethod
def get_instance():
""" Static access method. """
if PyTorchCheckpointing.__instance is None:
PyTorchCheckpointing.__instance = PyTorchCheckpointing()
return PyTorchCheckpointing.__instance

@dlp.log_init
def __init__(self):
super().__init__("pt")

@dlp.log
def get_tensor(self, size):
return torch.randint(high=1, size=(size,), dtype=torch.int8)

@dlp.log
def save_state(self, suffix, state):
name = self.get_name(suffix)
with open(name, "wb") as f:
torch.save(state, f)

@dlp.log
def checkpoint(self, epoch, step_number):
super().checkpoint(epoch, step_number)

@dlp.log
def finalize(self):
super().finalize()
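
One sizing detail worth noting: with dtype=torch.int8 each element occupies a single byte, so get_tensor(size) allocates exactly size bytes, and randint with high=1 yields all zeros, which is adequate for an I/O benchmark where only the byte volume matters. A quick check of that arithmetic:

import torch

t = torch.randint(high=1, size=(1024,), dtype=torch.int8)
assert t.numel() * t.element_size() == 1024  # one byte per requested unit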

61 changes: 61 additions & 0 deletions dlio_benchmark/checkpointing/tf_checkpointing.py
@@ -0,0 +1,61 @@
"""
Copyright (c) 2022, UChicago Argonne, LLC
All Rights Reserved
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os

from dlio_benchmark.checkpointing.base_checkpointing import BaseCheckpointing
from dlio_profiler.logger import fn_interceptor as Profile
import tensorflow as tf

from dlio_benchmark.common.constants import MODULE_CHECKPOINT
from dlio_benchmark.common.enumerations import CheckpointLocationType
from dlio_benchmark.utils.utility import DLIOMPI

dlp = Profile(MODULE_CHECKPOINT)


class TFCheckpointing(BaseCheckpointing):
__instance = None

@staticmethod
def get_instance():
""" Static access method. """
if TFCheckpointing.__instance is None:
TFCheckpointing.__instance = TFCheckpointing()
return TFCheckpointing.__instance

@dlp.log_init
def __init__(self):
super().__init__("pb")

@dlp.log
def get_tensor(self, size):
return tf.random.uniform((int(size / 4),), maxval=100, dtype=tf.dtypes.int32)

@dlp.log
def save_state(self, suffix, state):
name = self.get_name(suffix)
checkpoint = tf.train.Checkpoint()
checkpoint.mapped = state
checkpoint.save(name)

@dlp.log
def checkpoint(self, epoch, step_number):
super().checkpoint(epoch, step_number)

@dlp.log
def finalize(self):
super().finalize()
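
The TensorFlow mechanism keeps byte volumes comparable with the PyTorch one by a different route: int32 elements are four bytes each, hence the size / 4 element count (sizes not divisible by 4 round down). A quick check of that arithmetic, assuming TF2 eager execution:

import tensorflow as tf

size = 1024  # requested bytes
t = tf.random.uniform((size // 4,), maxval=100, dtype=tf.dtypes.int32)
assert int(tf.size(t)) * t.dtype.size == size  # four bytes per int32 element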
1 change: 1 addition & 0 deletions dlio_benchmark/common/constants.py
@@ -19,6 +19,7 @@
'''
MODULE_DATA_LOADER = "data_loader"
MODULE_AI_FRAMEWORK = "ai_framework"
MODULE_CHECKPOINT = "checkpoint"
MODULE_DATA_READER = "reader"
MODULE_DATA_GENERATOR = "generator"
MODULE_STORAGE = "storage"
15 changes: 14 additions & 1 deletion dlio_benchmark/common/enumerations.py
@@ -17,9 +17,22 @@

from enum import Enum


class CheckpointMechanismType(Enum):
"""
Different Checkpoint mechanisms.
"""
NONE = 'none'
CUSTOM = 'custom'
TF_SAVE = 'tf_save'
PT_SAVE = 'pt_save'

def __str__(self):
return self.value

class CheckpointLocationType(Enum):
"""
Different types of underlying storage
Different types of Checkpointing Locations
"""
RANK_ZERO = 'rank_zero'
ALL_RANKS = 'all_ranks'
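
The __str__ override makes each mechanism print as its configuration string, and the standard Enum value constructor provides the reverse mapping when parsing workload configs. A small demonstration:

from dlio_benchmark.common.enumerations import CheckpointMechanismType, CheckpointLocationType

assert str(CheckpointMechanismType.PT_SAVE) == "pt_save"
assert CheckpointMechanismType("pt_save") is CheckpointMechanismType.PT_SAVE
assert CheckpointLocationType("all_ranks") is CheckpointLocationType.ALL_RANKS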
3 changes: 2 additions & 1 deletion dlio_benchmark/common/error_code.py
@@ -34,4 +34,5 @@ class ErrorCodes:
EC1001 = {1001, "ERROR: Incorrect Format Type"}
EC1002 = {1002, "ERROR: Invalid Parameter Combination"}
EC1003 = {1003, "ERROR: Invalid Data Loader"}
EC1004 = {1004, "ERROR: Not supported"}
EC1004 = {1004, "ERROR: Not supported"}
EC1005 = {1005, "ERROR: Invalid Checkpointing Mechanism"}
7 changes: 0 additions & 7 deletions dlio_benchmark/framework/framework.py
@@ -43,7 +43,6 @@ class Framework(ABC):
def __init__(self):
self.args = ConfigArguments.get_instance()
self.output_folder = self.args.output_folder
self.checkpoint_folder = self.args.checkpoint_folder


@abstractmethod
@@ -53,9 +52,6 @@ def init_loader(self, format_type, epoch, data_loader=None):
self.reader_valid = DataLoaderFactory.get_loader(data_loader, format_type,
dataset_type=DatasetType.VALID, epoch=epoch)
self.storage = StorageFactory().get_storage(self.args.storage_type, self.args.storage_root, self.args.framework)
checkpoint_storage = StorageFactory().get_storage(self.args.storage_type, self.checkpoint_folder,
self.args.framework)
checkpoint_storage.create_namespace(exist_ok=True)

@abstractmethod
def get_type(self):
@@ -73,9 +69,6 @@ def stop_framework_profiler(self):
def trace_object(self, string, step, r):
pass

def checkpoint(self, epoch, step_number):
pass

def model(epoch, epoch_number, step, computation_time):
sleep(computation_time)

(Diffs for the remaining changed files are not shown here.)
