Skip to content

Commit

Permalink
simplify dolinsolve
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Sep 11, 2024
1 parent 94a6fbc commit 671775f
Show file tree
Hide file tree
Showing 39 changed files with 392 additions and 878 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ OrdinaryDiffEqQPRK = "1"
OrdinaryDiffEqRKN = "1"
OrdinaryDiffEqRosenbrock = "1"
OrdinaryDiffEqSDIRK = "1"
OrdinaryDiffEqStabilizedIRK = "1"
OrdinaryDiffEqSSPRK = "1"
OrdinaryDiffEqStabilizedIRK = "1"
OrdinaryDiffEqStabilizedRK = "1"
OrdinaryDiffEqSymplecticRK = "1"
OrdinaryDiffEqTsit5 = "1"
Expand Down
3 changes: 1 addition & 2 deletions lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache,
OrdinaryDiffEqNewtonAdaptiveAlgorithm,
OrdinaryDiffEqNewtonAlgorithm,
AbstractController, DEFAULT_PRECS,
CompiledFloats, uses_uprev,
AbstractController, CompiledFloats, uses_uprev,
alg_cache, _vec, _reshape, @cache,
isfsal, full_cache,
constvalue, isadaptive, error_constant,
Expand Down
135 changes: 42 additions & 93 deletions lib/OrdinaryDiffEqBDF/src/algorithms.jl

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ end
Divergence = -2
end
const TryAgain = SlowConvergence

DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
isdiscretecache(cache) = false

# unused. Delete this once StocasticDiffEq doesn't use it
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
include("doc_utils.jl")
include("misc_utils.jl")

Expand Down
36 changes: 1 addition & 35 deletions lib/OrdinaryDiffEqCore/src/doc_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ function differentiation_rk_docstring(description::String,
concrete_jac = nothing,
diff_type = Val{:forward},
linsolve = nothing,
precs = DEFAULT_PRECS,
""" * extra_keyword_default

keyword_default_description = """
Expand All @@ -111,40 +110,7 @@ function differentiation_rk_docstring(description::String,
For example, to use [KLU.jl](https://github.com/JuliaSparse/KLU.jl), specify
`$name(linsolve = KLUFactorization()`).
When `nothing` is passed, uses `DefaultLinearSolver`.
- `precs`: Any [LinearSolve.jl-compatible preconditioner](https://docs.sciml.ai/LinearSolve/stable/basics/Preconditioners/)
can be used as a left or right preconditioner.
Preconditioners are specified by the `Pl,Pr = precs(W,du,u,p,t,newW,Plprev,Prprev,solverdata)`
function where the arguments are defined as:
- `W`: the current Jacobian of the nonlinear system. Specified as either
``I - \\gamma J`` or ``I/\\gamma - J`` depending on the algorithm. This will
commonly be a `WOperator` type defined by OrdinaryDiffEq.jl. It is a lazy
representation of the operator. Users can construct the W-matrix on demand
by calling `convert(AbstractMatrix,W)` to receive an `AbstractMatrix` matching
the `jac_prototype`.
- `du`: the current ODE derivative
- `u`: the current ODE state
- `p`: the ODE parameters
- `t`: the current ODE time
- `newW`: a `Bool` which specifies whether the `W` matrix has been updated since
the last call to `precs`. It is recommended that this is checked to only
update the preconditioner when `newW == true`.
- `Plprev`: the previous `Pl`.
- `Prprev`: the previous `Pr`.
- `solverdata`: Optional extra data the solvers can give to the `precs` function.
Solver-dependent and subject to change.
The return is a tuple `(Pl,Pr)` of the LinearSolve.jl-compatible preconditioners.
To specify one-sided preconditioning, simply return `nothing` for the preconditioner
which is not used. Additionally, `precs` must supply the dispatch:
```julia
Pl, Pr = precs(W, du, u, p, t, ::Nothing, ::Nothing, ::Nothing, solverdata)
```
which is used in the solver setup phase to construct the integrator
type with the preconditioners `(Pl,Pr)`.
The default is `precs=DEFAULT_PRECS` where the default preconditioner function
is defined as:
```julia
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
```
""" * extra_keyword_description

generic_solver_docstring(
Expand Down
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqDefault/test/default_solver_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23(autodiff = false)))
sol = solve(prob_rober, reltol = 1e-7, abstol = 1e-7)
rosensol = solve(
prob_rober, AutoVern7(Rodas5P(autodiff = false)), reltol = 1e-7, abstol = 1e-7)
# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this).
# test that default has the same performance as AutoTsit5(Rodas5P()) (which we expect it to use for this).
@test sol.stats.naccept == rosensol.stats.naccept
@test sol.stats.nf == rosensol.stats.nf
@test unique(sol.alg_choice) == [2, 4]
Expand All @@ -75,7 +75,7 @@ for n in (100, 600)
vcat([1.0, 0.0, 0.0], ones(n)), (0.0, 100.0), (0.04, 3e7, 1e4))
global sol = solve(prob_ex_rober)
fsol = solve(prob_ex_rober, AutoTsit5(FBDF(; autodiff = false, linsolve)))
# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this).
# test that default has the same performance as AutoTsit5(FBDF()) (which we expect it to use for this).
@test sol.stats.naccept == fsol.stats.naccept
@test sol.stats.nf == fsol.stats.nf
@test unique(sol.alg_choice) == [1, stiffalg]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, S
using DiffEqBase: TimeGradientWrapper,
UJacobianWrapper, TimeDerivativeWrapper,
UDerivativeWrapper
using SciMLBase: AbstractSciMLOperator
using SciMLBase: AbstractSciMLOperator, DEIntegrator
import OrdinaryDiffEqCore
using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplicitAlgorithm,
DAEAlgorithm,
Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
# Thus setup JacVec and a concrete J, using sparsity when possible
_f = islin ? (isode ? f.f : f.f1.f) : f
J = if f.jac_prototype === nothing
ArrayInterface.undefmatrix(u)
ArrayInterface.zeromatrix(u)
else
deepcopy(f.jac_prototype)
end
Expand All @@ -907,7 +907,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
f.jac(uprev, p, t)
end
elseif f.jac_prototype === nothing
ArrayInterface.undefmatrix(u)
ArrayInterface.zeromatrix(u)
else
deepcopy(f.jac_prototype)
end
Expand Down Expand Up @@ -1003,4 +1003,4 @@ function resize_J_W!(cache, integrator, i)
end

nothing
end
end
44 changes: 18 additions & 26 deletions lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,20 @@ issuccess_W(W::Number) = !iszero(W)
issuccess_W(::Any) = true

function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
du = nothing, u = nothing, p = nothing, t = nothing,
weight = nothing, solverdata = nothing,
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
A !== nothing && (linsolve.A = A)
b !== nothing && (linsolve.b = b)
linu !== nothing && (linsolve.u = linu)

Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
linsolve.Pl
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
linsolve.Pr

_alg = unwrap_alg(integrator, true)

_Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev,
solverdata)
if (_Pl !== nothing || _Pr !== nothing)
__Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl
__Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr
linsolve.Pl = __Pl
linsolve.Pr = __Pr
if !isnothing(A)
if integrator isa DEIntegrator
(;u, p, t) = integrator
du = hasproperty(integrator, :du) ? integrator.du : nothing
p = (du, u, p, t)
reinit!(linsolve; A, p)
else
reinit!(linsolve; A)
end
end

linres = solve!(linsolve; reltol)
Expand All @@ -44,16 +37,15 @@ function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothi
return linres
end

function wrapprecs(_Pl::Nothing, _Pr::Nothing, weight, u)
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
Pr = Diagonal(_vec(weight))
Pl, Pr
end

function wrapprecs(_Pl, _Pr, weight, u)
Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pl
Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pr
Pl, Pr
function wrapprecs(linsolver, W, weight)
if hasproperty(linsolver, :precs) && isnothing(linsolver.precs)
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
Pr = Diagonal(_vec(weight))
precs = Returns((Pl, Pr))
return remake(linsolver; precs)
else
return linsolver
end
end

Base.resize!(p::LinearSolve.LinearCache, i) = p
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import OrdinaryDiffEqCore: alg_order, alg_maximum_order, get_current_adaptive_or
OrdinaryDiffEqAdaptiveAlgorithm,
OrdinaryDiffEqAdaptiveImplicitAlgorithm,
alg_cache, CompiledFloats, @threaded, stepsize_controller!,
DEFAULT_PRECS, full_cache,
full_cache,
constvalue, PolyesterThreads, Sequential, BaseThreads,
_digest_beta1_beta2, timedepentdtmin, _unwrap_val,
_reshape, _vec, get_fsalfirstlast, generic_solver_docstring,
Expand Down
41 changes: 16 additions & 25 deletions lib/OrdinaryDiffEqExtrapolation/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,9 @@ Similar to Hairer's SEULEX.",
thread = OrdinaryDiffEq.False(),
sequence = :harmonic
""")
struct ImplicitEulerExtrapolation{CS, AD, F, P, FDT, ST, CJ, TO} <:
struct ImplicitEulerExtrapolation{CS, AD, F, FDT, ST, CJ, TO} <:
OrdinaryDiffEqImplicitExtrapolationAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
precs::P
max_order::Int
min_order::Int
init_order::Int
Expand All @@ -69,7 +68,6 @@ end
function ImplicitEulerExtrapolation(; chunk_size = Val{0}(), autodiff = true,
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing,
precs = DEFAULT_PRECS,
max_order = 12, min_order = 3, init_order = 5,
threading = false, sequence = :harmonic)
linsolve = (linsolve === nothing &&
Expand Down Expand Up @@ -100,11 +98,10 @@ Initial order: " * lpad(init_order, 2, " ") * " --> " * lpad(init_order, 2, " ")
sequence = :harmonic
end
ImplicitEulerExtrapolation{_unwrap_val(chunk_size), _unwrap_val(autodiff),
typeof(linsolve), typeof(precs), diff_type,
typeof(linsolve), diff_type,
_unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(threading)}(linsolve, precs, max_order, min_order,
init_order,
threading, sequence)
typeof(threading)}(linsolve, max_order, min_order,
init_order, threading, sequence)
end

@doc generic_solver_docstring("Midpoint extrapolation using Barycentric coordinates.",
Expand Down Expand Up @@ -195,10 +192,9 @@ end
thread = OrdinaryDiffEq.False(),
sequence = :harmonic,
""")
struct ImplicitDeuflhardExtrapolation{CS, AD, F, P, FDT, ST, CJ, TO} <:
struct ImplicitDeuflhardExtrapolation{CS, AD, F, FDT, ST, CJ, TO} <:
OrdinaryDiffEqImplicitExtrapolationAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
precs::P
min_order::Int # Minimal extrapolation order
init_order::Int # Initial extrapolation order
max_order::Int # Maximal extrapolation order
Expand All @@ -207,8 +203,7 @@ struct ImplicitDeuflhardExtrapolation{CS, AD, F, P, FDT, ST, CJ, TO} <:
end
function ImplicitDeuflhardExtrapolation(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
linsolve = nothing, precs = DEFAULT_PRECS,
diff_type = Val{:forward},
linsolve = nothing, diff_type = Val{:forward},
min_order = 1, init_order = 5, max_order = 10,
sequence = :harmonic, threading = false)
# Enforce 1 <= min_order <= init_order <= max_order:
Expand Down Expand Up @@ -243,9 +238,9 @@ Initial order: " * lpad(init_order, 2, " ") * " --> " * lpad(init_order, 2, " ")

# Initialize algorithm
ImplicitDeuflhardExtrapolation{_unwrap_val(chunk_size), _unwrap_val(autodiff),
typeof(linsolve), typeof(precs), diff_type,
typeof(linsolve), diff_type,
_unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(threading)}(linsolve, precs, min_order,
typeof(threading)}(linsolve, min_order,
init_order, max_order,
sequence, threading)
end
Expand Down Expand Up @@ -342,10 +337,9 @@ end
thread = OrdinaryDiffEq.False(),
sequence = :harmonic,
""")
struct ImplicitHairerWannerExtrapolation{CS, AD, F, P, FDT, ST, CJ, TO} <:
struct ImplicitHairerWannerExtrapolation{CS, AD, F, FDT, ST, CJ, TO} <:
OrdinaryDiffEqImplicitExtrapolationAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
precs::P
min_order::Int # Minimal extrapolation order
init_order::Int # Initial extrapolation order
max_order::Int # Maximal extrapolation order
Expand All @@ -356,7 +350,7 @@ end
function ImplicitHairerWannerExtrapolation(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(),
concrete_jac = nothing,
linsolve = nothing, precs = DEFAULT_PRECS,
linsolve = nothing,
diff_type = Val{:forward},
min_order = 2, init_order = 5, max_order = 10,
sequence = :harmonic, threading = false)
Expand Down Expand Up @@ -392,11 +386,10 @@ Initial order: " * lpad(init_order, 2, " ") * " --> " * lpad(init_order, 2, " ")

# Initialize algorithm
ImplicitHairerWannerExtrapolation{_unwrap_val(chunk_size), _unwrap_val(autodiff),
typeof(linsolve), typeof(precs), diff_type,
typeof(linsolve), diff_type,
_unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(threading)}(linsolve, precs, min_order,
init_order,
max_order, sequence, threading)
typeof(threading)}(linsolve, min_order,
init_order, max_order, sequence, threading)
end

@doc differentiation_rk_docstring("Euler extrapolation using Barycentric coordinates,
Expand All @@ -420,10 +413,9 @@ end
sequence = :harmonic,
sequence_factor = 2,
""")
struct ImplicitEulerBarycentricExtrapolation{CS, AD, F, P, FDT, ST, CJ, TO} <:
struct ImplicitEulerBarycentricExtrapolation{CS, AD, F, FDT, ST, CJ, TO} <:
OrdinaryDiffEqImplicitExtrapolationAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
precs::P
min_order::Int # Minimal extrapolation order
init_order::Int # Initial extrapolation order
max_order::Int # Maximal extrapolation order
Expand All @@ -436,7 +428,7 @@ function ImplicitEulerBarycentricExtrapolation(; chunk_size = Val{0}(),
autodiff = Val{true}(),
standardtag = Val{true}(),
concrete_jac = nothing,
linsolve = nothing, precs = DEFAULT_PRECS,
linsolve = nothing,
diff_type = Val{:forward},
min_order = 3, init_order = 5,
max_order = 12, sequence = :harmonic,
Expand Down Expand Up @@ -473,10 +465,9 @@ Initial order: " * lpad(init_order, 2, " ") * " --> " * lpad(init_order, 2, " ")

# Initialize algorithm
ImplicitEulerBarycentricExtrapolation{_unwrap_val(chunk_size), _unwrap_val(autodiff),
typeof(linsolve), typeof(precs), diff_type,
typeof(linsolve), diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac), typeof(threading)}(linsolve,
precs,
min_order,
init_order,
max_order,
Expand Down
Loading

0 comments on commit 671775f

Please sign in to comment.