Skip to content

Commit

Permalink
Merge pull request SciML#2472 from oscardssmith/os/fix-jacobian!-stats
Browse files Browse the repository at this point in the history
fix `ForwardDiff` `jacobian!` stats
  • Loading branch information
oscardssmith authored Sep 19, 2024
2 parents b0f957c + 672d894 commit 2fa672b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
else
forwarddiff_color_jacobian!(J, f, x, jac_config)
end
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, maximum(jac_config.colorvec))
elseif alg_autodiff(alg) isa AutoFiniteDiff
isforward = alg_difftype(alg) === Val{:forward}
if isforward
Expand Down
98 changes: 23 additions & 75 deletions test/interface/stats_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,84 +5,32 @@ function f(u, p, t)
x[] += 1
return 5 * u
end
u0 = [1.0, 1.0]
tspan = (0.0, 1.0)
prob = ODEProblem(f, u0, tspan)

x[] = 0
sol = solve(prob, Vern7())
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, Vern8())
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, Vern9())
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, Tsit5())
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, BS3())
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, KenCarp4(; autodiff = true))
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, KenCarp4(; autodiff = false, diff_type = Val{:forward}))
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, KenCarp4(; autodiff = false, diff_type = Val{:central}))
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, KenCarp4(; autodiff = false, diff_type = Val{:complex}))
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, Rosenbrock23(; autodiff = true))
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, Rosenbrock23(; autodiff = false, diff_type = Val{:forward}))
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, Rosenbrock23(; autodiff = false, diff_type = Val{:central}))
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, Rosenbrock23(; autodiff = false, diff_type = Val{:complex}))
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, Rodas5(; autodiff = true))
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, Rodas5(; autodiff = false, diff_type = Val{:forward}))
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, Rodas5(; autodiff = false, diff_type = Val{:central}))
@test x[] == sol.stats.nf

x[] = 0
sol = solve(prob, Rodas5(; autodiff = false, diff_type = Val{:complex}))
@test x[] == sol.stats.nf

function g(du, u, p, t)
x[] += 1
@. du = 5 * u
end

u0 = [1.0, 1.0]
tspan = (0.0, 1.0)
probop = ODEProblem(f, u0, tspan)
probip = ODEProblem(g, u0, tspan)

x[] = 0
sol = solve(probip, ROCK4())
@test x[] == sol.stats.nf
@testset "stats_tests" begin
@testset "$prob" for prob in [probop, probip]
@testset "$alg" for alg in [BS3, Tsit5, Vern7, Vern9, ROCK4]
x[] = 0
sol = solve(prob, alg())
@test x[] == sol.stats.nf
end
@testset "$alg" for alg in [Rodas5P, KenCarp4]
@testset "$kwargs" for kwargs in [(autodiff = true,),
(autodiff = false, diff_type = Val{:forward}),
(autodiff = false, diff_type = Val{:central}),
(autodiff = false, diff_type = Val{:complex}),]
x[] = 0
sol = solve(prob, alg(;kwargs...))
@test x[] == sol.stats.nf
end
end
end
end

0 comments on commit 2fa672b

Please sign in to comment.