Skip to content

Commit

Permalink
Merge pull request #658 from AayushSabharwal/as/better-remake
Browse files Browse the repository at this point in the history
refactor: improve remake, use SII.remake_buffer, respect model defaults
  • Loading branch information
ChrisRackauckas authored Apr 1, 2024
2 parents a0d9a31 + a437026 commit ae4f9a8
Show file tree
Hide file tree
Showing 7 changed files with 567 additions and 338 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ DataFrames = "1.6"
Distributed = "1.10"
DocStringExtensions = "0.9"
EnumX = "1"
ForwardDiff = "0.10.36"
FunctionWrappersWrappers = "0.1.3"
IteratorInterfaceExtensions = "^1"
LinearAlgebra = "1.10"
Expand All @@ -83,7 +84,7 @@ SciMLStructures = "1.1"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
SymbolicIndexingInterface = "0.3.3"
SymbolicIndexingInterface = "0.3.15"
Tables = "1.11"
Zygote = "0.6.67"
julia = "1.10"
Expand All @@ -93,6 +94,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Expand All @@ -109,4 +111,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq"]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]
254 changes: 180 additions & 74 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,21 @@ function isrecompile(prob::ODEProblem{iip}) where {iip}
(prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true
end

"""
remake(prob::AbstractSciMLProblem; u0 = missing, p = missing, interpret_symbolicmap = true, use_defaults = false)
Remake the given problem `prob`. If `u0` or `p` are given, they will be used instead
of the unknowns/parameters of the problem. Either of them can be a symbolic map if
the problem has an associated system. If `interpret_symbolicmap == false`, `p` will never
be interpreted as a symbolic map and used as-is for parameters. `use_defaults` allows
controlling whether the default values from the system will be used to calculate missing
values in the symbolic map passed to `u0` or `p`. It is only valid when either `u0` or
`p` have been explicitly provided as a symbolic map and the problem has an associated
system.
"""
function remake(prob::AbstractSciMLProblem; u0 = missing,
p = missing, interpret_symbolicmap = true, kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
p = missing, interpret_symbolicmap = true, use_defaults = false, kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
_remake_internal(prob; kwargs..., u0, p)
end

Expand All @@ -56,8 +68,8 @@ function remake(prob::AbstractNoiseProblem; kwargs...)
end

function remake(
prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...)
p = updated_p(prob, p; interpret_symbolicmap)
prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, use_defaults = false, kwargs...)
_, p = updated_u0_p(prob, nothing, p; interpret_symbolicmap, use_defaults)
_remake_internal(prob; kwargs..., p)
end

Expand All @@ -74,12 +86,13 @@ function remake(prob::ODEProblem; f = missing,
p = missing,
kwargs = missing,
interpret_symbolicmap = true,
use_defaults = false,
_kwargs...)
if tspan === missing
tspan = prob.tspan
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)

iip = isinplace(prob)

Expand Down Expand Up @@ -132,12 +145,13 @@ Remake the given `BVProblem`.
"""
function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = missing,
u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing,
interpret_symbolicmap = true, _kwargs...) where {uType, tType, iip, nlls}
interpret_symbolicmap = true, use_defaults = false, _kwargs...) where {
uType, tType, iip, nlls}
if tspan === missing
tspan = prob.tspan
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)

if problem_type === missing
problem_type = prob.problem_type
Expand Down Expand Up @@ -194,14 +208,15 @@ function remake(prob::SDEProblem;
noise = missing,
noise_rate_prototype = missing,
interpret_symbolicmap = true,
use_defaults = false,
seed = missing,
kwargs = missing,
_kwargs...)
if tspan === missing
tspan = prob.tspan
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)

if noise === missing
noise = prob.noise
Expand Down Expand Up @@ -256,8 +271,9 @@ function remake(prob::OptimizationProblem;
sense = missing,
kwargs = missing,
interpret_symbolicmap = true,
use_defaults = false,
_kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
if f === missing
f = prob.f
end
Expand Down Expand Up @@ -307,8 +323,9 @@ function remake(prob::NonlinearProblem;
problem_type = missing,
kwargs = missing,
interpret_symbolicmap = true,
use_defaults = false,
_kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
if f === missing
f = prob.f
end
Expand All @@ -333,8 +350,8 @@ end
Remake the given `NonlinearLeastSquaresProblem`.
"""
function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing,
interpret_symbolicmap = true, kwargs = missing, _kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
interpret_symbolicmap = true, use_defaults = false, kwargs = missing, _kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)

if f === missing
f = prob.f
Expand All @@ -348,82 +365,171 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
end
end

function updated_u0_p(prob, u0, p; interpret_symbolicmap = true)
newp = updated_p(prob, p; interpret_symbolicmap)
newu0 = updated_u0(prob, u0, p)
return newu0, newp
function varmap_has_var(varmap, var)
haskey(varmap, var) || hasname(var) && haskey(varmap, getname(var))
end

function updated_u0(prob, u0, p)
if u0 === missing || u0 isa Function
return state_values(prob)
function varmap_get(varmap, var, default = nothing)
if haskey(varmap, var)
return varmap[var]
end
if u0 isa Number
return u0
if hasname(var)
name = getname(var)
if haskey(varmap, name)
return varmap[name]
end
end
if eltype(u0) <: Pair
u0 = Dict(u0)
else
return u0
return default
end

anydict(d) = Dict{Any, Any}(d)

function _updated_u0_p_internal(
prob, ::Missing, p; interpret_symbolicmap = true, use_defaults = false)
u0 = state_values(prob)

if p isa AbstractArray && isempty(p)
return _updated_u0_p_internal(
prob, u0, parameter_values(prob); interpret_symbolicmap)
end
if !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to the order."))
eltype(p) <: Pair && interpret_symbolicmap || return u0, p
defs = use_defaults ? default_values(prob) : nothing
p = fill_p(prob, anydict(p); defs)
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true))
end

function _updated_u0_p_internal(
prob, u0, ::Missing; interpret_symbolicmap = true, use_defaults = false)
p = parameter_values(prob)

eltype(u0) <: Pair || return u0, p
defs = use_defaults ? default_values(prob) : nothing
u0 = fill_u0(prob, anydict(u0); defs)
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false))
end

function _updated_u0_p_internal(
prob, u0, p; interpret_symbolicmap = true, use_defaults = false)
isu0symbolic = eltype(u0) <: Pair
ispsymbolic = eltype(p) <: Pair && interpret_symbolicmap

if !isu0symbolic && !ispsymbolic
return u0, p
end
newu0 = copy(state_values(prob))
if all(==(NotSymbolic()), symbolic_type.(values(u0)))
setu(prob, collect(keys(u0)))(newu0, collect(values(u0)))
else
value_syms = [k for (k, v) in u0 if symbolic_type(v) === NotSymbolic()]
dependent_syms = [k for (k, v) in u0 if symbolic_type(v) !== NotSymbolic()]
setu(prob, value_syms)(newu0, getindex.((u0,), value_syms))
obs = SymbolicIndexingInterface.observed(prob, getindex.((u0,), dependent_syms))
if is_time_dependent(prob)
dependent_vals = obs(newu0, p, current_time(prob))
else
dependent_vals = obs(newu0, p)
end
setu(prob, dependent_syms)(newu0, dependent_vals)
defs = use_defaults ? default_values(prob) : nothing
if isu0symbolic
u0 = fill_u0(prob, anydict(u0); defs)
end
return newu0
if ispsymbolic
p = fill_p(prob, anydict(p); defs)
end
return _updated_u0_p_symmap(prob, u0, Val(isu0symbolic), p, Val(ispsymbolic))
end

function updated_p(prob, p; interpret_symbolicmap = true)
if p === missing
return parameter_values(prob)
end
if eltype(p) <: Pair
if interpret_symbolicmap
has_sys(prob.f) ||
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
else
return p
end
p = Dict(p)
else
function fill_u0(prob, u0; defs = nothing)
vsyms = variable_symbols(prob)
if length(u0) == length(vsyms)
return u0
end
newvals = anydict(sym => if defs !== nothing && varmap_has_var(defs, sym)
varmap_get(defs, sym)
else
getu(prob, sym)(prob)
end for sym in vsyms if !varmap_has_var(u0, sym))
return merge(u0, newvals)
end

function fill_p(prob, p; defs = nothing)
psyms = parameter_symbols(prob)::Vector
if length(p) == length(psyms)
return p
end
newvals = anydict(sym => if defs !== nothing && varmap_has_var(defs, sym)
varmap_get(defs, sym)
else
getp(prob, sym)(prob)
end for sym in psyms if !varmap_has_var(p, sym))
return merge(p, newvals)
end

newp = copy(parameter_values(prob))
if all(==(NotSymbolic()), symbolic_type.(values(p)))
setp(prob, collect(keys(p)))(newp, collect(values(p)))
else
value_syms = [k for (k, v) in p if symbolic_type(v) === NotSymbolic()]
dependent_syms = [k for (k, v) in p if symbolic_type(v) !== NotSymbolic()]
setp(prob, value_syms)(newp, getindex.((p,), value_syms))
obs = SymbolicIndexingInterface.observed(prob, getindex.((p,), dependent_syms))
if is_time_dependent(prob)
dependent_vals = obs(state_values(prob), newp, current_time(prob))
else
dependent_vals = obs(state_values(prob), newp)
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false})
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
isdep || return remake_buffer(prob, state_values(prob), u0), p

u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
for (k, v) in u0)

isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
isdep || return remake_buffer(prob, state_values(prob), u0), p

temp_state = ProblemState(; p = p)
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
for (k, v) in u0)
return remake_buffer(prob, state_values(prob), u0), p
end

function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true})
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), p)

p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
for (k, v) in p)

isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), p)

temp_state = ProblemState(; u = u0)
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
for (k, v) in p)
return u0, remake_buffer(prob, parameter_values(prob), p)
end

function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
isu0dep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
ispdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)

if !isu0dep && !ispdep
return remake_buffer(prob, state_values(prob), u0),
remake_buffer(prob, parameter_values(prob), p)
end
if !isu0dep
u0 = remake_buffer(prob, state_values(prob), u0)
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true))
end
if !ispdep
p = remake_buffer(prob, parameter_values(prob), p)
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false))
end

varmap = merge(u0, p)
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
for (k, v) in u0)
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
for (k, v) in p)
return remake_buffer(prob, state_values(prob), u0),
remake_buffer(prob, parameter_values(prob), p)
end

function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults = false)
if u0 === missing && p === missing
return state_values(prob), parameter_values(prob)
end
if !has_sys(prob.f)
if interpret_symbolicmap && eltype(p) <: Pair
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
end
if eltype(u0) <: Pair
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to the order."))
end
setp(prob, dependent_syms)(newp, dependent_vals)
return (u0 === missing ? state_values(prob) : u0),
(p === missing ? parameter_values(prob) : p)
end
return newp
return _updated_u0_p_internal(prob, u0, p; interpret_symbolicmap, use_defaults)
end

# overloaded in MTK to intercept symbolic remake
Expand Down
Loading

0 comments on commit ae4f9a8

Please sign in to comment.