Skip to content

Commit

Permalink
Add warmup steps and alpha_f.
Browse files Browse the repository at this point in the history
  • Loading branch information
dwadden committed Aug 6, 2024
1 parent d3f145f commit 04aae0e
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions olmo/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,17 +792,24 @@ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
@dataclass
class CosLinearEnvelope(Scheduler):
    """
    Pointwise product of cosine schedule and linear decay; useful during annealing.

    For the first ``warmup_steps`` steps the learning rate is whatever
    ``self._linear_warmup`` returns. After that it decays from ``initial_lr``
    down to ``eta_min = initial_lr * alpha_f`` following::

        eta_min + (initial_lr - eta_min) * (1 - t) * (1 + cos(pi * t)) / 2

    where ``t`` is the fraction of the post-warmup schedule that has elapsed.
    """

    # Number of initial steps delegated to ``self._linear_warmup``.
    warmup_steps: int
    # Final learning rate, expressed as a fraction of ``initial_lr``.
    alpha_f: float = 0.1
    # When set, replaces the ``max_steps`` argument passed to ``get_lr``.
    t_max: Optional[int] = None

    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
        """
        Return the learning rate to use at ``step``.

        :param initial_lr: Peak learning rate, reached at the end of warmup.
        :param step: Current training step.
        :param max_steps: Total number of steps in the schedule; ignored when
            ``self.t_max`` is set.
        :returns: The scheduled learning rate; ``initial_lr * alpha_f`` once
            ``step`` reaches the end of the schedule.
        """
        max_steps = max_steps if self.t_max is None else self.t_max
        eta_min = initial_lr * self.alpha_f

        if step < self.warmup_steps:
            return self._linear_warmup(initial_lr, step, self.warmup_steps)
        if step >= max_steps:
            return eta_min

        # Measure progress relative to the end of warmup so the decay starts
        # exactly at ``initial_lr``. Here ``warmup_steps <= step < max_steps``,
        # so the denominator below is strictly positive.
        step = step - self.warmup_steps
        max_steps = max_steps - self.warmup_steps
        linear_envelope = 1 - (step / max_steps)
        cosine_schedule = (initial_lr - eta_min) * (1 + cos(pi * step / max_steps)) / 2
        return eta_min + linear_envelope * cosine_schedule


# Tuple of string field names; judging by the constant's name these are
# per-param-group entries -- the code that consumes it is outside this excerpt,
# so confirm against its usage.
PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names")
Expand Down

0 comments on commit 04aae0e

Please sign in to comment.