Skip to content

Commit

Permalink
Attempt at test fix for vector of vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Aug 6, 2024
1 parent 1a98506 commit 4bf1b12
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions test/zygote_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name
@testset "$name, derivatives w.r.t. input" begin
for _t in trange_exclude
adiff = DataInterpolations.derivative(func, _t)
zdiff = u isa AbstractMatrix ? only(Zygote.jacobian(func, _t)) : only(Zygote.gradient(func, _t))
zdiff = u isa AbstractVector{<:Real} ? only(Zygote.gradient(func, _t)) : only(Zygote.jacobian(func, _t))
isnothing(zdiff) && (zdiff = 0.0)
@test adiff zdiff
end
Expand All @@ -18,19 +18,26 @@ function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name
@testset "$name, derivatives w.r.t. u" begin
function f(u)
A = method(args..., u, t, args_after...; kwargs..., extrapolate = true)
out = u isa AbstractMatrix ? zero(u[:,1]) : zero(eltype(u))
out = if u isa AbstractVector{<:Real}
zero(eltype(u))
elseif u isa AbstractMatrix
zero(u[:, 1])
else
zero(u[1])
end

for _t in trange
out += A(_t)
end
out
end
zgrad, fgrad = if u isa AbstractMatrix
only(Zygote.jacobian(f, u)), ForwardDiff.jacobian(f, u)
zgrad, fgrad = if u isa AbstractVector{<:Real}
Zygote.gradient(f, u), ForwardDiff.gradient(f, u)
elseif u isa AbstractMatrix
Zygote.jacobian(f, u), ForwardDiff.jacobian(f, u)
else
only(Zygote.gradient(f, u)), ForwardDiff.gradient(f, u)
Zygote.jacobian(f, u), ForwardDiff.jacobian(f, hcat(u...))
end
@test zgrad fgrad
end
end
end
Expand All @@ -56,6 +63,9 @@ end
t = [1.0, 4.0]
u = [1.0 2.0; 0.0 1.0; 1.0 2.0; 0.0 1.0]
test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (matrix)")

u = [[1.0,2.0,3.0,4.0],[2.0,3.0,4.0,5.0]]
test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (vector of vectors)")
end

@testset "Cubic Hermite Spline" begin
Expand Down

0 comments on commit 4bf1b12

Please sign in to comment.