Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use LinearSolve precs interface #2318

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ OrdinaryDiffEqQPRK = "1"
OrdinaryDiffEqRKN = "1"
OrdinaryDiffEqRosenbrock = "1"
OrdinaryDiffEqSDIRK = "1"
OrdinaryDiffEqStabilizedIRK = "1"
OrdinaryDiffEqSSPRK = "1"
OrdinaryDiffEqStabilizedIRK = "1"
OrdinaryDiffEqStabilizedRK = "1"
OrdinaryDiffEqSymplecticRK = "1"
OrdinaryDiffEqTsit5 = "1"
Expand Down
3 changes: 1 addition & 2 deletions lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache,
OrdinaryDiffEqNewtonAdaptiveAlgorithm,
OrdinaryDiffEqNewtonAlgorithm,
AbstractController, DEFAULT_PRECS,
CompiledFloats, uses_uprev,
AbstractController, CompiledFloats, uses_uprev,
alg_cache, _vec, _reshape, @cache,
isfsal, full_cache,
constvalue, isadaptive, error_constant,
Expand Down
135 changes: 42 additions & 93 deletions lib/OrdinaryDiffEqBDF/src/algorithms.jl

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ end
Divergence = -2
end
const TryAgain = SlowConvergence

DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
isdiscretecache(cache) = false

# unused. Delete this once StocasticDiffEq doesn't use it
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
include("doc_utils.jl")
include("misc_utils.jl")

Expand Down
36 changes: 1 addition & 35 deletions lib/OrdinaryDiffEqCore/src/doc_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ function differentiation_rk_docstring(description::String,
concrete_jac = nothing,
diff_type = Val{:forward},
linsolve = nothing,
precs = DEFAULT_PRECS,
""" * extra_keyword_default

keyword_default_description = """
Expand All @@ -111,40 +110,7 @@ function differentiation_rk_docstring(description::String,
For example, to use [KLU.jl](https://github.com/JuliaSparse/KLU.jl), specify
`$name(linsolve = KLUFactorization()`).
When `nothing` is passed, uses `DefaultLinearSolver`.
- `precs`: Any [LinearSolve.jl-compatible preconditioner](https://docs.sciml.ai/LinearSolve/stable/basics/Preconditioners/)
can be used as a left or right preconditioner.
Preconditioners are specified by the `Pl,Pr = precs(W,du,u,p,t,newW,Plprev,Prprev,solverdata)`
function where the arguments are defined as:
- `W`: the current Jacobian of the nonlinear system. Specified as either
``I - \\gamma J`` or ``I/\\gamma - J`` depending on the algorithm. This will
commonly be a `WOperator` type defined by OrdinaryDiffEq.jl. It is a lazy
representation of the operator. Users can construct the W-matrix on demand
by calling `convert(AbstractMatrix,W)` to receive an `AbstractMatrix` matching
the `jac_prototype`.
- `du`: the current ODE derivative
- `u`: the current ODE state
- `p`: the ODE parameters
- `t`: the current ODE time
- `newW`: a `Bool` which specifies whether the `W` matrix has been updated since
the last call to `precs`. It is recommended that this is checked to only
update the preconditioner when `newW == true`.
- `Plprev`: the previous `Pl`.
- `Prprev`: the previous `Pr`.
- `solverdata`: Optional extra data the solvers can give to the `precs` function.
Solver-dependent and subject to change.
The return is a tuple `(Pl,Pr)` of the LinearSolve.jl-compatible preconditioners.
To specify one-sided preconditioning, simply return `nothing` for the preconditioner
which is not used. Additionally, `precs` must supply the dispatch:
```julia
Pl, Pr = precs(W, du, u, p, t, ::Nothing, ::Nothing, ::Nothing, solverdata)
```
which is used in the solver setup phase to construct the integrator
type with the preconditioners `(Pl,Pr)`.
The default is `precs=DEFAULT_PRECS` where the default preconditioner function
is defined as:
```julia
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
```
""" * extra_keyword_description

generic_solver_docstring(
Expand Down
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqDefault/test/default_solver_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23(autodiff = false)))
sol = solve(prob_rober, reltol = 1e-7, abstol = 1e-7)
rosensol = solve(
prob_rober, AutoVern7(Rodas5P(autodiff = false)), reltol = 1e-7, abstol = 1e-7)
# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this).
# test that default has the same performance as AutoTsit5(Rodas5P()) (which we expect it to use for this).
@test sol.stats.naccept == rosensol.stats.naccept
@test sol.stats.nf == rosensol.stats.nf
@test unique(sol.alg_choice) == [2, 4]
Expand All @@ -76,7 +76,7 @@ for n in (100, 600)
vcat([1.0, 0.0, 0.0], ones(n)), (0.0, 100.0), (0.04, 3e7, 1e4))
global sol = solve(prob_ex_rober)
fsol = solve(prob_ex_rober, AutoTsit5(FBDF(; autodiff = false, linsolve)))
# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this).
# test that default has the same performance as AutoTsit5(FBDF()) (which we expect it to use for this).
@test sol.stats.naccept == fsol.stats.naccept
@test sol.stats.nf == fsol.stats.nf
@test unique(sol.alg_choice) == [1, stiffalg]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, S
using DiffEqBase: TimeGradientWrapper,
UJacobianWrapper, TimeDerivativeWrapper,
UDerivativeWrapper
using SciMLBase: AbstractSciMLOperator
using SciMLBase: AbstractSciMLOperator, DEIntegrator
import OrdinaryDiffEqCore
using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplicitAlgorithm,
DAEAlgorithm,
Expand Down
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
# Thus setup JacVec and a concrete J, using sparsity when possible
_f = islin ? (isode ? f.f : f.f1.f) : f
J = if f.jac_prototype === nothing
ArrayInterface.undefmatrix(u)
ArrayInterface.zeromatrix(u)
else
deepcopy(f.jac_prototype)
end
Expand All @@ -907,7 +907,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
f.jac(uprev, p, t)
end
elseif f.jac_prototype === nothing
ArrayInterface.undefmatrix(u)
ArrayInterface.zeromatrix(u)
else
deepcopy(f.jac_prototype)
end
Expand Down
42 changes: 20 additions & 22 deletions lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,20 @@ issuccess_W(W::Number) = !iszero(W)
issuccess_W(::Any) = true

function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
du = nothing, u = nothing, p = nothing, t = nothing,
weight = nothing, solverdata = nothing,
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
A !== nothing && (linsolve.A = A)
b !== nothing && (linsolve.b = b)
linu !== nothing && (linsolve.u = linu)

Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
linsolve.Pl
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
linsolve.Pr

_alg = unwrap_alg(integrator, true)

_Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev,
solverdata)
if (_Pl !== nothing || _Pr !== nothing)
__Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl
__Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr
linsolve.Pl = __Pl
linsolve.Pr = __Pr
if !isnothing(A)
if integrator isa DEIntegrator
(;u, p, t) = integrator
du = hasproperty(integrator, :du) ? integrator.du : nothing
p = (du, u, p, t)
reinit!(linsolve; A, p)
else
reinit!(linsolve; A)
end
end

linres = solve!(linsolve; reltol)
Expand All @@ -44,16 +37,21 @@ function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothi
return linres
end

function wrapprecs(_Pl::Nothing, _Pr::Nothing, weight, u)
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
Pr = Diagonal(_vec(weight))
Pl, Pr
end

#for backward compat delete soon
function wrapprecs(_Pl, _Pr, weight, u)
Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pl
Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pr
Pl, Pr
end
function wrapprecs(linsolver, W, weight)
if hasproperty(linsolver, :precs) && isnothing(linsolver.precs)
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
Pr = Diagonal(_vec(weight))
precs = Returns((Pl, Pr))
return remake(linsolver; precs)
else
return linsolver
end
end

Base.resize!(p::LinearSolve.LinearCache, i) = p
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import OrdinaryDiffEqCore: alg_order, alg_maximum_order, get_current_adaptive_or
OrdinaryDiffEqAdaptiveAlgorithm,
OrdinaryDiffEqAdaptiveImplicitAlgorithm,
alg_cache, CompiledFloats, @threaded, stepsize_controller!,
DEFAULT_PRECS, full_cache,
full_cache,
constvalue, PolyesterThreads, Sequential, BaseThreads,
_digest_beta1_beta2, timedepentdtmin, _unwrap_val,
_reshape, _vec, get_fsalfirstlast, generic_solver_docstring,
Expand Down
41 changes: 16 additions & 25 deletions lib/OrdinaryDiffEqExtrapolation/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,9 @@ Similar to Hairer's SEULEX.",
thread = OrdinaryDiffEq.False(),
sequence = :harmonic
""")
struct ImplicitEulerExtrapolation{CS, AD, F, P, FDT, ST, CJ, TO} <:
struct ImplicitEulerExtrapolation{CS, AD, F, FDT, ST, CJ, TO} <:
OrdinaryDiffEqImplicitExtrapolationAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
precs::P
max_order::Int
min_order::Int
init_order::Int
Expand All @@ -69,7 +68,6 @@ end
function ImplicitEulerExtrapolation(; chunk_size = Val{0}(), autodiff = true,
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing,
precs = DEFAULT_PRECS,
max_order = 12, min_order = 3, init_order = 5,
threading = false, sequence = :harmonic)
linsolve = (linsolve === nothing &&
Expand Down Expand Up @@ -100,11 +98,10 @@ Initial order: " * lpad(init_order, 2, " ") * " --> " * lpad(init_order, 2, " ")
sequence = :harmonic
end
ImplicitEulerExtrapolation{_unwrap_val(chunk_size), _unwrap_val(autodiff),
typeof(linsolve), typeof(precs), diff_type,
typeof(linsolve), diff_type,
_unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(threading)}(linsolve, precs, max_order, min_order,
init_order,
threading, sequence)
typeof(threading)}(linsolve, max_order, min_order,
init_order, threading, sequence)
end

@doc generic_solver_docstring("Midpoint extrapolation using Barycentric coordinates.",
Expand Down Expand Up @@ -195,10 +192,9 @@ end
thread = OrdinaryDiffEq.False(),
sequence = :harmonic,
""")
struct ImplicitDeuflhardExtrapolation{CS, AD, F, P, FDT, ST, CJ, TO} <:
struct ImplicitDeuflhardExtrapolation{CS, AD, F, FDT, ST, CJ, TO} <:
OrdinaryDiffEqImplicitExtrapolationAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
precs::P
min_order::Int # Minimal extrapolation order
init_order::Int # Initial extrapolation order
max_order::Int # Maximal extrapolation order
Expand All @@ -207,8 +203,7 @@ struct ImplicitDeuflhardExtrapolation{CS, AD, F, P, FDT, ST, CJ, TO} <:
end
function ImplicitDeuflhardExtrapolation(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
linsolve = nothing, precs = DEFAULT_PRECS,
diff_type = Val{:forward},
linsolve = nothing, diff_type = Val{:forward},
min_order = 1, init_order = 5, max_order = 10,
sequence = :harmonic, threading = false)
# Enforce 1 <= min_order <= init_order <= max_order:
Expand Down Expand Up @@ -243,9 +238,9 @@ Initial order: " * lpad(init_order, 2, " ") * " --> " * lpad(init_order, 2, " ")

# Initialize algorithm
ImplicitDeuflhardExtrapolation{_unwrap_val(chunk_size), _unwrap_val(autodiff),
typeof(linsolve), typeof(precs), diff_type,
typeof(linsolve), diff_type,
_unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(threading)}(linsolve, precs, min_order,
typeof(threading)}(linsolve, min_order,
init_order, max_order,
sequence, threading)
end
Expand Down Expand Up @@ -342,10 +337,9 @@ end
thread = OrdinaryDiffEq.False(),
sequence = :harmonic,
""")
struct ImplicitHairerWannerExtrapolation{CS, AD, F, P, FDT, ST, CJ, TO} <:
struct ImplicitHairerWannerExtrapolation{CS, AD, F, FDT, ST, CJ, TO} <:
OrdinaryDiffEqImplicitExtrapolationAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
precs::P
min_order::Int # Minimal extrapolation order
init_order::Int # Initial extrapolation order
max_order::Int # Maximal extrapolation order
Expand All @@ -356,7 +350,7 @@ end
function ImplicitHairerWannerExtrapolation(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(),
concrete_jac = nothing,
linsolve = nothing, precs = DEFAULT_PRECS,
linsolve = nothing,
diff_type = Val{:forward},
min_order = 2, init_order = 5, max_order = 10,
sequence = :harmonic, threading = false)
Expand Down Expand Up @@ -392,11 +386,10 @@ Initial order: " * lpad(init_order, 2, " ") * " --> " * lpad(init_order, 2, " ")

# Initialize algorithm
ImplicitHairerWannerExtrapolation{_unwrap_val(chunk_size), _unwrap_val(autodiff),
typeof(linsolve), typeof(precs), diff_type,
typeof(linsolve), diff_type,
_unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(threading)}(linsolve, precs, min_order,
init_order,
max_order, sequence, threading)
typeof(threading)}(linsolve, min_order,
init_order, max_order, sequence, threading)
end

@doc differentiation_rk_docstring("Euler extrapolation using Barycentric coordinates,
Expand All @@ -420,10 +413,9 @@ end
sequence = :harmonic,
sequence_factor = 2,
""")
struct ImplicitEulerBarycentricExtrapolation{CS, AD, F, P, FDT, ST, CJ, TO} <:
struct ImplicitEulerBarycentricExtrapolation{CS, AD, F, FDT, ST, CJ, TO} <:
OrdinaryDiffEqImplicitExtrapolationAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
precs::P
min_order::Int # Minimal extrapolation order
init_order::Int # Initial extrapolation order
max_order::Int # Maximal extrapolation order
Expand All @@ -436,7 +428,7 @@ function ImplicitEulerBarycentricExtrapolation(; chunk_size = Val{0}(),
autodiff = Val{true}(),
standardtag = Val{true}(),
concrete_jac = nothing,
linsolve = nothing, precs = DEFAULT_PRECS,
linsolve = nothing,
diff_type = Val{:forward},
min_order = 3, init_order = 5,
max_order = 12, sequence = :harmonic,
Expand Down Expand Up @@ -473,10 +465,9 @@ Initial order: " * lpad(init_order, 2, " ") * " --> " * lpad(init_order, 2, " ")

# Initialize algorithm
ImplicitEulerBarycentricExtrapolation{_unwrap_val(chunk_size), _unwrap_val(autodiff),
typeof(linsolve), typeof(precs), diff_type,
typeof(linsolve), diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac), typeof(threading)}(linsolve,
precs,
min_order,
init_order,
max_order,
Expand Down
Loading
Loading