-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improvements to rules for norm
#337
Conversation
Codecov Report
@@ Coverage Diff @@
## master #337 +/- ##
==========================================
+ Coverage 97.64% 97.67% +0.03%
==========================================
Files 18 18
Lines 1018 1034 +16
==========================================
+ Hits 994 1010 +16
Misses 24 24
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great work, thanks! I have just some very minor suggestions, but overall LGTM. I think this might even fix some problems I am currently having with Zygote, but need to look into that some more.
We made the restriction of the julia> using LinearAlgebra, ChainRules
julia> A, ȳ = randn(ComplexF64, 5, 5), randn();
julia> hermA, Hermitian_back = rrule(Hermitian, A, :U);
julia> y, norm_back = rrule(norm, hermA);
julia> unthunk(norm_back(ȳ)[2])
ERROR: ArgumentError: Cannot set a non-diagonal index in a Hermitian matrix This is due to the broadcasting to a similar type to |
I would not describe an This is precisely because of the work-around for I guess I'm not completely sold that "all is forbidden except what is permitted" is the right policy, compared to Julia's usual AbstractArray flexibility. The motivating issue above involves StaticArrays, which used to work, and will behave well under the generic brodcast etc. of a fairly simple rule like this. (And less well under the |
The special problem is that there is a class of types which had an overly broad rule not been defined then the AD would have done the right thing, but with it defined the wrong thing happens. |
I'm not sure I follow. Which rule? In FluxML/Zygote.jl#860, there is no rule for
(Without this PR!) |
It is a general statement, not specifically related to this PR. |
Maybe this should be a separate message. Another problem with the
Does this package have opinions about second derivatives, and about tests for them? Seems tricky, and maybe not every rule can support them, but where easy it would be nice to have. |
It really seems that we rarely think about being second-differentiable, which seems like an oversight. |
Latest commit 8ab56db is an idea for fixing
Hermitian is handled by just broadcasting. And second derivatives may work:
Edit -- Tests pass locally, but fail on CI, in tests of Adjoint etc. Something is failing to be conjugated? Haven't made this happen locally yet. |
Supporting higher order rules is on the roadmap for ChainRules v2.0 I have head that it is a very common problem for Zygote nested AD not to work to to mutation.
That one will be fixed one Zygote is switched over to using |
src/rulesets/LinearAlgebra/norm.jl
Outdated
∂x = Thunk() do | ||
return if isempty(x) || p == 0 | ||
zero.(x) .* (zero(y) * zero(real(Δy))) | ||
InplaceableThunk( | ||
@thunk(zero.(x) .* (zero(y) * zero(real(Δy)))), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You know what is fun? this is broken for some arrays.
We might want to constrain the eltype
to be Number
Consider:
julia> norm([[[1]]], 2)
1.0
julia> zero.([[[1]]])
ERROR: MethodError: no method matching zero(::Type{Vector{Int64}})
because of JuliaLang/julia#38064
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not blocking for this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This entire file should probably be restricted to arrays of <:Number
?
Done in be61a4e, but not 100% sure that's a good idea -- haven't tried to audit which rules would or would not work for arrays of arrays.
`λ .* Diagonal(rand(3))` gives a dense matrix when `x==Inf`. | ||
But `withsomezeros_rewrap(x, λ .* parent(x))` is type-stable. | ||
""" | ||
WithSomeZeros{T} = Union{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would call these StructuredSparseArray
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You approve of the mechanism, #337 (comment)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am willing to give it a shot.
We can always change it later.
It's not going to lead to wrong behavour AFAICT.
It seems unfortunate not to take advantage of the fact that we know where the zeros are,
and we know that the pullback is going to map zeros to zeros, since linear.
So we should be able to skip some.
But idk that that is a generic API for our structurally sparse matrixes to know if an index will be zero.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe I misunderstand you, but both λ .* Diagonal(rand(3))
and this function do know where the zeros are, and do O(N)
work. That's the only really sparse one.
For UpperTriangular, I haven't tried to time this against broadcasting... there could be trade-offs, maybe broadcasting skips half, but if so it needs lots of if statements. Frankly I doubt that anyone has ever called norm(::UpperTriangular)
outside a test, though. So perhaps thinking about that can wait until this finds wider use where someone does need to care.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would also be good to fix this instability upstream. Can't we argue that the off-diagonal elements are a strong zero like false
, and make NaN .* Diagonal(rand(3))
just work? Is there an issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like all structural zeros should be strong yes.
I was sure I had seem julia displaying that behavour on SparseCSC matrixes, but I can't reproduce it right now.
cae1531
to
15e55ee
Compare
I wonder if we should just add StaticArrays as a dependency (it is a super popular package),
GPUArrays are also primative arrays in this sense I think? |
ae62bc0
to
741b141
Compare
Codecov Report
@@ Coverage Diff @@
## master #337 +/- ##
==========================================
+ Coverage 98.46% 98.49% +0.02%
==========================================
Files 23 23
Lines 1893 1929 +36
==========================================
+ Hits 1864 1900 +36
Misses 29 29
Continue to review full report at Codecov.
|
Maybe isn't a bad place to discuss this business of whether to write rules for Consider the following arrays, all
The function This is not true for But another position is that providing high-level rules allows more efficient code. Here's the present behaviour -- I guess you can blame
Note that In fact I would argue that such fallbacks are the reason
Ignoring this type instability issue, the generic rule already produces a (I think we should also add an overall projection step which ensures the preservation of structured types such as There isn't a special case for
At present Zygote can't differentiate The proposal to consider Some of my examples here could be evaded by making Being generic does introduce more ways to mess up -- it's obviously easier to test code which only accepts For those proposing much more restricted rules, would you care to argue where you disagree with the above? For (I think See also:
|
I would be surprised if there were an automatic way to distinguish these situations. I think we've narrowed the range of functions we're discussing down a bit now, which is good -- as I understand it we're now just considering whether, for functions for which specialised methods exist for a particular type, it's a better idea to default to utilising generic rrules, or to let AD have a go at differentiating the specialisation directly. I get the feeling that for the former is what you want for the kinds of situations you're considering, and the latter is preferable in the situations that I care about. So I think we're in the realm of trying to pick a tradeoff. Is this also your understanding? In terms of giving us nicer options, I wonder whether it's worth enabling two kinds of
This would give us the option to make quite strong statements where appropriate, and weaker ones where we just want to ensure that the fallback behaviour is reasonable. |
Yes, I think that's a fair statement of what's going on. If there's no automatic way, then some manual way to mark these distinctions is required. However, I'm not sure that tagging the original abstract rule is enough, as you suggest with these two classes of (And to make automation harder, the overload is one method down, The other manual way is to opt out. Writing By contrast, making people who don't care about AD have their specialisation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While we are having useful and converstations in this thread (that i have not yet caught up on)
this code benchmarks well.
and passes tests.
Can we address my last few comments and i will merge it.
If we want to change things later we can.
(nice thing about working in code rather than wood, mistakes can be removed invisably)
Co-authored-by: Simeon Schaub <[email protected]>
Co-authored-by: Lyndon White <[email protected]>
This fixes FluxML/Zygote.jl#860 by relaxing the signature on the rule for
norm
. While there is no longer an explicit mention ofDiagonal
in the rule, the tests still check that this is correctly reproduced.It should also fix
gradient(norm, [1, 2])
, which gave an InexactError.And adds a few more InplaceableThunks while I was at it.
Edit -- fixes FluxML/Zygote.jl#960, too.