-
-
Notifications
You must be signed in to change notification settings - Fork 212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use clamptype
mechanism to project onto cotangent space
#965
Conversation
Modification of my comment on the ZygoteRules PR -- I agree it's better to aggregate words in the same place! I generally really like this, but I wonder whether there's a better name availlable. This is a bit verbose, but something like project_onto_cotangent_space but might be more informative? My other question is whether there are any situations in which type information is insufficient to make this work? For example, to know how to clamp a I made some remarks on the other PR about implementing this here vs in ChainRules. It appears that this has been considered, and for the sake of obtaining something that works well, it probably makes sense to implement this here for now, rather than in ChainRules. We can think about transfering this to somewhere in ChainRules once ChainRules types are utilised fully in Zygote. |
Yes, one reason for here not ChainRules is that it can work today, as a step forward. The tests which work now, do so without interacting with ChainRules adjoints. Some of what they catch is caused by broadcasting, and I didn't think there was proposal that ChainRules would handle the entire broadcast (even if it did the scalars). I'd like to establish the precedent that this is what the correct behaviour is. If rules upstream get better at hitting the passthrough not the clamp, that’s great too. If the entire thing gets re-organised some day before Zygote is sent to the glue factory, that's also OK. Mike's objection to the original version was that passing the object means it has to stay in memory, but the type is lighter. (#342) Maybe that's the problem opaque closures will solve, but not today, if I understand right? We could introduce another hook by which types define what object they want passed forwards instead of their type... at the cost of a bit more complexity, define an extensible Are there other examples where this might be desirable? It doesn't seem obvious to me that zeros of a Not set on the name, Zygote's flavour is to make things compact so I try to fit in, but perhaps there's a better choice. I guess it's really |
clamptype
mechanism to project onto to tangent spaceclamptype
mechanism to project onto cotangent space
end | ||
end | ||
|
||
_twofold(trans, dx) = (dx .+ trans(dx)) ./ 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Notice this:
using Zygote # this PR
using LinearAlgebra, FiniteDifferences, ForwardDiff, ChainRules
gradient(x -> Symmetric([x 3x; 5x 7x])[1,2], pi) # (3,)
gradient(A -> A[1,2], Symmetric(rand(2,2))) # ([0.0 0.5; 0.5 0.0],)
gradient(A -> Symmetric(A)[1,2], rand(2,2)) # ([0.0 1.0; 0.0 0.0],)
grad(central_fdm(5, 1), A -> A[1,2], Symmetric(rand(2,2))) # ([0.0 1.0; 1.0 0.0],) # weird?
grad(central_fdm(5, 1), A -> Symmetric(A)[1,2], rand(2,2)) # ([0.0 1.0; 0.0 0.0],) # fine
If you use _twofold(trans, dx) = dx .+ trans(dx) .- Diagonal(dx)
, then the Zygote results will double, and the 1st & 3rd will be quite obviously wrong. And this isn't a projection operator.
Is the FiniteDifferences result the right thing for a tangent not cotangent? Or is it a bug? As far as I can tell it's doing something like this:
ve, re = to_vec(Symmetric([1 3; 5 7])) # ve == [1,3,3,7]
re([1,-42,3,7]) == [1 3; 3 7]
re2(v) = Symmetric(reshape(v,2,2))
v2 = rand(1:999, 4); re(v2) == re2(v2)
re2(ForwardDiff.gradient(v -> re2(v)[1,2], ve)) == [0 1; 1 0]
Looking in ChainRules, there seem to be few tests using this behaviour of FiniteDifferences. There is a rule for Matrix
, which applies (dx .+ dx' .- Diagonal(dx)
to match:
rrule(Matrix, Symmetric(rand(2,2)))[2]([1 3; 5 7])[2] == [1 8; 8 7]
pullback(Matrix, Symmetric(rand(3,3)))[2]([1 3; 5 7])[1] == [1 4; 4 7] # with this PR
Defined here https://github.com/JuliaDiff/ChainRules.jl/blob/4e3164a3a48d4da35e0112d30be7ea9dbdaf3920/src/rulesets/LinearAlgebra/symmetric.jl#L71 where _symmetric_back
is also the gradient of Symmetric
where it makes more sense IMO. (Originally from Zygote, I think.)
One example is #599 (restoring arrays from splats), for which this PR can provide a solution very similar to #489. But to recover a splatted matrix from a tuple, you need the size not just the type. (Discussion of options: #489 (comment), this is 1.) |
This adds a mechanism to enforce that gradients stay within a cotangent space defined by the original variable. And uses this to enforce that real numbers have real gradients, and that some LinearAlgebra structured matrix types are respected.
This is a conservative in that, rather than trying to have a watertight theory of what cotangent types are allowed, it only acts on types it knows how to fix, and lets everything else pass through, as it does today. Perhaps it could learn how to fix some gradients which are NamedTuples / Composite, but later. It should also be easy for packages to extend this for their own types, if desired.
I believe there is broad agreement that we do want to respect these constraints. The argument against an explicit clamp like this is that ideally rules would always respect them and so it would not be needed. That's a nice asperation but a world in which the rule for (say)
atan(x,y)
gets re-written to understand what to do aboutx::Real, y::Complex
seems far away. When the rule does get it right, this clamp should be free, as it's only about types. Individualclamptype
methods can, if we wish, contain@debug
statements to help track down problems. It is wasteful of course to carve aDiagonal
out of a full Matrix, but better to do it immediately than to let the problem propagate. Again, it's a step forwards, not a watertight re-design.Still WIP,
it hooks onto Zygote's own rules, but not yet onto those defined in ChainRules., some bugs.Closes #342, closes #402. Fixes #917, fixes #431.
@testset "clamped gradients"
shows some examples of what works.Needs FluxML/ZygoteRules.jl#16, thus CI won't pass. Locally, the remaining issues are things like matrix
exp
of Hermitian (where I believe the output has not changed), and FFT tests (many of which relied on real input giving a complex gradient).