reset!(optimiser_state) #163

Open
Vilin97 opened this issue Oct 11, 2023 · 2 comments


Vilin97 commented Oct 11, 2023

Motivation and description

In my application I run 25 gradient-descent update! steps in a loop (solving a differential equation). I need the momentum accumulated during one block of 25 steps NOT to carry over to the next block. In other words, the behavior I am looking for is equivalent to calling Flux.setup(optimiser, model) before every block. Unfortunately, Flux.setup is type-unstable (see #162). It would be great to have a function reset!(optimiser_state) that resets the momenta. A more stringent requirement might be that

state = Flux.setup(optimiser, model)
# do some training
reset!(state)
state == Flux.setup(optimiser, model)

holds.

Possible Implementation

Below is an implementation for Adam.

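using Optimisers: Optimisers, Leaf

# Reset a single Adam leaf: zero both moment estimates in place and
# restore the running powers of beta to their initial value.
function reset!(leaf::Leaf{A, S}) where {A <: Optimisers.Adam, S}
    leaf.state[1] .= 0
    leaf.state[2] .= 0
    leaf.state = (leaf.state[1], leaf.state[2], leaf.rule.beta)
    nothing
end
# Reset the state of a model whose state tree has a single `layers` field,
# where every layer has exactly a `weight` and a `bias`.
function reset!(state::NamedTuple{(:layers,), L}) where {L}
    for layer in state.layers
        reset!(layer.weight)
        reset!(layer.bias)
    end
    nothing
end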

mcabbott (Member) commented

One possible design is this:

reset!(tree) = foreach(reset!, tree)
reset!(ℓ::Leaf) = ℓ.state = reset!(ℓ.rule, ℓ.state)

reset!(::AbstractRule, ::Nothing) = nothing
reset!(rule::AbstractRule, state) = throw(ArgumentError("reset! does not know how to handle this rule."))

Rules would then need to opt in by defining a two-argument reset! method, perhaps using some fill!! that also works for immutable arrays?

reset!(rule::Adam, (mt, vt, βt)) = (fill!!(mt, 0), fill!!(vt, 0), rule.beta)

We can't easily fall back to calling init again for unknown rules, as we don't have the original parameters x here.

Falling back to zero like this might be OK for built-in rules like Momentum etc, but could be wrong for user-defined rules... probably we shouldn't:

reset!(rule::AbstractRule, state::AbstractArray) = fill!!(state, 0)
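
For illustration, the opt-in design above could be sketched end to end roughly as follows. This is only a sketch under assumptions, not an existing Optimisers.jl API: reset! and fill!! are hypothetical names defined here, the tree walk is restricted to Tuple and NamedTuple nodes, and the Adam method annotates its state as a Tuple so that it cannot be ambiguous with the Nothing method.

```julia
using Optimisers

# Hypothetical helper (not part of Optimisers.jl): fill in place when the array
# is mutable, otherwise build a new array of the same shape. `ismutable` is a
# crude check; wrappers such as views fall into the second branch.
fill!!(x::AbstractArray, v) =
    ismutable(x) ? fill!(x, v) : map(_ -> convert(eltype(x), v), x)

# Walk the state tree; only Tuple and NamedTuple nodes are recursed into.
reset!(tree::Union{Tuple,NamedTuple}) = foreach(reset!, tree)
reset!(ℓ::Optimisers.Leaf) = (ℓ.state = reset!(ℓ.rule, ℓ.state); nothing)

# Stateless rules need nothing; every other rule must opt in explicitly.
reset!(::Optimisers.AbstractRule, ::Nothing) = nothing
reset!(rule::Optimisers.AbstractRule, state) =
    throw(ArgumentError("reset! does not know how to handle this rule."))

# Opt-in for Adam: zero both moments and restore the initial powers of beta.
reset!(rule::Optimisers.Adam, (mt, vt, βt)::Tuple) =
    (fill!!(mt, 0), fill!!(vt, 0), rule.beta)

# Usage: one update step makes the moments nonzero, reset! zeroes them again.
model = (weight = Float32[1, 2, 3], bias = Float32[0])
state = Optimisers.setup(Optimisers.Adam(), model)
grads = (weight = Float32[0.1, 0.2, 0.3], bias = Float32[1])
state, model = Optimisers.update!(state, model, grads)
reset!(state)
```

Because the fallback throws for unknown rules, a user-defined rule is never silently zeroed; it fails loudly until a two-argument method is added for it.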

ToucheSir (Member) commented

We could always make reset! take the parameter tree of xs too, but that may come at the cost of sacrificing type stability.
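
A rough sketch of that alternative, walking the state tree and the parameter tree together and calling Optimisers.init again on each leaf, might look like the following. The names reinit! and child are hypothetical and defined only for this example, and no claim is made that this walk is type-stable.

```julia
using Optimisers

# Hypothetical helper: fetch the child of a model node using the key
# under which it appears in the state tree.
child(x, k::Symbol) = getproperty(x, k)
child(x, k::Integer) = x[k]

# Re-initialise a leaf from its rule and the matching parameter array.
reinit!(ℓ::Optimisers.Leaf, x) = (ℓ.state = Optimisers.init(ℓ.rule, x); nothing)

# Recurse through Tuple and NamedTuple nodes of the state tree in parallel
# with the model.
reinit!(tree::Union{Tuple,NamedTuple}, xs) =
    foreach(k -> reinit!(tree[k], child(xs, k)), keys(tree))

# Usage with Momentum, whose state is a single velocity array per parameter.
model = (layers = ((weight = Float32[1 2; 3 4], bias = Float32[0, 0]),),)
state = Optimisers.setup(Optimisers.Momentum(), model)
grads = (layers = ((weight = Float32[1 1; 1 1], bias = Float32[1, 1]),),)
state, model = Optimisers.update!(state, model, grads)
reinit!(state, model)
```

Since init is part of every rule's interface, this variant needs no per-rule opt-in, at the price of requiring the parameters at the call site.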
