-
-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Transparent handling of tied weights #100
Conversation
This makes `Leaf` a mutable type so that tied weights are represented by the same leaf instance. Co-authored-by: Michael Abbott <[email protected]>
_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, _...) = nothing | ||
_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, ::Zero, ::Zero...) = nothing | ||
_accumulate!(::AbstractDict{Leaf,Any}, ℓ::Leaf, _, ::Zero, ::Zero...) = nothing | ||
_accumulate!(::AbstractDict{Leaf,Any}, _, _, ::Zero, ::Zero...) = nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are a lot of overloads with high degrees of overlap across multiple functions. I couldn't think of a way to deduplicate some of them, so if anyone has ideas that would be swell.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it can be just 4 methods, if the state tree has () instead of nothing, as in #106.
I also think it would be clearer to write variable names more often, not _
, since 5 arguments is quite a few to count.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I tried that, but ran into ambiguities. This is the smallest number of methods I could come up with that didn't have ambiguities. If you can narrow that down, that would be superb.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The underscores are mostly to appease the linter and possibly improve latency(??) Perhaps ::Any
would work better, though I'm not sure that addresses your point about clarity?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see how names can affect latency. I just mean they let your eye know what the 4th argument means, which ::Any
doesn't help.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My impression was it implicitly acted as a @nospecialize
, but looking at https://github.com/JuliaLang/julia/blob/98e1b13a7db5aa1d05b9a48085993375cf2298d0/src/method.c#L656 that may not be the case.
tree′ = fmap(tree; cache, exclude = Base.Fix2(isa, Leaf)) do ℓ | ||
Leaf(ℓ.rule, fmap(copy, ℓ.state; cache, exclude = iswriteable)) | ||
end | ||
x′ = fmap(copy, x; cache = empty!(cache), exclude = iswriteable) | ||
x̄s′ = fmap(copy, x̄s; cache = empty!(cache), exclude = iswriteable) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It turns out we were not defensively copying state or gradients before, so they could still be mutated by a call to update
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems fine never to copy gradients. It's never safe to mutate them anyway, a rule which does so (or an rrule
likewise) is simply a bug.
For copying state, can't we just say @functor Leaf (state,)
and let fmap
do it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For copying state, can't we just say
@functor Leaf (state,)
and letfmap
do it?
That breaks Leaf
identity, unfortunately. fmap
will end up untying shared parameters by creating new leaves at each location during reconstruction.
Not defensively copying gradients seems fine though, good point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Won't fmap
will preserve the Leaf identifications? That's what its cache is for.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had a recollection that it sometimes only preserved leaves, but re-reading the code you are correct.
Doctests appear to be picking up changes on master that aren't present on this branch, is that expected? I can't tweak the test because it doesn't exist here! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok I tried to read this, and in the process of understanding what's going on, wrote #106. Maybe that explains my thoughts more clearly than the comments here.
mutable struct Leaf{R,S} | ||
rule::R | ||
state::S | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Once this is mutable, then update!(tree, model, grad)
can be guaranteed to alter the state tree in place. This opens the possibility of simplifying the interface, and never returning multiple things whose order you have to remember.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
xtree = map(cb, tree, x′, x̄s′...) | ||
return map(first, xtree), re(map(last, xtree)) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This complication exists I think to reconstruct both the tree and the model on the way out of the recursion. But once Leaf
is mutable, can't we skip that, and just mutate it? Just call fmap
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, absolutely. I held off from doing that here in case some user was stashing old state trees and would be blindsided by the values in those leaves suddenly changing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I follow. update!
claimed it would mutate the states if it wanted to, and would typically alter arrays. (And update
claimed not to, but had a bug.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But if you had immutable arrays in your state tree before, the original state tree would be unchanged after update!
. Perhaps we don't feel that was ever a solid guarantee (I don't), but we ought to get that point out in writing for posterity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. I think we should change the doc for update!
to be explicit that it now guarantees to update the state tree. (I thought the old one said "inputs are trash afterwards" but in fact it is explicit only about the model.)
update!
need not in fact return two arguments, but whether that is too confusing to change (and to differ from update
which must) is another question.
tree′ = fmap(tree; cache, exclude = Base.Fix2(isa, Leaf)) do ℓ | ||
Leaf(ℓ.rule, fmap(copy, ℓ.state; cache, exclude = iswriteable)) | ||
end | ||
x′ = fmap(copy, x; cache = empty!(cache), exclude = iswriteable) | ||
x̄s′ = fmap(copy, x̄s; cache = empty!(cache), exclude = iswriteable) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems fine never to copy gradients. It's never safe to mutate them anyway, a rule which does so (or an rrule
likewise) is simply a bug.
For copying state, can't we just say @functor Leaf (state,)
and let fmap
do it?
end | ||
end | ||
|
||
_add!(x, x̄) = iswriteable(x) ? (x .= x .+ x̄) : eltype(x).(x .+ x̄) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure this is what we want. We should never ever mutate a gradient, but I think we can just call @lazy x̄old + x̄new
and lazily accumulate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My worry with the lazy accumulation approach is threefold. First, it blows any chance of making this type stable out the window. Secondly, it's possible the lazy Broadcasted
may be evaluated multiple times as it passes through a chain of rules and thus incur accumulation overhead more than once. Lastly, complicated broadcasts come with a lot of compilation latency (especially on GPU) and I'm wary of making optimizers worse than they already are on that front.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can do it eagerly to avoid this. But we cannot mutate the gradients, as they may be shared with others (e.g. from the rule for +).
Lazy .+
is almost free, it's very difficult to picture evaluating this twice ever costing as much as a copy. Not sure about compile times.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point about aliased gradients. If this is a correctness issue, we don't have much of a choice :)
_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, _...) = nothing | ||
_accumulate!(::AbstractDict{Leaf,Any}, ::Nothing, _, ::Zero, ::Zero...) = nothing | ||
_accumulate!(::AbstractDict{Leaf,Any}, ℓ::Leaf, _, ::Zero, ::Zero...) = nothing | ||
_accumulate!(::AbstractDict{Leaf,Any}, _, _, ::Zero, ::Zero...) = nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it can be just 4 methods, if the state tree has () instead of nothing, as in #106.
I also think it would be clearer to write variable names more often, not _
, since 5 arguments is quite a few to count.
# slightly cleaner way of closing over update! internal state | ||
struct UpdateCallback | ||
acc_grads::IdDict{Leaf,Any} | ||
param_cache::IdDict{Leaf,Any} | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The limitation might be mine but I have to say I find this struct really hard to read, compared to just closing over things which have one name.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand what "just closing over things which have one name." entails here, can you elaborate? Another reason for the struct over a normal closure is self-recursion, which I use here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean things like this, which define a dict & then use it:
cache = IdDict{Leaf,Any}()
_accumulate!(cache, tree, x, x̄s...)
With no further names: no structs, no field names.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I recall trying this first, and deciding to bundle things into a struct after seeing a lot of long, long lines from threading the two IdDicts through multiple levels of functions. It may also have been tricky to get that working in a backwards compatible way, but it's been long enough that I don't remember the whole context.
Here's an evil case for shared parameters: mutable struct MutTwo; x; y; end
Functors.@functor MutTwo
tmp = MutTwo([1.0], [2.0])
model = (a=tmp, b=tmp, c=MutTwo(tmp.x, tmp.y))
state = Optimisers.setup(Momentum(), model)
model.a === model.b
model.a !== model.c # fields are identified, but struct is not
state.a.x === state.b.x
state.a === state.b
state.a === state.c # unavoidable, but means we can't use Leaf ID alone?
mgrad = (a=(x=[1.], y=[10.]), b=(x=[100], y=[1000]), c=(x=[1/3], y=[1/30]))
state2, model2 = Optimisers.update(state, model, mgrad)
model2.a === model2.b
model2.a !== model2.c The state of all 3 components is One answer here is to store tuples |
I don't think we ever guaranteed |
This makes
Leaf
a mutable type so that tied weights are represented by the same leaf instance.Although only mutable array types are automatically detected as tied, one can also tie immutable parameters by manually creating shared
Leaf
s.The test suite is practically the same as #42, with some slight modifications since there is no equivalent to
Tied
in this PR.