fix optim states dtype + fix lr scheduler initial value when resuming
NouamaneTazi committed Nov 19, 2024
1 parent 51ca40b commit 4cc62ae
Showing 3 changed files with 41 additions and 5 deletions.
25 changes: 25 additions & 0 deletions src/nanotron/sanity_checks.py
@@ -58,6 +58,7 @@ def before_tbi_sanity_checks(
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
+ lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check that the model params are synchronized across dp
@@ -84,6 +85,17 @@ def before_tbi_sanity_checks(
msg=lambda err: f"[Before train] Tied weights {name} are not synchronized. {err}",
)

+ # SANITY CHECK: Check that model grads are zeroed or None
+ for name, param in unwrapped_model.named_parameters():
+ if param.grad is not None:
+ torch.testing.assert_close(
+ param.grad,
+ torch.zeros_like(param.grad),
+ atol=0,
+ rtol=0,
+ msg="Model half precision grads must be zeroed or None in first accumulation step.",
+ )
+
# SANITY CHECK: Check that the grad accumulator buffers are ready for DDP
if grad_accumulator is not None:
for _, elt in grad_accumulator.fp32_grad_buffers.items():
@@ -96,6 +108,15 @@ def before_tbi_sanity_checks(
msg="Grad accumulator buffers must be zeroed in first accumulation step.",
)

+ # TODO: add checks for memory contiguousness
+
+ # SANITY CHECK: Check that optimizer's lr is synchronized with lr_scheduler
+ for i, group in enumerate(lr_scheduler.optimizer.param_groups):
+ assert (
+ group["lr"] == lr_scheduler.get_last_lr()[i]
+ ), f"Optimizer and LR scheduler are not in sync. Got {group['lr']} and {lr_scheduler.get_last_lr()[i]}"
+ break
+
# SANITY CHECK: run model specific sanity checks
unwrapped_model.before_tbi_sanity_checks()

@@ -211,6 +232,10 @@ def before_optim_step_sanity_checks(
msg=lambda err: f"[Before optimizer step] Tied weights {name} are not synchronized. {err}",
)

+ # SANITY CHECK: optimizer's lr is synchronized with lr_scheduler
+ for group in unwrapped_model.optimizer.param_groups:
+ assert group["lr"] == unwrapped_model.lr_scheduler.get_last_lr()[0]
+
# SANITY CHECK: run model specific sanity checks
unwrapped_model.before_optim_step_sanity_checks()

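The before_tbi_sanity_checks addition above ties the optimizer to the lr scheduler: each param group's "lr" must equal what the scheduler last computed. Below is a minimal, self-contained sketch of that invariant in plain PyTorch (not nanotron code; the model, base lr, and lambda schedule are made up for illustration):

```python
import torch

# Toy stand-ins for nanotron's optimizer / lr_scheduler pair.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 0.5**step)

for _ in range(3):
    optimizer.step()   # normally preceded by a forward/backward pass
    scheduler.step()

# The invariant the new sanity check asserts before each training batch iteration:
# the lr the optimizer will actually apply matches the lr the scheduler reports.
for i, group in enumerate(optimizer.param_groups):
    assert group["lr"] == scheduler.get_last_lr()[i], (
        f"Optimizer and LR scheduler are not in sync. "
        f"Got {group['lr']} and {scheduler.get_last_lr()[i]}"
    )
```

This is exactly the condition that breaks when a checkpointed scheduler is restored without re-applying its value to the optimizer, which the load_lr_scheduler change further down addresses.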
17 changes: 13 additions & 4 deletions src/nanotron/serialize/optimizer.py
@@ -174,18 +174,23 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
# NOTE: if the checkpoint is from a Zero-0 optimizer, then we don't need to merge the shards
# across data parallel dimension, just directly load the checkpoints
shard_paths = list(
root_folder.glob(f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_tp-*-of-{ckp_tp_size}.pt")
root_folder.glob(
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_tp-*-of-{ckp_tp_size}.pt"
) # WARN: wildcard here after tp can hold `0-of-1_exp-0`
)

ckp_sharded_optim_states = {}
for shard_path in shard_paths:
pp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=False)
- ckp_sharded_optim_states[(pp_rank, tp_rank)] = torch.load(shard_path, map_location=map_location)
+ ckp_sharded_optim_states[(pp_rank, tp_rank)] = torch.load(
+ shard_path, map_location=map_location
+ ) # load all optim states in mem

model_state_dict = model.state_dict()
new_optim_state_dict = optimizer.state_dict()
# TODO: this does not handle the edge case of different pipeline parallel optimizer state shards saving different state keys
OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
+ OPTIMIZER_STATE_DTYPE = ckp_sharded_optim_states[(0, 0)]["state"][0][OPTIMIZER_STATE_NAMES[0]].dtype
# NOTE: because we can only resume training with the same optimizer type
# (0, 0) = (pp_rank, tp_rank)
# NOTE: also we don't merge "step" because it's just a scalar
@@ -230,8 +235,9 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
for state_key in OPTIMIZER_STATE_NAMES:
# TODO(xrsrke): free the memory of the shards that isn't
# corresponding to the current rank
- buffer = torch.zeros_like(param, device="cuda")
- unsharded_buffer = torch.empty(new_unshared_shape, device="cuda")
+ # TODO: maybe better to allocate memory for all states at once
+ buffer = torch.zeros_like(param, device="cuda", dtype=OPTIMIZER_STATE_DTYPE)
+ unsharded_buffer = torch.empty(new_unshared_shape, device="cuda", dtype=OPTIMIZER_STATE_DTYPE)

for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
old_optim_state_index = find_optim_index_from_param_name(
@@ -278,6 +284,8 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
if step is not None:
new_optim_state_dict["state"][param_index]["step"] = step

+ # NOTE: we throw away ckp_optim_state['gradient_accumulator'] which has fp32 grads
+
new_optim_state_dict["names"] = new_optim_state_param_names
state_dict = new_optim_state_dict
else:
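For context on the dtype change above: the two unshard buffers were previously allocated without an explicit dtype (buffer inherited the parameter's dtype via zeros_like, unsharded_buffer the framework default), so checkpointed Adam moments stored in a different precision, for example fp32 exp_avg / exp_avg_sq next to bf16 params, would be silently cast on load. A standalone sketch of that failure mode and the fix, with made-up shapes and CPU tensors (the actual code allocates on "cuda" and takes the dtype from the checkpoint as OPTIMIZER_STATE_DTYPE):

```python
import torch

param = torch.zeros(8, dtype=torch.bfloat16)        # stand-in for a bf16 model shard
ckp_exp_avg = torch.randn(8, dtype=torch.float32)   # Adam state as stored in the checkpoint

# Old behaviour: the buffer inherits the *parameter* dtype ...
buggy = torch.zeros_like(param)
buggy.copy_(ckp_exp_avg)                             # ... so the fp32 state is downcast to bf16
assert buggy.dtype == torch.bfloat16

# Fixed behaviour: allocate the buffer with the checkpoint's optimizer-state dtype.
state_dtype = ckp_exp_avg.dtype
fixed = torch.zeros_like(param, dtype=state_dtype)
fixed.copy_(ckp_exp_avg)
assert fixed.dtype == torch.float32 and torch.equal(fixed, ckp_exp_avg)
```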
@@ -319,3 +327,4 @@ def load_lr_scheduler(

state_dict = torch.load(root_folder / lr_scheduler_filename())
lr_scheduler.load_state_dict(state_dict)
+ lr_scheduler._initial_step() # NOTE: this is required to set the initial learning rate
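The one-line addition to load_lr_scheduler covers the second half of the commit. LRScheduler.load_state_dict restores the scheduler's own counters (last_epoch, _last_lr, ...) but does not write a learning rate back into the optimizer's param groups, so a freshly rebuilt optimizer keeps its construction-time lr until the scheduler steps again, which is exactly the mismatch the new sanity check flags. A standalone sketch of the symptom in plain PyTorch (made-up base lr and schedule, not nanotron's checkpoint format):

```python
import torch


def build(base_lr: float = 1e-3):
    # Fresh optimizer + scheduler, as they are rebuilt when resuming from a checkpoint.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 0.5**step)
    return optimizer, scheduler


# "First run": advance the schedule, then checkpoint the scheduler.
optimizer, scheduler = build()
for _ in range(3):
    optimizer.step()
    scheduler.step()
ckpt = scheduler.state_dict()

# "Resume": rebuild everything and load the scheduler state.
optimizer, scheduler = build()
scheduler.load_state_dict(ckpt)

print(scheduler.get_last_lr()[0])       # 0.001 * 0.5**3 = 0.000125 (restored from the checkpoint)
print(optimizer.param_groups[0]["lr"])  # 0.001 (still the construction-time value)
```

Calling lr_scheduler._initial_step() after loading (a private PyTorch LRScheduler helper) re-runs a scheduler step so that a schedule-derived lr is written back into the optimizer's param groups, and training resumes with the optimizer and scheduler in agreement.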
4 changes: 3 additions & 1 deletion src/nanotron/trainer.py
@@ -470,7 +470,9 @@ def train(
def training_step(
self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]
) -> Tuple[Iterable[Dict], Optional[torch.Tensor]]:
- before_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator)
+ before_tbi_sanity_checks(
+ self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.lr_scheduler
+ )

if self.iteration_step < 5:
log_memory(logger=logger)
