From d3f145f1c04b4a1068be1f39f40b25f00ddea8eb Mon Sep 17 00:00:00 2001 From: dwadden Date: Tue, 6 Aug 2024 01:09:46 +0000 Subject: [PATCH 1/4] Add scheduler for cosine in linear envelope. This didn't prove particularly effective for annealing, but might as well have it as an option. --- olmo/config.py | 1 + olmo/optim.py | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/olmo/config.py b/olmo/config.py index 8d3ed0823..d3f94f37c 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -556,6 +556,7 @@ class SchedulerType(StrEnum): inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup" max_scheduler = "max_scheduler" constant = "constant" + cosine_linear_envelope = "cosine_linear_envelope" class SchedulerUnits(StrEnum): diff --git a/olmo/optim.py b/olmo/optim.py index d05536a45..8d24599ba 100644 --- a/olmo/optim.py +++ b/olmo/optim.py @@ -25,6 +25,7 @@ "InvSqrtWithWarmup", "MaxScheduler", "ConstantScheduler", + "CosLinearEnvelope", "BoltOnWarmupScheduler", "build_optimizer", "build_scheduler", @@ -788,6 +789,22 @@ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float: return initial_lr +@dataclass +class CosLinearEnvelope(Scheduler): + "Pointwise product of cosine schedule and linear decay; useful during annealing." + t_max: Optional[int] = None + + def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float: + eta_min = 0.0 + + if step >= max_steps: + return eta_min + else: + linear_envelope = initial_lr - (initial_lr - eta_min) * (step / max_steps) + cosine_schedule = (1 + cos(pi * step / max_steps)) / 2 + return linear_envelope * cosine_schedule + + PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names") @@ -981,5 +998,14 @@ def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = Non grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor, warmup_min_lr=sched_cfg.warmup_min_lr, ) + elif sched_cfg.name == SchedulerType.cosine_linear_envelope: + return CosLinearEnvelope( + grad_clip_warmup_steps=None + if sched_cfg.grad_clip_warmup_steps is None + else int(sched_cfg.grad_clip_warmup_steps), + grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor, + t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max), + warmup_min_lr=sched_cfg.warmup_min_lr, + ) else: raise NotImplementedError From 04aae0e7c28607f592de18a9e71c9f3531eddb58 Mon Sep 17 00:00:00 2001 From: dwadden Date: Tue, 6 Aug 2024 22:06:10 +0000 Subject: [PATCH 2/4] Add warmup steps and alpha_f. --- olmo/optim.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/olmo/optim.py b/olmo/optim.py index 8d24599ba..5c1ce119f 100644 --- a/olmo/optim.py +++ b/olmo/optim.py @@ -792,17 +792,24 @@ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float: @dataclass class CosLinearEnvelope(Scheduler): "Pointwise product of cosine schedule and linear decay; useful during annealing." + warmup_steps: int + alpha_f: float = 0.1 t_max: Optional[int] = None def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float: - eta_min = 0.0 + max_steps = max_steps if self.t_max is None else self.t_max + eta_min = initial_lr * self.alpha_f + if step < self.warmup_steps: + return self._linear_warmup(initial_lr, step, self.warmup_steps) if step >= max_steps: return eta_min else: - linear_envelope = initial_lr - (initial_lr - eta_min) * (step / max_steps) - cosine_schedule = (1 + cos(pi * step / max_steps)) / 2 - return linear_envelope * cosine_schedule + step = step - self.warmup_steps + max_steps = max_steps - self.warmup_steps + linear_envelope = 1 - (step / max_steps) + cosine_schedule = (initial_lr - eta_min) * (1 + cos(pi * step / max_steps)) / 2 + return eta_min + linear_envelope * cosine_schedule PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names") From 00f20a01ffba717b33f31db9c0c2615fdcbab7c3 Mon Sep 17 00:00:00 2001 From: dwadden Date: Tue, 6 Aug 2024 22:08:01 +0000 Subject: [PATCH 3/4] Update changelog. --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c86f78bc9..382cf638e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `model.rope_theta` configuration option. - Added `model.embedding_layer_norm` configuration option for adding a LN to the embeddings. - Added `model.emb_init_std` configuration option to override the standard deviation used to initialize the embeddings. +- Added `CosLinearEnvelope` scheduler, which is a pointwise product of a cosine schedule and a linear decay. ### Changed From 1cf30406737918e57109c5a46724530b3f4ef774 Mon Sep 17 00:00:00 2001 From: dwadden Date: Wed, 7 Aug 2024 18:43:48 +0000 Subject: [PATCH 4/4] Pass in warmup_steps and alpha_f. --- olmo/optim.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/olmo/optim.py b/olmo/optim.py index 5c1ce119f..5460ccee1 100644 --- a/olmo/optim.py +++ b/olmo/optim.py @@ -1007,10 +1007,12 @@ def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = Non ) elif sched_cfg.name == SchedulerType.cosine_linear_envelope: return CosLinearEnvelope( - grad_clip_warmup_steps=None - if sched_cfg.grad_clip_warmup_steps is None - else int(sched_cfg.grad_clip_warmup_steps), + grad_clip_warmup_steps=( + None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps) + ), grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor, + warmup_steps=int(sched_cfg.t_warmup), + alpha_f=sched_cfg.alpha_f, t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max), warmup_min_lr=sched_cfg.warmup_min_lr, )