Allow shared parameters, take III #106
I'd say this is less of a hack and something we should be doing more often. Either define a custom cache type, or (better) attach the cache to the callback itself by memoizing it. Then
src/interface.jl
function setup(rule::AbstractRule, model)
  cnt = Ref(0)
  # Rely on Functors to identify shared arrays, they will share a Leaf in this tree:
  tree = fmapstructure(model, exclude = isnumeric) do x
It's pretty surprising tests pass with this, as it doesn't check `trainable` at all.
update!(t′, x′, x̄s...)
function _update!(tree, x; grads, params)
  haskey(params, (tree,x)) && return params[(tree,x)]
  isbits(tree) && return x  # means () is not cached, and also (((),),)
This does imply we will be caching almost every level of an average Flux model (since `BitsType{NotBits, BitsTypes...}` is not a bitstype). `objectid` being not the fastest function in the world, perhaps both cache lookup and insertion should be additionally guarded by `ismutable(x)`.
I wondered this too. For large ImmutableArrays this may eventually need something fancier. But for now I think every `fmap` walk does the same thing.
Oh I wasn't even thinking about those, but cases like JuliaLang/julia#43542. We're unlikely to see any truly pathological behaviour, but I have to imagine the single comparison `ismutable` makes is more efficient than the recursive hash function `objectid` uses.
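To make the proposed guard concrete, here is a minimal Base-only sketch. `Wrap`, `MutWrap`, `cached` and `work` are hypothetical names made up for illustration and are not part of Optimisers.jl or Functors.jl; the sketch only shows the idea of skipping the `IdDict` for values without their own identity.

```julia
# Hypothetical containers, for illustration only.
struct Wrap{T}           # immutable, non-leaf container
    inner::T
end
mutable struct MutWrap   # mutable container
    inner
end

# Consult the cache only for mutable values; immutable ones are recomputed.
function cached(f, cache::IdDict, x)
    ismutable(x) || return f(x)
    return get!(() -> f(x), cache, x)
end

cache = IdDict()
calls = Ref(0)
work(x) = (calls[] += 1; x)   # counts how often the real work runs

m = MutWrap([1.0])
cached(work, cache, m); cached(work, cache, m)   # second call hits the cache
w = Wrap(m)
cached(work, cache, w); cached(work, cache, w)   # never cached, so work runs twice

@assert calls[] == 3
@assert length(cache) == 1
@assert isbits(()) && isbits(((),))   # the cases the isbits check skips
@assert !isbits(w)                    # holds a reference, so not a bits value
```

The trade-off is that an immutable wrapper is walked again each time it is reached, in exchange for avoiding `objectid` on every level of the model.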
OK. I guess `ismutable` really is right here. For parameter arrays IIRC there was a concern that it tells you e.g. that PermutedDimsArray is immutable. But for known non-leaf types, maybe it's always right?
Good point. `PermutedDimsArray` at least does implement `functor`, but you can always find an array wrapper which hasn't. Perhaps then the check should be `isleaf` instead? The `isbits` check is still useful either way.
Edit: I suppose `isnumeric` makes more sense since it forwards to `isleaf` already and `setup` guarantees only unfamiliar immutable wrappers of immutable arrays will get their own `Leaf`. Moving the `isbits` check up front also seems safe and could save a couple cycles on dict lookups.
function _update!(tree, x; grads, params)
isbits(tree) && return x # means () is not cached, and also (((),),)
isnum = isnumeric(x)
isnum && haskey(params, (tree,x)) && return params[(tree,x)]
children, re = functor(x)
children′ = map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, children)
x′ = re(children′)
isnum ? (params[(tree,x)] = x′) : x′
end
It's likely this can be simplified, but I wanted to get something on the page first in case there are any unforeseen edge cases present in this formulation.
I think anything `isnumeric` should have a corresponding `Leaf` and hit the `_update!(::Leaf, x; ...)` method.
This one wants only to deal with mutable non-leaf things, like my `mutable struct MutTwo` example. Which makes me think that `ismutable` is fine -- if we have `Foo(MutTwo(Bar(Transpose(Array`, then the Array is the leaf, and the only level at which it's worthwhile for this method to cache anything is the `MutTwo` one. If this whole stack appears twice, a fresh new `struct Foo` cannot be distinguished from the old one.
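The identity argument can be checked with plain Julia. In this sketch the struct definitions are assumptions (only the names `Foo`, `Bar` and `MutTwo` come from the comment above), and `PermutedDimsArray` from Base stands in for `Transpose` so that no extra package is needed.

```julia
# Field layouts are assumed; only the type names come from the discussion.
mutable struct MutTwo
    a
    b
end
struct Foo{T}
    inner::T
end
struct Bar{T}
    inner::T
end

arr = [1.0 2.0; 3.0 4.0]
wrapped = PermutedDimsArray(arr, (2, 1))   # immutable wrapper around a mutable Array
mt = MutTwo(Bar(wrapped), nothing)

first_foo, second_foo = Foo(mt), Foo(mt)   # the same stack appearing twice

@assert first_foo === second_foo                      # immutable structs with egal fields are egal
@assert objectid(first_foo) == objectid(second_foo)   # so an IdDict sees a single key
@assert Foo(MutTwo(Bar(wrapped), nothing)) !== first_foo   # a new MutTwo has a new identity
@assert ismutable(mt) && !ismutable(first_foo)
@assert !ismutable(wrapped) && ismutable(arr)
```

So in this stack the `MutTwo` level is the only non-leaf node with its own identity, which is why it is the only level worth caching.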
Shall we do this? I don't love it, and feel a bit bad about re-writing #100 in order to understand it... but this does add some features in the end. But I do think we ought to handle shared parameters, and that we want mutable We can re-write the internals if FluxML/Functors.jl#43 or something allows for a prettier version. The tests are pretty good. Maybe Edit: In fact perhaps
I have no objections assuming we're not considering any behavioural changes after those Functors PRs are merged.
I am also okay with doing this
Co-authored-by: Brian Chen <[email protected]>
Ok let's do it.
Another take on #100. Borrows the idea of making `Leaf` mutable.

Tried to be simpler by pushing more of the recursion onto `fmap`:

`setup` is just `fmapstructure` really. Its notion of sharing is thus exactly the one of Functors, one source of truth. We should fix that not to share `isbits` types, eventually.

`update!` is just `fmap`. Much of the complication of the old walk was to reconstruct both the state and the model on the way out. But this isn't needed if Leaf is mutated.

Tests from #100 pass with first commit. However, the shared Leaves must always match shared Arrays. It's possible that this scenario can be done even more simply, possibly without mutable Leaf.

What #100 does is instead to take shared Leaves as the truth about parameter sharing, which some future API could set in a way not matching the model (for ImmutableArrays, etc) even though present `setup` will not. Then `update!` cannot just be `fmap`, and needs one more separate `IdDict` for the parameters. Second commit here e84b61b bolts that on, and adds a test of it (which also passes using #100). But it's a bit ugly.

Edit: Third commit 0de29e1 instead just replaces the walk used for `fmap(f, tree, x)` to use `re` from its 2nd argument, while Functors still uses the cache on the 1st argument. That's tidier.

But the state tree contains the same `()` at every non-parameter node, and Functors caches the results of these... we should fix this upstream? A possible hack for now would be to supply a special cache `IdDict{Leaf}` which cannot store anything else -- done in e17e474.

But... that's still not right. If there are mutable layer structs, then I think you cannot rely on the ID of mutable `Leaf` to tie things. So I gave up on customising `fmap` and wrote out the recursion using `(x,Leaf())` as the key for reconstruction.

Gradient accumulation uses an IdDict as in #100, but stores a `broadcasted` adding the pieces, which thus requires all `apply!` methods to accept it. They all do. Changed to eager addition.

Does not at present allow for more than one derivative. But no rules use that. Added. There were no tests it seems.

Fixes the bug noted in #100 that `update` could in fact mutate the state. Does this by just saying `@functor Leaf`. Added a test.

One further possibility with a mutable `Leaf` is that it can easily have a flag to mark some parameters as temporarily frozen. This is implemented here (with no API to set the flag). Not sure it's what we want though. Easy to remove but perhaps if we're changing the struct we should consider other changes we might want.

Because `setup` does not call itself in recursion, it is fairly easy to add a warning if the model has no parameters. This was something someone complained about, I forget where.

Closes #42, closes #100, closes #97
does not call itself in recursion, it is fairly easy to add a warning if the model has no parameters. This was something someone complained about, I forget where.Closes #42, closes #100, closes #97