Use clamptype mechanism to project onto cotangent space #965

Closed
mcabbott wants to merge 11 commits

Conversation

@mcabbott (Member) commented May 6, 2021

This adds a mechanism to enforce that gradients stay within the cotangent space defined by the original variable, and uses this to enforce that real numbers have real gradients and that some LinearAlgebra structured matrix types are respected.

This is conservative in that, rather than trying to have a watertight theory of what cotangent types are allowed, it only acts on types it knows how to fix, and lets everything else pass through, as it does today. Perhaps it could learn how to fix some gradients which are NamedTuples / Composite, but that can come later. It should also be easy for packages to extend this for their own types, if desired.
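Roughly, the mechanism consists of methods like these (a simplified sketch for orientation, not the PR's actual definitions; those live in the diff and in the ZygoteRules PR):

using LinearAlgebra

clamptype(::Type, dx) = dx                            # fallback: pass everything else through, as today
clamptype(::Type{<:Real}, dx::Complex) = real(dx)     # a real variable gets a real gradient
clamptype(::Type{<:Diagonal}, dx::AbstractMatrix) = Diagonal(dx)   # keep only the diagonal

clamptype(Float64, 1.0 + 2.0im)                                    # 1.0
clamptype(Diagonal{Float64, Vector{Float64}}, [1.0 2.0; 3.0 4.0])  # Diagonal([1.0, 4.0])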

I believe there is broad agreement that we do want to respect these constraints. The argument against an explicit clamp like this is that ideally rules would always respect them, and so it would not be needed. That's a nice aspiration, but a world in which the rule for (say) atan(x,y) gets re-written to understand what to do about x::Real, y::Complex seems far away. When the rule does get it right, this clamp should be free, as it's only about types. Individual clamptype methods can, if we wish, contain @debug statements to help track down problems. It is wasteful of course to carve a Diagonal out of a full Matrix, but better to do that immediately than to let the problem propagate. Again, it's a step forwards, not a watertight re-design.
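For a flavour of the intended behaviour (illustrative calls only, not taken from the testset; exact return values depend on the branch):

using Zygote, LinearAlgebra   # on this PR's branch

# real input through a complex intermediate: the clamp should hand back a real gradient
gradient(x -> abs2(exp(im * x) + 1), 0.3)

# Diagonal input: the clamp should carve a Diagonal out of any dense gradient that escapes
gradient(D -> sum(D * [1.0, 2.0]), Diagonal([1.0, 2.0]))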

Still WIP: it hooks onto Zygote's own rules, but not yet onto those defined in ChainRules, and there are still some bugs.

Closes #342, closes #402. Fixes #917, fixes #431.

@testset "clamped gradients" shows some examples of what works.

Needs FluxML/ZygoteRules.jl#16, thus CI won't pass. Locally, the remaining issues are things like matrix exp of Hermitian (where I believe the output has not changed), and FFT tests (many of which relied on real input giving a complex gradient).

@willtebbutt (Member)

Modification of my comment on the ZygoteRules PR -- I agree it's better to aggregate words in the same place!

I generally really like this, but I wonder whether there's a better name available.

This is a bit verbose, but something like project_onto_cotangent_space might be more informative?

My other question is whether there are any situations in which type information is insufficient to make this work? For example, to know how to clamp a SparseMatrixCSC you need to know the precise sparsity pattern, which is only available at runtime, so this seems like an example where you need the object itself. I'm not proposing to actually implement stuff for SparseMatrixCSC, just saying that perhaps the API should change to require the arguments to the rrule be passed to clamptype, rather than just their type.
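To make that concrete (just a sketch of the point, with a hypothetical project_onto_pattern; not something I'm proposing to add):

using SparseArrays

# projecting a dense gradient onto the sparsity pattern needs the object itself,
# because the pattern only exists at runtime
function project_onto_pattern(x::SparseMatrixCSC, dx::AbstractMatrix)
    I, J, _ = findnz(x)
    sparse(I, J, [dx[i, j] for (i, j) in zip(I, J)], size(x)...)
end

x = sparse([1, 2], [1, 2], [1.0, 2.0])
project_onto_pattern(x, [1.0 2.0; 3.0 4.0])    # keeps only the stored positions (1,1) and (2,2)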

I made some remarks on the other PR about implementing this here vs in ChainRules. It appears that this has been considered, and for the sake of obtaining something that works well, it probably makes sense to implement this here for now, rather than in ChainRules. We can think about transferring this to somewhere in ChainRules once ChainRules types are utilised fully in Zygote.

@mcabbott (Member Author) commented May 6, 2021

Yes, one reason for doing this here rather than in ChainRules is that it can work today, as a step forward. The tests which work now do so without interacting with ChainRules adjoints. Some of what they catch is caused by broadcasting, and I didn't think there was a proposal that ChainRules would handle the entire broadcast (even if it did the scalars). I'd like to establish the precedent that this is the correct behaviour. If rules upstream get better at hitting the passthrough rather than the clamp, that's great too.
If the entire thing gets re-organised some day before Zygote is sent to the glue factory, that's also OK.

Mike's objection to the original version (#342) was that passing the object means it has to stay in memory, while the type is lighter. Maybe that's the problem opaque closures will solve, but not today, if I understand right? We could introduce another hook by which types define what object they want passed forwards instead of their type, at the cost of a bit more complexity: define an extensible passthis(x) = typeof(x) and use it in place of typeof in argTs = map(typeof, ($(argnames...),)) here (in the ZygoteRules bit).
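To sketch that hook (passthis is just the name used above; nothing here is committed code):

using SparseArrays

passthis(x) = typeof(x)              # default: only the type is carried forward, as now
passthis(x::SparseMatrixCSC) = x     # e.g. a sparse array could choose to keep the whole object

# ZygoteRules would then build  argTs = map(passthis, args)  instead of  map(typeof, args)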



Are there other examples where this might be desirable? It doesn't seem obvious to me that zeros of a SparseMatrixCSC should automatically be regarded as structural, see e.g. this discussion #163 (comment) (and perhaps JuliaDiff/ForwardDiff.jl#480 for accidental-to-structural promotion problems). It was suggested that some TakeSparsitySeriously wrapper could opt in to such things (which would still need a mechanism).

I'm not set on the name; Zygote's flavour is to make things compact, so I tried to fit in, but perhaps there's a better choice. I guess it's really clamptype!! in that it may mutate the gradient; I see that near JuliaDiff/ChainRules.jl#232 (comment) I also proposed project_cotangent!! for exactly this.

@mcabbott changed the title from "Use clamptype mechanism to project onto to tangent space" to "Use clamptype mechanism to project onto cotangent space" on May 7, 2021
_twofold(trans, dx) = (dx .+ trans(dx)) ./ 2
@mcabbott (Member Author)

Notice this:

using Zygote # this PR
using LinearAlgebra, FiniteDifferences, ForwardDiff, ChainRules

gradient(x -> Symmetric([x 3x; 5x 7x])[1,2], pi)   # (3,)
gradient(A -> A[1,2], Symmetric(rand(2,2)))        # ([0.0 0.5; 0.5 0.0],)
gradient(A -> Symmetric(A)[1,2], rand(2,2))        # ([0.0 1.0; 0.0 0.0],)

grad(central_fdm(5, 1), A -> A[1,2], Symmetric(rand(2,2))) # ([0.0 1.0; 1.0 0.0],) # weird? 
grad(central_fdm(5, 1), A -> Symmetric(A)[1,2], rand(2,2)) # ([0.0 1.0; 0.0 0.0],) # fine

If you use _twofold(trans, dx) = dx .+ trans(dx) .- Diagonal(dx), then the Zygote results will double, and the 1st & 3rd will be quite obviously wrong. And this isn't a projection operator.
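A quick check of that last claim (not code from the PR):

using LinearAlgebra

proj(dx)  = (dx .+ transpose(dx)) ./ 2
other(dx) = dx .+ transpose(dx) .- Diagonal(dx)

dx = [1.0 3.0; 5.0 7.0]
proj(proj(dx)) == proj(dx)        # true: applying it twice changes nothing
other(other(dx)) == other(dx)     # false: so it cannot be a projection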

Is the FiniteDifferences result the right thing for a tangent rather than a cotangent? Or is it a bug? As far as I can tell it's doing something like this:

ve, re = to_vec(Symmetric([1 3; 5 7])) # ve == [1,3,3,7]
re([1,-42,3,7]) == [1 3; 3 7]

re2(v) = Symmetric(reshape(v,2,2))
v2 = rand(1:999, 4); re(v2) == re2(v2)
re2(ForwardDiff.gradient(v -> re2(v)[1,2], ve)) == [0 1; 1 0]

Looking in ChainRules, there seem to be few tests using this behaviour of FiniteDifferences. There is a rule for Matrix, which applies dx .+ dx' .- Diagonal(dx) to match:

rrule(Matrix, Symmetric(rand(2,2)))[2]([1 3; 5 7])[2]    == [1 8; 8 7] 
pullback(Matrix, Symmetric(rand(2,2)))[2]([1 3; 5 7])[1] == [1 4; 4 7] # with this PR

Defined here https://github.com/JuliaDiff/ChainRules.jl/blob/4e3164a3a48d4da35e0112d30be7ea9dbdaf3920/src/rulesets/LinearAlgebra/symmetric.jl#L71 where _symmetric_back is also used as the gradient of Symmetric, where it makes more sense IMO. (Originally from Zygote, I think.)

@mcabbott (Member Author)

whether there are any situations in which type information is insufficient to make this work?

One example is #599 (restoring arrays from splats), for which this PR can provide a solution very similar to #489. But to recover a splatted matrix from a tuple, you need the size, not just the type. (Discussion of options: #489 (comment); this is option 1.)
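For illustration (not code from this PR), rebuilding the matrix gradient from the tuple of per-element gradients needs size(x), which typeof(x) does not carry:

x = rand(2, 3)
dx_tuple = ntuple(_ -> 1.0, length(x))     # the kind of thing the splatted call hands back
dx = reshape(collect(dx_tuple), size(x))   # size(x) is runtime information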
