Skip to content

Commit

Permalink
refactor: improve remake, use SII.remake_buffer, respect model defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Mar 28, 2024
1 parent 7888fa4 commit 0a74ca1
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 62 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ SciMLStructures = "1.1"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
SymbolicIndexingInterface = "0.3.3"
SymbolicIndexingInterface = "0.3.12"
Tables = "1.11"
Zygote = "0.6.67"
julia = "1.10"
Expand Down
189 changes: 128 additions & 61 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,82 +348,149 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
end
end

function updated_u0_p(prob, u0, p; interpret_symbolicmap = true)
newp = updated_p(prob, p; interpret_symbolicmap)
newu0 = updated_u0(prob, u0, p)
return newu0, newp
function varmap_has_var(varmap, var)
haskey(varmap, var) || hasname(var) && haskey(varmap, getname(var))
end

function updated_u0(prob, u0, p)
if u0 === missing || u0 isa Function
return state_values(prob)
function varmap_get(varmap, var, default = nothing)
if haskey(varmap, var)
return varmap[var]
end
if u0 isa Number
return u0
if hasname(var)
name = getname(var)
if haskey(varmap, name)
return varmap[name]
end
end
if eltype(u0) <: Pair
u0 = Dict(u0)
else
return u0
return default
end

function updated_u0_p(prob, u0, p; interpret_symbolicmap = true)
if u0 === missing && p === missing
return state_values(prob), parameter_values(prob)
end
if !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to the order."))
if interpret_symbolicmap && eltype(p) <: Pair
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
end
if eltype(u0) <: Pair
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to the order."))
end
return (u0 === missing ? state_values(prob) : u0),
(p === missing ? parameter_values(prob) : p)
end
varmap = Dict()
newu0 = nothing
newp = nothing
isu0symbolic = eltype(u0) <: Pair
isu0dependent = false
ispsymbolic = (eltype(p) <: Pair) && interpret_symbolicmap
ispdependent = false
if u0 === missing || u0 isa AbstractArray && isempty(u0)
newu0 = state_values(prob)
elseif !isu0symbolic
newu0 = u0
end
if p === missing || p isa AbstractArray && isempty(p) ||
p === SciMLBase.NullParameters()
newp = parameter_values(prob)
elseif !ispsymbolic
newp = p
end
if newu0 !== nothing && newp !== nothing
return newu0, newp
end

defs = default_values(prob)
varsyms = variable_symbols(prob)
parsyms = parameter_symbols(prob)
if isu0symbolic
u0 = Dict{Any, Any}(u0)
if length(u0) != length(state_values(prob)) # some missing values
for sym in varsyms
if !varmap_has_var(u0, sym)
if varmap_has_var(defs, sym) # prefer defaults
u0[sym] = varmap_get(defs, sym)
else
u0[sym] = getu(prob, sym)(prob)
end
end
end
end
isu0dependent = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
end

if ispsymbolic
p = Dict{Any, Any}(p)
if length(p) != length(state_values(prob)) # some missing values
for sym in parsyms
if !varmap_has_var(p, sym)
if varmap_has_var(defs, sym) # prefer defaults
p[sym] = varmap_get(defs, sym)
else
p[sym] = getp(prob, sym)(prob)
end
end
end
end
ispdependent = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
end
newu0 = copy(state_values(prob))
if all(==(NotSymbolic()), symbolic_type.(values(u0)))
setu(prob, collect(keys(u0)))(newu0, collect(values(u0)))
else
value_syms = [k for (k, v) in u0 if symbolic_type(v) === NotSymbolic()]
dependent_syms = [k for (k, v) in u0 if symbolic_type(v) !== NotSymbolic()]
setu(prob, value_syms)(newu0, getindex.((u0,), value_syms))
obs = SymbolicIndexingInterface.observed(prob, getindex.((u0,), dependent_syms))
if is_time_dependent(prob)
dependent_vals = obs(newu0, p, current_time(prob))
else
dependent_vals = obs(newu0, p)
if !isu0dependent && !ispdependent
if newu0 === nothing
newu0 = remake_buffer(prob, state_values(prob), u0)
end
setu(prob, dependent_syms)(newu0, dependent_vals)
if newp === nothing
newp = remake_buffer(prob, parameter_values(prob), p)
end
return newu0, newp
end
return newu0
end
merge!(varmap, defs)

function updated_p(prob, p; interpret_symbolicmap = true)
if p === missing
return parameter_values(prob)
end
if eltype(p) <: Pair
if interpret_symbolicmap
has_sys(prob.f) ||
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
else
return p
end
p = Dict(p)
if isu0symbolic
merge!(varmap, u0)
else
return p
for sym in varsyms
if !varmap_has_var(varmap, sym)
varmap[sym] = getu(prob, sym)(newu0)
end
end
end

newp = copy(parameter_values(prob))
if all(==(NotSymbolic()), symbolic_type.(values(p)))
setp(prob, collect(keys(p)))(newp, collect(values(p)))
if ispsymbolic
merge!(varmap, p)
else
value_syms = [k for (k, v) in p if symbolic_type(v) === NotSymbolic()]
dependent_syms = [k for (k, v) in p if symbolic_type(v) !== NotSymbolic()]
setp(prob, value_syms)(newp, getindex.((p,), value_syms))
obs = SymbolicIndexingInterface.observed(prob, getindex.((p,), dependent_syms))
if is_time_dependent(prob)
dependent_vals = obs(state_values(prob), newp, current_time(prob))
else
dependent_vals = obs(state_values(prob), newp)
for sym in parsyms
if !varmap_has_var(varmap, sym)
varmap[sym] = getp(prob, sym)(newp)
end
end
end
if isu0dependent
for (k, v) in u0
if symbolic_type(v) !== NotSymbolic()
u0[k] = symbolic_evaluate(v, varmap)
end
end
end
if newu0 === nothing
newu0 = remake_buffer(prob, state_values(prob), u0)
end
if ispdependent
for (k, v) in p
if symbolic_type(v) !== NotSymbolic()
p[k] = symbolic_evaluate(v, varmap)
end
end
setp(prob, dependent_syms)(newp, dependent_vals)
newp = remake_buffer(prob, parameter_values(prob), p)
end
return newp
if newp === nothing
newp = remake_buffer(prob, parameter_values(prob), p)
end
return newu0, newp
end

# overloaded in MTK to intercept symbolic remake
Expand Down

0 comments on commit 0a74ca1

Please sign in to comment.