Handle inplace operation in forward when offloading
pablomlago committed Nov 27, 2024
1 parent abf4a40 commit 21c0cb1
Showing 3 changed files with 447 additions and 2 deletions.
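
For context, a minimal sketch of the failure mode this commit addresses (not part of the commit; the module below is purely illustrative). When a forward pass mutates a parameter in place, only the on-device copy changes, while the offloaded copy that accelerate restores from on the next dispatch still holds the old value:

import torch
from torch import nn

class InplaceScale(nn.Module):
    """Toy module whose forward updates its own weight in place."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(4))

    def forward(self, x):
        # In-place update: with offloading enabled, the CPU copy kept in the
        # weights_map is not touched, so a later re-dispatch would silently
        # restore the stale value.
        self.weight.data.mul_(2.0)
        return x * self.weight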
203 changes: 201 additions & 2 deletions src/brevitas_examples/common/accelerate_utils/accelerate.py
@@ -2,13 +2,15 @@
# SPDX-License-Identifier: BSD-3-Clause

import logging
from typing import Dict, Mapping, Optional, Union
from typing import Dict, List, Mapping, Optional, Union

from accelerate import dispatch_model
from accelerate import infer_auto_device_map
from accelerate.hooks import add_hook_to_module
from accelerate.hooks import AlignDevicesHook
from accelerate.hooks import ModelHook
from accelerate.hooks import remove_hook_from_module
from accelerate.hooks import SequentialHook
from accelerate.utils import check_tied_parameters_in_config
from accelerate.utils import compute_module_sizes
from accelerate.utils import find_tied_parameters
@@ -18,6 +20,7 @@
from accelerate.utils.modeling import named_module_tensors
from psutil import virtual_memory
import torch
from torch import nn

import brevitas.config as config
from brevitas.graph.utils import get_module
@@ -382,10 +385,205 @@ def calc_cpu_device_map(absolute_mem_margin: float = 2.0 * 1e9,
return cpu_device_map


class UpdateStateDictHook(ModelHook):
"""
    `ModelHook` that ensures that in-place operations during the model forward pass update the values
    in the weights_map, so that future calls to offload_model retrieve the updated model.
Args:
offload (`bool`, *optional*, defaults to `False`):
Whether or not the weights should be offloaded after the forward pass.
weights_map (`Mapping[str, torch.Tensor]`, *optional*):
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
"""

def __init__(
self,
execution_device: Optional[Union[int, str, torch.device]] = None,
offload: bool = False,
weights_map: Optional[Mapping] = None,
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
):
self.execution_device = execution_device
self.offload = offload
self.weights_map = weights_map

        # The hook's pre_forward/post_forward methods need to know about this dictionary, as updating the values
        # in the state dict should remove the old values that might have been cached on each device.
self.tied_params_map = tied_params_map

    def __repr__(self):
        return f"UpdateStateDictHook(offload={self.offload})"

def post_forward(self, module, output):
if self.offload:
prefix = self.weights_map.prefix
for key in module.state_dict().keys():
value = recurse_getattr(module, key)
                # It might happen that we call a quantized module's inner modules, which can cause some parameters
                # to already be on the meta device. This is not a problem for their values, but we need to check for it here.
curr_device = value.device
if str(curr_device) != "meta":
# Check if there is an old value that needs to be replaced
self.weights_map.dataset.state_dict[prefix + key].copy_(value.detach().cpu())

return output
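
As a rough standalone illustration of the copy-back above (a sketch only: a plain dict stands in for accelerate's weights_map.dataset.state_dict, and the prefix is hardcoded), every non-meta tensor is written back into the offloaded copy in place:

import torch
from torch import nn

# Stand-ins for self.weights_map.dataset.state_dict and self.weights_map.prefix.
offloaded_state_dict = {"layer.weight": torch.zeros(2, 2)}
prefix = "layer."

layer = nn.Linear(2, 2, bias=False)
layer.weight.data.fill_(1.0)  # pretend the forward pass mutated this in place

for key, value in layer.state_dict().items():
    if str(value.device) != "meta":
        # copy_ updates the offloaded tensor in place, so cached references
        # to it keep pointing at the refreshed data.
        offloaded_state_dict[prefix + key].copy_(value.detach().cpu())

assert torch.equal(offloaded_state_dict["layer.weight"], torch.ones(2, 2))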


# TODO: Remove, depending on whether we go with the first option. Additional logic still needs to be
# incorporated to handle tied parameters.
class UpdateStateDictLegacyHook(ModelHook):
"""
    `ModelHook` that ensures that in-place operations during the model forward pass update the values
    in the weights_map, so that future calls to offload_model retrieve the updated model.
Args:
offload (`bool`, *optional*, defaults to `False`):
Whether or not the weights should be offloaded after the forward pass.
weights_map (`Mapping[str, torch.Tensor]`, *optional*):
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
"""

def __init__(
self,
align_device_hook: AlignDevicesHook,
execution_device: Optional[Union[int, str, torch.device]] = None,
offload: bool = False,
weights_map: Optional[Mapping] = None,
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
):
self.execution_device = execution_device
self.offload = offload
self.weights_map = weights_map

self.align_device_hook = align_device_hook

        # The hook's pre_forward/post_forward methods need to know about this dictionary, as updating the values
        # in the state dict should remove the old values that might have been cached on each device.
self.tied_params_map = tied_params_map

    def __repr__(self):
        return f"UpdateStateDictLegacyHook(offload={self.offload})"

def post_forward(self, module, output):
if self.offload:
prefix = self.weights_map.prefix
for key in module.state_dict().keys():
value = recurse_getattr(module, key)
                # It might happen that we call a quantized module's inner modules, which can cause some parameters
                # to already be on the meta device. This is not a problem for their values, but we need to check for it here.
curr_device = value.device
if str(curr_device) != "meta":
# Update tied_pointers_to_remove
update_tied_pointers_to_remove = False
# Check if there is an old value that needs to be replaced
if prefix + key in self.weights_map.dataset.state_dict:
old_value = self.weights_map.dataset.state_dict[prefix + key]
if (old_value is not None and self.tied_params_map is not None and
old_value.data_ptr() in self.tied_params_map):
# Remove from tied_params_map if present there
del self.tied_params_map[old_value.data_ptr()]

if (old_value is not None and
self.align_device_hook.tied_pointers_to_remove is not None and
(old_value.data_ptr(), self.execution_device)
in self.align_device_hook.tied_pointers_to_remove):
self.align_device_hook.tied_pointers_to_remove.remove(
(old_value.data_ptr(), self.execution_device))
                            # Ensure that the appropriate value is added
update_tied_pointers_to_remove = True
# Move to CPU before storing it in the weights_map
detached_value = value.detach().cpu()
# Reassign in tied_params_map to make sure that the tensor is re-used
self.tied_params_map[detached_value.data_ptr()] = {}
self.tied_params_map[detached_value.data_ptr()][self.execution_device] = value

if update_tied_pointers_to_remove:
self.align_device_hook.tied_pointers_to_remove.add(
(detached_value.data_ptr(), self.execution_device))

# Reassign the tensor in the state_dict
self.weights_map.dataset.state_dict[prefix + key] = detached_value

return output


def add_hook_to_module_with_pre_append(
module: nn.Module, hook: ModelHook, append: bool = False, pre_append: bool = False):
"""
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
this behavior and restore the original `forward` method, use `remove_hook_from_module`.
<Tip warning={true}>
    If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
    together, pass `append=True` (new hook runs after the existing one) or `pre_append=True` (new hook runs before it);
    the current and new hook are then chained into an instance of the `SequentialHook` class.
</Tip>
Args:
module (`torch.nn.Module`):
The module to attach a hook to.
hook (`ModelHook`):
The hook to attach.
append (`bool`, *optional*, defaults to `False`):
Whether the hook should be chained after an existing one (if module already contains a hook) or not.
pre_append (`bool`, *optional*, defaults to `False`):
Whether the hook should be chained before an existing one (if module already contains a hook) or not.
Returns:
`torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
be discarded).
"""

if (append or pre_append) and (getattr(module, "_hf_hook", None) is not None):
old_hook = module._hf_hook
remove_hook_from_module(module)
        if append and not pre_append:
            hook = SequentialHook(old_hook, hook)
        elif not append and pre_append:
            hook = SequentialHook(hook, old_hook)
        else:
            raise ValueError(
                "Setting both append and pre_append to True is not allowed when adding a hook.")

    # Append is set to False as the appropriate SequentialHook is already attached
return add_hook_to_module(module, hook, append=False)
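
A small usage sketch of the ordering this enables (TraceHook is illustrative). Hooks inside a SequentialHook run in insertion order for both pre_forward and post_forward, so pre_append=True makes the new hook's post_forward fire first:

import torch
from torch import nn
from accelerate.hooks import ModelHook

class TraceHook(ModelHook):

    def __init__(self, name):
        self.name = name

    def post_forward(self, module, output):
        print(f"post_forward: {self.name}")
        return output

layer = nn.Linear(2, 2)
add_hook_to_module_with_pre_append(layer, TraceHook("offload"))
# Chained before the existing hook, mirroring how UpdateStateDictHook must
# copy values out before AlignDevicesHook offloads the weights again.
add_hook_to_module_with_pre_append(layer, TraceHook("update"), pre_append=True)
layer(torch.randn(1, 2))  # prints "post_forward: update" then "post_forward: offload"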


def attach_update_state_dict_hook_on_modules(module: nn.Module) -> None:
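    """Recursively attach an UpdateStateDictHook in front of any AlignDevicesHook found on the module tree."""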
if hasattr(module, "_hf_hook"):
hf_hooks = module._hf_hook
align_device_hook = None
if isinstance(hf_hooks, SequentialHook):
for hook in hf_hooks.hooks:
if isinstance(hook, AlignDevicesHook):
align_device_hook = hook
break
elif isinstance(hf_hooks, AlignDevicesHook):
align_device_hook = hf_hooks
# If the align devices hook is present, include the update state dict hook
if align_device_hook is not None:
hook = UpdateStateDictHook(
execution_device=align_device_hook.execution_device,
offload=align_device_hook.offload,
weights_map=align_device_hook.weights_map,
tied_params_map=align_device_hook.tied_params_map,
)
# Add hook so post-forward gets run first
add_hook_to_module_with_pre_append(module, hook, pre_append=True)

for child in module.children():
attach_update_state_dict_hook_on_modules(child)
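
To sanity-check the attachment (an illustrative snippet, assuming model has already been through offload_model and the call above), each dispatched module should now carry a SequentialHook with UpdateStateDictHook in front:

from accelerate.hooks import SequentialHook

for name, submodule in model.named_modules():
    hook = getattr(submodule, "_hf_hook", None)
    if isinstance(hook, SequentialHook):
        # Expected order: [UpdateStateDictHook, AlignDevicesHook], so the
        # state dict is refreshed before the weights are offloaded again.
        print(name, [type(h).__name__ for h in hook.hooks])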


def offload_model(
model: torch.nn.Module,
gpu_device_map: Optional[Dict[int, float]] = None,
cpu_device_map: Optional[Dict[str, float]] = None,
preload_module_classes: Optional[List[str]] = None,
) -> torch.nn.Module:
"""
Wraps accelerate's infer_auto_device_map and dispatch_model.
@@ -408,7 +606,8 @@ def offload_model(
device_map = infer_auto_device_map(
model, memory_map, no_split_module_classes=model._no_split_modules)

model = dispatch_model(model, device_map)
model = dispatch_model(
model=model, device_map=device_map, preload_module_classes=preload_module_classes)

    # Fixes an asymmetric behavior in Accelerate where hooks are not attached at all when a single device is used.
# TODO: Fix directly in accelerate.
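
A usage sketch of the new preload_module_classes pass-through ("QuantDecoderLayer" is a placeholder class name). Accelerate loads all weights of a listed module class at the start of that module's forward, which matters when submodules are used without being called directly:

model = offload_model(model, preload_module_classes=["QuantDecoderLayer"])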
6 changes: 6 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -20,6 +20,8 @@
from brevitas.graph.equalize import LayerwiseActivationRotation
from brevitas.graph.quantize import layerwise_quantize
from brevitas.graph.utils import get_module
from brevitas_examples.common.accelerate_utils.accelerate import \
attach_update_state_dict_hook_on_modules
from brevitas_examples.common.accelerate_utils.accelerate import offload_model
from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
from brevitas_examples.common.generative.quantize import generate_quant_maps
@@ -363,10 +365,14 @@ def main(args):
model = add_zero_bias_to_linear(model)

model = offload_model(model)
attach_update_state_dict_hook_on_modules(model)

with torch.no_grad():
model(**calibration_loader[0])

remove_hooks(model)
model = offload_model(model)

if args.act_calibration:
print("Apply act calibration...")
apply_calibration(model, calibration_loader)
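
The main.py change above is the intended end-to-end pattern; condensed, with the reasoning as comments (a sketch, assuming calibration_loader holds tokenized batches):

model = offload_model(model)
attach_update_state_dict_hook_on_modules(model)

# The forward pass may rewrite weights in place; the hook copies each update
# into the offloaded state dict before AlignDevicesHook moves the parameters
# back off-device.
with torch.no_grad():
    model(**calibration_loader[0])

# Drop every hook, then re-dispatch so the device placement is rebuilt from
# the updated CPU state dict.
remove_hooks(model)
model = offload_model(model)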
