Skip to content

Commit

Permalink
fix(activation_checkpoint.py): fix rng mode in activation ckpt (#177)
Browse the repository at this point in the history
  • Branch information:
huangting4201 authored Apr 3, 2024
1 parent 27b95d0 commit 0efc2d7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions internlm/core/context/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def set_state(self, parallel_mode: ParallelMode, state: Tensor):
assert parallel_mode in self._seed_states, f"{parallel_mode} not found in seed manager"
self._seed_states[parallel_mode] = state

def set_mode(self, parallel_mode: ParallelMode):
def set_mode(self, parallel_mode: ParallelMode, update_rng_current_mode: bool = True):
"""Sets the current mode of the seed manager."""
if self.current_mode:
if update_rng_current_mode and self.current_mode:
# save state for current mode
self._seed_states[self._current_mode] = internlm_accelerator.get_rng_state()

Expand Down Expand Up @@ -107,9 +107,9 @@ def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
_SEED_MANAGER.add_seed(parallel_mode, seed, overwrite)


def set_mode(parallel_mode: ParallelMode):
def set_mode(parallel_mode: ParallelMode, update_rng_current_mode: bool = True):
"""Sets the current mode of the seed manager."""
_SEED_MANAGER.set_mode(parallel_mode)
_SEED_MANAGER.set_mode(parallel_mode, update_rng_current_mode=update_rng_current_mode)


def set_seed_states(parallel_mode: ParallelMode, state: Tensor):
Expand Down
4 changes: 2 additions & 2 deletions internlm/solver/activation_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def backward(ctx, *args):
torch.set_rng_state(ctx.fwd_cpu_rng_state)
for parallel_mode, state in ctx.fwd_seed_states.items():
set_seed_states(parallel_mode, state)
set_mode(ctx.fwd_current_mode)
set_mode(ctx.fwd_current_mode, update_rng_current_mode=False)
if ctx.activation_offload:
tensors = copy_to_device(tensors, ctx.device)

Expand All @@ -136,7 +136,7 @@ def backward(ctx, *args):
torch.set_rng_state(bwd_cpu_rng_state)
for parallel_mode, state in bwd_seed_states.items():
set_seed_states(parallel_mode, state)
set_mode(bwd_current_mode)
set_mode(bwd_current_mode, update_rng_current_mode=False)

# run backward() with only tensor that requires grad
outputs_with_grad = []
Expand Down

0 comments on commit 0efc2d7

Please sign in to comment.