Skip to content

Commit

Permalink
refactor: change SCCNonlinearProblem fields
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 2, 2024
1 parent 7d4a687 commit 28a6176
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
30 changes: 16 additions & 14 deletions src/problems/nonlinear_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,28 +462,30 @@ Note that this example aliases the parameters together for a memory-reduced repr
* `probs`: the collection of problems to solve
* `explictfuns!`: the explicit functions for mutating the parameter set
"""
mutable struct SCCNonlinearProblem{uType, iip, P, E, I, Par} <:
mutable struct SCCNonlinearProblem{uType, iip, P, E, F <: NonlinearFunction{iip}, Par} <:
AbstractNonlinearProblem{uType, iip}
probs::P
explicitfuns!::E
full_index_provider::I
parameter_object::Par
# NonlinearFunction with `f = Returns(nothing)`
f::F
p::Par
parameters_alias::Bool

function SCCNonlinearProblem{P, E, I, Par}(
probs::P, funs::E, indp::I, pobj::Par, alias::Bool) where {P, E, I, Par}
function SCCNonlinearProblem{P, E, F, Par}(probs::P, funs::E, f::F, pobj::Par,
alias::Bool) where {P, E, F <: NonlinearFunction, Par}
u0 = mapreduce(
state_values, vcat, probs; init = similar(state_values(first(probs)), 0))
uType = typeof(u0)
new{uType, false, P, E, I, Par}(probs, funs, indp, pobj, alias)
new{uType, false, P, E, F, Par}(probs, funs, f, pobj, alias)
end
end

function SCCNonlinearProblem(probs, explicitfuns!, full_index_provider = nothing,
parameter_object = nothing, parameters_alias = false)
function SCCNonlinearProblem(probs, explicitfuns!, parameter_object = nothing,
parameters_alias = false; kwargs...)
f = NonlinearFunction{false}(Returns(nothing); kwargs...)
return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!),
typeof(full_index_provider), typeof(parameter_object)}(
probs, explicitfuns!, full_index_provider, parameter_object, parameters_alias)
typeof(f), typeof(parameter_object)}(
probs, explicitfuns!, f, parameter_object, parameters_alias)
end

function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol)
Expand All @@ -496,10 +498,10 @@ function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol)
end

function SymbolicIndexingInterface.symbolic_container(prob::SCCNonlinearProblem)
prob.full_index_provider
prob.f
end
function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem)
prob.parameter_object
prob.p
end
function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem)
mapreduce(
Expand All @@ -516,8 +518,8 @@ function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, id
end

function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx)
if prob.parameter_object !== nothing
set_parameter!(prob.parameter_object, val, idx)
if prob.p !== nothing
set_parameter!(prob.p, val, idx)
prob.parameters_alias && return
end
for scc in prob.probs
Expand Down
4 changes: 2 additions & 2 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ prob = SteadyStateProblem(osys, u0, ps)
prob = NonlinearProblem(model, [])
sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3],
SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]),
model, copy(cache))
copy(cache); sys = model)

for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]]
@test prob[sym] sccprob[sym]
Expand All @@ -384,7 +384,7 @@ prob = SteadyStateProblem(osys, u0, ps)
end
sccprob.ps[p] = 2.5
@test sccprob.ps[p] 2.5
@test sccprob.parameter_object[1] 2.5
@test sccprob.p[1] 2.5
for scc in sccprob.probs
@test parameter_values(scc)[1] 2.5
end
Expand Down

0 comments on commit 28a6176

Please sign in to comment.