Support bitsandbytes optimizers in factory
rwightman committed May 9, 2023
1 parent 21e57c0 commit e3363a7
Showing 1 changed file with 49 additions and 6 deletions.
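For usage context: the new names plug into the existing create_optimizer_v2 entry point. A minimal sketch, assuming bitsandbytes is installed alongside a CUDA-enabled PyTorch build; the model and hyperparameter values below are illustrative placeholders, not taken from this commit.

import torch.nn as nn

from timm.optim import create_optimizer_v2

# stand-in model; any nn.Module (or an iterable of parameters) works here
model = nn.Linear(128, 10)

# the 'bnb' prefix routes to bitsandbytes optimizers,
# and the '8bit' suffix selects the 8-bit-state variant
optimizer = create_optimizer_v2(
    model,
    opt='bnbadamw8bit',
    lr=1e-3,
    weight_decay=0.05,
)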
timm/optim/optim_factory.py
@@ -27,11 +27,6 @@
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 
-try:
-    from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
-    has_apex = True
-except ImportError:
-    has_apex = False
 
 _logger = logging.getLogger(__name__)
 
@@ -254,9 +249,23 @@ def create_optimizer_v2(
     opt_lower = opt.lower()
     opt_split = opt_lower.split('_')
     opt_lower = opt_split[-1]
-    if 'fused' in opt_lower:
+
+    if opt_lower.startswith('fused'):
+        try:
+            from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
+            has_apex = True
+        except ImportError:
+            has_apex = False
         assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
 
+    if opt_lower.startswith('bnb'):
+        try:
+            import bitsandbytes as bnb
+            has_bnb = True
+        except ImportError:
+            has_bnb = False
+        assert has_bnb and torch.cuda.is_available(), 'bitsandbytes and CUDA required for bnb optimizers'
+
     opt_args = dict(weight_decay=weight_decay, **kwargs)
 
     if lr is not None:
@@ -357,6 +366,40 @@ def create_optimizer_v2(
         opt_args.setdefault('betas', (0.95, 0.98))
         optimizer = FusedNovoGrad(parameters, **opt_args)
 
+    # bitsandbytes optimizers, require bitsandbytes to be installed
+    elif opt_lower == 'bnbsgd':
+        opt_args.pop('eps', None)
+        optimizer = bnb.optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
+    elif opt_lower == 'bnbsgd8bit':
+        opt_args.pop('eps', None)
+        optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, nesterov=True, **opt_args)
+    elif opt_lower == 'bnbmomentum':
+        opt_args.pop('eps', None)
+        optimizer = bnb.optim.SGD(parameters, momentum=momentum, **opt_args)
+    elif opt_lower == 'bnbmomentum8bit':
+        opt_args.pop('eps', None)
+        optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, **opt_args)
+    elif opt_lower == 'bnbadam':
+        optimizer = bnb.optim.Adam(parameters, **opt_args)
+    elif opt_lower == 'bnbadam8bit':
+        optimizer = bnb.optim.Adam8bit(parameters, **opt_args)
+    elif opt_lower == 'bnbadamw':
+        optimizer = bnb.optim.AdamW(parameters, **opt_args)
+    elif opt_lower == 'bnbadamw8bit':
+        optimizer = bnb.optim.AdamW8bit(parameters, **opt_args)
+    elif opt_lower == 'bnblamb':
+        optimizer = bnb.optim.LAMB(parameters, **opt_args)
+    elif opt_lower == 'bnblamb8bit':
+        optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
+    elif opt_lower == 'bnblars':
+        optimizer = bnb.optim.LARS(parameters, **opt_args)
+    elif opt_lower == 'bnblars8bit':
+        optimizer = bnb.optim.LARS8bit(parameters, **opt_args)
+    elif opt_lower == 'bnblion':
+        optimizer = bnb.optim.Lion(parameters, **opt_args)
+    elif opt_lower == 'bnblion8bit':
+        optimizer = bnb.optim.Lion8bit(parameters, **opt_args)
+
     else:
         assert False and "Invalid optimizer"
         raise ValueError
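For reference, each bnb* string above resolves to a plain bitsandbytes constructor, so the factory call is roughly equivalent to instantiating the optimizer directly. A minimal sketch, assuming bitsandbytes and CUDA are available; the model and values are placeholders.

import bitsandbytes as bnb
import torch.nn as nn

model = nn.Linear(128, 10)  # placeholder module

# roughly what opt='bnbadamw8bit' resolves to inside the factory
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-3, weight_decay=0.05)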