From 747ca22e24608025a517560c338f499b0fecf255 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 2 Dec 2024 15:03:38 +0530 Subject: [PATCH] refactor: change `SCCNonlinearProblem` fields --- src/problems/nonlinear_problems.jl | 30 +++++++++++++++------------- test/downstream/problem_interface.jl | 2 +- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/problems/nonlinear_problems.jl b/src/problems/nonlinear_problems.jl index b61b19ff1..c7d5d36b2 100644 --- a/src/problems/nonlinear_problems.jl +++ b/src/problems/nonlinear_problems.jl @@ -462,28 +462,30 @@ Note that this example aliases the parameters together for a memory-reduced repr * `probs`: the collection of problems to solve * `explictfuns!`: the explicit functions for mutating the parameter set """ -mutable struct SCCNonlinearProblem{uType, iip, P, E, I, Par} <: +mutable struct SCCNonlinearProblem{uType, iip, P, E, F <: NonlinearFunction{iip}, Par} <: AbstractNonlinearProblem{uType, iip} probs::P explicitfuns!::E - full_index_provider::I - parameter_object::Par + # NonlinearFunction with `f = Returns(nothing)` + f::F + p::Par parameters_alias::Bool - function SCCNonlinearProblem{P, E, I, Par}( - probs::P, funs::E, indp::I, pobj::Par, alias::Bool) where {P, E, I, Par} + function SCCNonlinearProblem{P, E, F, Par}(probs::P, funs::E, f::F, pobj::Par, + alias::Bool) where {P, E, F <: NonlinearFunction, Par} u0 = mapreduce( state_values, vcat, probs; init = similar(state_values(first(probs)), 0)) uType = typeof(u0) - new{uType, false, P, E, I, Par}(probs, funs, indp, pobj, alias) + new{uType, false, P, E, F, Par}(probs, funs, f, pobj, alias) end end -function SCCNonlinearProblem(probs, explicitfuns!, full_index_provider = nothing, - parameter_object = nothing, parameters_alias = false) +function SCCNonlinearProblem(probs, explicitfuns!, parameter_object = nothing, + parameters_alias = false; kwargs...) + f = NonlinearFunction{false}(Returns(nothing); kwargs...) return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!), - typeof(full_index_provider), typeof(parameter_object)}( - probs, explicitfuns!, full_index_provider, parameter_object, parameters_alias) + typeof(f), typeof(parameter_object)}( + probs, explicitfuns!, f, parameter_object, parameters_alias) end function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol) @@ -496,10 +498,10 @@ function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol) end function SymbolicIndexingInterface.symbolic_container(prob::SCCNonlinearProblem) - prob.full_index_provider + prob.f end function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem) - prob.parameter_object + prob.p end function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem) mapreduce( @@ -516,8 +518,8 @@ function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, id end function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx) - if prob.parameter_object !== nothing - set_parameter!(prob.parameter_object, val, idx) + if prob.p !== nothing + set_parameter!(prob.p, val, idx) prob.parameters_alias && return end for scc in prob.probs diff --git a/test/downstream/problem_interface.jl b/test/downstream/problem_interface.jl index 7e68b127b..3fad0c24c 100644 --- a/test/downstream/problem_interface.jl +++ b/test/downstream/problem_interface.jl @@ -367,7 +367,7 @@ prob = SteadyStateProblem(osys, u0, ps) prob = NonlinearProblem(model, []) sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3], SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]), - model, copy(cache)) + copy(cache); sys = model) for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]] @test prob[sym] ≈ sccprob[sym]