Remove redundant sum() rules #1453
base: master
I thought this existed in order to opt out of the Zygote rule for sum, which makes a FillArray.
We could delete that too; it sometimes saves one copy, but that rarely matters in real code, and it causes problems.
Deleting that rule fixes all but one testsuite, Line 53 in 6129613
Lines 340 to 342 in 6129613
Integers and convert it to an rrule(::ZygoteRuleConfig, ...) for future-proofing at the same time?
We could certainly delete the rule for bool arrays, as there's one here: IDK what the issue with that Dict test is. (Considering integers to be differentiable was a mistake, IMO, but it would be a breaking change to fix that, here or in CR.)
The old rule was arguably wrong, because it was passing through the gradient for the summed value without doing any form of projection. If this were a scalar function, asking to differentiate wrt an integer argument would return a float gradient. So in my mind the test is actually capturing incorrect and inconsistent behaviour of the current rule. If we agree on that, I'll just tweak the test and we'll be back on green CI (minus known AbstractFFT failures).
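To make the projection point concrete, here is a minimal sketch of a sum-like rule that projects its gradient back onto the input. It is not the rule in Zygote or ChainRules: `mysum` and `mysum_pullback` are hypothetical names chosen so that no real rule is overwritten, and the sketch assumes ChainRulesCore's `ProjectTo` treats integer inputs as floats.

```julia
using ChainRulesCore

# Hypothetical stand-in for `sum`, so this sketch does not overwrite any real rule.
mysum(x::AbstractArray{<:Number}) = sum(x)

function ChainRulesCore.rrule(::typeof(mysum), x::AbstractArray{<:Number})
    # Projector built from the primal input; it records the input's axes and
    # element type (assumption: integer element types are projected to floats).
    project = ProjectTo(x)
    function mysum_pullback(ȳ)
        # Every element of x contributes to the sum with weight one.
        x̄ = fill(unthunk(ȳ), size(x))
        return NoTangent(), project(x̄)
    end
    return mysum(x), mysum_pullback
end

# Usage: an integer input with a float cotangent.
y, back = rrule(mysum, [1, 2, 3])
_, x̄ = back(1.0)
```

Under those assumptions `x̄` comes back as a float array even though the input holds integers, which matches the scalar behaviour described above.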
Sorry I didn't look closely, but if the change is just that now you get a Dict of Floats not Ints, then that seems totally fine, we just adjust the test.
1037852 to a32f039
The one remaining failure:
sum, prod, cumsum: Test Failed at /var/lib/buildkite-agent/builds/gpuci-1/julialang/zygote-dot-jl/test/gradcheck.jl:117
  Expression: gradient(sum, [true, false, true]) == (nothing,)
   Evaluated: nothing == (nothing,)
Which comes from Zygote.jl/src/compiler/interface.jl, Line 98 in e0d3d8b
@mcabbott do you recall why we're collapsing to
My memory is that Zygote is eager to collapse any tuple of nothings to just nothing, but doesn't always manage to do so. I think at least
It looks like
Hi,
Maybe, if we can get some consensus on the behaviour of
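As a rough illustration of the collapsing behaviour described here, the sketch below uses a hypothetical helper, `collapse_nothings`. It is not Zygote's actual code, only the simplest function with the described effect.

```julia
# Hypothetical helper, not Zygote's implementation: if every entry of the
# gradient tuple is `nothing`, return a bare `nothing` instead of the tuple.
collapse_nothings(grads::Tuple) = all(isnothing, grads) ? nothing : grads

collapsed = collapse_nothings((nothing,))     # all entries are nothing
kept = collapse_nothings((nothing, 1.0))      # one real gradient, tuple kept

# The failing test compares the collapsed value against the uncollapsed tuple.
matches_test = collapsed == (nothing,)
```

With a collapse like this, a comparison against `(nothing,)` no longer holds, which is the shape of the failure reported above.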
The pullback is non-differentiable, which messes with nested AD (#1450). It's also not clear to me why this rule still exists when ChainRules has a seemingly GPU-compatible one. Let's see what CI says.
PR Checklist
Tests are added
Documentation, if applicable