[checkpointio] fix hybrid plugin model save (#6106)
ver217 authored Oct 31, 2024
1 parent 89a9a60 commit c2e8f61
Showing 4 changed files with 41 additions and 38 deletions.
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py (4 additions and 5 deletions)
@@ -21,7 +21,7 @@
to_padded_tensor,
to_unpadded_tensor,
)
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, get_non_persistent_buffers_set

from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@@ -105,8 +105,9 @@ def _model_sharder(
yield block, block_size

# Save buffers.
+non_persist_buffers_set = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
-if buf is not None and name not in model._non_persistent_buffers_set:
+if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
@@ -352,9 +353,7 @@ def _load(name: str):
_load(name)

# Load buffers.
-non_persistent_buffers = set()
-for n, m in model.named_modules():
-non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
+non_persistent_buffers = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)
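For context on the fix above: `model.named_buffers()` yields fully qualified (dotted) names, while `model._non_persistent_buffers_set` only lists the root module's own buffer names, so the old check let non-persistent buffers of submodules slip into the saved state dict. A minimal sketch of the mismatch, using hypothetical module names that are not part of this repository:

```python
import torch
import torch.nn as nn


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffer: PyTorch itself excludes it from state_dict().
        self.register_buffer("cache", torch.zeros(2), persistent=False)


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()
        self.register_buffer("step", torch.zeros(1), persistent=False)


model = Outer()
# The root module's private set only knows its own buffer names.
print(model._non_persistent_buffers_set)            # {'step'}, 'inner.cache' is not listed
# named_buffers() yields dotted names for submodule buffers.
print([name for name, _ in model.named_buffers()])  # ['step', 'inner.cache']
```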
colossalai/utils/__init__.py (2 additions and 0 deletions)
@@ -5,6 +5,7 @@
ensure_path_exists,
free_storage,
get_current_device,
+get_non_persistent_buffers_set,
is_ddp_ignored,
set_seed,
)
@@ -25,4 +26,5 @@
"set_seed",
"get_current_device",
"is_ddp_ignored",
"get_non_persistent_buffers_set",
]
colossalai/utils/common.py (33 additions and 1 deletion)
@@ -5,10 +5,11 @@
import random
from contextlib import contextmanager
from pathlib import Path
-from typing import Callable
+from typing import Callable, Optional, Set

import numpy as np
import torch
+import torch.nn as nn

from colossalai.accelerator import get_accelerator

@@ -76,3 +77,34 @@ def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


+def get_non_persistent_buffers_set(
+    module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+):
+    r"""
+    Args:
+        memo: a memo to store the set of modules already added to the result
+        prefix: a prefix that will be added to the name of the module
+        remove_duplicate: whether to remove the duplicated module instances in the result
+            or not
+    """
+
+    if memo is None:
+        memo = set()
+    self_non_persistent_set = set()
+    if module not in memo:
+        if remove_duplicate:
+            memo.add(module)
+        self_non_persistent_set = set(
+            map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
+        )
+    for name, sub_module in module._modules.items():
+        if sub_module is None:
+            continue
+        submodule_prefix = prefix + ("." if prefix else "") + name
+        child_non_persistent_set = get_non_persistent_buffers_set(
+            sub_module, memo, submodule_prefix, remove_duplicate
+        )
+        self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+    return self_non_persistent_set
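A small usage sketch of the new helper (the toy module below is illustrative and not part of this repository). It returns fully qualified buffer names, which is what makes the membership test against `model.named_buffers()` keys in the checkpoint code above correct:

```python
import torch
import torch.nn as nn

from colossalai.utils import get_non_persistent_buffers_set


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("running_max", torch.zeros(1), persistent=False)


model = nn.Sequential(Block(), Block())
# Dotted names are produced for nested modules, e.g. {'0.running_max', '1.running_max'},
# so they can be compared directly against the names yielded by model.named_buffers().
print(get_non_persistent_buffers_set(model))
```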
colossalai/zero/gemini/gemini_ddp.py (2 additions and 32 deletions)
@@ -35,7 +35,7 @@
to_unpadded_tensor,
)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, get_non_persistent_buffers_set, is_ddp_ignored

from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
@@ -187,7 +187,7 @@ def __init__(
pin_memory=pin_memory,
)
super().__init__(module)
-self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
+self._non_persistent_buffers_set = get_non_persistent_buffers_set(module)
self._cast_buffers()

# register grad hook
@@ -257,36 +257,6 @@ def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
for p in params_to_ignore:
p._ddp_to_ignore = True

-    def _get_non_persistent_buffers_set(
-        self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
-    ):
-        r"""
-        Args:
-            memo: a memo to store the set of modules already added to the result
-            prefix: a prefix that will be added to the name of the module
-            remove_duplicate: whether to remove the duplicated module instances in the result
-                or not
-        """
-
-        if memo is None:
-            memo = set()
-        self_non_persistent_set = set()
-        if module not in memo:
-            if remove_duplicate:
-                memo.add(module)
-            self_non_persistent_set = set(
-                map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
-            )
-        for name, sub_module in module._modules.items():
-            if sub_module is None:
-                continue
-            submodule_prefix = prefix + ("." if prefix else "") + name
-            child_non_persistent_set = self._get_non_persistent_buffers_set(
-                sub_module, memo, submodule_prefix, remove_duplicate
-            )
-            self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
-        return self_non_persistent_set

def _post_forward(self):
"""This function is only triggered for inference."""
access_list = list(self.chunk_manager.accessed_chunks)
