From 672d894a33feb5e136a5125455c3d2a8da9b0220 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Thu, 19 Sep 2024 11:39:19 -0400 Subject: [PATCH] fix stats --- .../src/derivative_wrappers.jl | 2 +- test/interface/stats_tests.jl | 98 +++++-------------- 2 files changed, 24 insertions(+), 76 deletions(-) diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl index 1093ab555e..8c7b94a85a 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl @@ -239,7 +239,7 @@ function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}, else forwarddiff_color_jacobian!(J, f, x, jac_config) end - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, maximum(jac_config.colorvec)) elseif alg_autodiff(alg) isa AutoFiniteDiff isforward = alg_difftype(alg) === Val{:forward} if isforward diff --git a/test/interface/stats_tests.jl b/test/interface/stats_tests.jl index fd954134de..13f06cb127 100644 --- a/test/interface/stats_tests.jl +++ b/test/interface/stats_tests.jl @@ -5,84 +5,32 @@ function f(u, p, t) x[] += 1 return 5 * u end -u0 = [1.0, 1.0] -tspan = (0.0, 1.0) -prob = ODEProblem(f, u0, tspan) - -x[] = 0 -sol = solve(prob, Vern7()) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, Vern8()) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, Vern9()) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, Tsit5()) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, BS3()) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, KenCarp4(; autodiff = true)) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, KenCarp4(; autodiff = false, diff_type = Val{:forward})) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, KenCarp4(; autodiff = false, diff_type = Val{:central})) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, KenCarp4(; autodiff = false, diff_type = Val{:complex})) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, Rosenbrock23(; autodiff = true)) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, Rosenbrock23(; autodiff = false, diff_type = Val{:forward})) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, Rosenbrock23(; autodiff = false, diff_type = Val{:central})) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, Rosenbrock23(; autodiff = false, diff_type = Val{:complex})) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, Rodas5(; autodiff = true)) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, Rodas5(; autodiff = false, diff_type = Val{:forward})) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, Rodas5(; autodiff = false, diff_type = Val{:central})) -@test x[] == sol.stats.nf - -x[] = 0 -sol = solve(prob, Rodas5(; autodiff = false, diff_type = Val{:complex})) -@test x[] == sol.stats.nf - function g(du, u, p, t) x[] += 1 @. du = 5 * u end + +u0 = [1.0, 1.0] +tspan = (0.0, 1.0) +probop = ODEProblem(f, u0, tspan) probip = ODEProblem(g, u0, tspan) -x[] = 0 -sol = solve(probip, ROCK4()) -@test x[] == sol.stats.nf +@testset "stats_tests" begin + @testset "$prob" for prob in [probop, probip] + @testset "$alg" for alg in [BS3, Tsit5, Vern7, Vern9, ROCK4] + x[] = 0 + sol = solve(prob, alg()) + @test x[] == sol.stats.nf + end + @testset "$alg" for alg in [Rodas5P, KenCarp4] + @testset "$kwargs" for kwargs in [(autodiff = true,), + (autodiff = false, diff_type = Val{:forward}), + (autodiff = false, diff_type = Val{:central}), + (autodiff = false, diff_type = Val{:complex}),] + x[] = 0 + sol = solve(prob, alg(;kwargs...)) + @test x[] == sol.stats.nf + end + end + end +end