From 7ce7c247666bd1b213e18bc7301e0cea8a4563d7 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 5 Oct 2024 01:34:25 -0400 Subject: [PATCH] format --- ext/DataInterpolationsChainRulesCoreExt.jl | 2 +- test/zygote_tests.jl | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index f038db98..9e9659d1 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -16,7 +16,7 @@ else end function ChainRulesCore.rrule(::typeof(munge_data), u, t) - u_out, t_out = munge_data(u,t) + u_out, t_out = munge_data(u, t) # For now modifications by munge_data not supported @assert (u == u_out && t == t_out) diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl index 4d26b848..09b7be17 100644 --- a/test/zygote_tests.jl +++ b/test/zygote_tests.jl @@ -9,7 +9,8 @@ function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name @testset "$name, derivatives w.r.t. input" begin for _t in trange_exclude adiff = DataInterpolations.derivative(func, _t) - zdiff = u isa AbstractVector{<:Real} ? only(Zygote.gradient(func, _t)) : only(Zygote.jacobian(func, _t)) + zdiff = u isa AbstractVector{<:Real} ? only(Zygote.gradient(func, _t)) : + only(Zygote.jacobian(func, _t)) isnothing(zdiff) && (zdiff = 0.0) @test adiff ≈ zdiff end @@ -64,8 +65,9 @@ end u = [1.0 2.0; 0.0 1.0; 1.0 2.0; 0.0 1.0] test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (matrix)") - u = [[1.0,2.0,3.0,4.0],[2.0,3.0,4.0,5.0]] - test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation (vector of vectors)") + u = [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]] + test_zygote( + ConstantInterpolation, u, t, name = "Constant Interpolation (vector of vectors)") end @testset "Cubic Hermite Spline" begin