Skip to content

Commit

Permalink
simplify dolinsolve
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Aug 14, 2024
1 parent f5f1cc4 commit 67f869d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 130 deletions.
16 changes: 4 additions & 12 deletions lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,19 @@ issuccess_W(W::Number) = !iszero(W)
issuccess_W(::Any) = true

function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
du = nothing, u = nothing, p = nothing, t = nothing,
weight = nothing, solverdata = nothing,
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
A !== nothing && (linsolve.A = A)
b !== nothing && (linsolve.b = b)
linu !== nothing && (linsolve.u = linu)

Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
linsolve.Pl
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
linsolve.Pr

_alg = unwrap_alg(integrator, true)

_Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev,
solverdata)
if (_Pl !== nothing || _Pr !== nothing)
__Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl
__Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr
linsolve.Pl = __Pl
linsolve.Pr = __Pr
if !isnothing(A)
(;du, u, p, t) = integrator
p = isnothing(integrator) ? nothing : (du, u, p, t)
reinit!(linsolve; A, p)
end

linres = solve!(linsolve; reltol)
Expand Down
48 changes: 0 additions & 48 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,7 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve_tmp = zero(rate_prototype)

linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))

grad_config = build_grad_config(alg, f, tf, du1, t)
Expand Down Expand Up @@ -141,11 +137,7 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))

Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2, Val(false))
Expand Down Expand Up @@ -289,11 +281,7 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -375,11 +363,7 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -568,11 +552,7 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -612,11 +592,7 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -735,11 +711,7 @@ function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -795,11 +767,7 @@ function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -855,11 +823,7 @@ function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -915,11 +879,7 @@ function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -1032,11 +992,7 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down Expand Up @@ -1096,11 +1052,7 @@ function alg_cache(
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Expand Down
77 changes: 7 additions & 70 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = γ))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = γ))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

vecu = _vec(linres.u)
veck₁ = _vec(k₁)
Expand Down Expand Up @@ -162,16 +153,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = γ))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = γ))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

vecu = _vec(linres.u)
veck₁ = _vec(k₁)
Expand Down Expand Up @@ -521,16 +503,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

vecu = _vec(linres.u)
veck1 = _vec(k1)
Expand Down Expand Up @@ -716,16 +689,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

vecu = _vec(linres.u)
veck1 = _vec(k1)
Expand Down Expand Up @@ -1024,16 +988,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

@.. broadcast=false $(_vec(k1))=-linres.u

Expand Down Expand Up @@ -1387,16 +1342,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

@.. broadcast=false $(_vec(k1))=-linres.u

Expand Down Expand Up @@ -1790,16 +1736,7 @@ end
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
end
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

vecu = _vec(linres.u)
veck1 = _vec(k1)
Expand Down

0 comments on commit 67f869d

Please sign in to comment.