Disk cache support in CPU-only mode
Gnome Ann committed Jun 20, 2022
1 parent af07d7a · commit 90fd8b1
Showing 2 changed files with 111 additions and 12 deletions.
25 changes: 15 additions & 10 deletions aiserver.py
@@ -516,10 +516,11 @@ def device_config(config):
     import breakmodel
     n_layers = utils.num_layers(config)
     if(args.breakmodel_gpulayers is not None or (utils.HAS_ACCELERATE and args.breakmodel_disklayers is not None)):
-        if(args.breakmodel_gpulayers is None):
-            args.breakmodel_gpulayers = ",".join(["0"] * torch.cuda.device_count())
         try:
-            breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(',')))
+            if(not args.breakmodel_gpulayers):
+                breakmodel.gpu_blocks = []
+            else:
+                breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(',')))
             assert len(breakmodel.gpu_blocks) <= torch.cuda.device_count()
             s = n_layers
             for i in range(len(breakmodel.gpu_blocks)):
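
The change above guards against an empty `--breakmodel_gpulayers` value, which previously reached `int('')` inside `list(map(int, ...))` and raised a ValueError; an empty value now simply means "no GPU layers", which is what the CPU-only disk-cache case needs. A minimal standalone sketch of the new behavior (the `parse_gpu_blocks` helper is hypothetical, for illustration only):

    def parse_gpu_blocks(breakmodel_gpulayers: str) -> list:
        # New behavior: an empty/unset value means "no GPU layers"
        # instead of raising ValueError from int('').
        if not breakmodel_gpulayers:
            return []
        return list(map(int, breakmodel_gpulayers.split(',')))

    assert parse_gpu_blocks("") == []          # CPU-only / disk-only case
    assert parse_gpu_blocks("9,7") == [9, 7]   # layers per GPU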
@@ -622,19 +623,16 @@ def device_config(config):
 def move_model_to_devices(model):
     global generator
 
-    if(not vars.breakmodel):
+    if(not utils.HAS_ACCELERATE and not vars.breakmodel):
         if(vars.usegpu):
             model = model.half().to(vars.gpu_device)
         else:
             model = model.to('cpu').float()
         generator = model.generate
         return
 
-    model.half()
-    gc.collect()
-
     if(utils.HAS_ACCELERATE):
-        import accelerate
+        import breakmodel
         disk_blocks = breakmodel.disk_blocks
         gpu_blocks = breakmodel.gpu_blocks
         ram_blocks = len(vars.layers_module_names) - sum(gpu_blocks)
@@ -646,11 +644,14 @@ def move_model_to_devices(model):
             device_map[name] = device
         for name in utils.get_missing_module_names(model, list(device_map.keys())):
             device_map[name] = breakmodel.primary_device
-        accelerate.dispatch_model(model, device_map, main_device=breakmodel.primary_device, offload_buffers=True, offload_dir="accelerate-disk-cache")
+        breakmodel.dispatch_model_ex(model, device_map, main_device=breakmodel.primary_device, offload_buffers=True, offload_dir="accelerate-disk-cache")
         gc.collect()
         generator = model.generate
         return
 
+    model.half()
+    gc.collect()
+
     if(hasattr(model, "transformer")):
         model.transformer.wte.to(breakmodel.primary_device)
         model.transformer.ln_f.to(breakmodel.primary_device)
@@ -1874,7 +1875,7 @@ def maybe_use_float16(always_use=False):
 
     # If we're using torch_lazy_loader, we need to get breakmodel config
     # early so that it knows where to load the individual model tensors
-    if(vars.lazy_load and vars.hascuda and vars.breakmodel):
+    if(utils.HAS_ACCELERATE or vars.lazy_load and vars.hascuda and vars.breakmodel):
         device_config(model_config)
 
     # Download model from Huggingface if it does not exist, otherwise load locally
@@ -2003,6 +2004,10 @@ def new_rebuild_tensor(storage: Union[torch_lazy_loader.LazyTensor, torch.Storage
         if(not vars.lazy_load):
             device_config(model.config)
         move_model_to_devices(model)
+    elif(utils.HAS_ACCELERATE):
+        move_model_to_devices(model)
+        vars.modeldim = get_hidden_size_from_model(model)
+        generator = model.generate
     else:
         model = model.to('cpu').float()
         vars.modeldim = get_hidden_size_from_model(model)
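
For reference, the `device_map` consumed by `breakmodel.dispatch_model_ex` assigns each layer module to "disk", "cpu", or a GPU index. The exact expression computing `device` is elided above the visible hunk, so the following is an illustrative sketch only, assuming the convention implied by `disk_blocks`, `gpu_blocks`, and `ram_blocks`: the first `disk_blocks` of the non-GPU layers go to disk, the rest stay on CPU, and the remaining layers are spread across GPUs in order.

    import bisect
    import itertools

    def build_device_map(layer_names, disk_blocks, gpu_blocks):
        # Illustrative reconstruction, not the commit's exact code.
        ram_blocks = len(layer_names) - sum(gpu_blocks)  # layers not on a GPU
        cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
        device_map = {}
        for name in layer_names:
            layer = int(name.rsplit(".", 1)[1])  # trailing layer index
            if layer < ram_blocks:
                device = "disk" if layer < disk_blocks else "cpu"
            else:
                device = bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
            device_map[name] = device
        return device_map

    # 8 layers: 2 cached on disk, 3 on CPU, the last 3 on GPU 0.
    print(build_device_map([f"transformer.h.{i}" for i in range(8)], 2, [3]))
    # {'transformer.h.0': 'disk', 'transformer.h.1': 'disk', 'transformer.h.2': 'cpu',
    #  'transformer.h.3': 'cpu', 'transformer.h.4': 'cpu', 'transformer.h.5': 0,
    #  'transformer.h.6': 0, 'transformer.h.7': 0}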
98 changes: 96 additions & 2 deletions breakmodel.py
@@ -4,7 +4,7 @@
 The ORIGINAL version of the patch is released under the Apache License 2.0
 Copyright 2021 arrmansa
 Copyright 2021 finetuneanon
-Copyright 2018 The Hugging Face team
+Copyright 2018, 2022 The Hugging Face team
 
 Apache License
@@ -216,11 +216,13 @@
 import torch.cuda.comm
 import copy
 import gc
+import os
 import sys
 import itertools
 import bisect
 import random
-from typing import Optional
+import utils
+from typing import Dict, List, Optional, Union
 
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions
 
@@ -234,6 +236,98 @@
 primary_device = 0 if torch.cuda.device_count() > 0 else "cpu"
 
 
+if utils.HAS_ACCELERATE:
+    from accelerate.hooks import attach_align_device_hook_on_blocks
+    from accelerate.utils import OffloadedWeightsLoader, check_device_map, extract_submodules_state_dict, offload_state_dict
+    from accelerate import dispatch_model
+
+    def dispatch_model_ex(
+        model: nn.Module,
+        device_map: Dict[str, Union[str, int, torch.device]],
+        main_device: Optional[torch.device] = None,
+        state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        offload_dir: Union[str, os.PathLike] = None,
+        offload_buffers: bool = False,
+        **kwargs,
+    ):
+        """
+        This is a modified version of
+        https://github.com/huggingface/accelerate/blob/eeaba598f455fbd2c48661d7e816d3ff25ab050b/src/accelerate/big_modeling.py#L130
+        that still works when the main device is the CPU.
+
+        Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on
+        the CPU or even the disk.
+
+        Args:
+            model (`torch.nn.Module`):
+                The model to dispatch.
+            device_map (`Dict[str, Union[str, int, torch.device]]`):
+                A dictionary mapping module names in the models `state_dict` to the device they should go to. Note that
+                `"disk"` is accepted even if it's not a proper value for `torch.device`.
+            main_device (`str`, `int` or `torch.device`, *optional*):
+                The main execution device. Will default to the first device in the `device_map` different from `"cpu"` or
+                `"disk"`.
+            state_dict (`Dict[str, torch.Tensor]`, *optional*):
+                The state dict of the part of the model that will be kept on CPU.
+            offload_dir (`str` or `os.PathLike`):
+                The folder in which to offload the model weights (or where the model weights are already offloaded).
+            offload_buffers (`bool`, *optional*, defaults to `False`):
+                Whether or not to offload the buffers with the model parameters.
+            preload_module_classes (`List[str]`, *optional*):
+                A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+                of the forward. This should only be used for classes that have submodules which are registered but not
+                called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+                `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+        """
+        if main_device != "cpu":
+            return dispatch_model(model, device_map, main_device, state_dict, offload_dir=offload_dir, offload_buffers=offload_buffers, **kwargs)
+
+        # Error early if the device map is incomplete.
+        check_device_map(model, device_map)
+
+        offload_devices = ["cpu", "disk"] if main_device != "cpu" else ["disk"]
+
+        if main_device is None:
+            main_device = [d for d in device_map.values() if d not in offload_devices][0]
+
+        cpu_modules = [name for name, device in device_map.items() if device == "cpu"] if main_device != "cpu" else []
+        if state_dict is None and len(cpu_modules) > 0:
+            state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)
+
+        disk_modules = [name for name, device in device_map.items() if device == "disk"]
+        if offload_dir is None and len(disk_modules) > 0:
+            raise ValueError(
+                "We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules "
+                f"need to be offloaded: {', '.join(disk_modules)}."
+            )
+        if len(disk_modules) > 0 and (
+            not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json"))
+        ):
+            disk_state_dict = extract_submodules_state_dict(model.state_dict(), disk_modules)
+            offload_state_dict(offload_dir, disk_state_dict)
+
+        execution_device = {
+            name: main_device if device in offload_devices else device for name, device in device_map.items()
+        }
+        offload = {name: device in offload_devices for name, device in device_map.items()}
+        save_folder = offload_dir if len(disk_modules) > 0 else None
+        if state_dict is not None or save_folder is not None:
+            weights_map = OffloadedWeightsLoader(state_dict=state_dict, save_folder=save_folder)
+        else:
+            weights_map = None
+
+        attach_align_device_hook_on_blocks(
+            model,
+            execution_device=execution_device,
+            offload=offload,
+            offload_buffers=offload_buffers,
+            weights_map=weights_map,
+            **kwargs,
+        )
+        model.hf_device_map = device_map
+        return model
+
+
 # Copied from transformers.models.bart.modeling_bart._expand_mask
 def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
     """
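
A hedged usage sketch of the new function: dispatching a toy module with the CPU as the main device and one submodule paged in from disk, which is exactly the case upstream `accelerate.dispatch_model` did not handle. The module and device map are illustrative, and this assumes it runs inside the KoboldAI source tree (so `breakmodel` and its `utils` dependency import cleanly) with accelerate installed.

    import torch
    import torch.nn as nn
    import breakmodel  # requires the KoboldAI source tree on sys.path

    # Toy two-block model; nn.Sequential names its children "0" and "1".
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)).eval()

    # Keep block "0" resident on the CPU and offload block "1" to disk,
    # mirroring aiserver.py's offload_dir="accelerate-disk-cache".
    model = breakmodel.dispatch_model_ex(
        model,
        device_map={"0": "cpu", "1": "disk"},
        main_device="cpu",  # the code path this commit adds
        offload_dir="accelerate-disk-cache",
        offload_buffers=True,
    )

    with torch.no_grad():
        print(model(torch.zeros(1, 4)))  # disk weights are loaded on demand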
