Skip to content

Commit

Permalink
chore: format
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 27, 2023
1 parent 0ae3915 commit ffe02c3
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 72 deletions.
41 changes: 31 additions & 10 deletions ext/DiffEqBaseEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,17 @@ isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)
using ChainRulesCore
using EnzymeCore

function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT
function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{
1,
},
func::Const{typeof(DiffEqBase.solve_up)},
::Type{Duplicated{RT}},
prob,
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
u0,
p,
args...;
kwargs...) where {RT}
@inline function copy_or_reuse(val, idx)
if EnzymeCore.EnzymeRules.overwritten(config)[idx] && ismutable(val)
return deepcopy(val)
Expand All @@ -17,24 +27,35 @@ function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.
end

@inline function arg_copy(i)
copy_or_reuse(args[i].val, i+5)
copy_or_reuse(args[i].val, i + 5)
end

res = DiffEqBase._solve_adjoint(copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), SciMLBase.ChainRulesOriginator(), ntuple(arg_copy, Val(length(args)))...;

res = DiffEqBase._solve_adjoint(copy_or_reuse(prob.val, 2),
copy_or_reuse(sensealg.val, 3), copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5),
SciMLBase.ChainRulesOriginator(), ntuple(arg_copy, Val(length(args)))...;
kwargs...)

dres = deepcopy(res[1])::RT
for v in dres.u
v.= 0
v .= 0
end
tup = (dres, res[2])
return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
end

function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{<:Duplicated{RT}}, tape, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT
dres, clos = tape
function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1},
func::Const{typeof(DiffEqBase.solve_up)},
::Type{<:Duplicated{RT}},
tape,
prob,
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
u0,
p,
args...;
kwargs...) where {RT}
dres, clos = tape
dres = dres::RT
dargs = clos(dres)
dargs = clos(dres)
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
if ptr isa EnzymeCore.Const
continue
Expand All @@ -45,9 +66,9 @@ function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWid
ptr.dval .+= darg
end
for v in dres.u
v.= 0
v .= 0
end
return ntuple(_ -> nothing, Val(length(args)+4))
return ntuple(_ -> nothing, Val(length(args) + 4))
end

end
2 changes: 1 addition & 1 deletion src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ end
@require MPI="da04e1cc-30fd-572f-bb4f-1f8673147195" begin
include("../ext/DiffEqBaseMPIExt.jl")
end

@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin
include("../ext/DiffEqBaseEnzymeExt.jl")
end
Expand Down
118 changes: 59 additions & 59 deletions src/stats.jl
Original file line number Diff line number Diff line change
@@ -1,62 +1,62 @@
@static if isdefined(SciMLBase, :DEStats)
const Stats = SciMLBase.DEStats
const Stats = SciMLBase.DEStats
else
"""
$(TYPEDEF)
Statistics from the differential equation solver about the solution process.
## Fields
- nf: Number of function evaluations. If the differential equation is a split function,
such as a `SplitFunction` for implicit-explicit (IMEX) integration, then `nf` is the
number of function evaluations for the first function (the implicit function)
- nf2: If the differential equation is a split function, such as a `SplitFunction`
for implicit-explicit (IMEX) integration, then `nf2` is the number of function
evaluations for the second function, i.e. the function treated explicitly. Otherwise
it is zero.
- nw: The number of W=I-gamma*J (or W=I/gamma-J) matrices constructed during the solving
process.
- nsolve: The number of linear solves `W\b` required for the integration.
- njacs: Number of Jacobians calculated during the integration.
- nnonliniter: Total number of iterations for the nonlinear solvers.
- nnonlinconvfail: Number of nonlinear solver convergence failures.
- ncondition: Number of calls to the condition function for callbacks.
- naccept: Number of accepted steps.
- nreject: Number of rejected steps.
- maxeig: Maximum eigenvalue over the solution. This is only computed if the
method is an auto-switching algorithm.
"""
mutable struct Stats
nf::Int
nf2::Int
nw::Int
nsolve::Int
njacs::Int
nnonliniter::Int
nnonlinconvfail::Int
ncondition::Int
naccept::Int
nreject::Int
maxeig::Float64
end
Base.@deprecate_binding DEStats Stats false
Stats(x::Int = -1) = Stats(x, x, x, x, x, x, x, x, x, x, 0.0)
function Base.show(io::IO, s::Stats)
println(io, summary(s))
@printf io "%-50s %-d\n" "Number of function 1 evaluations:" s.nf
@printf io "%-50s %-d\n" "Number of function 2 evaluations:" s.nf2
@printf io "%-50s %-d\n" "Number of W matrix evaluations:" s.nw
@printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve
@printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs
@printf io "%-50s %-d\n" "Number of nonlinear solver iterations:" s.nnonliniter
@printf io "%-50s %-d\n" "Number of nonlinear solver convergence failures:" s.nnonlinconvfail
@printf io "%-50s %-d\n" "Number of rootfind condition calls:" s.ncondition
@printf io "%-50s %-d\n" "Number of accepted steps:" s.naccept
@printf io "%-50s %-d" "Number of rejected steps:" s.nreject
iszero(s.maxeig) || @printf io "\n%-50s %-d" "Maximum eigenvalue recorded:" s.maxeig
end
"""
$(TYPEDEF)
Statistics from the differential equation solver about the solution process.
## Fields
- nf: Number of function evaluations. If the differential equation is a split function,
such as a `SplitFunction` for implicit-explicit (IMEX) integration, then `nf` is the
number of function evaluations for the first function (the implicit function)
- nf2: If the differential equation is a split function, such as a `SplitFunction`
for implicit-explicit (IMEX) integration, then `nf2` is the number of function
evaluations for the second function, i.e. the function treated explicitly. Otherwise
it is zero.
- nw: The number of W=I-gamma*J (or W=I/gamma-J) matrices constructed during the solving
process.
- nsolve: The number of linear solves `W\b` required for the integration.
- njacs: Number of Jacobians calculated during the integration.
- nnonliniter: Total number of iterations for the nonlinear solvers.
- nnonlinconvfail: Number of nonlinear solver convergence failures.
- ncondition: Number of calls to the condition function for callbacks.
- naccept: Number of accepted steps.
- nreject: Number of rejected steps.
- maxeig: Maximum eigenvalue over the solution. This is only computed if the
method is an auto-switching algorithm.
"""
mutable struct Stats
nf::Int
nf2::Int
nw::Int
nsolve::Int
njacs::Int
nnonliniter::Int
nnonlinconvfail::Int
ncondition::Int
naccept::Int
nreject::Int
maxeig::Float64
end

Base.@deprecate_binding DEStats Stats false

Stats(x::Int = -1) = Stats(x, x, x, x, x, x, x, x, x, x, 0.0)

function Base.show(io::IO, s::Stats)
println(io, summary(s))
@printf io "%-50s %-d\n" "Number of function 1 evaluations:" s.nf
@printf io "%-50s %-d\n" "Number of function 2 evaluations:" s.nf2
@printf io "%-50s %-d\n" "Number of W matrix evaluations:" s.nw
@printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve
@printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs
@printf io "%-50s %-d\n" "Number of nonlinear solver iterations:" s.nnonliniter
@printf io "%-50s %-d\n" "Number of nonlinear solver convergence failures:" s.nnonlinconvfail
@printf io "%-50s %-d\n" "Number of rootfind condition calls:" s.ncondition
@printf io "%-50s %-d\n" "Number of accepted steps:" s.naccept
@printf io "%-50s %-d" "Number of rejected steps:" s.nreject
iszero(s.maxeig) || @printf io "\n%-50s %-d" "Maximum eigenvalue recorded:" s.maxeig
end
end
5 changes: 4 additions & 1 deletion test/downstream/kwarg_warn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ tspan = (0.0, 100.0)
prob = ODEProblem(lorenz, u0, tspan)
@test_nowarn sol = solve(prob, Tsit5(), reltol = 1e-6)
sol = solve(prob, Tsit5(), rel_tol = 1e-6, kwargshandle = DiffEqBase.KeywordArgWarn)
@test_logs (:warn, DiffEqBase.KWARGWARN_MESSAGE) sol=solve(prob, Tsit5(), rel_tol = 1e-6, kwargshandle = DiffEqBase.KeywordArgWarn)
@test_logs (:warn, DiffEqBase.KWARGWARN_MESSAGE) sol=solve(prob,
Tsit5(),
rel_tol = 1e-6,
kwargshandle = DiffEqBase.KeywordArgWarn)
@test_throws DiffEqBase.CommonKwargError sol=solve(prob, Tsit5(), rel_tol = 1e-6)

prob = ODEProblem(lorenz, u0, tspan, test = 2.0, kwargshandle = DiffEqBase.KeywordArgWarn)
Expand Down
2 changes: 1 addition & 1 deletion test/remake_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ noise2 = remake(noise1; tspan = tspan2);

# Test remake with TwoPointBVPFunction (manually defined):
f1 = SciMLBase.TwoPointBVPFunction((u, p, t) -> 1, ((u_a, p) -> 2, (u_b, p) -> 2))
@test_broken f2 = remake(f1; bc = ((u_a, p) -> 3, (u_b, p) -> 4))
@test_broken f2 = remake(f1; bc = ((u_a, p) -> 3, (u_b, p) -> 4))
@test_broken f1.bc() == 1
@test_broken f2.bc() == 2

Expand Down

0 comments on commit ffe02c3

Please sign in to comment.