[API] Preventing errors from misplaced optimizer objects #2106

Closed
MilesCranmer opened this issue Nov 9, 2022 · 2 comments

@MilesCranmer

I was wondering if you would be open to an API improvement, which would be completely optional and also minimal. The goal would be to reduce the potential for bugs and to make code more intuitive.

The following code is how one currently initializes an optimizer:

```julia
p = params(model)
opt = Adam(1e-3)

for i=1:1000
    ...
    update!(opt, p, grad)
end
```

It is not until the update! step that the optimizer actually initializes its state and reads the parameters.
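To illustrate that lazy initialization, here is a minimal, self-contained sketch. `LazyMomentum` and its `update!` method are hypothetical stand-ins, not Flux's code, and plain vectors of arrays stand in for `Params` and the gradients. The constructor stores only hyperparameters; per-array state goes into an `IdDict` keyed by array identity the first time `update!` sees each array.

```julia
# Hypothetical momentum optimiser (not Flux's code): its state lives in an
# IdDict keyed by each parameter array's identity and is created lazily.
struct LazyMomentum
    eta::Float64
    rho::Float64
    velocity::IdDict{Any,Any}
end
LazyMomentum(eta::Real = 0.01, rho::Real = 0.9) = LazyMomentum(eta, rho, IdDict{Any,Any}())

function update!(opt::LazyMomentum, ps, grads)
    for (x, g) in zip(ps, grads)
        # Allocate a zero velocity the first time this array is seen; reuse it afterwards.
        v = get!(() -> zero(x), opt.velocity, x)
        v .= opt.rho .* v .- opt.eta .* g
        x .+= v
    end
    return ps
end
```

Because nothing ties `opt` to a particular parameter collection, passing it a new vector that holds the same arrays (as a second call to params(model) would) silently reuses the velocity accumulated earlier.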

I think there are a couple of potential issues with this:

  1. params and Adam are initialized quite far apart from each other, even though they are closely related objects. This distance makes it easy to rename p but forget to rename opt, so I could end up running update! with an old optimizer (and its old state) without even realizing it.
  2. Since Adam actually records information about the parameters, it seems unintuitive for it not to know about them when the object is created. A beginner might think that simply initializing Adam after params somehow connects the two through some global mechanism, when in fact the loading happens at the update! step.

I am wondering what you think about the following minimal (and optional) change to remedy this. I think it could be good for each optimizer to record the objectid of the parameter object at initialization, and to throw an error if the user attempts to update it with a different set of parameters.

For example:

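```julia
mutable struct Descent <: AbstractOptimiser
  eta::Float64
  param_id::Union{UInt,Nothing}  # objectid of the bound parameters, or nothing
  # Inner constructors replace the default ones, so the two-argument method
  # cannot overwrite (and recurse into) the field-wise constructor.
  Descent(eta::Real = 0.1) = new(eta, nothing)
  Descent(eta::Real, parameters) = new(eta, objectid(parameters))
end
```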

Then, in the update! code, you would check whether param_id is nothing, and if it is not, you would verify that objectid(parameters) indeed matches what is stored.
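A minimal sketch of that check, written as standalone Julia so it runs without Flux. `CheckedDescent` and `check_params` are hypothetical names, a plain vector of arrays stands in for `Params`, and a vector of gradient arrays stands in for Flux's gradient container.

```julia
# Hypothetical stand-in for a Flux optimiser that remembers which parameter
# collection it was created for (not Flux's actual Descent).
mutable struct CheckedDescent
    eta::Float64
    param_id::Union{UInt,Nothing}   # `nothing` means "not bound to any parameters"
    CheckedDescent(eta::Real = 0.1) = new(eta, nothing)
    CheckedDescent(eta::Real, parameters) = new(eta, objectid(parameters))
end

# Throw if the optimiser was bound to a different parameter collection.
function check_params(opt::CheckedDescent, parameters)
    if opt.param_id !== nothing && opt.param_id != objectid(parameters)
        throw(ArgumentError("optimiser was created for a different set of parameters"))
    end
    return nothing
end

# Plain gradient-descent step over a vector of arrays, preceded by the check.
function update!(opt::CheckedDescent, parameters::Vector{<:AbstractArray}, grads)
    check_params(opt, parameters)
    for (x, g) in zip(parameters, grads)
        x .-= opt.eta .* g
    end
    return parameters
end
```

An optimiser built with `CheckedDescent(0.1)` leaves `param_id` as `nothing` and accepts any parameters, so existing code keeps working, while `CheckedDescent(0.1, p)` rejects any collection other than `p`.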

Thus, it wouldn't affect any current code, but in the future it would let people start writing safer and more intuitive code, like:

```julia
p = params(model)
opt = Adam(1e-3, p)

for i=1:1000
    ...
    update!(opt, p, grad)
end
```

And, if I forget to change the name of p when copying the loop:

```julia
p2 = params(model)
opt2 = Adam(1e-3, p2)

for i=1:1000
    ...
    update!(opt2, p, grad)  # Throws an error!
end
```

In the future you could also have the constructor initialize the state of the optimizer, rather than just recording the objectid. I think just having this is a good start, though.
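A sketch of that eager variant, again hypothetical (`EagerMomentum` is not a Flux type) and using plain vectors of arrays in place of `Params`. Because the optimiser holds a reference to the collection it was built for, the check can use `===` directly instead of a stored objectid.

```julia
# Hypothetical optimiser (not part of Flux) whose constructor receives the
# parameters and allocates all per-array state up front.
struct EagerMomentum{P}
    eta::Float64
    rho::Float64
    ps::P                          # the exact collection this optimiser serves
    velocity::IdDict{Any,Any}      # one zero-initialized buffer per array
end

function EagerMomentum(eta::Real, rho::Real, ps)
    velocity = IdDict{Any,Any}()
    for x in ps
        velocity[x] = zero(x)      # state exists before the first update!
    end
    return EagerMomentum(Float64(eta), Float64(rho), ps, velocity)
end

function update!(opt::EagerMomentum, ps, grads)
    ps === opt.ps ||
        throw(ArgumentError("optimiser was constructed for a different parameter collection"))
    for (x, g) in zip(ps, grads)
        v = opt.velocity[x]        # look up the pre-allocated state
        v .= opt.rho .* v .- opt.eta .* g
        x .+= v
    end
    return ps
end
```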

Thoughts on this?

Cheers,
Miles

@darsnack
Member

darsnack commented Nov 9, 2022

Something conceptually very similar exists in Optimisers.jl, which is what we are planning to transition our optimizers to. See #2082 for some of that discussion (it's been going on for some time). While it doesn't do the objectid check that you want, I think that check no longer matters there, because the state is no longer stored in an IdDict.

The main reason for not introducing this change would be that it applies to stuff that will get deprecated soon. Otherwise seems reasonable!
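For comparison, Optimisers.jl keeps the rule and its state apart: `Optimisers.setup(rule, model)` returns a state tree that mirrors the model, and `Optimisers.update(state, model, grad)` returns the updated state and model. The plain-Julia sketch below imitates that idea on a flat vector of arrays so it runs without any packages; `MomentumRule`, `init_state` and `apply_step` are hypothetical names and not part of Optimisers.jl's API.

```julia
# Hypothetical rule (not Optimisers.jl's code): it holds only hyperparameters.
struct MomentumRule
    eta::Float64
    rho::Float64
end

# The state is a separate value mirroring the parameters position by position.
init_state(::MomentumRule, ps::Vector{<:AbstractArray}) = [zero(x) for x in ps]

# One step: update `ps` in place and return the new state, which the caller
# threads through the loop explicitly (no hidden IdDict inside the rule).
function apply_step(rule::MomentumRule, state, ps, grads)
    length(state) == length(ps) == length(grads) ||
        throw(DimensionMismatch("state, parameters and gradients must have equal length"))
    new_state = map(state, grads) do v, g
        rule.rho .* v .- rule.eta .* g
    end
    for (x, v) in zip(ps, new_state)
        x .+= v
    end
    return new_state
end
```

Since the state is an ordinary value passed in and returned on every step, there is no dictionary hidden inside the rule that can carry stale state from one loop into another; a step uses whatever state the caller hands it.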

@MilesCranmer
Author

Cool, sounds good to me!
