Compatibility with ChainRules, Drop Julia 1.0 support #414

Merged Apr 17, 2021 (11 commits)

@rick2047
Contributor

rick2047 commented Apr 14, 2021

Related to #396. As mentioned by @moesphere in that thread, we need to define an rrule for the Interpolations types. I've defined a simple rrule that lets me integrate with, for example, Zygote like this:

julia> using Interpolations, Zygote

julia> y = sin.(1:10)
10-element Vector{Float64}:
  0.8414709848078965
  0.9092974268256817
  0.1411200080598672
 -0.7568024953079282
 -0.9589242746631385
 -0.27941549819892586
  0.6569865987187891
  0.9893582466233818
  0.4121184852417566
 -0.5440211108893698

julia> itp = interpolate(y,BSpline(Cubic(Reflect(OnCell()))))
10-element interpolate(OffsetArray(::Vector{Float64}, 0:11), BSpline(Cubic(Reflect(OnCell())))) with element type Float64:
  0.7125302276762112
  0.769963450028415
  0.1194958272928704
 -0.6408357079724973
 -0.8119858486932346
 -0.23659994479000876
  0.556314857216602
  0.8377563450756788
  0.3489685127835064
 -0.29399432638595374

julia> Zygote.gradient(itp, 2)
([-0.35017548837401463],)

julia> Interpolations.gradient(itp, 2)
1-element StaticArrays.SVector{1, Float64} with indices SOneTo(1):
 -0.35017548837401463

This should allow us to integrate with Flux.

@codecov

codecov bot commented Apr 14, 2021

Codecov Report

Merging #414 (8a82a9d) into master (3b00087) will increase coverage by 0.12%.
The diff coverage is 100.00%.

❗ Current head 8a82a9d differs from pull request most recent head 1ca330c. Consider uploading reports for the commit 1ca330c to get more accurate results

@@            Coverage Diff             @@
##           master     #414      +/-   ##
==========================================
+ Coverage   84.40%   84.52%   +0.12%     
==========================================
  Files          23       24       +1     
  Lines        1590     1596       +6     
==========================================
+ Hits         1342     1349       +7     
+ Misses        248      247       -1     
Impacted Files                  Coverage Δ
src/Interpolations.jl           76.98% <ø> (+0.18%) ⬆️
src/chainrules/chainrules.jl    100.00% <100.00%> (ø)
src/b-splines/indexing.jl       64.22% <0.00%> (+0.81%) ⬆️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@mkitti
Collaborator

mkitti commented Apr 14, 2021

Awesome. What's the best way to test this?

@rick2047
Contributor Author

Actually, I was going to ask the community the same question. I can of course write a naive test that takes the gradient of a known function and compares the values, but I don't think that is the best approach. What we really want to check is the integration with ChainRules.

The discussion started with trying to implement upsample methods. Maybe we can do something with that?
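The naive test mentioned above (compare an analytic gradient with a numerical one on a known function) could look roughly like the following sketch. It is deliberately package-free: `LinearItp`, `slope`, and `central_diff` are hypothetical stand-ins written for illustration and are not part of Interpolations.jl or ChainRules.

```julia
# Hypothetical stand-in for an interpolation object: a callable struct (functor)
# doing piecewise-linear interpolation on the integer grid 1:length(coefs).
struct LinearItp
    coefs::Vector{Float64}
end

# Evaluate at a real-valued position x in [1, length(coefs)].
function (itp::LinearItp)(x::Real)
    i = clamp(floor(Int, x), 1, length(itp.coefs) - 1)
    t = x - i
    return (1 - t) * itp.coefs[i] + t * itp.coefs[i + 1]
end

# Analytic derivative: the slope of the segment that contains x.
function slope(itp::LinearItp, x::Real)
    i = clamp(floor(Int, x), 1, length(itp.coefs) - 1)
    return itp.coefs[i + 1] - itp.coefs[i]
end

# Central finite difference, used as an independent numerical reference.
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

itp = LinearItp(sin.(1:10))
for x in (1.5, 2.25, 7.8)   # interior points, away from the knots
    @assert isapprox(slope(itp, x), central_diff(itp, x); atol = 1e-6)
end
```

A check like this only validates the derivative values. It says nothing about whether an AD package actually picks up the rule, which is the integration concern raised in the comment.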

@rick2047
Contributor Author

I found ChainRulesTestUtils.jl; I guess that's the best approach.

@rick2047
Contributor Author

I tried using ChainRulesTestUtils but it seems like it doesn't work for Interpolation objects. Adding code here in case I am doing something obviously wrong.

julia> using ChainRulesTestUtils

julia> x = collect(1:10)
julia> y = sin.(x)

julia> itp = interpolate(y,BSpline(Linear()))


julia> test_rrule(itp, 1, 1)
test_rrule: [0.8414709848078965, 0.9092974268256817, 0.1411200080598672, -0.7568024953079282, -0.9589242746631385, -0.27941549819892586, 0.6569865987187891, 0.9893582466233818, 0.4121184852417566, -0.5440211108893698] at (1, 1): Error During Test at C:\Users\paresh\.julia\packages\ChainRulesTestUtils\NCO8i\src\testers.jl:168
  Got exception outside of a @test
  ArgumentError: test_rrule cannot be used on closures/functors (such as [0.8414709848078965, 0.9092974268256817, 0.1411200080598672, -0.7568024953079282, -0.9589242746631385, -0.27941549819892586, 0.6569865987187891, 0.9893582466233818, 0.4121184852417566, -0.5440211108893698])
  Stacktrace:
    [1] _ensure_not_running_on_functor(f::Interpolations.BSplineInterpolation{Float64, 1, Vector{Float64}, BSpline{Linear}, Tuple{Base.OneTo{Int64}}}, name::String)
      @ ChainRulesTestUtils ~\.julia\packages\ChainRulesTestUtils\NCO8i\src\testers.jl:238
    [2] macro expansion
      @ ~\.julia\packages\ChainRulesTestUtils\NCO8i\src\testers.jl:169 [inlined]
    [3] macro expansion
      @ C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\Test\src\Test.jl:1151 [inlined]
    [4] test_rrule(::Interpolations.BSplineInterpolation{Float64, 1, Vector{Float64}, BSpline{Linear}, Tuple{Base.OneTo{Int64}}}, ::Int64, ::Vararg{Int64, N} where N; output_tangent::ChainRulesTestUtils.Auto, fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, check_inferred::Bool, fkwargs::NamedTuple{(), Tuple{}}, rtol::Float64, atol::Float64, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
      @ ChainRulesTestUtils ~\.julia\packages\ChainRulesTestUtils\NCO8i\src\testers.jl:169
    [5] test_rrule(::Interpolations.BSplineInterpolation{Float64, 1, Vector{Float64}, BSpline{Linear}, Tuple{Base.OneTo{Int64}}}, ::Int64, ::Vararg{Int64, N} where N)
      @ ChainRulesTestUtils ~\.julia\packages\ChainRulesTestUtils\NCO8i\src\testers.jl:166
    [6] top-level scope
      @ REPL[32]:1
    [7] eval
      @ .\boot.jl:360 [inlined]
    [8] eval
      @ .\Base.jl:39 [inlined]
    [9] repleval(m::Module, code::Expr, #unused#::String)
      @ VSCodeServer ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\packages\VSCodeServer\src\repl.jl:124
   [10] (::VSCodeServer.var"#47#49"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
      @ VSCodeServer ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\packages\VSCodeServer\src\repl.jl:99
   [11] with_logstate(f::Function, logstate::Any)
      @ Base.CoreLogging .\logging.jl:491
   [12] with_logger
      @ .\logging.jl:603 [inlined]
   [13] (::VSCodeServer.var"#46#48"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
      @ VSCodeServer ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\packages\VSCodeServer\src\repl.jl:100
   [14] #invokelatest#2
      @ .\essentials.jl:708 [inlined]
   [15] invokelatest(::Any)
      @ Base .\essentials.jl:706
   [16] macro expansion
      @ ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\packages\VSCodeServer\src\eval.jl:34 [inlined]
   [17] (::VSCodeServer.var"#60#61")()
      @ VSCodeServer .\task.jl:406
Test Summary:
                                                                      | Error  Total
test_rrule: [0.8414709848078965, 0.9092974268256817, 0.1411200080598672, -0.7568024953079282, -0.9589242746631385, -0.27941549819892586, 0.6569865987187891, 0.9893582466233818, 0.4121184852417566, -0.5440211108893698] at (1, 1) |     1      1
ERROR: Some tests did not pass: 0 passed, 0 failed, 1 errored, 0 broken.

So this seems like a dud. Any other ideas about testing are welcome.

function ChainRulesCore.rrule(itp::AbstractInterpolation, x...)
    y = itp(x...)
    function back(Δ)
        (NO_FIELDS, Δ * gradient(itp, x...))

Would you not want to update any of the fields of the interpolation too?

Contributor Author

I am not sure about that. I am confident that the second element of the returned tuple needs to be the gradient, but I am not sure what the first element is even supposed to be. I hacked this in to see if it works, but I am not confident it is right. What would you recommend? Could you clarify what the first and second elements are supposed to be?
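As a sketch of the general ChainRules convention (not a statement about how this PR was finally resolved): the pullback of an rrule returns one entry per argument of the call, and the callable itself counts as the first argument. Writing \bar{y} for the incoming cotangent (Δ in the snippet) and a bar for each returned entry, the snippet under review corresponds to:

```latex
y = \mathrm{itp}(x), \qquad
\mathrm{pullback}(\bar{y}) = \bigl(\,\overline{\mathrm{itp}},\; \bar{x}\,\bigr), \qquad
\bar{x} = \bar{y}\,\nabla_{x}\,\mathrm{itp}(x)
```

Here \overline{\mathrm{itp}} is the entry for the interpolation object and its fields (for example its coefficients), and \bar{x} is the entry for the evaluation point. Putting NO_FIELDS in the first position declares that nothing is propagated back to the interpolation object, which is exactly what the reviewer's question about updating the fields is probing.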

Contributor Author

I will change the call to Interpolations.gradient; I made the choice at random anyway.

@DhairyaLGandhi

Yay! This looks like a really nice step. Maybe qualify the call to gradient as Interpolations.gradient for readability if nothing else?

@DhairyaLGandhi

I would directly test that Zygote is producing the correct and expected value.

@rick2047
Contributor Author

@DhairyaLGandhi thanks for your attention here. I have added a naive test; if this works, we can think about adding more comprehensive testing.

@rick2047
Contributor Author

rick2047 commented Apr 16, 2021

I dug around a bit more in ChainRulesTestUtils. They already have an issue open for the functor problem: JuliaDiff/ChainRulesTestUtils.jl#117 (comment). It seems this is already a milestone for their v1. I would still proceed with the Zygote test for now, but eventually we should shift to this package.

@rick2047
Contributor Author

I have no idea why these tests are failing. Did I set up the dependencies wrong?

@mkitti
Collaborator

mkitti commented Apr 16, 2021

Interpolations: Error During Test at D:\a\Interpolations.jl\Interpolations.jl\test\chainrules.jl:10
  Test threw exception
  Expression: (Zygote.gradient(itp, 1))[1] == Interpolations.gradient(itp, 1)
  ArgumentError: unable to check bounds for indices of type Interpolations.WeightedAdjIndex{2,Int64}

@rick2047
Contributor Author

@mkitti I saw that; it is the original error we set out to solve. I don't understand how it works on 3 tests and fails on 3 tests.

also, itworksonmymachine.jpg :(

@mkitti
Collaborator

mkitti commented Apr 16, 2021

Ah I see. The tests are failing on Julia 1.0.5. They succeed on Julia 1.6.

@mkitti
Collaborator

mkitti commented Apr 16, 2021

Where the tests are failing:

 Installed Zygote ─────────────── v0.4.20
 Installed ZygoteRules ────────── v0.2.1

@rick2047
Contributor Author

I added another test for the 2D interpolation case. Now I am a bit stuck on how we should proceed. @DhairyaLGandhi is still scrutinizing the actual rule, but apart from that it seems like integration with Zygote is working. Ideas on how to proceed are welcome. I was thinking maybe we can write what is asked for in this issue: JuliaImages/ImageTransformations.jl#113?

@DhairyaLGandhi

What else is needed right now? I am thinking it might be good to merge this and move from there. What do you think @mkitti ?

@mkitti
Collaborator

mkitti commented Apr 17, 2021

The new code is in a mergeable state. Some housekeeping is required.

If we're going to break compat with Julia 1.0, this probably should be v0.14.0. However, it's only clear to me that we need Julia 1.3 for testing.

Do we have any bounds on compat with ChainRulesCore.jl or should we just declare compat against the latest version?

@DhairyaLGandhi

The current minor version of ChainRulesCore should be fine. You can also make a patch release, it's allowed within semver.

@rick2047
Contributor Author

If this is mergeable, let's do it. This doesn't break anything and adds new functionality. My only apprehension was whether our rule, which basically amounts to Zygote.gradient(itp, x)[1] == Interpolations.gradient(itp, x), is the right definition. Should I document this somewhere?

@mkitti
Collaborator

mkitti commented Apr 17, 2021

Well the tests do break Julia 1.0 compatibility, which is fine. It just needs to be managed. I'll probably run a v0.13 branch that is compatible with Julia 1.0 for a while longer in parallel with a v0.14 branch that requires Julia 1.3.

Documentation would be great. Could you add a docs/src/chainrules.md to describe this?

@mkitti mkitti changed the title [WIP] Compatibility with ChainRules Compatibility with ChainRules Apr 17, 2021
@mkitti mkitti changed the title Compatibility with ChainRules Compatibility with ChainRules, Drop Julia 1.0 support Apr 17, 2021
@mkitti mkitti merged commit ee4a49b into JuliaMath:master Apr 17, 2021
function ChainRulesCore.rrule(itp::AbstractInterpolation, x...)
    y = itp(x...)
    function pullback(Δy)
        (NO_FIELDS, Δy * Interpolations.gradient(itp, x...))

One quick thought (sorry, this should have come earlier): NO_FIELDS -> DoesNotExist(). @rick2047 would you be up for a tiny follow-up?

Contributor Author

Sure, I was bothered by this too. But this is already merged; can I still push to this branch?

Collaborator

You can use the same branch. It's your branch. Just send a new PR when you are ready.

@mkitti
Collaborator

mkitti commented Jun 8, 2021

We seem to be having CI issues with ChainRulesCore.jl 0.10. @rick2047 @DhairyaLGandhi, do you have any ideas where the issue may be?
