From 25a8be0b3537e94f88f0403920ef864770ef9825 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 13 May 2024 09:23:47 +0300
Subject: [PATCH] Rename managed_* to externally_managed_*

These functions rely on the caller keeping track of which parameters
have been trained, and passing that information in the function call.
---
 mammoth/distributed/__init__.py      | 4 ++--
 mammoth/distributed/communication.py | 2 +-
 mammoth/trainer.py                   | 4 ++--
 mammoth/utils/optimizers.py          | 6 +++---
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/mammoth/distributed/__init__.py b/mammoth/distributed/__init__.py
index 1a6aff66..e00d6aea 100644
--- a/mammoth/distributed/__init__.py
+++ b/mammoth/distributed/__init__.py
@@ -4,7 +4,7 @@
     batch_producer,
     consumer,
     broadcast_tensors,
-    managed_reduce_and_rescale_grads,
+    externally_managed_reduce_and_rescale_grads,
     ErrorHandler,
 )
 from mammoth.distributed.contexts import DeviceContext, WorldContext, DeviceContextEnum
@@ -20,7 +20,7 @@
     "batch_producer",
     "broadcast_tensors",
     "consumer",
-    "managed_reduce_and_rescale_grads",
+    "externally_managed_reduce_and_rescale_grads",
     "ErrorHandler",
     "DeviceContext",
     "WorldContext",
diff --git a/mammoth/distributed/communication.py b/mammoth/distributed/communication.py
index daebdd0b..c1ce3c71 100644
--- a/mammoth/distributed/communication.py
+++ b/mammoth/distributed/communication.py
@@ -36,7 +36,7 @@ def broadcast_tensors(tensors, src=0, group=None):
         torch.distributed.broadcast(t, src, group=group)
 
 
-def managed_reduce_and_rescale_grads(
+def externally_managed_reduce_and_rescale_grads(
     named_parameters,
     has_local_gradient: bool,
     gradient_norm: int,
diff --git a/mammoth/trainer.py b/mammoth/trainer.py
index 9950c24f..052fe6e8 100644
--- a/mammoth/trainer.py
+++ b/mammoth/trainer.py
@@ -278,7 +278,7 @@ def train(
                         continue
                     # logger.warning(f'Syncing {component.get_name()}')  # DEBUG
                     params = component.named_parameters(self.model)
-                    mammoth.distributed.managed_reduce_and_rescale_grads(
+                    mammoth.distributed.externally_managed_reduce_and_rescale_grads(
                         named_parameters=params,
                         has_local_gradient=gradient_sync.has_local_gradient,
                         gradient_norm=gradient_sync.gradient_norm,
@@ -288,7 +288,7 @@ def train(
                 self._maybe_update_stats_from_parameters(report_stats, self.model.named_parameters())
 
                 # Including single-device components
-                self.optim.managed_step(gradient_syncs)
+                self.optim.externally_managed_step(gradient_syncs)
                 self.optim.zero_grad()
 
                 if step % 1000 == 0 and step > 0:
diff --git a/mammoth/utils/optimizers.py b/mammoth/utils/optimizers.py
index 9b99b383..c446c036 100644
--- a/mammoth/utils/optimizers.py
+++ b/mammoth/utils/optimizers.py
@@ -202,7 +202,7 @@ def zero_grad(self):
         for name in self.optimizers:
             self.optimizers[name].zero_grad()
 
-    def managed_step(self, gradient_syncs, grad_scaler=None):
+    def externally_managed_step(self, gradient_syncs, grad_scaler=None):
         """Step through only the trained suboptimizers"""
         trained_components = {
             gradient_sync.component.get_name() for gradient_sync in gradient_syncs
@@ -387,7 +387,7 @@ def backward(self, loss):
         else:
             loss.backward()
 
-    def managed_step(self, *args, **kwargs):
+    def externally_managed_step(self, *args, **kwargs):
         """Update the model parameters based on current gradients.
 
         Optionally, will employ gradient modification or update learning
@@ -415,7 +415,7 @@
             # Updates the scale for next iteration.
             self._scaler.update()
         else:
-            self._optimizer.managed_step(*args, **kwargs)
+            self._optimizer.externally_managed_step(*args, **kwargs)
         self._decay_step += 1
         self._training_step += 1
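
For context: the rename is meant to signal that these functions do no bookkeeping
of their own; the caller decides which components actually received gradients in
the current step and passes that information in (via the gradient_syncs objects
seen in trainer.py). The sketch below is illustrative only, with simplified toy
stand-ins (ToyComponent, ToyGradientSync, externally_managed_step_sketch) rather
than the real mammoth classes; it mimics how externally_managed_step consults the
caller-supplied syncs to step only the trained sub-optimizers.

    # Illustrative sketch; the real classes live in mammoth.distributed and
    # mammoth.utils.optimizers.
    from dataclasses import dataclass

    import torch


    @dataclass
    class ToyComponent:
        name: str

        def get_name(self) -> str:
            return self.name


    @dataclass
    class ToyGradientSync:
        # The caller records, per component, whether this device produced a
        # gradient in the current step and the normalization to apply.
        component: ToyComponent
        has_local_gradient: bool
        gradient_norm: int


    def externally_managed_step_sketch(optimizers, gradient_syncs):
        # Mirrors the pattern in MultipleOptimizer.externally_managed_step:
        # only sub-optimizers whose components were trained this step advance.
        trained = {sync.component.get_name() for sync in gradient_syncs}
        for name, optimizer in optimizers.items():
            if name in trained:
                optimizer.step()


    # Example: two components, but only "decoder" received gradients this step,
    # so only its sub-optimizer is stepped.
    params = {name: torch.nn.Parameter(torch.zeros(2)) for name in ("encoder", "decoder")}
    optimizers = {name: torch.optim.SGD([p], lr=0.1) for name, p in params.items()}
    params["decoder"].grad = torch.ones(2)
    syncs = [ToyGradientSync(ToyComponent("decoder"), has_local_gradient=True, gradient_norm=1)]
    externally_managed_step_sketch(optimizers, syncs)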