-
Notifications
You must be signed in to change notification settings - Fork 113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Compatibility with ChainRules, Drop Julia 1.0 support #414
Conversation
Codecov Report
@@ Coverage Diff @@
## master #414 +/- ##
==========================================
+ Coverage 84.40% 84.52% +0.12%
==========================================
Files 23 24 +1
Lines 1590 1596 +6
==========================================
+ Hits 1342 1349 +7
+ Misses 248 247 -1
Continue to review full report at Codecov.
|
Awesome. What's the best way to test this? |
Actually I was going to ask the community the same question. I can of course write a naive test which takes gradient over a known function and we compare values. But that isn't the best approach I think. What we really want to do is check integration with chain rules. The discussion started with trying to implement upsample methods. Maybe we can do something with that? |
I found ChainRulesTestUtils.jl I guess that's the best approach. |
I tried using ChainRulesTestUtils but it seems like it doesn't work for Interpolation objects. Adding code here in case I am doing something obviously wrong. julia> using ChainRulesTestUtils
julia> x = collect(1:10)
julia> y = sin.(x)
julia> itp = interpolate(y,BSpline(Linear()))
julia> test_rrule(itp, 1, 1)
test_rrule: [0.8414709848078965, 0.9092974268256817, 0.1411200080598672, -0.7568024953079282, -0.9589242746631385, -0.27941549819892586, 0.6569865987187891, 0.9893582466233818, 0.4121184852417566, -0.5440211108893698] at (1, 1): Error During Test at C:\Users\paresh\.julia\packages\ChainRulesTestUtils\NCO8i\src\testers.jl:168
Got exception outside of a @test
ArgumentError: test_rrule cannot be used on closures/functors (such as [0.8414709848078965, 0.9092974268256817, 0.1411200080598672, -0.7568024953079282, -0.9589242746631385, -0.27941549819892586, 0.6569865987187891, 0.9893582466233818, 0.4121184852417566, -0.5440211108893698])
Stacktrace:
[1] _ensure_not_running_on_functor(f::Interpolations.BSplineInterpolation{Float64, 1, Vector{Float64}, BSpline{Linear}, Tuple{Base.OneTo{Int64}}}, name::String)
@ ChainRulesTestUtils ~\.julia\packages\ChainRulesTestUtils\NCO8i\src\testers.jl:238
[2] macro expansion
@ ~\.julia\packages\ChainRulesTestUtils\NCO8i\src\testers.jl:169 [inlined]
[3] macro expansion
@ C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\Test\src\Test.jl:1151 [inlined]
[4] test_rrule(::Interpolations.BSplineInterpolation{Float64, 1, Vector{Float64}, BSpline{Linear}, Tuple{Base.OneTo{Int64}}}, ::Int64, ::Vararg{Int64, N}
where N; output_tangent::ChainRulesTestUtils.Auto, fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, check_inferred::Bool, fkwargs::NamedTuple{(), Tuple{}}, rtol::Float64, atol::Float64, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ ChainRulesTestUtils ~\.julia\packages\ChainRulesTestUtils\NCO8i\src\testers.jl:169
[5] test_rrule(::Interpolations.BSplineInterpolation{Float64, 1, Vector{Float64}, BSpline{Linear}, Tuple{Base.OneTo{Int64}}}, ::Int64, ::Vararg{Int64, N}
where N)
@ ChainRulesTestUtils ~\.julia\packages\ChainRulesTestUtils\NCO8i\src\testers.jl:166
[6] top-level scope
@ REPL[32]:1
[7] eval
@ .\boot.jl:360 [inlined]
[8] eval
@ .\Base.jl:39 [inlined]
[9] repleval(m::Module, code::Expr, #unused#::String)
@ VSCodeServer ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\packages\VSCodeServer\src\repl.jl:124
[10] (::VSCodeServer.var"#47#49"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\packages\VSCodeServer\src\repl.jl:99
[11] with_logstate(f::Function, logstate::Any)
@ Base.CoreLogging .\logging.jl:491
[12] with_logger
@ .\logging.jl:603 [inlined]
[13] (::VSCodeServer.var"#46#48"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\packages\VSCodeServer\src\repl.jl:100
[14] #invokelatest#2
@ .\essentials.jl:708 [inlined]
[15] invokelatest(::Any)
@ Base .\essentials.jl:706
[16] macro expansion
@ ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\packages\VSCodeServer\src\eval.jl:34 [inlined]
[17] (::VSCodeServer.var"#60#61")()
@ VSCodeServer .\task.jl:406
Test Summary:
| Error Total
test_rrule: [0.8414709848078965, 0.9092974268256817, 0.1411200080598672, -0.7568024953079282, -0.9589242746631385, -0.27941549819892586, 0.6569865987187891, 0.9893582466233818, 0.4121184852417566, -0.5440211108893698] at (1, 1) | 1 1
ERROR: Some tests did not pass: 0 passed, 0 failed, 1 errored, 0 broken. So this seems like a dud, any other ideas about testing are welcome |
src/chainrules/chainrules.jl
Outdated
function ChainRulesCore.rrule(itp::AbstractInterpolation, x...) | ||
y = itp(x...) | ||
function back(Δ) | ||
(NO_FIELDS, Δ * gradient(itp, x...)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you not want to update any of the fields of the interpolation too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure about that. I am confident that the second argument needs to be the gradient, but I am not sure what the first argument is even supposed to be. I have hacked this in to see if it works but I am not confident this is right. What would you recommend? Maybe you can clarify a bit more on what the first and second argument are supposed to be
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will change the call to Interpolations.gradient, I made the choice at random anyways.
Yay! This looks like a really nice step. Maybe qualify the call to gradient as |
I would directly test that zygote is producing the correct and expected value |
@DhairyaLGandhi thanks for your attention here. I have added a naive test, if this works we can think about adding more comprehensive testing. |
I dug around a bit more on ChainRulesTestUtils, they already have an issue open for the functor issue JuliaDiff/ChainRulesTestUtils.jl#117 (comment) . It seems like this is already a milestone for their v1. I would still proceed with the zygote but eventually we should shift to this package |
I have no idea why these tests are failing. Did I do dependencies wrong? |
Interpolations: Error During Test at D:\a\Interpolations.jl\Interpolations.jl\test\chainrules.jl:10
Test threw exception
Expression: (Zygote.gradient(itp, 1))[1] == Interpolations.gradient(itp, 1)
ArgumentError: unable to check bounds for indices of type Interpolations.WeightedAdjIndex{2,Int64} |
@mkitti I saw that, that is the original error we started to solve. I don't understand how it works on 3 tests and doesn't work on 3 tests. also, itworksonmymachine.jpg :( |
Ah I see. The tests are failing on Julia 1.0.5. They succeed on Julia 1.6. |
Where the tests are failing:
|
I added another test for 2D interpolation case. Now I am a bit stuck about how should we proceed. @DhairyaLGandhi is still scrutinizing the actual rule, but apart from that it seems like integration with Zygote is working. Ideas on how to proceed are welcome, I was thinking maybe we can write what is asked for in this issue JuliaImages/ImageTransformations.jl#113? |
What else is needed right now? I am thinking it might be good to merge this and move from there. What do you think @mkitti ? |
The new code is in a mergeable state. Some housekeeping is required. If we're going to break compat with Julia 1.0, this probably should be v0.14.0. However, it's only clear to me that we need Julia 1.3 for testing. Do we have any bounds on compat with ChainRulesCore.jl or should we just declare compat against the latest version? |
The current minor version of ChainRulesCore should be fine. You can also make a patch release, it's allowed within semver. |
If this is mergeable thats do it. this doesn't break anything and adds new functionality. The only apprehension I had was if our rule, which basically is |
Well the tests do break Julia 1.0 compatibility, which is fine. It just needs to be managed. I'll probably run a v0.13 branch that is compatible with Julia 1.0 for a while longer in parallel with a v0.14 branch that requires Julia 1.3. Documentation would be great. Could you add a |
function ChainRulesCore.rrule(itp::AbstractInterpolation, x...) | ||
y = itp(x...) | ||
function pullback(Δy) | ||
(NO_FIELDS, Δy * Interpolations.gradient(itp, x...)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One quick thought (sorry this should have come earlier) NO_FIELDS
-> DoesNotExist()
. @rick2047 would you be up for a tiny follow up?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, I was bothered by this too. but this is already merged, can I still push to this branch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use the same branch. It's your branch. Just send a new PR when you are ready.
We seem to be having CI issues with ChainRulesCore.jl 0.10. @rick2047 @DhairyaLGandhi , do you have any ideas where the issue may be? |
related to #396. As mentioned by @moesphere in that thread, we need to define a rrule for the Interolations type. I've defined a simple rrule which enables me to integrate with, for example, Zygote like this:
This should allow us to integrate with Flux.