Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add remake for SCCNonlinearProblem #883

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,52 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
end
end

"""
remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = missing,
parameters_alias = prob.parameters_alias, sys = missing, explicitfuns! = missing)

Remake the given `SCCNonlinearProblem`. `u0` is the state vector for the entire problem,
which will be chunked appropriately and used to `remake` the individual subproblems. `p`
is the parameter object for `prob`. If `parameters_alias`, the same parameter object will be
used to `remake` the individual subproblems. Otherwise if `p !== missing`, this function will
error and require that `probs` be specified. `probs` is the collection of subproblems. Even if
`probs` is explicitly specified, the value of `u0` provided to `remake` will be used to
override the values in `probs`. `sys` is the index provider for the full system.
"""
function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = missing,
parameters_alias = prob.parameters_alias, sys = missing,
interpret_symbolicmap = true, use_defaults = false, explicitfuns! = missing)
if p !== missing && !parameters_alias && probs === missing
throw(ArgumentError("`parameters_alias` is `false` for the given `SCCNonlinearProblem`. Please provide the subproblems using the keyword `probs` with the parameters updated appropriately in each."))
end
newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults,
indp = sys === missing ? prob.full_index_provider : sys)
if probs === missing
probs = prob.probs
end
offset = 0
if u0 !== missing || p !== missing && parameters_alias
probs = map(probs) do subprob
subprob = if parameters_alias
remake(subprob;
u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))],
p = newp)
else
remake(subprob;
u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))])
end
offset += length(state_values(subprob))
return subprob
end
end
if sys === missing
sys = prob.full_index_provider
end
return SCCNonlinearProblem{
typeof(probs), typeof(explicitfuns!), typeof(sys), typeof(newp)}(
probs, explicitfuns!, sys, newp, parameters_alias)
end

function varmap_has_var(varmap, var)
haskey(varmap, var) || hasname(var) && haskey(varmap, getname(var))
end
Expand Down Expand Up @@ -737,11 +783,12 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
end

function updated_u0_p(
prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false)
prob, u0, p, t0 = nothing; interpret_symbolicmap = true,
use_defaults = false, indp = has_sys(prob.f) ? prob.f.sys : nothing)
if u0 === missing && p === missing
return state_values(prob), parameter_values(prob)
end
if !has_sys(prob.f)
if indp === nothing
if interpret_symbolicmap && eltype(p) !== Union{} && eltype(p) <: Pair
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
Expand Down
62 changes: 62 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ discu0 = Dict([u0..., x(k - 1) => 0.0, y(k - 1) => 0.0, z(k - 1) => 0.0])
push!(syss, discsys)
push!(probs, DiscreteProblem(fn, getindex.((discu0,), unknowns(discsys)), (0, 10), ps))

# TODO: Rewrite this example when the MTK codegen is merged
@named sys1 = NonlinearSystem(
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ])
sys1 = complete(sys1)
@named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], [])
sys2 = complete(sys2)
@named fullsys = NonlinearSystem(
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4],
[x, y, z], [σ, β, ρ])
fullsys = complete(fullsys)

prob1 = NonlinearProblem(sys1, u0, p)
prob2 = NonlinearProblem(sys2, u0, prob1.p)
sccprob = SCCNonlinearProblem(
[prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true)
push!(syss, fullsys)
push!(probs, sccprob)

for (sys, prob) in zip(syss, probs)
@test parameter_values(prob) isa ModelingToolkit.MTKParameters

Expand Down Expand Up @@ -274,3 +292,47 @@ end
@test_throws SciMLBase.CyclicDependencyError remake(
prob; u0 = [x => 2y + p, y => q + 3], p = [p => x + y, q => p + 3])
end

@testset "SCCNonlinearProblem" begin
@named sys1 = NonlinearSystem(
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ])
sys1 = complete(sys1)
@named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], [])
sys2 = complete(sys2)
@named fullsys = NonlinearSystem(
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4],
[x, y, z], [σ, β, ρ])
fullsys = complete(fullsys)

u0 = [x => 1.0,
y => 0.0,
z => 0.0]

p = [σ => 28.0,
ρ => 10.0,
β => 8 / 3]

prob1 = NonlinearProblem(sys1, u0, p)
prob2 = NonlinearProblem(sys2, u0, prob1.p)
sccprob = SCCNonlinearProblem(
[prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true)

sccprob2 = remake(sccprob; u0 = 2ones(3))
@test state_values(sccprob2) ≈ 2ones(3)
@test sccprob2.probs[1].u0 ≈ 2ones(2)
@test sccprob2.probs[2].u0 ≈ 2ones(1)

sccprob3 = remake(sccprob; p = [σ => 2.0])
@test sccprob3.parameter_object === sccprob3.probs[1].p
@test sccprob3.parameter_object === sccprob3.probs[2].p

@test_throws ["parameters_alias", "SCCNonlinearProblem"] remake(
sccprob; parameters_alias = false, p = [σ => 2.0])

newp = remake_buffer(sys1, prob1.p, [σ], [3.0])
sccprob4 = remake(sccprob; parameters_alias = false, p = newp,
probs = [remake(prob1; p = [σ => 3.0]), prob2])
@test !sccprob4.parameters_alias
@test sccprob4.parameter_object !== sccprob4.probs[1].p
@test sccprob4.parameter_object !== sccprob4.probs[2].p
end
Loading