From 844ebfd7be60a668f820c4df2723bbfe6611090d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 2 Dec 2024 13:52:28 +0530 Subject: [PATCH] feat: add `remake` for `SCCNonlinearProblem` --- src/remake.jl | 51 ++++++++++++++++++- test/downstream/modelingtoolkit_remake.jl | 62 +++++++++++++++++++++++ 2 files changed, 111 insertions(+), 2 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 280957671..1a3d1c34b 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -508,6 +508,52 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p end end +""" + remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = missing, + parameters_alias = prob.parameters_alias, sys = missing, explicitfuns! = missing) + +Remake the given `SCCNonlinearProblem`. `u0` is the state vector for the entire problem, +which will be chunked appropriately and used to `remake` the individual subproblems. `p` +is the parameter object for `prob`. If `parameters_alias`, the same parameter object will be +used to `remake` the individual subproblems. Otherwise if `p !== missing`, this function will +error and require that `probs` be specified. `probs` is the collection of subproblems. Even if +`probs` is explicitly specified, the value of `u0` provided to `remake` will be used to +override the values in `probs`. `sys` is the index provider for the full system. +""" +function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = missing, + parameters_alias = prob.parameters_alias, sys = missing, + interpret_symbolicmap = true, use_defaults = false, explicitfuns! = missing) + if p !== missing && !parameters_alias && probs === missing + throw(ArgumentError("`parameters_alias` is `false` for the given `SCCNonlinearProblem`. Please provide the subproblems using the keyword `probs` with the parameters updated appropriately in each.")) + end + newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults, + indp = sys === missing ? prob.full_index_provider : sys) + if probs === missing + probs = prob.probs + end + offset = 0 + if u0 !== missing || p !== missing && parameters_alias + probs = map(probs) do subprob + subprob = if parameters_alias + remake(subprob; + u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))], + p = newp) + else + remake(subprob; + u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))]) + end + offset += length(state_values(subprob)) + return subprob + end + end + if sys === missing + sys = prob.full_index_provider + end + return SCCNonlinearProblem{ + typeof(probs), typeof(explicitfuns!), typeof(sys), typeof(newp)}( + probs, explicitfuns!, sys, newp, parameters_alias) +end + function varmap_has_var(varmap, var) haskey(varmap, var) || hasname(var) && haskey(varmap, getname(var)) end @@ -737,11 +783,12 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0) end function updated_u0_p( - prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false) + prob, u0, p, t0 = nothing; interpret_symbolicmap = true, + use_defaults = false, indp = has_sys(prob.f) ? prob.f.sys : nothing) if u0 === missing && p === missing return state_values(prob), parameter_values(prob) end - if !has_sys(prob.f) + if indp === nothing if interpret_symbolicmap && eltype(p) !== Union{} && eltype(p) <: Pair throw(ArgumentError("This problem does not support symbolic maps with " * "`remake`, i.e. it does not have a symbolic origin. Please use `remake`" * diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 228a26dd6..0395bb395 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -73,6 +73,24 @@ discu0 = Dict([u0..., x(k - 1) => 0.0, y(k - 1) => 0.0, z(k - 1) => 0.0]) push!(syss, discsys) push!(probs, DiscreteProblem(fn, getindex.((discu0,), unknowns(discsys)), (0, 10), ps)) +# TODO: Rewrite this example when the MTK codegen is merged +@named sys1 = NonlinearSystem( + [0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ]) +sys1 = complete(sys1) +@named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], []) +sys2 = complete(sys2) +@named fullsys = NonlinearSystem( + [0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4], + [x, y, z], [σ, β, ρ]) +fullsys = complete(fullsys) + +prob1 = NonlinearProblem(sys1, u0, p) +prob2 = NonlinearProblem(sys2, u0, prob1.p) +sccprob = SCCNonlinearProblem( + [prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true) +push!(syss, fullsys) +push!(probs, sccprob) + for (sys, prob) in zip(syss, probs) @test parameter_values(prob) isa ModelingToolkit.MTKParameters @@ -274,3 +292,47 @@ end @test_throws SciMLBase.CyclicDependencyError remake( prob; u0 = [x => 2y + p, y => q + 3], p = [p => x + y, q => p + 3]) end + +@testset "SCCNonlinearProblem" begin + @named sys1 = NonlinearSystem( + [0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ]) + sys1 = complete(sys1) + @named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], []) + sys2 = complete(sys2) + @named fullsys = NonlinearSystem( + [0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4], + [x, y, z], [σ, β, ρ]) + fullsys = complete(fullsys) + + u0 = [x => 1.0, + y => 0.0, + z => 0.0] + + p = [σ => 28.0, + ρ => 10.0, + β => 8 / 3] + + prob1 = NonlinearProblem(sys1, u0, p) + prob2 = NonlinearProblem(sys2, u0, prob1.p) + sccprob = SCCNonlinearProblem( + [prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true) + + sccprob2 = remake(sccprob; u0 = 2ones(3)) + @test state_values(sccprob2) ≈ 2ones(3) + @test sccprob2.probs[1].u0 ≈ 2ones(2) + @test sccprob2.probs[2].u0 ≈ 2ones(1) + + sccprob3 = remake(sccprob; p = [σ => 2.0]) + @test sccprob3.parameter_object === sccprob3.probs[1].p + @test sccprob3.parameter_object === sccprob3.probs[2].p + + @test_throws ["parameters_alias", "SCCNonlinearProblem"] remake( + sccprob; parameters_alias = false, p = [σ => 2.0]) + + newp = remake_buffer(sys1, prob1.p, [σ], [3.0]) + sccprob4 = remake(sccprob; parameters_alias = false, p = newp, + probs = [remake(prob1; p = [σ => 3.0]), prob2]) + @test !sccprob4.parameters_alias + @test sccprob4.parameter_object !== sccprob4.probs[1].p + @test sccprob4.parameter_object !== sccprob4.probs[2].p +end