From f6dd49568b378f2ca7aabb25a56e668f3ae9aaaf Mon Sep 17 00:00:00 2001
From: nopperl <54780682+nopperl@users.noreply.github.com>
Date: Thu, 15 Feb 2024 11:21:24 +0100
Subject: [PATCH] Implement pipeline parallel-agnostic optimizer state loading

---
 src/nanotron/optim/zero.py          |  7 +--
 src/nanotron/serialize/optimizer.py | 77 +++++++++++++++++++----------
 2 files changed, 55 insertions(+), 29 deletions(-)

diff --git a/src/nanotron/optim/zero.py b/src/nanotron/optim/zero.py
index 3da518d5..cb61c8b7 100644
--- a/src/nanotron/optim/zero.py
+++ b/src/nanotron/optim/zero.py
@@ -348,16 +348,17 @@ def find_optim_index_from_param_name(
     # NOTE: (pp_rank, dp_rank, tp_rank) or (pp_rank, tp_rank)
     ckp_sharded_optim_states: Union[Tuple[Tuple[int, int, int], torch.Tensor], Tuple[Tuple[int, int], torch.Tensor]],
     is_zero1: bool,
+    pp_rank: int = 0,
 ) -> int:
     param_name = param_name.replace("module.", "")
     # NOTE: since all shards have the same optim state names
-    # so we take the first shard
+    # so we take the first shard (except optionally the pp dimension)
     if is_zero1 is True:
         # NOTE: (pp_rank, dp_rank, tp_rank)
-        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(0, 0, 0)]["names"]
+        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(pp_rank, 0, 0)]["names"]
     else:
         # NOTE: (pp_rank, tp_rank)
-        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(0, 0)]["names"]
+        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(pp_rank, 0)]["names"]
 
     return next((k for k, v in OPTIM_STATE_INDEX_TO_PARAM_NAME.items() if v == param_name), None)
 
diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py
index 7554a157..96a4cda1 100644
--- a/src/nanotron/serialize/optimizer.py
+++ b/src/nanotron/serialize/optimizer.py
@@ -141,7 +141,9 @@
     ckp_tp_size = ckp_optimizer_config["parallelism"]["tp_size"]
     ckp_dp_size = ckp_optimizer_config["parallelism"]["dp_size"]
 
-    if int(ckp_tp_size) != int(parallel_context.tp_pg.size()):
+    if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int(
+        parallel_context.pp_pg.size()
+    ):
         assert (
             param_shard_metadata is not None
         ), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}"
@@ -179,14 +181,17 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
         model_state_dict = model.state_dict()
         new_optim_state_dict = optimizer.state_dict()
 
+        # TODO: this does not handle the edge case of different pipeline parallel optimizer state shards saving different state keys
         OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
         # NOTE: because we can only resume training with the same optimizer type
         # (0, 0) = (pp_rank, tp_rank)
         # NOTE: also we don't merge "step" because it's just a scalar
-
-        param_names = sorted(model_state_dict.items(), key=lambda x: x[0])
-        for param_name, _ in tqdm(
-            param_names,
+        param_names = list(model_state_dict.keys())
+        new_optim_state_param_names = {}
+        # NOTE: iterate through all model parameters in the local pipeline parallel rank (hence, possibly not the full model).
+        # Since model parameters and optimizer states are aligned, load only the optimizer states for these parameters from the checkpoint shards.
+        for param_index, param_name in tqdm(
+            enumerate(param_names),
            disable=dist.get_rank(parallel_context.world_pg) != 0,
            desc="Topology-agnostic optimizer loading",
        ):
@@ -198,19 +203,27 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
             if not isinstance(param, NanotronParameter):
                 raise NotImplementedError("Parameters are required to be NanotronParameter")
 
+            # NOTE: for tied parameters, the metadata is stored using the parameter name,
+            # while the data is stored using the name of the main tied parameter,
+            # which may be different (e.g. `model.token_position_embeddings.pp_block.token_embedding.weight`
+            # for `model.lm_head.pp_block.weight`).
+            base_name = param.get_tied_info().name if param.is_tied else param_name
+            if param_name != base_name:
+                # NOTE: skip tied parameter if main tied parameter has already been loaded
+                # (not always the case if pipeline parallel)
+                if base_name in new_optim_state_param_names.values():
+                    continue
+            new_optim_state_param_names[param_index] = base_name
+
             if param.is_sharded:
                 # NOTE: optimizer states's shape is equal to the parameter's shape
                 # NOTE: sometines an unsharded parameter's shape differ
                 # from an unsharded optimizer state's shape
                 new_shard_metadata = param.get_sharded_info()
                 new_unshared_shape = new_shard_metadata.unsharded_shape
-
-                # NOTE: merging optimizer states
-                optim_state_index = find_optim_index_from_param_name(
-                    param_name, ckp_sharded_optim_states, is_zero1=False
-                )
-
-                new_optim_state_dict["state"][optim_state_index] = {}
+                new_optim_state_dict["state"][param_index] = {}
+                # NOTE: restore each state tensor (e.g. exp_avg) by iterating through
+                # the optimizer state shards saved using the previous topology
                 for state_key in OPTIMIZER_STATE_NAMES:
                     # TODO(xrsrke): free the memory of the shards that isn't
                     # corresponding to the current rank
@@ -218,8 +231,21 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
                     unsharded_buffer = torch.empty(new_unshared_shape, device="cuda")
 
                     for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
-                        ckp_shard_metadata = get_checkpoint_state_metadata(param_name, pp_rank, tp_rank)
-                        ckp_shard_data = ckp_optim_state["state"][optim_state_index][state_key]
+                        old_optim_state_index = find_optim_index_from_param_name(
+                            base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
+                        )
+                        if old_optim_state_index is None:
+                            continue  # NOTE: param is not in this pp shard
+                        ckp_shard_data = ckp_optim_state["state"][old_optim_state_index][state_key]
+                        # NOTE: the metadata for the main parameter of a tied parameter might be in a
+                        # different pipeline parallel shard.
+                        if param.is_tied:
+                            metadata_pp_rank = next(
+                                iter(param_shard_metadata[param_name.replace("module.", "")].keys())
+                            )[0]
+                        else:
+                            metadata_pp_rank = pp_rank
+                        ckp_shard_metadata = get_checkpoint_state_metadata(param_name, metadata_pp_rank, tp_rank)
 
                         # NOTE: if the checkpoint is from a Zero-1 optimizer,
                         # so it's flattened, so we need to reshape it
@@ -229,7 +255,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
                             orig_shape = [int(dim) for dim in orig_shape]
                             ckp_shard_data = ckp_shard_data.view(orig_shape)
 
-                        new_optim_state_dict["state"][optim_state_index][state_key] = merge_and_shard_tp_tensors(
+                        new_optim_state_dict["state"][param_index][state_key] = merge_and_shard_tp_tensors(
                             buffer,
                             unsharded_buffer,
                             [
@@ -240,17 +266,16 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
                         if ckp_optim_type == ZeroDistributedOptimizer.__name__:
                             # NOTE: flatten the optimizer states
-                            new_optim_state_dict["state"][optim_state_index][state_key] = new_optim_state_dict[
-                                "state"
-                            ][optim_state_index][state_key].flatten()
-
-            new_optim_state_dict["state"][optim_state_index]["step"] = ckp_optim_state["state"][optim_state_index][
-                "step"
-            ]
-
-        # NOTE: since all shards have the same optim state names
-        # so we take the first shard
-        new_optim_state_dict["names"] = ckp_sharded_optim_states[(0, 0)]["names"]
+                            new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][
+                                param_index
+                            ][state_key].flatten()
+                        # NOTE: a bit awkward, but while we're already reading this (pp,tp) shard for whatever state_key,
+                        # try to get the step value as well.
+                        step = ckp_optim_state["state"][old_optim_state_index].get("step")
+                        if step is not None:
+                            new_optim_state_dict["state"][param_index]["step"] = step
+
+        new_optim_state_dict["names"] = new_optim_state_param_names
 
         state_dict = new_optim_state_dict
     else:
         # TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be be loading something else completely
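
Illustration (not part of the patch): the loading loop above is topology-agnostic because each
checkpoint shard, keyed by (pp_rank, tp_rank), stores a "names" mapping from optimizer-state index
to parameter name, so a parameter owned by another pipeline stage is simply absent from the current
pp shard's mapping and that shard can be skipped. Below is a minimal, self-contained sketch of that
name-based lookup; the helper name, the toy shard dict, and the parameter names are invented for
illustration, and only the (pp_rank, tp_rank) keying and the index-to-name mapping mirror the
checkpoint layout assumed here.

# sketch_pp_agnostic_lookup.py -- illustration only, not part of the patch
from typing import Dict, Optional, Tuple

# Toy stand-in for the checkpoint layout: shards keyed by (pp_rank, tp_rank),
# each carrying a "names" dict that maps optimizer-state index -> parameter name.
ToyShards = Dict[Tuple[int, int], dict]


def find_index_in_pp_shard(param_name: str, shards: ToyShards, pp_rank: int = 0) -> Optional[int]:
    # All tp shards of a given pp rank share the same state names, so tp_rank 0 is representative.
    names = shards[(pp_rank, 0)]["names"]
    return next((index for index, name in names.items() if name == param_name), None)


if __name__ == "__main__":
    shards: ToyShards = {
        (0, 0): {"names": {0: "model.decoder.0.attn.qkv_proj.weight"}},  # hypothetical names
        (1, 0): {"names": {0: "model.decoder.1.mlp.down_proj.weight"}},
    }
    # A parameter owned by pipeline stage 1 is not found in stage 0's shard (None),
    # which is why the loading loop can skip shards that do not hold the current parameter.
    assert find_index_in_pp_shard("model.decoder.1.mlp.down_proj.weight", shards, pp_rank=0) is None
    assert find_index_in_pp_shard("model.decoder.1.mlp.down_proj.weight", shards, pp_rank=1) == 0

Matching by parameter name rather than by positional state index is what lets the resumed run use a
different pipeline parallel split than the one that wrote the checkpoint.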