Merge pull request #71 from nopperl/topology-agnostic-loading
Implement pipeline parallel size-agnostic optimizer state loading
NouamaneTazi authored Feb 16, 2024
2 parents 1676cec + f6dd495 commit 372fdc1
Showing 2 changed files with 55 additions and 29 deletions.
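For context before the diff: checkpointed optimizer-state shards are keyed by (pp_rank, tp_rank) (or (pp_rank, dp_rank, tp_rank) for ZeRO-1), so after this change a parameter's state can be found in whichever pipeline-rank shard holds it instead of being assumed to live in pp_rank 0. A minimal toy sketch of that idea (the dictionary layout and names are illustrative, not the exact nanotron checkpoint format):

```python
# Toy sketch of the core idea (illustrative names/layout, not the real checkpoint
# format): shards are keyed by (pp_rank, tp_rank), and a parameter's state may sit
# in any pipeline rank's shard, so the loader scans every shard instead of
# assuming pp_rank == 0.
ckp_sharded_optim_states = {
    (0, 0): {"names": {0: "embed.weight"}, "state": {0: {"exp_avg": [0.1]}}},
    (1, 0): {"names": {0: "lm_head.weight"}, "state": {0: {"exp_avg": [0.2]}}},
}

def find_state_for_param(param_name):
    # Return the state from the first shard whose name table contains the parameter.
    for (pp_rank, tp_rank), shard in ckp_sharded_optim_states.items():
        index = next((k for k, v in shard["names"].items() if v == param_name), None)
        if index is not None:
            return shard["state"][index]
    return None

print(find_state_for_param("lm_head.weight"))  # {'exp_avg': [0.2]}
```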
7 changes: 4 additions & 3 deletions src/nanotron/optim/zero.py
@@ -348,16 +348,17 @@ def find_optim_index_from_param_name(
# NOTE: (pp_rank, dp_rank, tp_rank) or (pp_rank, tp_rank)
ckp_sharded_optim_states: Union[Tuple[Tuple[int, int, int], torch.Tensor], Tuple[Tuple[int, int], torch.Tensor]],
is_zero1: bool,
pp_rank=0,
) -> int:
param_name = param_name.replace("module.", "")
# NOTE: since all shards have the same optim state names
# so we take the first shard
# so we take the first shard (except optionally the pp dimension)
if is_zero1 is True:
# NOTE: (pp_rank, dp_rank, tp_rank)
OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(0, 0, 0)]["names"]
OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(pp_rank, 0, 0)]["names"]
else:
# NOTE: (pp_rank, tp_rank)
OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(0, 0)]["names"]
OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(pp_rank, 0)]["names"]

return next((k for k, v in OPTIM_STATE_INDEX_TO_PARAM_NAME.items() if v == param_name), None)

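A hedged usage sketch of the updated helper, assuming nanotron is installed and the function is importable from nanotron.optim.zero; the shard dictionary here is a toy stand-in for the one built by the serialization code in the next file:

```python
from nanotron.optim.zero import find_optim_index_from_param_name  # assumed import path

# Toy shard dictionary; with is_zero1=False the keys are (pp_rank, tp_rank).
ckp_sharded_optim_states = {
    (0, 0): {"names": {0: "model.decoder.0.attn.qkv.weight"}},
    (1, 0): {"names": {0: "model.lm_head.weight"}},
}

# pp_rank=1 looks the name up in pipeline rank 1's shard instead of always rank 0.
index = find_optim_index_from_param_name(
    "model.lm_head.weight", ckp_sharded_optim_states, is_zero1=False, pp_rank=1
)
assert index == 0

# A name missing from the chosen shard yields None, which the loader below uses
# to skip pipeline shards that do not contain the parameter.
assert find_optim_index_from_param_name(
    "model.lm_head.weight", ckp_sharded_optim_states, is_zero1=False, pp_rank=0
) is None
```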
77 changes: 51 additions & 26 deletions src/nanotron/serialize/optimizer.py
@@ -144,7 +144,9 @@ def load_optimizer(
ckp_dp_size = ckp_optimizer_config["parallelism"]["dp_size"]
ckpt_expert_parallel_size = ckp_optimizer_config["parallelism"]["expert_parallel_size"]

if int(ckp_tp_size) != int(parallel_context.tp_pg.size()):
if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int(
parallel_context.pp_pg.size()
):
assert (
param_shard_metadata is not None
), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}"
@@ -182,14 +184,17 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -

model_state_dict = model.state_dict()
new_optim_state_dict = optimizer.state_dict()
# TODO: this does not handle the edge case of different pipeline parallel optimizer state shards saving different state keys
OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
# NOTE: because we can only resume training with the same optimizer type
# (0, 0) = (pp_rank, tp_rank)
# NOTE: also we don't merge "step" because it's just a scalar

param_names = sorted(model_state_dict.items(), key=lambda x: x[0])
for param_name, _ in tqdm(
param_names,
param_names = list(model_state_dict.keys())
new_optim_state_param_names = {}
# NOTE: iterates through all model parameters in the local pipeline parallel rank (hence, might not be the full model).
# Since model parameters and optimizer states are aligned, loads only the optimizer states for these parameters from the checkpoint shards.
for param_index, param_name in tqdm(
enumerate(param_names),
disable=dist.get_rank(parallel_context.world_pg) != 0,
desc="Topology-agnostic optimizer loading",
):
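The loop now enumerates the local model's parameter names so that param_index, the parameter's position in the local state_dict, becomes the key under which its merged optimizer state is stored, and new_optim_state_param_names records which name that index maps to. A toy sketch of that bookkeeping, ignoring tied parameters for simplicity (names are illustrative):

```python
# Toy sketch of the new bookkeeping: optimizer state is keyed by the parameter's
# position in the local (per pipeline rank) state_dict, and the names mapping is
# rebuilt from those positions (illustrative names, no real model involved).
model_state_dict = {"embed.weight": None, "layer.0.weight": None}
new_optim_state_dict = {"state": {}, "names": {}}
new_optim_state_param_names = {}

for param_index, param_name in enumerate(model_state_dict.keys()):
    new_optim_state_param_names[param_index] = param_name
    new_optim_state_dict["state"][param_index] = {}  # filled from the checkpoint shards

new_optim_state_dict["names"] = new_optim_state_param_names
print(new_optim_state_dict["names"])  # {0: 'embed.weight', 1: 'layer.0.weight'}
```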
@@ -201,28 +206,49 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
if not isinstance(param, NanotronParameter):
raise NotImplementedError("Parameters are required to be NanotronParameter")

# NOTE: for tied parameters, the metadata is stored using the parameter name,
# while the data is stored using the name of the main tied parameter,
# which may be different (e.g. `model.token_position_embeddings.pp_block.token_embedding.weight`
# for `model.lm_head.pp_block.weight`).
base_name = param.get_tied_info().name if param.is_tied else param_name
if param_name != base_name:
# NOTE: skip tied parameter if main tied parameter has already been loaded
# (not always the case if pipeline parallel)
if base_name in new_optim_state_param_names.values():
continue
new_optim_state_param_names[param_index] = base_name

if param.is_sharded:
# NOTE: optimizer states' shape is equal to the parameter's shape
# NOTE: sometimes an unsharded parameter's shape differs
# from an unsharded optimizer state's shape
new_shard_metadata = param.get_sharded_info()
new_unshared_shape = new_shard_metadata.unsharded_shape

# NOTE: merging optimizer states
optim_state_index = find_optim_index_from_param_name(
param_name, ckp_sharded_optim_states, is_zero1=False
)

new_optim_state_dict["state"][optim_state_index] = {}
new_optim_state_dict["state"][param_index] = {}
# NOTE: restore each state tensor (e.g. exp_avg) by iterating through
# the optimizer state shards saved using the previous topology
for state_key in OPTIMIZER_STATE_NAMES:
# TODO(xrsrke): free the memory of the shards that isn't
# corresponding to the current rank
buffer = torch.zeros_like(param, device="cuda")
unsharded_buffer = torch.empty(new_unshared_shape, device="cuda")

for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
ckp_shard_metadata = get_checkpoint_state_metadata(param_name, pp_rank, tp_rank)
ckp_shard_data = ckp_optim_state["state"][optim_state_index][state_key]
old_optim_state_index = find_optim_index_from_param_name(
base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
)
if old_optim_state_index is None:
continue # NOTE: param is not in this pp shard
ckp_shard_data = ckp_optim_state["state"][old_optim_state_index][state_key]
# NOTE: the metadata for the main parameter of a tied parameter might be in a
# different pipeline parallel shard.
if param.is_tied:
metadata_pp_rank = next(
iter(param_shard_metadata[param_name.replace("module.", "")].keys())
)[0]
else:
metadata_pp_rank = pp_rank
ckp_shard_metadata = get_checkpoint_state_metadata(param_name, metadata_pp_rank, tp_rank)

# NOTE: if the checkpoint is from a ZeRO-1 optimizer,
# the state is flattened, so we need to reshape it
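The tied-parameter handling above can also be seen in isolation: several parameter names may share one underlying tensor whose state is stored under the main tied name, so duplicates are skipped once that name has been handled. A toy sketch of the dedup (module paths are illustrative, not nanotron's exact names):

```python
# Toy sketch of the tied-parameter dedup: the lm_head weight is tied to the embedding,
# its state is stored under the embedding's (main tied) name, and the second
# occurrence is skipped.
tied_main_name = {
    "model.token_embedding.weight": "model.token_embedding.weight",
    "model.lm_head.weight": "model.token_embedding.weight",  # tied to the embedding
}

new_optim_state_param_names = {}
for param_index, param_name in enumerate(tied_main_name):
    base_name = tied_main_name[param_name]
    if param_name != base_name and base_name in new_optim_state_param_names.values():
        continue  # the main tied parameter's state has already been loaded
    new_optim_state_param_names[param_index] = base_name

print(new_optim_state_param_names)  # {0: 'model.token_embedding.weight'}
```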
@@ -232,7 +258,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
orig_shape = [int(dim) for dim in orig_shape]
ckp_shard_data = ckp_shard_data.view(orig_shape)

new_optim_state_dict["state"][optim_state_index][state_key] = merge_and_shard_tp_tensors(
new_optim_state_dict["state"][param_index][state_key] = merge_and_shard_tp_tensors(
buffer,
unsharded_buffer,
[
@@ -243,17 +269,16 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -

if ckp_optim_type == ZeroDistributedOptimizer.__name__:
# NOTE: flatten the optimizer states
new_optim_state_dict["state"][optim_state_index][state_key] = new_optim_state_dict[
"state"
][optim_state_index][state_key].flatten()

new_optim_state_dict["state"][optim_state_index]["step"] = ckp_optim_state["state"][optim_state_index][
"step"
]

# NOTE: since all shards have the same optim state names
# so we take the first shard
new_optim_state_dict["names"] = ckp_sharded_optim_states[(0, 0)]["names"]
new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][
param_index
][state_key].flatten()
# NOTE: a bit awkward, but while we're already reading this (pp,tp) shard for whatever state_key,
# try to get the step value as well.
step = ckp_optim_state["state"][old_optim_state_index].get("step")
if step is not None:
new_optim_state_dict["state"][param_index]["step"] = step

new_optim_state_dict["names"] = new_optim_state_param_names
state_dict = new_optim_state_dict
else:
# TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be be loading something else completely
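Putting the inner loop together: for each parameter the loader visits every (pp_rank, tp_rank) shard, skips shards that do not contain the parameter, collects the state slices, and records the scalar step from whichever shard held it. A generic, hedged sketch of that gather-and-merge; the real code reshapes each slice using the saved shard metadata and merges with nanotron's merge_and_shard_tp_tensors, while this toy version assumes a plain 1-D split along dim 0 and made-up names:

```python
import torch

# Hedged, generic sketch of the gather-and-merge step (toy data, simple dim-0 split).
ckp_sharded_optim_states = {
    (0, 0): {"names": {0: "w"}, "state": {0: {"exp_avg": torch.tensor([1.0, 2.0]), "step": 10}}},
    (0, 1): {"names": {0: "w"}, "state": {0: {"exp_avg": torch.tensor([3.0, 4.0]), "step": 10}}},
    (1, 0): {"names": {0: "other"}, "state": {0: {"exp_avg": torch.tensor([9.0]), "step": 10}}},
}

slices, step = [], None
for (pp_rank, tp_rank), shard in sorted(ckp_sharded_optim_states.items()):
    index = next((k for k, v in shard["names"].items() if v == "w"), None)
    if index is None:
        continue  # this shard does not hold the parameter
    slices.append(shard["state"][index]["exp_avg"])
    step = shard["state"][index].get("step", step)

merged = torch.cat(slices, dim=0)
print(merged, step)  # tensor([1., 2., 3., 4.]) 10
```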
