Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Oct 5, 2024
1 parent a883ffe commit 7ce7c24
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ else
end

function ChainRulesCore.rrule(::typeof(munge_data), u, t)
u_out, t_out = munge_data(u,t)
u_out, t_out = munge_data(u, t)

# For now modifications by munge_data not supported
@assert (u == u_out && t == t_out)
Expand Down
8 changes: 5 additions & 3 deletions test/zygote_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name
@testset "$name, derivatives w.r.t. input" begin
for _t in trange_exclude
adiff = DataInterpolations.derivative(func, _t)
zdiff = u isa AbstractVector{<:Real} ? only(Zygote.gradient(func, _t)) : only(Zygote.jacobian(func, _t))
zdiff = u isa AbstractVector{<:Real} ? only(Zygote.gradient(func, _t)) :
only(Zygote.jacobian(func, _t))
isnothing(zdiff) && (zdiff = 0.0)
@test adiff zdiff
end
Expand Down Expand Up @@ -64,8 +65,9 @@ end
u = [1.0 2.0; 0.0 1.0; 1.0 2.0; 0.0 1.0]
test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (matrix)")

u = [[1.0,2.0,3.0,4.0],[2.0,3.0,4.0,5.0]]
test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (vector of vectors)")
u = [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]]
test_zygote(
ConstantInterpolation, u, t, name = "Constant Interpolation (vector of vectors)")
end

@testset "Cubic Hermite Spline" begin
Expand Down

0 comments on commit 7ce7c24

Please sign in to comment.