Add support for config.lr_configs.base_learning_rate = 0 iff config.optimizer.per_example_clipping is used.

PiperOrigin-RevId: 559356473
AndrBaer authored and Scenic Authors committed Aug 23, 2023
1 parent 2dd1974 commit 50be9b0
Showing 2 changed files with 18 additions and 4 deletions.
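Before the diff, a minimal config sketch of the case this commit enables. It is illustrative only: the schedule name 'constant' and the exact ConfigDict nesting under config.optimizer are assumptions for the example, not taken from the change itself.

    import ml_collections

    def get_config():
      """Hypothetical Scenic config fragment for a zero base learning rate."""
      config = ml_collections.ConfigDict()
      config.lr_configs = ml_collections.ConfigDict()
      # Assumed schedule name for this sketch.
      config.lr_configs.learning_rate_schedule = 'constant'
      # A zero base learning rate is no longer rejected in lr_schedules.py;
      # the check moved to optax.make.
      config.lr_configs.base_learning_rate = 0.0
      config.optimizer = ml_collections.ConfigDict()
      # With per-example clipping enabled, optax.make tolerates the zero value.
      config.optimizer.per_example_clipping = True
      return config
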
12 changes: 8 additions & 4 deletions scenic/train_lib/lr_schedules.py
@@ -311,10 +311,14 @@ def get_learning_rate_fn(config: ml_collections.ConfigDict):
     raise ValueError(
         '`base_learning_rate` has to be defined in the lr_config.')
   if not config.lr_configs.base_learning_rate:
-    raise ValueError(  # raised for {0, False, None, [], (), {}}
-        f'`base_learning_rate = {config.lr_configs.base_learning_rate}` is not '
-        'allowed for training parameters. If your intention was to freeze '
-        'parameters, use Scenic optax and `config.lr_configs = None` instead.')
+    # raise ValueError(  # raised for {0, False, None, [], (), {}}
+    #     f'`base_learning_rate = {config.lr_configs.base_learning_rate}` is not '
+    #     'allowed for training parameters. If your intention was to freeze '
+    #     'parameters, use Scenic optax and `config.lr_configs = None` instead.')
+    pass
+    # Circumvent failing of config.lr_configs.base_learning_rate in {0, False,
+    # None, [], (), {}} here as a short-term solution. This case is for now
+    # handled in optax.make to handle edge cases.
   if 'learning_rate_schedule' in config.lr_configs:
     # A function that given the current step, returns the LR.
     return lr_fn_dict[config.lr_configs['learning_rate_schedule']](
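To make the guard's semantics concrete, a short standalone sketch (plain Python, not Scenic code) of the truthiness test that the removed ValueError relied on and that optax.make now performs:

    # Exactly the values listed in the in-code comment, {0, False, None, [], (), {}},
    # evaluate as falsy, so `not base_learning_rate` is True for each of them.
    falsy_base_lrs = [0, 0.0, False, None, [], (), {}]
    assert all(not lr for lr in falsy_base_lrs)

    # Any genuinely non-zero learning rate passes the check.
    assert not any(not lr for lr in [1e-3, 0.1, 1.0])
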
10 changes: 10 additions & 0 deletions scenic/train_lib/optax.py
@@ -188,6 +188,16 @@ def make(config: ml_collections.ConfigDict,
       schedule function and base learning rate (for WD decoupling).
     params: Model parameters.
   """
+  if not config.get('per_example_clipping'):
+    # Collect all base_lrs and transform to bool. Each element of schedule fol-
+    # lows the structure (re, name, (fn, base_lr)) [see above].
+    base_lrs = [fn_base_lr[1] for _, _, fn_base_lr in schedule]
+    if any([not base_lr for base_lr in base_lrs]):
+      raise ValueError(  # raised if base_lr in {0, False, None, [], (), {}}
+          f'`base_learning_rate` contains unsupported values {base_lrs}. '
+          'Unsupported values are: 0, False, None, [], (), {}. If '
+          'your intention was to freeze parameters, use Scenic optax and '
+          '`config.lr_configs = None` instead.')
+
   masks, scheds = _make_mask_trees(params, schedule, log='schedule')
   frozen_mask, masks, scheds = _split_frozen(masks, scheds)
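A standalone sketch of the added guard, using illustrative names (check_base_lrs is not part of the Scenic API); each schedule element follows the (regex, name, (schedule_fn, base_lr)) structure described in the added comment:

    from typing import Any, Callable, Optional, Sequence, Tuple

    # One schedule entry: (regex, name, (schedule_fn, base_lr)).
    ScheduleEntry = Tuple[str, str, Tuple[Optional[Callable[[int], float]], Any]]

    def check_base_lrs(schedule: Sequence[ScheduleEntry],
                       per_example_clipping: bool) -> None:
      """Raises if any base_lr is falsy and per-example clipping is off."""
      if per_example_clipping:
        # With per-example clipping, zero/empty base learning rates are allowed.
        return
      base_lrs = [fn_base_lr[1] for _, _, fn_base_lr in schedule]
      if any(not base_lr for base_lr in base_lrs):
        raise ValueError(
            f'`base_learning_rate` contains unsupported values {base_lrs}. '
            'Unsupported values are: 0, False, None, [], (), {}. If your '
            'intention was to freeze parameters, use Scenic optax and '
            '`config.lr_configs = None` instead.')

    # Usage: passes with per-example clipping, would raise without it.
    sched = [('.*', 'all_params', (lambda step: 0.0, 0.0))]
    check_base_lrs(sched, per_example_clipping=True)
    # check_base_lrs(sched, per_example_clipping=False)  # raises ValueError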
