diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 988e519b..9e9659d1 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -4,17 +4,27 @@ if isdefined(Base, :get_extension) LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox, get_idx, get_parameters, - _quad_interp_indices + _quad_interp_indices, munge_data using ChainRulesCore else using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox, get_parameters, - _quad_interp_indices + _quad_interp_indices, munge_data using ..ChainRulesCore end +function ChainRulesCore.rrule(::typeof(munge_data), u, t) + u_out, t_out = munge_data(u, t) + + # For now modifications by munge_data not supported + @assert (u == u_out && t == t_out) + + munge_data_pullback = Δ -> (NoTangent(), Δ[1], Δ[2]) + (u_out, t_out), munge_data_pullback +end + function ChainRulesCore.rrule( ::Type{LinearInterpolation}, u, t, I, p, extrapolate, cache_parameters) A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) diff --git a/src/derivatives.jl b/src/derivatives.jl index ff5b9850..bdb4d770 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -114,12 +114,12 @@ function _derivative(A::ConstantInterpolation, t::Number, iguess) return zero(first(A.u)) end -function _derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number) +function _derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number, iguess) ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) return isempty(searchsorted(A.t, t)) ? zero(A.u[1]) : eltype(A.u)(NaN) end -function _derivative(A::ConstantInterpolation{<:AbstractMatrix}, t::Number) +function _derivative(A::ConstantInterpolation{<:AbstractMatrix}, t::Number, iguess) ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) return isempty(searchsorted(A.t, t)) ? zero(A.u[:, 1]) : eltype(A.u)(NaN) .* A.u[:, 1] end diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index a4eb349b..fd9745de 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -133,9 +133,11 @@ end @testset "Constant Interpolation" begin u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] - t = collect(0.0:10.0) + t = collect(0.0:11.0) A = ConstantInterpolation(u, t) - @test all(derivative.(Ref(A), t) .== 0.0) + t2 = collect(0.0:10.0) + @test all(isnan, derivative.(Ref(A), t)) + @test all(derivative.(Ref(A), t2 .+ 0.1) .== 0.0) end @testset "Quadratic Spline" begin diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl index 1a7fc447..09b7be17 100644 --- a/test/zygote_tests.jl +++ b/test/zygote_tests.jl @@ -4,13 +4,13 @@ using Zygote function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name::String) func = method(args..., u, t, args_after...; kwargs..., extrapolate = true) - (; u, t) = func trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1)) trange_exclude = filter(x -> !in(x, t), trange) @testset "$name, derivatives w.r.t. input" begin for _t in trange_exclude adiff = DataInterpolations.derivative(func, _t) - zdiff = only(Zygote.gradient(func, _t)) + zdiff = u isa AbstractVector{<:Real} ? only(Zygote.gradient(func, _t)) : + only(Zygote.jacobian(func, _t)) isnothing(zdiff) && (zdiff = 0.0) @test adiff ≈ zdiff end @@ -19,15 +19,26 @@ function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name @testset "$name, derivatives w.r.t. u" begin function f(u) A = method(args..., u, t, args_after...; kwargs..., extrapolate = true) - out = zero(eltype(u)) + out = if u isa AbstractVector{<:Real} + zero(eltype(u)) + elseif u isa AbstractMatrix + zero(u[:, 1]) + else + zero(u[1]) + end + for _t in trange out += A(_t) end out end - zgrad = only(Zygote.gradient(f, u)) - fgrad = ForwardDiff.gradient(f, u) - @test zgrad ≈ fgrad + zgrad, fgrad = if u isa AbstractVector{<:Real} + Zygote.gradient(f, u), ForwardDiff.gradient(f, u) + elseif u isa AbstractMatrix + Zygote.jacobian(f, u), ForwardDiff.jacobian(f, u) + else + Zygote.jacobian(f, u), ForwardDiff.jacobian(f, hcat(u...)) + end end end end @@ -48,7 +59,15 @@ end @testset "Constant Interpolation" begin u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] t = collect(0.0:10.0) - test_zygote(ConstantInterpolation, u, t; name = "Constant Interpolation") + test_zygote(ConstantInterpolation, u, t; name = "Constant Interpolation (vector)") + + t = [1.0, 4.0] + u = [1.0 2.0; 0.0 1.0; 1.0 2.0; 0.0 1.0] + test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (matrix)") + + u = [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]] + test_zygote( + ConstantInterpolation, u, t, name = "Constant Interpolation (vector of vectors)") end @testset "Cubic Hermite Spline" begin