Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Jun 16, 2024
1 parent 3c4b161 commit c492796
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 28 deletions.
11 changes: 2 additions & 9 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0'
- '1.1'
- '1.2'
- '1.3'
- '1.4'
- '1.5'
- '1.6'
- '1.7'
- '1.8'
- '1.9'
- 'nightly'
os:
- ubuntu-latest
arch:
Expand Down
7 changes: 7 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.0"
manifest_format = "2.0"
project_hash = "f6dd4230b36847a6d40a3eea7d2682365388770e"

[deps]
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ authors = ["murrellb <[email protected]> and contributors"]
version = "0.1.0"

[compat]
julia = "1"
julia = "1.9"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
38 changes: 21 additions & 17 deletions src/schedules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ struct LinearSchedule{T} <: AbstractSchedule{T}
steps::Int
end

function Base.iterate(schedule::LinearSchedule{T}, (rate, state)::Tuple{T, Int}=(schedule.initial, 0)) where T
x = clamp(state / schedule.steps, 0, 1)
function Base.iterate(schedule::LinearSchedule{T}, (rate, step)::Tuple{T, Int}=(schedule.initial, 0)) where T
x = clamp(step / schedule.steps, 0, 1)
rate = schedule.initial + (schedule.final - schedule.initial) * x
return (rate, (rate, state + 1))
return (rate, (rate, step + 1))
end

"""
Expand All @@ -45,36 +45,38 @@ struct BurninSchedule{T} <: AbstractSchedule{T}
decay::T
end

function Base.iterate(schedule::BurninSchedule{T}, (rate, state)::Tuple{T, Int}=(schedule.min, 1)) where T
if state == 1
function Base.iterate(schedule::BurninSchedule{T}, (rate, stage)::Tuple{T, Int}=(schedule.min, 0)) where T
if stage == 0
stage = 1
elseif stage == 1
rate *= schedule.inflate
if rate schedule.max
rate = schedule.max
state = 2
stage = 2
end
elseif state == 2
elseif stage == 2
rate *= schedule.decay
if rate schedule.min
rate = schedule.min
state = 3
stage = 3
end
end
return (rate, (rate, state))
return (rate, (rate, stage))
end

"""
BurninHyperbolicSchedule{T} <: AbstractSchedule{T}
BurninHyperbolicSchedule(min::T, max::T, inflate::T, decay::T, floor::T)
A learning rate schedule with exponential inflation and hyperbolic decay states.
A learning rate schedule with exponential inflation and hyperbolic decay stages.
The rate starts at `min`, inflates exponentially to `max`, then decays hyperbolically to `min`.
# Arguments
- `min::T`: The minimum learning rate.
- `max::T`: The maximum learning rate.
- `inflate::T`: The inflation factor during stage 1.
- `decay::T`: The decay factor during stage 2 (starts after max is reached).
- `floor::T = zero(T)`: The floor value for the decay.
- `floor::T`: idk ask Ben or look at the code lol
"""
struct BurninHyperbolicSchedule{T} <: AbstractSchedule{T}
min::T
Expand All @@ -84,19 +86,21 @@ struct BurninHyperbolicSchedule{T} <: AbstractSchedule{T}
floor::T
end

function Base.iterate(schedule::BurninHyperbolicSchedule{T}, (rate, state)::Tuple{T, Int}=(schedule.min, 1)) where T
if state == 1
function Base.iterate(schedule::BurninHyperbolicSchedule{T}, (rate, stage)::Tuple{T, Int}=(schedule.min, 0)) where T
if stage == 0
stage = 1
elseif stage == 1
rate *= schedule.inflate
if rate schedule.max
rate = schedule.max
state = 2
stage = 2
end
elseif state == 2
elseif stage == 2
rate = (rate - schedule.floor) / (one(T) + schedule.decay * (rate - schedule.floor))
if rate schedule.min
rate = schedule.min
state = 3
stage = 3
end
end
return (rate, (rate, state))
return (rate, (rate, stage))
end
20 changes: 19 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,29 @@ using LearningSchedules
using Test

@testset "LearningSchedules.jl" begin

@testset "schedules.jl" begin
@testset "LinearSchedule" begin
schedule = LinearSchedule(1.0, 0.6, 4)
@test collect(r for (i, r) in zip(1:6, schedule)) == [1.0, 0.9, 0.8, 0.7, 0.6, 0.6]
end
@testset "BurninSchedule" begin
schedule = BurninSchedule(1.0, 8.0, 2.0, 0.5)
@test collect(r for (i, r) in zip(1:8, schedule)) == [1.0, 2.0, 4.0, 8.0, 4.0, 2.0, 1.0, 1.0]
end
@testset "BurninHyperbolicSchedule" begin
schedule = BurninHyperbolicSchedule(1.0, 8.0, 2.0, 0.5, 0.0)
@test collect(r for (i, r) in zip(1:8, schedule)) == [1.0, 2.0, 4.0, 8.0, 1.6, 1.0, 1.0, 1.0]
end
end

@testset "stateful.jl" begin
@testset "next_rate!" begin
schedule = Stateful(LinearSchedule(1.0, 0.6, 4))
@test next_rate(schedule) == 1.0
@test next_rate!(schedule) == 1.0
@test next_rate!(schedule) == 0.9
end
end

end

0 comments on commit c492796

Please sign in to comment.