From 1c39c25bb098da161673089bdc95dd38b59e2008 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Tue, 6 Aug 2024 16:45:56 +0200 Subject: [PATCH 1/4] POC with ConstantInterpolation --- ext/DataInterpolationsChainRulesCoreExt.jl | 14 ++++++++++++-- src/derivatives.jl | 4 ++-- test/zygote_tests.jl | 19 +++++++++++++------ 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 988e519b..f038db98 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -4,17 +4,27 @@ if isdefined(Base, :get_extension) LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox, get_idx, get_parameters, - _quad_interp_indices + _quad_interp_indices, munge_data using ChainRulesCore else using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox, get_parameters, - _quad_interp_indices + _quad_interp_indices, munge_data using ..ChainRulesCore end +function ChainRulesCore.rrule(::typeof(munge_data), u, t) + u_out, t_out = munge_data(u,t) + + # For now modifications by munge_data not supported + @assert (u == u_out && t == t_out) + + munge_data_pullback = Δ -> (NoTangent(), Δ[1], Δ[2]) + (u_out, t_out), munge_data_pullback +end + function ChainRulesCore.rrule( ::Type{LinearInterpolation}, u, t, I, p, extrapolate, cache_parameters) A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) diff --git a/src/derivatives.jl b/src/derivatives.jl index ff5b9850..bdb4d770 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -114,12 +114,12 @@ function _derivative(A::ConstantInterpolation, t::Number, iguess) return zero(first(A.u)) end -function _derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number) +function _derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number, iguess) ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) return isempty(searchsorted(A.t, t)) ? zero(A.u[1]) : eltype(A.u)(NaN) end -function _derivative(A::ConstantInterpolation{<:AbstractMatrix}, t::Number) +function _derivative(A::ConstantInterpolation{<:AbstractMatrix}, t::Number, iguess) ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) return isempty(searchsorted(A.t, t)) ? zero(A.u[:, 1]) : eltype(A.u)(NaN) .* A.u[:, 1] end diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl index 1a7fc447..fcbd1f1c 100644 --- a/test/zygote_tests.jl +++ b/test/zygote_tests.jl @@ -4,13 +4,12 @@ using Zygote function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name::String) func = method(args..., u, t, args_after...; kwargs..., extrapolate = true) - (; u, t) = func trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1)) trange_exclude = filter(x -> !in(x, t), trange) @testset "$name, derivatives w.r.t. input" begin for _t in trange_exclude adiff = DataInterpolations.derivative(func, _t) - zdiff = only(Zygote.gradient(func, _t)) + zdiff = u isa AbstractMatrix ? only(Zygote.jacobian(func, _t)) : only(Zygote.gradient(func, _t)) isnothing(zdiff) && (zdiff = 0.0) @test adiff ≈ zdiff end @@ -19,14 +18,18 @@ function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name @testset "$name, derivatives w.r.t. u" begin function f(u) A = method(args..., u, t, args_after...; kwargs..., extrapolate = true) - out = zero(eltype(u)) + out = u isa AbstractMatrix ? zero(u[:,1]) : zero(eltype(u)) + for _t in trange out += A(_t) end out end - zgrad = only(Zygote.gradient(f, u)) - fgrad = ForwardDiff.gradient(f, u) + zgrad, fgrad = if u isa AbstractMatrix + only(Zygote.jacobian(f, u)), ForwardDiff.jacobian(f, u) + else + only(Zygote.gradient(f, u)), ForwardDiff.gradient(f, u) + end @test zgrad ≈ fgrad end end @@ -48,7 +51,11 @@ end @testset "Constant Interpolation" begin u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] t = collect(0.0:10.0) - test_zygote(ConstantInterpolation, u, t; name = "Constant Interpolation") + test_zygote(ConstantInterpolation, u, t; name = "Constant Interpolation (vector)") + + t = [1.0, 4.0] + u = [1.0 2.0; 0.0 1.0; 1.0 2.0; 0.0 1.0] + test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (matrix)") end @testset "Cubic Hermite Spline" begin From f0ef765ed682b5aa26c1842023a28fae933fa285 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Tue, 6 Aug 2024 17:31:00 +0200 Subject: [PATCH 2/4] Attempt at test fix for vector of vectors --- test/zygote_tests.jl | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl index fcbd1f1c..4d26b848 100644 --- a/test/zygote_tests.jl +++ b/test/zygote_tests.jl @@ -9,7 +9,7 @@ function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name @testset "$name, derivatives w.r.t. input" begin for _t in trange_exclude adiff = DataInterpolations.derivative(func, _t) - zdiff = u isa AbstractMatrix ? only(Zygote.jacobian(func, _t)) : only(Zygote.gradient(func, _t)) + zdiff = u isa AbstractVector{<:Real} ? only(Zygote.gradient(func, _t)) : only(Zygote.jacobian(func, _t)) isnothing(zdiff) && (zdiff = 0.0) @test adiff ≈ zdiff end @@ -18,19 +18,26 @@ function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name @testset "$name, derivatives w.r.t. u" begin function f(u) A = method(args..., u, t, args_after...; kwargs..., extrapolate = true) - out = u isa AbstractMatrix ? zero(u[:,1]) : zero(eltype(u)) + out = if u isa AbstractVector{<:Real} + zero(eltype(u)) + elseif u isa AbstractMatrix + zero(u[:, 1]) + else + zero(u[1]) + end for _t in trange out += A(_t) end out end - zgrad, fgrad = if u isa AbstractMatrix - only(Zygote.jacobian(f, u)), ForwardDiff.jacobian(f, u) + zgrad, fgrad = if u isa AbstractVector{<:Real} + Zygote.gradient(f, u), ForwardDiff.gradient(f, u) + elseif u isa AbstractMatrix + Zygote.jacobian(f, u), ForwardDiff.jacobian(f, u) else - only(Zygote.gradient(f, u)), ForwardDiff.gradient(f, u) + Zygote.jacobian(f, u), ForwardDiff.jacobian(f, hcat(u...)) end - @test zgrad ≈ fgrad end end end @@ -56,6 +63,9 @@ end t = [1.0, 4.0] u = [1.0 2.0; 0.0 1.0; 1.0 2.0; 0.0 1.0] test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (matrix)") + + u = [[1.0,2.0,3.0,4.0],[2.0,3.0,4.0,5.0]] + test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (vector of vectors)") end @testset "Cubic Hermite Spline" begin From a883ffe7c79ea0ea003637803243bfe651dadbe3 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 5 Oct 2024 01:03:30 -0400 Subject: [PATCH 3/4] Fix derivative tests --- test/derivative_tests.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index a4eb349b..fd9745de 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -133,9 +133,11 @@ end @testset "Constant Interpolation" begin u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] - t = collect(0.0:10.0) + t = collect(0.0:11.0) A = ConstantInterpolation(u, t) - @test all(derivative.(Ref(A), t) .== 0.0) + t2 = collect(0.0:10.0) + @test all(isnan, derivative.(Ref(A), t)) + @test all(derivative.(Ref(A), t2 .+ 0.1) .== 0.0) end @testset "Quadratic Spline" begin From 7ce7c247666bd1b213e18bc7301e0cea8a4563d7 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 5 Oct 2024 01:34:25 -0400 Subject: [PATCH 4/4] format --- ext/DataInterpolationsChainRulesCoreExt.jl | 2 +- test/zygote_tests.jl | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index f038db98..9e9659d1 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -16,7 +16,7 @@ else end function ChainRulesCore.rrule(::typeof(munge_data), u, t) - u_out, t_out = munge_data(u,t) + u_out, t_out = munge_data(u, t) # For now modifications by munge_data not supported @assert (u == u_out && t == t_out) diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl index 4d26b848..09b7be17 100644 --- a/test/zygote_tests.jl +++ b/test/zygote_tests.jl @@ -9,7 +9,8 @@ function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name @testset "$name, derivatives w.r.t. input" begin for _t in trange_exclude adiff = DataInterpolations.derivative(func, _t) - zdiff = u isa AbstractVector{<:Real} ? only(Zygote.gradient(func, _t)) : only(Zygote.jacobian(func, _t)) + zdiff = u isa AbstractVector{<:Real} ? only(Zygote.gradient(func, _t)) : + only(Zygote.jacobian(func, _t)) isnothing(zdiff) && (zdiff = 0.0) @test adiff ≈ zdiff end @@ -64,8 +65,9 @@ end u = [1.0 2.0; 0.0 1.0; 1.0 2.0; 0.0 1.0] test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (matrix)") - u = [[1.0,2.0,3.0,4.0],[2.0,3.0,4.0,5.0]] - test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (vector of vectors)") + u = [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]] + test_zygote( + ConstantInterpolation, u, t, name = "Constant Interpolation (vector of vectors)") end @testset "Cubic Hermite Spline" begin