From e8238e2a2c6ce36df8291eb7281ec0e27e9d3f73 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 2 Dec 2024 15:03:38 +0530 Subject: [PATCH] refactor: change `SCCNonlinearProblem` fields --- src/problems/nonlinear_problems.jl | 30 ++++++++++++----------- src/remake.jl | 5 ++-- test/downstream/modelingtoolkit_remake.jl | 12 ++++----- test/downstream/problem_interface.jl | 4 +-- 4 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/problems/nonlinear_problems.jl b/src/problems/nonlinear_problems.jl index b61b19ff1..c7d5d36b2 100644 --- a/src/problems/nonlinear_problems.jl +++ b/src/problems/nonlinear_problems.jl @@ -462,28 +462,30 @@ Note that this example aliases the parameters together for a memory-reduced repr * `probs`: the collection of problems to solve * `explictfuns!`: the explicit functions for mutating the parameter set """ -mutable struct SCCNonlinearProblem{uType, iip, P, E, I, Par} <: +mutable struct SCCNonlinearProblem{uType, iip, P, E, F <: NonlinearFunction{iip}, Par} <: AbstractNonlinearProblem{uType, iip} probs::P explicitfuns!::E - full_index_provider::I - parameter_object::Par + # NonlinearFunction with `f = Returns(nothing)` + f::F + p::Par parameters_alias::Bool - function SCCNonlinearProblem{P, E, I, Par}( - probs::P, funs::E, indp::I, pobj::Par, alias::Bool) where {P, E, I, Par} + function SCCNonlinearProblem{P, E, F, Par}(probs::P, funs::E, f::F, pobj::Par, + alias::Bool) where {P, E, F <: NonlinearFunction, Par} u0 = mapreduce( state_values, vcat, probs; init = similar(state_values(first(probs)), 0)) uType = typeof(u0) - new{uType, false, P, E, I, Par}(probs, funs, indp, pobj, alias) + new{uType, false, P, E, F, Par}(probs, funs, f, pobj, alias) end end -function SCCNonlinearProblem(probs, explicitfuns!, full_index_provider = nothing, - parameter_object = nothing, parameters_alias = false) +function SCCNonlinearProblem(probs, explicitfuns!, parameter_object = nothing, + parameters_alias = false; kwargs...) + f = NonlinearFunction{false}(Returns(nothing); kwargs...) return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!), - typeof(full_index_provider), typeof(parameter_object)}( - probs, explicitfuns!, full_index_provider, parameter_object, parameters_alias) + typeof(f), typeof(parameter_object)}( + probs, explicitfuns!, f, parameter_object, parameters_alias) end function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol) @@ -496,10 +498,10 @@ function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol) end function SymbolicIndexingInterface.symbolic_container(prob::SCCNonlinearProblem) - prob.full_index_provider + prob.f end function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem) - prob.parameter_object + prob.p end function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem) mapreduce( @@ -516,8 +518,8 @@ function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, id end function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx) - if prob.parameter_object !== nothing - set_parameter!(prob.parameter_object, val, idx) + if prob.p !== nothing + set_parameter!(prob.p, val, idx) prob.parameters_alias && return end for scc in prob.probs diff --git a/src/remake.jl b/src/remake.jl index 1a3d1c34b..f7edb3be8 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -549,9 +549,8 @@ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = mi if sys === missing sys = prob.full_index_provider end - return SCCNonlinearProblem{ - typeof(probs), typeof(explicitfuns!), typeof(sys), typeof(newp)}( - probs, explicitfuns!, sys, newp, parameters_alias) + return SCCNonlinearProblem( + probs, explicitfuns!, newp, parameters_alias; sys) end function varmap_has_var(varmap, var) diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 0395bb395..ef0bca75f 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -87,7 +87,7 @@ fullsys = complete(fullsys) prob1 = NonlinearProblem(sys1, u0, p) prob2 = NonlinearProblem(sys2, u0, prob1.p) sccprob = SCCNonlinearProblem( - [prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true) + [prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys) push!(syss, fullsys) push!(probs, sccprob) @@ -315,7 +315,7 @@ end prob1 = NonlinearProblem(sys1, u0, p) prob2 = NonlinearProblem(sys2, u0, prob1.p) sccprob = SCCNonlinearProblem( - [prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true) + [prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys) sccprob2 = remake(sccprob; u0 = 2ones(3)) @test state_values(sccprob2) ≈ 2ones(3) @@ -323,8 +323,8 @@ end @test sccprob2.probs[2].u0 ≈ 2ones(1) sccprob3 = remake(sccprob; p = [σ => 2.0]) - @test sccprob3.parameter_object === sccprob3.probs[1].p - @test sccprob3.parameter_object === sccprob3.probs[2].p + @test sccprob3.p === sccprob3.probs[1].p + @test sccprob3.p === sccprob3.probs[2].p @test_throws ["parameters_alias", "SCCNonlinearProblem"] remake( sccprob; parameters_alias = false, p = [σ => 2.0]) @@ -333,6 +333,6 @@ end sccprob4 = remake(sccprob; parameters_alias = false, p = newp, probs = [remake(prob1; p = [σ => 3.0]), prob2]) @test !sccprob4.parameters_alias - @test sccprob4.parameter_object !== sccprob4.probs[1].p - @test sccprob4.parameter_object !== sccprob4.probs[2].p + @test sccprob4.p !== sccprob4.probs[1].p + @test sccprob4.p !== sccprob4.probs[2].p end diff --git a/test/downstream/problem_interface.jl b/test/downstream/problem_interface.jl index 7e68b127b..1a500ebcf 100644 --- a/test/downstream/problem_interface.jl +++ b/test/downstream/problem_interface.jl @@ -367,7 +367,7 @@ prob = SteadyStateProblem(osys, u0, ps) prob = NonlinearProblem(model, []) sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3], SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]), - model, copy(cache)) + copy(cache); sys = model) for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]] @test prob[sym] ≈ sccprob[sym] @@ -384,7 +384,7 @@ prob = SteadyStateProblem(osys, u0, ps) end sccprob.ps[p] = 2.5 @test sccprob.ps[p] ≈ 2.5 - @test sccprob.parameter_object[1] ≈ 2.5 + @test sccprob.p[1] ≈ 2.5 for scc in sccprob.probs @test parameter_values(scc)[1] ≈ 2.5 end