Skip to content

Commit

Permalink
add constant learning rate with custom rule (huggingface#3133)
Browse files Browse the repository at this point in the history
* add constant lr with rules

* add constant with rules in TYPE_TO_SCHEDULER_FUNCTION

* add constant lr rate with rule

* hotfix code quality

* fix doc style

* change name constant_with_rules to piecewise constant
  • Loading branch information
jason9075 authored Apr 28, 2023
1 parent 8bb8052 commit 83c4ce7
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class SchedulerType(Enum):
POLYNOMIAL = "polynomial"
CONSTANT = "constant"
CONSTANT_WITH_WARMUP = "constant_with_warmup"
PIECEWISE_CONSTANT = "piecewise_constant"


def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
Expand Down Expand Up @@ -77,6 +78,48 @@ def lr_lambda(current_step: int):
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
"""
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
step_rules (`string`):
The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
steps and multiple 0.005 for the other steps.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""

rules_dict = {}
rule_list = step_rules.split(",")
for rule_str in rule_list[:-1]:
value_str, steps_str = rule_str.split(":")
steps = int(steps_str)
value = float(value_str)
rules_dict[steps] = value
last_lr_multiple = float(rule_list[-1])

def create_rules_function(rules_dict, last_lr_multiple):
def rule_func(steps: int) -> float:
sorted_steps = sorted(rules_dict.keys())
for i, sorted_step in enumerate(sorted_steps):
if steps < sorted_step:
return rules_dict[sorted_steps[i]]
return last_lr_multiple

return rule_func

rules_func = create_rules_function(rules_dict, last_lr_multiple)

return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
"""
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
Expand Down Expand Up @@ -232,12 +275,14 @@ def lr_lambda(current_step: int):
SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
SchedulerType.CONSTANT: get_constant_schedule,
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
}


def get_scheduler(
name: Union[str, SchedulerType],
optimizer: Optimizer,
step_rules: Optional[str] = None,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
num_cycles: int = 1,
Expand All @@ -252,6 +297,8 @@ def get_scheduler(
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
step_rules (`str`, *optional*):
A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
Expand All @@ -270,6 +317,9 @@ def get_scheduler(
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer, last_epoch=last_epoch)

if name == SchedulerType.PIECEWISE_CONSTANT:
return schedule_func(optimizer, rules=step_rules, last_epoch=last_epoch)

# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
Expand Down

0 comments on commit 83c4ce7

Please sign in to comment.