
Commit

feat(comm/attn_offload.py): support selective ckpt and cpu offload (#383)
huangting4201 authored Dec 31, 2024
1 parent 141e9eb commit e3f5001
Showing 8 changed files with 549 additions and 11 deletions.
3 changes: 3 additions & 0 deletions internlm/core/parallel/comm/__init__.py
@@ -0,0 +1,3 @@
from .attn_offload import get_offload_manager, initialize_offload_manager

__all__ = ["initialize_offload_manager", "get_offload_manager"]
127 changes: 127 additions & 0 deletions internlm/core/parallel/comm/attn_offload.py
@@ -0,0 +1,127 @@
import torch

from internlm.utils.common import get_current_device

global_attn_offload = None


class AttnOffloadManager:
"""
A manager for attention output CPU offloading and GPU prefetch loading.
"""

def __init__(self, enable_cpu_offload: bool = False) -> None:
# cpu offload overlapping
self.cpu_offload = enable_cpu_offload
        # mapping from layer id to flash attn output
self.fa_output_mapping = {}
self.fa_stream = torch.cuda.Stream()
self.d2h_final_event = torch.cuda.Event()
self.h2d_final_event = torch.cuda.Event()
        # tensor buffers for offloaded outputs, double-buffered by layer_id % 2
self.tensor_id_to_tensor_bufs = {}

def get_tensor_buf_for_offloaded_tensor(self, tensor, layer_id, tensor_id):
"""Get tensor buffer for offloaded tensor."""
layer_id = layer_id % 2
if layer_id not in self.tensor_id_to_tensor_bufs:
self.tensor_id_to_tensor_bufs[layer_id] = {}

if tensor_id not in self.tensor_id_to_tensor_bufs[layer_id]:
allocate_new_buf = True
else:
tensor_buf = self.tensor_id_to_tensor_bufs[layer_id][tensor_id]
            allocate_new_buf = tensor_buf.size() != tensor.size() or tensor_buf.dtype != tensor.dtype

if allocate_new_buf:
# supposed to only execute once
buffer = torch.empty(
tensor.size(),
dtype=tensor.dtype,
layout=tensor.layout,
device=tensor.device,
)

self.tensor_id_to_tensor_bufs[layer_id][tensor_id] = buffer

return self.tensor_id_to_tensor_bufs[layer_id][tensor_id]

def insert_fa_output_with_layer(self, layer_idx, output):
assert layer_idx not in self.fa_output_mapping
if self.cpu_offload is False:
self.fa_output_mapping[layer_idx] = output
return

tensors = []
for tensor_id, tensor in enumerate(output):
if tensor is None:
tensors.append(None)
continue
tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, layer_idx, tensor_id)
tensor_buf.copy_(tensor)
tensors.append(tensor_buf)
self.fa_output_mapping[layer_idx] = tensors

def get_fa_output_with_layer(self, layer_idx):
assert layer_idx in self.fa_output_mapping
return self.fa_output_mapping.pop(layer_idx)

def offload_fa_output_with_layer(self, layer_idx):
assert layer_idx in self.fa_output_mapping

self.fa_stream.wait_stream(torch.cuda.current_stream())
self.fa_stream.wait_event(self.d2h_final_event)

with torch.cuda.stream(self.fa_stream):
_gpu_tensors = self.fa_output_mapping.pop(layer_idx)
_cpu_tensors = []
for _tensor in _gpu_tensors:
if _tensor is None:
_cpu_tensors.append(_tensor)
continue

_cpu_backup = torch.empty(
_tensor.size(),
dtype=_tensor.dtype,
layout=_tensor.layout,
device="cpu",
pin_memory=True,
)
_cpu_backup.copy_(_tensor, non_blocking=True)
_cpu_tensors.append(_cpu_backup)

# _cpu_tensors.append(_tensor.to("cpu", non_blocking=False))

self.fa_output_mapping[layer_idx] = _cpu_tensors

self.fa_stream.record_event(self.d2h_final_event)

def preload_fa_output_with_layer(self, layer_idx):
assert layer_idx in self.fa_output_mapping

self.fa_stream.wait_stream(torch.cuda.current_stream())
self.fa_stream.wait_event(self.h2d_final_event)

        # Important: get the device before entering the stream context; querying the device inside the stream context is erroneous
_device = get_current_device()
with torch.cuda.stream(self.fa_stream):
_cpu_tensors = self.fa_output_mapping.pop(layer_idx)
self.fa_output_mapping[layer_idx] = [
_tensor.to(device=_device, non_blocking=True) if _tensor is not None else _tensor
for _tensor in _cpu_tensors
]

self.fa_stream.record_event(self.h2d_final_event)


def initialize_offload_manager(enable_cpu_offload: bool = False):
global global_attn_offload
if global_attn_offload is None:
global_attn_offload = AttnOffloadManager(enable_cpu_offload)

return global_attn_offload


def get_offload_manager():
assert global_attn_offload is not None
return global_attn_offload
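
Taken together, the manager is driven once per transformer block: the flash-attn output is registered during the forward pass, copied to pinned CPU memory on a side stream, and prefetched back to the GPU just before the backward recomputation needs it. Below is a minimal usage sketch (requires a CUDA device); the driver loop, tensor shapes, and num_layers value are illustrative assumptions, and only the manager API itself comes from the file above. In the real code the offload/preload calls are issued by the ISP pre-forward hooks shown in the next file.

import torch

from internlm.core.parallel.comm import initialize_offload_manager

# A hypothetical driver loop; values below are placeholders for illustration.
manager = initialize_offload_manager(enable_cpu_offload=True)
num_layers = 4

# Forward: register each block's flash-attn output, then offload it to CPU.
# The last block is consumed first in backward, so it is kept on the GPU
# (mirroring layers_fa_not_release in isp.py).
for layer_idx in range(num_layers):
    attn_output = (torch.randn(2, 16, 64, device="cuda"), None)
    manager.insert_fa_output_with_layer(layer_idx, attn_output)
    if layer_idx != num_layers - 1:
        manager.offload_fa_output_with_layer(layer_idx)

torch.cuda.synchronize()  # ensure all D2H copies have landed

# Backward: prefetch block idx - 1 while block idx is being recomputed, then
# fetch (and drop) the entry for the block currently being recomputed.
for layer_idx in reversed(range(num_layers)):
    if layer_idx - 1 >= 0:
        manager.preload_fa_output_with_layer(layer_idx - 1)
    torch.cuda.current_stream().wait_event(manager.h2d_final_event)
    saved_output = manager.get_fa_output_with_layer(layer_idx)

Note that the manager double-buffers its GPU staging space by layer_id % 2, so a block's output must be offloaded (or consumed) before the block two layers later registers its own output.
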
34 changes: 34 additions & 0 deletions internlm/core/parallel/comm/isp.py
@@ -37,6 +37,8 @@
params_dispatch_with_condition,
)

from .attn_offload import get_offload_manager


# not strictly necessary; kept only as a typing hint for readability.
class WPCommunicator(ABC):
@@ -306,6 +308,7 @@ def __init__(
overlap: bool = False,
process_group: dist.ProcessGroup = None,
is_moe: bool = False,
selective_ckpt_offload: bool = False,
) -> None:
self.process_group = process_group
self.overlap = overlap
@@ -316,6 +319,14 @@ def __init__(
self._forward_prefetch_prerequisites = []
self._forward_overlap_per = self._get_forward_overlap_granularity()
self._launch_before_module = self._get_launch_before_module()
        # As an optimization, do not release the weights after forward for the last
        # transformer block, since weight parallel would prefetch them immediately
self.layers_wp_not_release = [] # [gpc.config.isp_num_layers - 1]
self.layers_fa_not_release = [
gpc.config.isp_num_layers - 1,
int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) - 1,
]
self.sc_offload = selective_ckpt_offload

# real overlap state for each chunk.
self._overlap_states: Dict[int, ISPOverlapState] = {}
@@ -411,6 +422,7 @@ def is_allgather_launch_module(name, module):
self._overlap_states[cid].index_to_isp_modules[idx].append(child)

setattr(child, "isp_name", name)
setattr(child, "isp_layer_idx", idx)

full_name = f"{cid}.{idx}.{name}"
setattr(
Expand Down Expand Up @@ -506,6 +518,25 @@ def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args)
if block_index + 1 < self._num_blocks:
self._all_gather_block_weight(block_index + 1)

        # register offload and prefetch hooks for selective ckpt with the wo linear
if self.sc_offload is True:
            # move the current layer's attn output from GPU to CPU asynchronously
if (
self.is_forward is True
and gpc.config.selective_checkpoint
and block_index not in self.layers_fa_not_release
and block_index < self._ckpt_block_num
):
get_offload_manager().offload_fa_output_with_layer(layer_idx=block_index)

            # load the previous layer's attn output from CPU back to GPU asynchronously
if (
self.is_forward is False
and gpc.config.selective_checkpoint
and (0 <= (block_index - 1) < self._ckpt_block_num)
):
get_offload_manager().preload_fa_output_with_layer(layer_idx=block_index - 1)

def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613
if module not in self._weight_global_handle:
self._all_gather_module_weight(module)
@@ -539,6 +570,9 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: dis
self._all_gather_module_weight(next_module)

def _post_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613
if int(module.isp_layer_idx) in self.layers_wp_not_release:
# print(f"the layer {module.isp_layer_idx} after forward not clear weight")
return
if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False):
self._clear_handle(module)
self._clear_weight(module)
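
To make the layer bookkeeping above concrete: layers_fa_not_release holds the last transformer block and the last checkpointed block, which are needed first when backward starts, so offloading them would only add copy latency without saving peak memory. A small worked example follows, assuming 32 layers, model.checkpoint = 0.5, and that _ckpt_block_num equals int(model.checkpoint * isp_num_layers); the numbers and the _ckpt_block_num formula are illustrative assumptions, not values taken from this diff.

# Illustrative numbers only; in the real code these come from gpc.config.
isp_num_layers = 32
checkpoint_ratio = 0.5  # fraction of blocks using activation checkpointing

ckpt_block_num = int(checkpoint_ratio * isp_num_layers)            # 16
layers_fa_not_release = [isp_num_layers - 1, ckpt_block_num - 1]   # [31, 15]

# Forward: blocks 0..14 offload their attn output; block 15 stays on the GPU.
offloaded = [
    idx
    for idx in range(isp_num_layers)
    if idx not in layers_fa_not_release and idx < ckpt_block_num
]
assert offloaded == list(range(15))

# Backward over block idx: block idx - 1 is prefetched back to the GPU,
# provided 0 <= idx - 1 < ckpt_block_num.
prefetched = [
    idx - 1 for idx in reversed(range(isp_num_layers)) if 0 <= idx - 1 < ckpt_block_num
]
assert prefetched == list(range(15, -1, -1))
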
4 changes: 4 additions & 0 deletions internlm/core/trainer_builder.py
@@ -11,6 +11,7 @@
from internlm.checkpoint.checkpoint_manager import CheckpointManager
from internlm.core.context import global_context as gpc
from internlm.core.context.process_group_initializer import ParallelMode
from internlm.core.parallel.comm import initialize_offload_manager
from internlm.core.trainer import Trainer
from internlm.data.streaming.utils import streaming_simple_resume
from internlm.data.train_state import get_train_state
@@ -118,6 +119,9 @@ def __init__(
# initialize isp communicator
isp_communicator = initialize_parallel_communicator(model)

# initialize cpu offload manager for selective checkpoint
initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False))

# initialize train state
train_state = get_train_state(train_dl)

37 changes: 30 additions & 7 deletions internlm/initialize/launch.py
@@ -66,13 +66,22 @@ def get_default_parser():
def args_sanity_check():
assert gpc.config is not None, "config is not load!"

gpc.is_forward = True

if "JOB_NAME" not in gpc.config:
gpc.config._add_item("JOB_NAME", "AnonymousJob")

# the default model type is INTERNLM
if "model_type" not in gpc.config:
gpc.config._add_item("model_type", ModelType.INTERNLM.name)

if gpc.config.model_type == "InternLM3_M":
        # TODO: needs verification for isp overlap
num_layers = gpc.config.model.num_self_decoder_layers + gpc.config.model.num_cross_decoder_layers
else:
num_layers = gpc.config.model.num_layers
gpc.config.isp_num_layers = num_layers

if "use_apex_adam" not in gpc.config:
gpc.config._add_item("use_apex_adam", False)

@@ -388,17 +397,18 @@ def args_sanity_check():
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name)
if gpc.config.parallel["tensor"].get("mode", None) is None:
gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name
assert (
gpc.config.VOCAB_SIZE % gpc.config.parallel.tensor.size == 0
), "VOCAB_SIZE must be integer multiple of tensor parallel size"
if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name:
assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp"
assert (
torch.__version__ >= "2.1.0"
), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}"
assert (
gpc.config.VOCAB_SIZE % gpc.config.parallel.weight.size == 0
), "VOCAB_SIZE must be integer multiple of wp size"

assert (
gpc.config.model.vocab_size % gpc.config.parallel.weight.size == 0
), "model.vocab_size must be integer multiple of weight parallel size"
assert (
gpc.config.model.vocab_size % gpc.config.parallel.tensor.size == 0
), "model.vocab_size must be integer multiple of tensor parallel size"

assert gpc.config.parallel["tensor"].get("mode", None) in [
TensorParallelMode.mtp.name,
@@ -524,7 +534,20 @@ def args_sanity_check():
gpc.config.loss._add_item("moe_loss_coeff", 1.0)

if "selective_checkpoint" not in gpc.config:
gpc.config._add_item("selective_checkpoint", False)
gpc.config.selective_checkpoint = False
if "selective_checkpoint_offload" not in gpc.config:
gpc.config.selective_checkpoint_offload = False
if gpc.config.selective_checkpoint is True:
assert (
gpc.config.parallel["tensor"]["mode"] == "isp"
), "When using selective_checkpoint, tensor parallel mode must be isp"
if gpc.config.selective_checkpoint_offload is True:
assert (
gpc.config.selective_checkpoint is True
), "When using selective_checkpoint_offload, selective_checkpoint must be True"
assert (
gpc.config.parallel.weight.launch_allgather_before == "wo"
), "When using selective_checkpoint_offload, wp launch allgather communication should be set before 'wo' module"

# moe not support overlap and zero1.5 for now
if gpc.config.model.get("num_experts", 1) > 1:
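
The new sanity checks pin down how the feature is enabled in a training config. The fragment below is a hedged sketch in the repo's Python-config style: the keys mirror the fields referenced by the assertions above, while the concrete sizes and the elided fields are placeholders, not values taken from this commit.

# Illustrative config fragment; sizes and elided fields are placeholders.
model = dict(
    checkpoint=0.5,  # fraction of blocks under activation checkpointing
    # ...
)

parallel = dict(
    tensor=dict(size=8, mode="isp"),  # selective_checkpoint requires the isp mode
    weight=dict(
        size=8,
        launch_allgather_before="wo",  # required by selective_checkpoint_offload
        # ...
    ),
    # zero1, pipeline, ...
)

selective_checkpoint = True           # must be True before enabling the offload
selective_checkpoint_offload = True   # turns on the CPU offload path added here
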
