diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index de66945..1f64ba3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -23,15 +23,8 @@ jobs: fail-fast: false matrix: version: - - '1.0' - - '1.1' - - '1.2' - - '1.3' - - '1.4' - - '1.5' - - '1.6' - - '1.7' - - '1.8' + - '1.9' + - 'nightly' os: - ubuntu-latest arch: diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 0000000..7286d9a --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,7 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.0" +manifest_format = "2.0" +project_hash = "f6dd4230b36847a6d40a3eea7d2682365388770e" + +[deps] diff --git a/Project.toml b/Project.toml index 7e11247..1193c56 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,7 @@ authors = ["murrellb and contributors"] version = "0.1.0" [compat] -julia = "1" +julia = "1.9" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/schedules.jl b/src/schedules.jl index 913379a..030cff5 100644 --- a/src/schedules.jl +++ b/src/schedules.jl @@ -19,10 +19,10 @@ struct LinearSchedule{T} <: AbstractSchedule{T} steps::Int end -function Base.iterate(schedule::LinearSchedule{T}, (rate, state)::Tuple{T, Int}=(schedule.initial, 0)) where T - x = clamp(state / schedule.steps, 0, 1) +function Base.iterate(schedule::LinearSchedule{T}, (rate, step)::Tuple{T, Int}=(schedule.initial, 0)) where T + x = clamp(step / schedule.steps, 0, 1) rate = schedule.initial + (schedule.final - schedule.initial) * x - return (rate, (rate, state + 1)) + return (rate, (rate, step + 1)) end """ @@ -45,28 +45,30 @@ struct BurninSchedule{T} <: AbstractSchedule{T} decay::T end -function Base.iterate(schedule::BurninSchedule{T}, (rate, state)::Tuple{T, Int}=(schedule.min, 1)) where T - if state == 1 +function Base.iterate(schedule::BurninSchedule{T}, (rate, stage)::Tuple{T, Int}=(schedule.min, 0)) where T + if stage == 0 + stage = 1 + elseif stage == 1 rate *= schedule.inflate if rate ≥ schedule.max rate = schedule.max - state = 2 + stage = 2 end - elseif state == 2 + elseif stage == 2 rate *= schedule.decay if rate ≤ schedule.min rate = schedule.min - state = 3 + stage = 3 end end - return (rate, (rate, state)) + return (rate, (rate, stage)) end """ BurninHyperbolicSchedule{T} <: AbstractSchedule{T} BurninHyperbolicSchedule(min::T, max::T, inflate::T, decay::T, floor::T) -A learning rate schedule with exponential inflation and hyperbolic decay states. +A learning rate schedule with exponential inflation and hyperbolic decay stages. The rate starts at `min`, inflates exponentially to `max`, then decays hyperbolically to `min`. # Arguments @@ -74,7 +76,7 @@ The rate starts at `min`, inflates exponentially to `max`, then decays hyperboli - `max::T`: The maximum learning rate. - `inflate::T`: The inflation factor during stage 1. - `decay::T`: The decay factor during stage 2 (starts after max is reached). -- `floor::T = zero(T)`: The floor value for the decay. +- `floor::T`: idk ask Ben or look at the code lol """ struct BurninHyperbolicSchedule{T} <: AbstractSchedule{T} min::T @@ -84,19 +86,21 @@ struct BurninHyperbolicSchedule{T} <: AbstractSchedule{T} floor::T end -function Base.iterate(schedule::BurninHyperbolicSchedule{T}, (rate, state)::Tuple{T, Int}=(schedule.min, 1)) where T - if state == 1 +function Base.iterate(schedule::BurninHyperbolicSchedule{T}, (rate, stage)::Tuple{T, Int}=(schedule.min, 0)) where T + if stage == 0 + stage = 1 + elseif stage == 1 rate *= schedule.inflate if rate ≥ schedule.max rate = schedule.max - state = 2 + stage = 2 end - elseif state == 2 + elseif stage == 2 rate = (rate - schedule.floor) / (one(T) + schedule.decay * (rate - schedule.floor)) if rate ≤ schedule.min rate = schedule.min - state = 3 + stage = 3 end end - return (rate, (rate, state)) + return (rate, (rate, stage)) end diff --git a/test/runtests.jl b/test/runtests.jl index aac87ae..06dbe8f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,11 +2,29 @@ using LearningSchedules using Test @testset "LearningSchedules.jl" begin - + @testset "schedules.jl" begin + @testset "LinearSchedule" begin + schedule = LinearSchedule(1.0, 0.6, 4) + @test collect(r for (i, r) in zip(1:6, schedule)) == [1.0, 0.9, 0.8, 0.7, 0.6, 0.6] + end + @testset "BurninSchedule" begin + schedule = BurninSchedule(1.0, 8.0, 2.0, 0.5) + @test collect(r for (i, r) in zip(1:8, schedule)) == [1.0, 2.0, 4.0, 8.0, 4.0, 2.0, 1.0, 1.0] + end + @testset "BurninHyperbolicSchedule" begin + schedule = BurninHyperbolicSchedule(1.0, 8.0, 2.0, 0.5, 0.0) + @test collect(r for (i, r) in zip(1:8, schedule)) == [1.0, 2.0, 4.0, 8.0, 1.6, 1.0, 1.0, 1.0] + end end @testset "stateful.jl" begin + @testset "next_rate!" begin + schedule = Stateful(LinearSchedule(1.0, 0.6, 4)) + @test next_rate(schedule) == 1.0 + @test next_rate!(schedule) == 1.0 + @test next_rate!(schedule) == 0.9 + end end end