[WIP] optimize infer_auto_device_map for multi-GPU allocation #3321

Draft · wants to merge 2 commits into `main`
81 changes: 60 additions & 21 deletions src/accelerate/utils/modeling.py
@@ -1306,6 +1306,7 @@ def infer_auto_device_map(
     model: nn.Module,
     max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
     no_split_module_classes: Optional[List[str]] = None,
+    reserve_max_layer: bool = True,
     dtype: Optional[Union[str, torch.dtype]] = None,
     special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
     verbose: bool = False,
@@ -1339,6 +1340,9 @@ def infer_auto_device_map(
         no_split_module_classes (`List[str]`, *optional*):
             A list of layer class names that should never be split across device (for instance any layer that has a
             residual connection).
+        reserve_max_layer (`bool`, *optional*, defaults to `True`):
+            Whether to reserve room on the main device(s) for the largest layer. Disabling this allows tighter
+            packing when multiple GPUs are present and no offloading to CPU or disk is needed.
         dtype (`str` or `torch.dtype`, *optional*):
             If provided, the weights will be converted to that type when loaded.
         special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
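A minimal sketch of how the new flag is meant to be called; the model and memory limits are lifted from the test added below, everything else is illustrative:

```python
import torch.nn as nn

from accelerate import infer_auto_device_map

# Three layers holding 220, 120, and 360 bytes of float32 parameters.
model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 5), nn.Linear(5, 15))

# With reserve_max_layer=False, no space is held back on the first GPU for a
# potentially offloaded largest layer, so the layers pack onto the GPUs tightly.
device_map = infer_auto_device_map(
    model,
    max_memory={0: 145 * 4, 1: 100 * 4, "cpu": 700 * 4},
    reserve_max_layer=False,
)
# Per the new test, this yields {"0": 0, "1": 0, "2": 1} instead of
# offloading layer "2" to the CPU.
```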
@@ -1374,21 +1378,24 @@ def infer_auto_device_map(
     device_minimum_assignment_memory = {}

     # Initialize maximum largest layer, to know which space to keep in memory
-    max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)
-
+    if reserve_max_layer:
+        max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)
+    else:
+        max_layer_size, max_layer_names = 0, []
     # Ready ? This is going to be a bit messy.
     while len(modules_to_treat) > 0:
         name, module = modules_to_treat.pop(0)
         if verbose:
             print(f"\nTreating module {name}.")
         # Max size in the remaining layers may have changed since we took one, so we maybe update it.
-        max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")]
-        if len(max_layer_names) == 0:
-            max_layer_size, max_layer_names = get_max_layer_size(
-                [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
-                module_sizes,
-                no_split_module_classes,
-            )
+        if reserve_max_layer:
+            max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")]
+            if len(max_layer_names) == 0:
+                max_layer_size, max_layer_names = get_max_layer_size(
+                    [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
+                    module_sizes,
+                    no_split_module_classes,
+                )
         # Assess size needed
         module_size = module_sizes[name]

@@ -1417,8 +1424,8 @@ def infer_auto_device_map(
         device = devices[current_device]
         current_max_size = max_memory[device] if device != "disk" else None
         current_memory_reserved = 0
-        # Reduce max size available by the largest layer.
-        if devices[current_device] in main_devices:
+
+        if devices[current_device] in main_devices and reserve_max_layer:
             current_max_size = current_max_size - max_layer_size
             current_memory_reserved = max_layer_size

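To make the reserved budget concrete, a hand-computed sketch using the layer sizes from the test added in this PR (float32, 4 bytes per parameter):

```python
# Sizes for nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 5), nn.Linear(5, 15)):
layer_bytes = {
    "0": (10 * 5 + 5) * 4,   # 220 bytes (50 weights + 5 biases)
    "1": (5 * 5 + 5) * 4,    # 120 bytes
    "2": (5 * 15 + 15) * 4,  # 360 bytes, the largest layer
}
max_layer_size = max(layer_bytes.values())  # 360

gpu_0_budget = 145 * 4  # 580
# With reserve_max_layer=True the main device keeps room for the largest layer,
current_max_size = gpu_0_budget - max_layer_size  # 580 - 360 = 220
# so layer "0" (220 bytes) fits exactly and layers "1" and "2" must go elsewhere.
```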
@@ -1497,11 +1504,12 @@ def infer_auto_device_map(
                     + modules_to_treat[tied_module_index + 1 :]
                 )
                 # Update the max layer size.
-                max_layer_size, max_layer_names = get_max_layer_size(
-                    [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
-                    module_sizes,
-                    no_split_module_classes,
-                )
+                if reserve_max_layer:
+                    max_layer_size, max_layer_names = get_max_layer_size(
+                        [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
+                        module_sizes,
+                        no_split_module_classes,
+                    )
                 split_happened = True
                 break

@@ -1537,11 +1545,12 @@ def infer_auto_device_map(
             modules_children = list(module.named_parameters(recurse=False)) + modules_children
             modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat
             # Update the max layer size.
-            max_layer_size, max_layer_names = get_max_layer_size(
-                [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
-                module_sizes,
-                no_split_module_classes,
-            )
+            if reserve_max_layer:
+                max_layer_size, max_layer_names = get_max_layer_size(
+                    [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
+                    module_sizes,
+                    no_split_module_classes,
+                )
             continue

         # If no module is assigned to the current device, we attempt to allocate a fallback module
@@ -1573,6 +1582,36 @@ def infer_auto_device_map(

     device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}

+    # Before we return, we check whether the device map contains offloaded layers that aren't accounted for as
+    # used memory. If so, we call infer_auto_device_map again with reserve_max_layer set to True.
+    if not reserve_max_layer:
+        if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
+            main_device = "cpu"
+        else:
+            main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
+
+        execution_device = {
+            name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items()
+        }
+
+        execution_device[""] = main_device
+
+        offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
+        offload = {name: device in offloaded_devices for name, device in device_map.items()}
+
+        if any(offload.values()):
+            return infer_auto_device_map(
+                model,
+                max_memory,
+                no_split_module_classes,
+                reserve_max_layer=True,
+                dtype=dtype,
+                special_dtypes=special_dtypes,
+                verbose=verbose,
+                clean_result=clean_result,
+                offload_buffers=offload_buffers,
+                fallback_allocation=fallback_allocation,
+            )
     if clean_result:
         device_map = clean_device_map(device_map)

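The net effect of this tail check, traced by hand against the four-layer test case added below; the final map matches the test's expectations, the intermediate step is my reading of the greedy pass:

```python
import torch.nn as nn

from accelerate import infer_auto_device_map

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 5), nn.Linear(5, 15), nn.Linear(5, 15))
max_memory = {0: 145 * 4, 1: 100 * 4, "cpu": 700 * 4}

# With reserve_max_layer=False, the greedy first pass would leave layer "3" on
# the CPU without reserving GPU-0 headroom to execute it; the tail check spots
# the offload and recurses with reserve_max_layer=True, so both calls agree:
map_false = infer_auto_device_map(model, max_memory=max_memory, reserve_max_layer=False)
map_true = infer_auto_device_map(model, max_memory=max_memory, reserve_max_layer=True)
assert map_false == map_true == {"0": 0, "1": 1, "2": "cpu", "3": "cpu"}
```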
22 changes: 22 additions & 0 deletions tests/test_modeling_utils.py
@@ -563,6 +563,28 @@ def test_infer_auto_device_map(self):
         )
         assert device_map == {"0": 0, "1": 1, "2": 1}

+        # Setting reserve_max_layer to False prevents unnecessary offloading
+        model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 5), nn.Linear(5, 15))
+        gpu_0_mem = 145 * 4
+        gpu_1_mem = 100 * 4
+        cpu_mem = 700 * 4  # large enough mem
+        device_map = infer_auto_device_map(model, max_memory={0: gpu_0_mem, 1: gpu_1_mem, "cpu": cpu_mem}, reserve_max_layer=True)
+        assert device_map == {"0": 0, "1": 1, "2": "cpu"}
+
+        device_map = infer_auto_device_map(model, max_memory={0: gpu_0_mem, 1: gpu_1_mem, "cpu": cpu_mem}, reserve_max_layer=False)
+        assert device_map == {"0": 0, "1": 0, "2": 1}
+
+        # Setting reserve_max_layer to False doesn't prevent necessary offloading
+        model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 5), nn.Linear(5, 15), nn.Linear(5, 15))
+
+        expected_device_map = {"0": 0, "1": 1, "2": "cpu", "3": "cpu"}
+        device_map = infer_auto_device_map(model, max_memory={0: gpu_0_mem, 1: gpu_1_mem, "cpu": cpu_mem}, reserve_max_layer=True)
+        assert device_map == expected_device_map
+
+        device_map = infer_auto_device_map(model, max_memory={0: gpu_0_mem, 1: gpu_1_mem, "cpu": cpu_mem}, reserve_max_layer=False)
+        assert device_map == expected_device_map
+
+
     def test_infer_auto_device_map_with_tied_weights(self):
         model = nn.Sequential(
             OrderedDict([("layer1", ModelForTest()), ("layer2", ModelForTest()), ("layer3", ModelForTest())])
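A note on the magic numbers in these tests, reconstructed from the layer shapes rather than stated anywhere in the PR:

```python
# Each nn.Linear(in_f, out_f) stores in_f * out_f weights plus out_f biases,
# at 4 bytes per float32 parameter:
#   nn.Linear(10, 5) -> 55 params -> 220 bytes
#   nn.Linear(5, 5)  -> 30 params -> 120 bytes
#   nn.Linear(5, 15) -> 90 params -> 360 bytes
# gpu_0_mem = 145 * 4 = 580: exactly one 220-byte layer plus a 360-byte
#   reservation, so with reserve_max_layer=True GPU 0 holds a single layer.
# gpu_1_mem = 100 * 4 = 400: fits any one layer, never two.
# cpu_mem = 700 * 4 = 2800: effectively unlimited for these models.
```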