Skip to content

Commit

Permalink
Fix merge gore
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkgr committed Nov 2, 2024
1 parent 75a1395 commit f8f8a1e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 32 deletions.
2 changes: 1 addition & 1 deletion olmo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,8 @@ class SchedulerType(StrEnum):
inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
max_scheduler = "max_scheduler"
constant = "constant"
constant_with_warmup = "constant_with_warmup"
cosine_linear_envelope = "cosine_linear_envelope"
constant_with_warmup = "constant_with_warmup"


class SchedulerUnits(StrEnum):
Expand Down
42 changes: 11 additions & 31 deletions olmo/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,17 +789,6 @@ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
return initial_lr


@dataclass
class ConstantWithWarmupScheduler(Scheduler):
warmup_steps: int

def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
if step < self.warmup_steps:
return self._linear_warmup(initial_lr, step, self.warmup_steps)
del max_steps
return initial_lr


@dataclass
class CosLinearEnvelope(Scheduler):
"Pointwise product of cosine schedule and linear decay; useful during annealing."
Expand All @@ -823,6 +812,17 @@ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
return eta_min + linear_envelope * cosine_schedule


@dataclass
class ConstantWithWarmupScheduler(Scheduler):
warmup_steps: int

def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
if step < self.warmup_steps:
return self._linear_warmup(initial_lr, step, self.warmup_steps)
del max_steps
return initial_lr


PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names")


Expand Down Expand Up @@ -1016,35 +1016,16 @@ def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = Non
grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
warmup_min_lr=sched_cfg.warmup_min_lr,
)
<<<<<<< HEAD
elif sched_cfg.name == SchedulerType.constant_with_warmup:
return ConstantWithWarmupScheduler(
=======
elif sched_cfg.name == SchedulerType.cosine_linear_envelope:
return CosLinearEnvelope(
>>>>>>> origin/main
grad_clip_warmup_steps=(
None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
),
grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
<<<<<<< HEAD
warmup_min_lr=sched_cfg.warmup_min_lr,
warmup_steps=int(sched_cfg.t_warmup)
)
elif sched_cfg.name == SchedulerType.cosine_linear_envelope:
return CosLinearEnvelope(
grad_clip_warmup_steps=(
None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
),
grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
=======
>>>>>>> origin/main
warmup_steps=int(sched_cfg.t_warmup),
alpha_f=sched_cfg.alpha_f,
t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
warmup_min_lr=sched_cfg.warmup_min_lr,
<<<<<<< HEAD
=======
)
elif sched_cfg.name == SchedulerType.constant_with_warmup:
return ConstantWithWarmupScheduler(
Expand All @@ -1054,7 +1035,6 @@ def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = Non
grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
warmup_min_lr=sched_cfg.warmup_min_lr,
warmup_steps=int(sched_cfg.t_warmup),
>>>>>>> origin/main
)
else:
raise NotImplementedError

0 comments on commit f8f8a1e

Please sign in to comment.