Motivation and description
In my application I do 25 gradient descent (update!) steps in a loop (solving a differential equation). I need the momentum from the previous 25 GD steps to NOT carry over to the next 25 GD steps. In other words, the behavior I am looking for is analogous to calling Flux.setup(optimiser, model) every time. Unfortunately, Flux.setup is type-unstable (#162). It would be great to have a function reset!(optimiser_state) that resets the momenta. Maybe a more stringent requirement is that
state = Flux.setup(optimiser, model)
# do some training
reset!(state)
state == Flux.setup(optimiser, model)
holds.
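Concretely, the usage pattern would look roughly like this (a sketch of the proposed API; optimiser, model, loss, data, and time_steps are placeholders):

state = Flux.setup(optimiser, model)
for t in time_steps                    # outer loop of the ODE solve
    for _ in 1:25                      # 25 gradient descent steps
        grads = Flux.gradient(m -> loss(m, data), model)
        Flux.update!(state, model, grads[1])
    end
    reset!(state)                      # momenta must not carry over to the next 25 steps
end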
Possible Implementation
Below is an implementation for Adam.
using Optimisers
using Optimisers: Leaf

# Zero Adam's two moment estimates in place and restore the running beta
# powers to the rule's initial betas (assumes a recent Optimisers.jl where
# Leaf is mutable, so its state field can be reassigned).
function reset!(leaf::Leaf{A, S}) where {A <: Optimisers.Adam, S}
    leaf.state[1] .= 0
    leaf.state[2] .= 0
    leaf.state = (leaf.state[1], leaf.state[2], leaf.rule.beta)
    return nothing
end

# Walk the state tree of a Chain and reset the weight and bias leaves of each layer.
function reset!(state::NamedTuple{(:layers,), L}) where {L}
    for layer in state.layers
        reset!(layer.weight)
        reset!(layer.bias)
    end
    return nothing
end
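For example, exercising the two methods above (the model, loss, and data here are just placeholders):

using Flux, Optimisers

model = Chain(Dense(2 => 4, relu), Dense(4 => 1))
state = Flux.setup(Optimisers.Adam(1e-3), model)

x, y = randn(Float32, 2, 16), randn(Float32, 1, 16)
grads = Flux.gradient(m -> Flux.mse(m(x), y), model)
Flux.update!(state, model, grads[1])   # momenta are now nonzero

reset!(state)   # hits the (:layers,) method, which resets each weight/bias Leaf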
reset!(tree) = foreach(reset!, tree)
reset!(ℓ::Leaf) = ℓ.state = reset!(ℓ.rule, ℓ.state)
reset!(::AbstractRule, ::Nothing) = nothing
reset!(rule::AbstractRule, state) = throw(ArgumentError("reset! does not know how to handle this rule."))
Then rules need to opt in by defining a method of the 2-arg reset!... perhaps with some fill!! which allows for immutable arrays?
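As a sketch, the opt-in method for Adam under this 2-arg design could look like the following (assuming the moment arrays are mutable; the fill!! idea above is what would cover immutable ones):

# Zero Adam's moment estimates and restore the beta powers from the rule.
function reset!(rule::Adam, state)
    mt, vt, _ = state
    fill!(mt, 0)
    fill!(vt, 0)
    return (mt, vt, rule.beta)
end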