diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index 2b7e5fcaa19..ba6d1e1bab7 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -1306,6 +1306,7 @@ def infer_auto_device_map(
     model: nn.Module,
     max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
     no_split_module_classes: Optional[List[str]] = None,
+    reserve_max_layer: bool = True,
     dtype: Optional[Union[str, torch.dtype]] = None,
     special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
     verbose: bool = False,
@@ -1339,6 +1340,9 @@ def infer_auto_device_map(
         no_split_module_classes (`List[str]`, *optional*):
             A list of layer class names that should never be split across device (for instance any layer that has a
             residual connection).
+        reserve_max_layer (`bool`, *optional*, defaults to `True`):
+            Whether to reserve room for the largest layer on the main devices. Setting this to `False` allows a more
+            efficient memory allocation when multiple GPUs are present and no offloading to CPU or disk is needed.
         dtype (`str` or `torch.dtype`, *optional*):
             If provided, the weights will be converted to that type when loaded.
         special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
@@ -1374,21 +1378,24 @@ def infer_auto_device_map(
     device_minimum_assignment_memory = {}
 
     # Initialize maximum largest layer, to know which space to keep in memory
-    max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)
-
+    if reserve_max_layer:
+        max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)
+    else:
+        max_layer_size, max_layer_names = 0, []
     # Ready ? This is going to be a bit messy.
     while len(modules_to_treat) > 0:
         name, module = modules_to_treat.pop(0)
         if verbose:
             print(f"\nTreating module {name}.")
         # Max size in the remaining layers may have changed since we took one, so we maybe update it.
-        max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")]
-        if len(max_layer_names) == 0:
-            max_layer_size, max_layer_names = get_max_layer_size(
-                [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
-                module_sizes,
-                no_split_module_classes,
-            )
+        if reserve_max_layer:
+            max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")]
+            if len(max_layer_names) == 0:
+                max_layer_size, max_layer_names = get_max_layer_size(
+                    [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
+                    module_sizes,
+                    no_split_module_classes,
+                )
 
         # Assess size needed
         module_size = module_sizes[name]
@@ -1417,8 +1424,8 @@ def infer_auto_device_map(
         device = devices[current_device]
         current_max_size = max_memory[device] if device != "disk" else None
         current_memory_reserved = 0
-        # Reduce max size available by the largest layer.
-        if devices[current_device] in main_devices:
+        # Reduce max size available by the largest layer, if we are reserving space for it.
+        if devices[current_device] in main_devices and reserve_max_layer:
             current_max_size = current_max_size - max_layer_size
             current_memory_reserved = max_layer_size
 
@@ -1497,11 +1504,12 @@ def infer_auto_device_map(
                         + modules_to_treat[tied_module_index + 1 :]
                     )
                     # Update the max layer size.
-                    max_layer_size, max_layer_names = get_max_layer_size(
-                        [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
-                        module_sizes,
-                        no_split_module_classes,
-                    )
+                    if reserve_max_layer:
+                        max_layer_size, max_layer_names = get_max_layer_size(
+                            [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
+                            module_sizes,
+                            no_split_module_classes,
+                        )
                     split_happened = True
                     break
 
@@ -1537,11 +1545,12 @@ def infer_auto_device_map(
                 modules_children = list(module.named_parameters(recurse=False)) + modules_children
                 modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat
                 # Update the max layer size.
-                max_layer_size, max_layer_names = get_max_layer_size(
-                    [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
-                    module_sizes,
-                    no_split_module_classes,
-                )
+                if reserve_max_layer:
+                    max_layer_size, max_layer_names = get_max_layer_size(
+                        [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
+                        module_sizes,
+                        no_split_module_classes,
+                    )
                 continue
 
         # If no module is assigned to the current device, we attempt to allocate a fallback module
@@ -1573,6 +1582,36 @@ def infer_auto_device_map(
 
     device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}
 
+    # Before we return, check whether the device map computed without the max-layer reservation still offloads
+    # layers to CPU or disk. If so, call infer_auto_device_map again with reserve_max_layer set to True.
+    if not reserve_max_layer:
+        if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
+            main_device = "cpu"
+        else:
+            main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
+
+        execution_device = {
+            name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items()
+        }
+
+        execution_device[""] = main_device
+
+        offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
+        offload = {name: device in offloaded_devices for name, device in device_map.items()}
+
+        if any(offload.values()):
+            return infer_auto_device_map(
+                model,
+                max_memory,
+                no_split_module_classes,
+                reserve_max_layer=True,
+                dtype=dtype,
+                special_dtypes=special_dtypes,
+                verbose=verbose,
+                clean_result=clean_result,
+                offload_buffers=offload_buffers,
+                fallback_allocation=fallback_allocation,
+            )
 
     if clean_result:
         device_map = clean_device_map(device_map)
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 4ee90d07141..37cabd71618 100644
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -563,6 +563,28 @@ def test_infer_auto_device_map(self):
         )
         assert device_map == {"0": 0, "1": 1, "2": 1}
 
+        # Setting reserve_max_layer to False prevents unnecessary offloading
+        model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 5), nn.Linear(5, 15))
+        gpu_0_mem = 145 * 4
+        gpu_1_mem = 100 * 4
+        cpu_mem = 700 * 4  # large enough mem
+        device_map = infer_auto_device_map(model, max_memory={0: gpu_0_mem, 1: gpu_1_mem, "cpu": cpu_mem}, reserve_max_layer=True)
+        assert device_map == {"0": 0, "1": 1, "2": "cpu"}
+
+        device_map = infer_auto_device_map(model, max_memory={0: gpu_0_mem, 1: gpu_1_mem, "cpu": cpu_mem}, reserve_max_layer=False)
+        assert device_map == {"0": 0, "1": 0, "2": 1}
+
+        # Setting reserve_max_layer to False doesn't prevent necessary offloading
+        model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 5), nn.Linear(5, 15), nn.Linear(5, 15))
+
+        expected_device_map = {"0": 0, "1": 1, "2": "cpu", "3": "cpu"}
+        device_map = infer_auto_device_map(model, max_memory={0: gpu_0_mem, 1: gpu_1_mem, "cpu": cpu_mem}, reserve_max_layer=True)
+        assert device_map == expected_device_map
+
+        device_map = infer_auto_device_map(model, max_memory={0: gpu_0_mem, 1: gpu_1_mem, "cpu": cpu_mem}, reserve_max_layer=False)
+        assert device_map == expected_device_map
+
     def test_infer_auto_device_map_with_tied_weights(self):
         model = nn.Sequential(
             OrderedDict([("layer1", ModelForTest()), ("layer2", ModelForTest()), ("layer3", ModelForTest())])
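
For reference, a minimal usage sketch of the new flag, mirroring the values exercised in the test above; it assumes an accelerate build with this patch applied, and the expected maps are taken from the test assertions:

import torch.nn as nn

from accelerate import infer_auto_device_map

# Three linear layers totalling 175 fp32 parameters (700 bytes of weights); the GPU
# budgets are sized so that reserving room for the largest layer forces CPU offload.
model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 5), nn.Linear(5, 15))
max_memory = {0: 145 * 4, 1: 100 * 4, "cpu": 700 * 4}

# Default behaviour: space for the largest remaining layer is reserved on each main
# device, so the last layer spills to CPU.
print(infer_auto_device_map(model, max_memory=max_memory, reserve_max_layer=True))
# -> {'0': 0, '1': 1, '2': 'cpu'}

# Without the reservation the model packs onto the two GPUs; if this tighter allocation
# still required offloading, the function would recompute the map with the reservation.
print(infer_auto_device_map(model, max_memory=max_memory, reserve_max_layer=False))
# -> {'0': 0, '1': 0, '2': 1}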