Skip to content

Commit

Permalink
feat: enable OverrideInit to solve for du0 of DAEProblems
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 28, 2024
1 parent a5ee8e9 commit f3938ad
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 4 deletions.
22 changes: 18 additions & 4 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
A collection of all the data required for `OverrideInit`.
"""
struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap, IProbDu0Map}
"""
The `AbstractNonlinearProblem` to solve for initialization.
"""
Expand All @@ -29,12 +29,18 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
initialized will be returned as-is.
"""
initializeprobpmap::IProbPmap
"""
A function which takes the solution of `initializeprob` and returns the
`du0` vector of the original problem.
"""
initializeprob_du0map::IProbDu0Map

function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K,
initprobpmap::L) where {I, J, K, L}
initprobpmap::L, initprob_du0map::M = nothing) where {I, J, K, L, M}
@assert initprob isa
Union{SCCNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem}
return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap)
return new{I, J, K, L, M}(
initprob, update_initprob!, initprobmap, initprobpmap, initprob_du0map)
end
end

Expand Down Expand Up @@ -171,9 +177,12 @@ Keyword arguments:
provided to the `OverrideInit` constructor takes priority over this keyword argument.
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
an error will be thrown.
- `return_du0`: Whether to use `initializeprob_du0map` (if present) and return
`du0, u0, p, success`.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing,
reltol = nothing, return_du0 = false, kwargs...)
u0 = state_values(valp)
p = parameter_values(valp)

Expand Down Expand Up @@ -214,5 +223,10 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
p = initdata.initializeprobpmap(valp, nlsol)
end

if initdata.initializeprob_du0map !== nothing && return_du0
du0 = initdata.initializeprob_du0map(nlsol)
return du0, u0, p, SciMLBase.successful_retcode(nlsol)
end

return u0, p, SciMLBase.successful_retcode(nlsol)
end
58 changes: 58 additions & 0 deletions test/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,4 +229,62 @@ end
@test p 0.0
@test success
end

@testset "DAEProblem" begin
function daerhs(du, u, p, t)
return [u[1] * t + p, u[1]^2 - u[2]^2]
end
# unknowns are u[2], p, D(u[1]), D(u[2]). Parameters are u[1], t
initprob = NonlinearProblem([1.0, 1.0, 1.0, 1.0], [1.0, 0.0]) do x, _p
u2, p, du1, du2 = x
u1, t = _p
return [u1^3 - u2^3, p^2 - 2p + 1, du1 - u1 * t - p, 2u1 * du1 - 2u2 * du2]
end

update_initializeprob! = function (iprob, integ)
iprob.p[1] = integ.u[1]
iprob.p[2] = integ.t
end
initprobmap = function (nlsol)
return [parameter_values(nlsol)[1], nlsol.u[1]]
end
initprobpmap = function (_, nlsol)
return nlsol.u[2]
end
initprob_du0map = function (nlsol)
return nlsol.u[3:4]
end
initialization_data = SciMLBase.OverrideInitData(
initprob, update_initializeprob!, initprobmap, initprobpmap, initprob_du0map)
fn = DAEFunction(daerhs; initialization_data)
prob = DAEProblem(fn, [0.0, 0.0], [2.0, 0.0], (0.0, 1.0), 0.0)
integ = init(prob, DImplicitEuler(); initializealg = NoInit())

initialization_data2 = SciMLBase.OverrideInitData(
initprob, update_initializeprob!, initprobmap, initprobpmap)
fn2 = DAEFunction(daerhs; initialization_data = initialization_data2)
prob2 = DAEProblem(fn2, [0.0, 0.0], [2.0, 0.0], (0.0, 1.0), 0.0)
integ2 = init(prob2, DImplicitEuler(); initializealg = NoInit())

nlsolve_alg = FastShortcutNonlinearPolyalg()
@testset "Doesn't return `du0` by default" begin
@test length(SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg, abstol, reltol)) == 3
end
@testset "Doesn't return `du0` if missing `du0map`" begin
@test length(SciMLBase.get_initial_values(
prob2, integ2, fn2, SciMLBase.OverrideInit(), Val(false);
nlsolve_alg, abstol, reltol, return_du0 = true)) == 3
end
@testset "With `return_du0 = true`" begin
du0, u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(), Val(false);
nlsolve_alg, abstol, reltol, return_du0 = true)
@test du0 [1.0, 1.0]
@test u0 [2.0, 2.0]
@test p 1.0
@test success
end
end
end

0 comments on commit f3938ad

Please sign in to comment.