Skip to content

Commit

Permalink
simplify dolinsolve
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Aug 7, 2024
1 parent 07470f9 commit d4ea754
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 132 deletions.
16 changes: 4 additions & 12 deletions lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,19 @@ issuccess_W(W::Number) = !iszero(W)
issuccess_W(::Any) = true

function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
du = nothing, u = nothing, p = nothing, t = nothing,
weight = nothing, solverdata = nothing,
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
A !== nothing && (linsolve.A = A)
b !== nothing && (linsolve.b = b)
linu !== nothing && (linsolve.u = linu)

Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
linsolve.Pl
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
linsolve.Pr

_alg = unwrap_alg(integrator, true)

_Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev,
solverdata)
if (_Pl !== nothing || _Pr !== nothing)
__Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl
__Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr
linsolve.Pl = __Pl
linsolve.Pr = __Pr
if !isnothing(A)
(;du, u, p, t) = integrator
p = isnothing(integrator) ? nothing : (du, u, p, t)
reinit!(linsolve; A, p)
end

linres = solve!(linsolve; reltol)
Expand Down
50 changes: 1 addition & 49 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,7 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve_tmp = zero(rate_prototype)

linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))

grad_config = build_grad_config(alg, f, tf, du1, t)
Expand Down Expand Up @@ -143,11 +139,7 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))

Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2, Val(false))
Expand Down Expand Up @@ -291,11 +283,7 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -377,11 +365,7 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -570,11 +554,7 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -615,11 +595,7 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -740,11 +716,7 @@ function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -802,11 +774,7 @@ function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -862,11 +830,7 @@ function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -922,11 +886,7 @@ function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -1041,11 +1001,7 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -1105,11 +1061,7 @@ function alg_cache(
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -1140,4 +1092,4 @@ end

### RosenbrockW6S4O

@RosenbrockW6S4OS(:cache)
@RosenbrockW6S4OS(:cache)
79 changes: 8 additions & 71 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = γ))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = γ))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

vecu = _vec(linres.u)
veck₁ = _vec(k₁)
Expand Down Expand Up @@ -162,16 +153,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = γ))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = γ))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

vecu = _vec(linres.u)
veck₁ = _vec(k₁)
Expand Down Expand Up @@ -521,16 +503,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

vecu = _vec(linres.u)
veck1 = _vec(k1)
Expand Down Expand Up @@ -716,16 +689,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

vecu = _vec(linres.u)
veck1 = _vec(k1)
Expand Down Expand Up @@ -1024,16 +988,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

@.. broadcast=false $(_vec(k1))=-linres.u

Expand Down Expand Up @@ -1387,16 +1342,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

@.. broadcast=false $(_vec(k1))=-linres.u

Expand Down Expand Up @@ -1790,16 +1736,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

vecu = _vec(linres.u)
veck1 = _vec(k1)
Expand Down Expand Up @@ -1986,4 +1923,4 @@ end
end

@RosenbrockW6S4OS(:init)
@RosenbrockW6S4OS(:performstep)
@RosenbrockW6S4OS(:performstep)

0 comments on commit d4ea754

Please sign in to comment.