From ffe02c342de13ba7830ab14adcc9ac1edb75bce1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 27 Oct 2023 13:48:28 +0530 Subject: [PATCH] chore: format --- ext/DiffEqBaseEnzymeExt.jl | 41 +++++++++--- src/init.jl | 2 +- src/stats.jl | 118 +++++++++++++++++----------------- test/downstream/kwarg_warn.jl | 5 +- test/remake_tests.jl | 2 +- 5 files changed, 96 insertions(+), 72 deletions(-) diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index ed2237ae2..c8d2038ef 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -7,7 +7,17 @@ isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) using ChainRulesCore using EnzymeCore -function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT +function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{ + 1, + }, + func::Const{typeof(DiffEqBase.solve_up)}, + ::Type{Duplicated{RT}}, + prob, + sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, + u0, + p, + args...; + kwargs...) where {RT} @inline function copy_or_reuse(val, idx) if EnzymeCore.EnzymeRules.overwritten(config)[idx] && ismutable(val) return deepcopy(val) @@ -17,24 +27,35 @@ function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules. end @inline function arg_copy(i) - copy_or_reuse(args[i].val, i+5) + copy_or_reuse(args[i].val, i + 5) end - - res = DiffEqBase._solve_adjoint(copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), SciMLBase.ChainRulesOriginator(), ntuple(arg_copy, Val(length(args)))...; + + res = DiffEqBase._solve_adjoint(copy_or_reuse(prob.val, 2), + copy_or_reuse(sensealg.val, 3), copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), + SciMLBase.ChainRulesOriginator(), ntuple(arg_copy, Val(length(args)))...; kwargs...) dres = deepcopy(res[1])::RT for v in dres.u - v.= 0 + v .= 0 end tup = (dres, res[2]) return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) end -function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{<:Duplicated{RT}}, tape, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT - dres, clos = tape +function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, + func::Const{typeof(DiffEqBase.solve_up)}, + ::Type{<:Duplicated{RT}}, + tape, + prob, + sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, + u0, + p, + args...; + kwargs...) where {RT} + dres, clos = tape dres = dres::RT - dargs = clos(dres) + dargs = clos(dres) for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) if ptr isa EnzymeCore.Const continue @@ -45,9 +66,9 @@ function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWid ptr.dval .+= darg end for v in dres.u - v.= 0 + v .= 0 end - return ntuple(_ -> nothing, Val(length(args)+4)) + return ntuple(_ -> nothing, Val(length(args) + 4)) end end diff --git a/src/init.jl b/src/init.jl index ba245c09d..88342cb0d 100644 --- a/src/init.jl +++ b/src/init.jl @@ -47,7 +47,7 @@ end @require MPI="da04e1cc-30fd-572f-bb4f-1f8673147195" begin include("../ext/DiffEqBaseMPIExt.jl") end - + @require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin include("../ext/DiffEqBaseEnzymeExt.jl") end diff --git a/src/stats.jl b/src/stats.jl index 71e66ed2e..14833f69b 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -1,62 +1,62 @@ @static if isdefined(SciMLBase, :DEStats) - const Stats = SciMLBase.DEStats + const Stats = SciMLBase.DEStats else - """ - $(TYPEDEF) - - Statistics from the differential equation solver about the solution process. - - ## Fields - - - nf: Number of function evaluations. If the differential equation is a split function, - such as a `SplitFunction` for implicit-explicit (IMEX) integration, then `nf` is the - number of function evaluations for the first function (the implicit function) - - nf2: If the differential equation is a split function, such as a `SplitFunction` - for implicit-explicit (IMEX) integration, then `nf2` is the number of function - evaluations for the second function, i.e. the function treated explicitly. Otherwise - it is zero. - - nw: The number of W=I-gamma*J (or W=I/gamma-J) matrices constructed during the solving - process. - - nsolve: The number of linear solves `W\b` required for the integration. - - njacs: Number of Jacobians calculated during the integration. - - nnonliniter: Total number of iterations for the nonlinear solvers. - - nnonlinconvfail: Number of nonlinear solver convergence failures. - - ncondition: Number of calls to the condition function for callbacks. - - naccept: Number of accepted steps. - - nreject: Number of rejected steps. - - maxeig: Maximum eigenvalue over the solution. This is only computed if the - method is an auto-switching algorithm. - """ - mutable struct Stats - nf::Int - nf2::Int - nw::Int - nsolve::Int - njacs::Int - nnonliniter::Int - nnonlinconvfail::Int - ncondition::Int - naccept::Int - nreject::Int - maxeig::Float64 - end - - Base.@deprecate_binding DEStats Stats false - - Stats(x::Int = -1) = Stats(x, x, x, x, x, x, x, x, x, x, 0.0) - - function Base.show(io::IO, s::Stats) - println(io, summary(s)) - @printf io "%-50s %-d\n" "Number of function 1 evaluations:" s.nf - @printf io "%-50s %-d\n" "Number of function 2 evaluations:" s.nf2 - @printf io "%-50s %-d\n" "Number of W matrix evaluations:" s.nw - @printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve - @printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs - @printf io "%-50s %-d\n" "Number of nonlinear solver iterations:" s.nnonliniter - @printf io "%-50s %-d\n" "Number of nonlinear solver convergence failures:" s.nnonlinconvfail - @printf io "%-50s %-d\n" "Number of rootfind condition calls:" s.ncondition - @printf io "%-50s %-d\n" "Number of accepted steps:" s.naccept - @printf io "%-50s %-d" "Number of rejected steps:" s.nreject - iszero(s.maxeig) || @printf io "\n%-50s %-d" "Maximum eigenvalue recorded:" s.maxeig - end + """ + $(TYPEDEF) + + Statistics from the differential equation solver about the solution process. + + ## Fields + + - nf: Number of function evaluations. If the differential equation is a split function, + such as a `SplitFunction` for implicit-explicit (IMEX) integration, then `nf` is the + number of function evaluations for the first function (the implicit function) + - nf2: If the differential equation is a split function, such as a `SplitFunction` + for implicit-explicit (IMEX) integration, then `nf2` is the number of function + evaluations for the second function, i.e. the function treated explicitly. Otherwise + it is zero. + - nw: The number of W=I-gamma*J (or W=I/gamma-J) matrices constructed during the solving + process. + - nsolve: The number of linear solves `W\b` required for the integration. + - njacs: Number of Jacobians calculated during the integration. + - nnonliniter: Total number of iterations for the nonlinear solvers. + - nnonlinconvfail: Number of nonlinear solver convergence failures. + - ncondition: Number of calls to the condition function for callbacks. + - naccept: Number of accepted steps. + - nreject: Number of rejected steps. + - maxeig: Maximum eigenvalue over the solution. This is only computed if the + method is an auto-switching algorithm. + """ + mutable struct Stats + nf::Int + nf2::Int + nw::Int + nsolve::Int + njacs::Int + nnonliniter::Int + nnonlinconvfail::Int + ncondition::Int + naccept::Int + nreject::Int + maxeig::Float64 + end + + Base.@deprecate_binding DEStats Stats false + + Stats(x::Int = -1) = Stats(x, x, x, x, x, x, x, x, x, x, 0.0) + + function Base.show(io::IO, s::Stats) + println(io, summary(s)) + @printf io "%-50s %-d\n" "Number of function 1 evaluations:" s.nf + @printf io "%-50s %-d\n" "Number of function 2 evaluations:" s.nf2 + @printf io "%-50s %-d\n" "Number of W matrix evaluations:" s.nw + @printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve + @printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs + @printf io "%-50s %-d\n" "Number of nonlinear solver iterations:" s.nnonliniter + @printf io "%-50s %-d\n" "Number of nonlinear solver convergence failures:" s.nnonlinconvfail + @printf io "%-50s %-d\n" "Number of rootfind condition calls:" s.ncondition + @printf io "%-50s %-d\n" "Number of accepted steps:" s.naccept + @printf io "%-50s %-d" "Number of rejected steps:" s.nreject + iszero(s.maxeig) || @printf io "\n%-50s %-d" "Maximum eigenvalue recorded:" s.maxeig + end end diff --git a/test/downstream/kwarg_warn.jl b/test/downstream/kwarg_warn.jl index 3c7e9b014..2d44a1f28 100644 --- a/test/downstream/kwarg_warn.jl +++ b/test/downstream/kwarg_warn.jl @@ -9,7 +9,10 @@ tspan = (0.0, 100.0) prob = ODEProblem(lorenz, u0, tspan) @test_nowarn sol = solve(prob, Tsit5(), reltol = 1e-6) sol = solve(prob, Tsit5(), rel_tol = 1e-6, kwargshandle = DiffEqBase.KeywordArgWarn) -@test_logs (:warn, DiffEqBase.KWARGWARN_MESSAGE) sol=solve(prob, Tsit5(), rel_tol = 1e-6, kwargshandle = DiffEqBase.KeywordArgWarn) +@test_logs (:warn, DiffEqBase.KWARGWARN_MESSAGE) sol=solve(prob, + Tsit5(), + rel_tol = 1e-6, + kwargshandle = DiffEqBase.KeywordArgWarn) @test_throws DiffEqBase.CommonKwargError sol=solve(prob, Tsit5(), rel_tol = 1e-6) prob = ODEProblem(lorenz, u0, tspan, test = 2.0, kwargshandle = DiffEqBase.KeywordArgWarn) diff --git a/test/remake_tests.jl b/test/remake_tests.jl index a4950b877..f1a49b945 100644 --- a/test/remake_tests.jl +++ b/test/remake_tests.jl @@ -79,7 +79,7 @@ noise2 = remake(noise1; tspan = tspan2); # Test remake with TwoPointBVPFunction (manually defined): f1 = SciMLBase.TwoPointBVPFunction((u, p, t) -> 1, ((u_a, p) -> 2, (u_b, p) -> 2)) -@test_broken f2 = remake(f1; bc = ((u_a, p) -> 3, (u_b, p) -> 4)) +@test_broken f2 = remake(f1; bc = ((u_a, p) -> 3, (u_b, p) -> 4)) @test_broken f1.bc() == 1 @test_broken f2.bc() == 2