
Commit

Rename managed_* to externally_managed_*
These functions rely on the caller keeping track of which parameters
have been trained, and passing that information in the function call.
Waino committed May 20, 2024
1 parent a197f6b commit 25a8be0
Showing 4 changed files with 8 additions and 8 deletions.
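
For illustration, a minimal sketch (not part of the commit) of the calling pattern the commit description refers to: the caller tracks which components actually received gradients this step and passes that bookkeeping into each call. It assumes model, optim, and an iterable gradient_syncs are already in scope, and shows only the keyword arguments visible in the diff below; the remaining arguments are elided there as well.

    # Hedged sketch, not from the commit: a hypothetical driver loop assuming
    # model, optim, and gradient_syncs already exist in the trainer's scope.
    import mammoth.distributed

    for gradient_sync in gradient_syncs:
        component = gradient_sync.component
        params = component.named_parameters(model)
        mammoth.distributed.externally_managed_reduce_and_rescale_grads(
            named_parameters=params,
            has_local_gradient=gradient_sync.has_local_gradient,
            gradient_norm=gradient_sync.gradient_norm,
            # further arguments are elided in the diff below and omitted here
        )

    # Step only the sub-optimizers of components that were trained this step,
    # then clear gradients before the next accumulation cycle.
    optim.externally_managed_step(gradient_syncs)
    optim.zero_grad()

The mammoth/trainer.py hunks below show the actual call sites inside train().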
4 changes: 2 additions & 2 deletions mammoth/distributed/__init__.py
@@ -4,7 +4,7 @@
batch_producer,
consumer,
broadcast_tensors,
- managed_reduce_and_rescale_grads,
+ externally_managed_reduce_and_rescale_grads,
ErrorHandler,
)
from mammoth.distributed.contexts import DeviceContext, WorldContext, DeviceContextEnum
@@ -20,7 +20,7 @@
"batch_producer",
"broadcast_tensors",
"consumer",
"managed_reduce_and_rescale_grads",
"externally_managed_reduce_and_rescale_grads",
"ErrorHandler",
"DeviceContext",
"WorldContext",
2 changes: 1 addition & 1 deletion mammoth/distributed/communication.py
@@ -36,7 +36,7 @@ def broadcast_tensors(tensors, src=0, group=None):
torch.distributed.broadcast(t, src, group=group)


- def managed_reduce_and_rescale_grads(
+ def externally_managed_reduce_and_rescale_grads(
named_parameters,
has_local_gradient: bool,
gradient_norm: int,
4 changes: 2 additions & 2 deletions mammoth/trainer.py
@@ -278,7 +278,7 @@ def train(
continue
# logger.warning(f'Syncing {component.get_name()}') # DEBUG
params = component.named_parameters(self.model)
- mammoth.distributed.managed_reduce_and_rescale_grads(
+ mammoth.distributed.externally_managed_reduce_and_rescale_grads(
named_parameters=params,
has_local_gradient=gradient_sync.has_local_gradient,
gradient_norm=gradient_sync.gradient_norm,
@@ -288,7 +288,7 @@
self._maybe_update_stats_from_parameters(report_stats, self.model.named_parameters())

# Including single-device components
- self.optim.managed_step(gradient_syncs)
+ self.optim.externally_managed_step(gradient_syncs)
self.optim.zero_grad()

if step % 1000 == 0 and step > 0:
6 changes: 3 additions & 3 deletions mammoth/utils/optimizers.py
@@ -202,7 +202,7 @@ def zero_grad(self):
for name in self.optimizers:
self.optimizers[name].zero_grad()

- def managed_step(self, gradient_syncs, grad_scaler=None):
+ def externally_managed_step(self, gradient_syncs, grad_scaler=None):
"""Step through only the trained suboptimizers"""
trained_components = {
gradient_sync.component.get_name() for gradient_sync in gradient_syncs
@@ -387,7 +387,7 @@ def backward(self, loss):
else:
loss.backward()

- def managed_step(self, *args, **kwargs):
+ def externally_managed_step(self, *args, **kwargs):
"""Update the model parameters based on current gradients.
Optionally, will employ gradient modification or update learning
@@ -415,7 +415,7 @@ def managed_step(self, *args, **kwargs):
# Updates the scale for next iteration.
self._scaler.update()
else:
- self._optimizer.managed_step(*args, **kwargs)
+ self._optimizer.externally_managed_step(*args, **kwargs)
self._decay_step += 1
self._training_step += 1

