Skip to content

Commit

Permalink
Merge pull request #324 from SouthEndMusic/zygote_vector_u
Browse files Browse the repository at this point in the history
Fix Zygote AD with `u` a matrix or vector of arrays
  • Loading branch information
ChrisRackauckas authored Oct 5, 2024
2 parents aa02884 + 7ce7c24 commit b39b336
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 13 deletions.
14 changes: 12 additions & 2 deletions ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,27 @@ if isdefined(Base, :get_extension)
LinearInterpolation, QuadraticInterpolation,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox, get_idx, get_parameters,
_quad_interp_indices
_quad_interp_indices, munge_data
using ChainRulesCore
else
using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation,
LinearInterpolation, QuadraticInterpolation,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox, get_parameters,
_quad_interp_indices
_quad_interp_indices, munge_data
using ..ChainRulesCore
end

function ChainRulesCore.rrule(::typeof(munge_data), u, t)
u_out, t_out = munge_data(u, t)

# For now modifications by munge_data not supported
@assert (u == u_out && t == t_out)

munge_data_pullback = Δ -> (NoTangent(), Δ[1], Δ[2])
(u_out, t_out), munge_data_pullback
end

function ChainRulesCore.rrule(
::Type{LinearInterpolation}, u, t, I, p, extrapolate, cache_parameters)
A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters)
Expand Down
4 changes: 2 additions & 2 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ function _derivative(A::ConstantInterpolation, t::Number, iguess)
return zero(first(A.u))
end

function _derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number)
function _derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number, iguess)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
return isempty(searchsorted(A.t, t)) ? zero(A.u[1]) : eltype(A.u)(NaN)
end

function _derivative(A::ConstantInterpolation{<:AbstractMatrix}, t::Number)
function _derivative(A::ConstantInterpolation{<:AbstractMatrix}, t::Number, iguess)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
return isempty(searchsorted(A.t, t)) ? zero(A.u[:, 1]) : eltype(A.u)(NaN) .* A.u[:, 1]
end
Expand Down
6 changes: 4 additions & 2 deletions test/derivative_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,11 @@ end

@testset "Constant Interpolation" begin
u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0]
t = collect(0.0:10.0)
t = collect(0.0:11.0)
A = ConstantInterpolation(u, t)
@test all(derivative.(Ref(A), t) .== 0.0)
t2 = collect(0.0:10.0)
@test all(isnan, derivative.(Ref(A), t))
@test all(derivative.(Ref(A), t2 .+ 0.1) .== 0.0)
end

@testset "Quadratic Spline" begin
Expand Down
33 changes: 26 additions & 7 deletions test/zygote_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ using Zygote

function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name::String)
func = method(args..., u, t, args_after...; kwargs..., extrapolate = true)
(; u, t) = func
trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1))
trange_exclude = filter(x -> !in(x, t), trange)
@testset "$name, derivatives w.r.t. input" begin
for _t in trange_exclude
adiff = DataInterpolations.derivative(func, _t)
zdiff = only(Zygote.gradient(func, _t))
zdiff = u isa AbstractVector{<:Real} ? only(Zygote.gradient(func, _t)) :
only(Zygote.jacobian(func, _t))
isnothing(zdiff) && (zdiff = 0.0)
@test adiff zdiff
end
Expand All @@ -19,15 +19,26 @@ function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name
@testset "$name, derivatives w.r.t. u" begin
function f(u)
A = method(args..., u, t, args_after...; kwargs..., extrapolate = true)
out = zero(eltype(u))
out = if u isa AbstractVector{<:Real}
zero(eltype(u))
elseif u isa AbstractMatrix
zero(u[:, 1])
else
zero(u[1])
end

for _t in trange
out += A(_t)
end
out
end
zgrad = only(Zygote.gradient(f, u))
fgrad = ForwardDiff.gradient(f, u)
@test zgrad fgrad
zgrad, fgrad = if u isa AbstractVector{<:Real}
Zygote.gradient(f, u), ForwardDiff.gradient(f, u)
elseif u isa AbstractMatrix
Zygote.jacobian(f, u), ForwardDiff.jacobian(f, u)
else
Zygote.jacobian(f, u), ForwardDiff.jacobian(f, hcat(u...))
end
end
end
end
Expand All @@ -48,7 +59,15 @@ end
@testset "Constant Interpolation" begin
u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0]
t = collect(0.0:10.0)
test_zygote(ConstantInterpolation, u, t; name = "Constant Interpolation")
test_zygote(ConstantInterpolation, u, t; name = "Constant Interpolation (vector)")

t = [1.0, 4.0]
u = [1.0 2.0; 0.0 1.0; 1.0 2.0; 0.0 1.0]
test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (matrix)")

u = [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]]
test_zygote(
ConstantInterpolation, u, t, name = "Constant Interpolation (vector of vectors)")
end

@testset "Cubic Hermite Spline" begin
Expand Down

0 comments on commit b39b336

Please sign in to comment.