Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
gritukan committed Nov 2, 2024
1 parent 51ca40b commit 2e8f1b0
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/nanotron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from nanotron.serialize import (
load_lr_scheduler,
load_meta,
load_random_states,
load_weights,
parse_ckpt_path,
save,
Expand Down Expand Up @@ -170,6 +171,11 @@ def __init__(
self.random_states = init_random_states(
parallel_config=self.config.parallelism, tp_pg=self.parallel_context.tp_pg
)
if self.init_checkpoint_path is not None:
self.random_states = load_random_states(
parallel_context=self.parallel_context,
root_folder=self.init_checkpoint_path,
)
self.model = self.init_model() # Defines self.model
self.unwrapped_model: NanotronModel = (
self.model.module if isinstance(self.model, DistributedDataParallel) else self.model
Expand Down

0 comments on commit 2e8f1b0

Please sign in to comment.