From 339960f6da75fe4150ff76603818d7e1d9cdd895 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 27 May 2024 11:18:21 +0300
Subject: [PATCH] Remove obsolete multiOptims_Langs

---
 mammoth/utils/optimizers.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/mammoth/utils/optimizers.py b/mammoth/utils/optimizers.py
index ef4e88c7..8d1e7b58 100644
--- a/mammoth/utils/optimizers.py
+++ b/mammoth/utils/optimizers.py
@@ -1,13 +1,10 @@
 """ Optimizers class """
 import functools
-import importlib
 import math
 import torch
 import torch.optim as optim
-import types
 from collections import Counter
 from math import sqrt
-from mammoth.utils.misc import fn_args
 from torch.nn.utils import clip_grad_norm_
 
 from mammoth.utils.logging import logger
@@ -170,11 +167,10 @@ def linear_warmup_decay(step, warmup_steps, rate, train_steps):
 
 
 class MultipleOptimizer(object):
-    """Implement multiple optimizers"""
+    """Separate optimizers for each distributed component"""
 
-    def __init__(self, op, multiOptims_Langs=None):
+    def __init__(self, op):
         self.optimizers = op
-        self.multiOptims_Langs = multiOptims_Langs
         self._steps = Counter()
 
     @property
@@ -212,7 +208,6 @@ def state_dict(self):
         """Returns the state dictionary"""
         return {
             'optimizers': {k: v.state_dict() for k, v in self.optimizers.items()},
-            'multiOptims_Langs': self.multiOptims_Langs,
             'steps': self._steps,
         }
 
@@ -231,7 +226,6 @@ def load_state_dict(self, state_dict):
         else:
             logger.info("Some components do not match. Do not load optimizer from checkpoint.")
 
-        self.multiOptims_Langs = state_dict["multiOptims_Langs"]
         self._steps = state_dict["steps"]
 
 