From e3363a7159992ae27e84b4ed4dad21942018b650 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 9 May 2023 11:33:51 -0700 Subject: [PATCH] Support bitsandbytes optimizers in factory --- timm/optim/optim_factory.py | 55 +++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 10950210ed..2e3020119b 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -27,11 +27,6 @@ from .rmsprop_tf import RMSpropTF from .sgdp import SGDP -try: - from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD - has_apex = True -except ImportError: - has_apex = False _logger = logging.getLogger(__name__) @@ -254,9 +249,23 @@ def create_optimizer_v2( opt_lower = opt.lower() opt_split = opt_lower.split('_') opt_lower = opt_split[-1] - if 'fused' in opt_lower: + + if opt_lower.startswith('fused'): + try: + from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD + has_apex = True + except ImportError: + has_apex = False assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' + if opt_lower.startswith('bnb'): + try: + import bitsandbytes as bnb + has_bnb = True + except ImportError: + has_bnb = False + assert has_bnb and torch.cuda.is_available(), 'bitsandbytes and CUDA required for bnb optimizers' + opt_args = dict(weight_decay=weight_decay, **kwargs) if lr is not None: @@ -357,6 +366,40 @@ def create_optimizer_v2( opt_args.setdefault('betas', (0.95, 0.98)) optimizer = FusedNovoGrad(parameters, **opt_args) + # bitsandbytes optimizers, require bitsandbytes to be installed + elif opt_lower == 'bnbsgd': + opt_args.pop('eps', None) + optimizer = bnb.optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args) + elif opt_lower == 'bnbsgd8bit': + opt_args.pop('eps', None) + optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, nesterov=True, **opt_args) + elif opt_lower == 
'bnbmomentum': + opt_args.pop('eps', None) + optimizer = bnb.optim.SGD(parameters, momentum=momentum, **opt_args) + elif opt_lower == 'bnbmomentum8bit': + opt_args.pop('eps', None) + optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, **opt_args) + elif opt_lower == 'bnbadam': + optimizer = bnb.optim.Adam(parameters, **opt_args) + elif opt_lower == 'bnbadam8bit': + optimizer = bnb.optim.Adam8bit(parameters, **opt_args) + elif opt_lower == 'bnbadamw': + optimizer = bnb.optim.AdamW(parameters, **opt_args) + elif opt_lower == 'bnbadamw8bit': + optimizer = bnb.optim.AdamW8bit(parameters, **opt_args) + elif opt_lower == 'bnblamb': + optimizer = bnb.optim.LAMB(parameters, **opt_args) + elif opt_lower == 'bnblamb8bit': + optimizer = bnb.optim.LAMB8bit(parameters, **opt_args) + elif opt_lower == 'bnblars': + optimizer = bnb.optim.LARS(parameters, **opt_args) + elif opt_lower == 'bnblars8bit': + optimizer = bnb.optim.LARS8bit(parameters, **opt_args) + elif opt_lower == 'bnblion': + optimizer = bnb.optim.Lion(parameters, **opt_args) + elif opt_lower == 'bnblion8bit': + optimizer = bnb.optim.Lion8bit(parameters, **opt_args) + else: assert False and "Invalid optimizer" raise ValueError