Skip to content

Commit

Permalink
Merge pull request #644 from AayushSabharwal/as/remake
Browse files Browse the repository at this point in the history
fix: fix remake for symbolic indexing, add tests for remake
  • Loading branch information
ChrisRackauckas authored Mar 8, 2024
2 parents e01b907 + 30c96ab commit 0974769
Show file tree
Hide file tree
Showing 7 changed files with 411 additions and 123 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Expand Down Expand Up @@ -78,6 +79,7 @@ RecursiveArrayTools = "3.8.0"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5.12"
SciMLOperators = "0.3.7"
SciMLStructures = "1.1"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
Expand Down
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ if isdefined(Base, :Experimental) &&
end
using ConstructionBase
using RecipesBase, RecursiveArrayTools, Tables
using SciMLStructures
using SymbolicIndexingInterface
using DocStringExtensions
using LinearAlgebra
Expand Down
227 changes: 104 additions & 123 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ Re-construct `thing` with new field values specified by the keyword
arguments.
"""
function remake(thing; kwargs...)
_remake_internal(thing; kwargs...)
end

function _remake_internal(thing; kwargs...)
T = remaker_of(thing)
if :kwargs fieldnames(typeof(thing))
if :kwargs keys(kwargs)
Expand All @@ -41,6 +45,20 @@ function isrecompile(prob::ODEProblem{iip}) where {iip}
(prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true
end

function remake(prob::AbstractSciMLProblem; u0 = missing, p = missing, interpret_symbolicmap = true, kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
_remake_internal(prob; kwargs..., u0, p)
end

function remake(prob::AbstractNoiseProblem; kwargs...)
_remake_internal(prob; kwargs...)
end

function remake(prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...)
p = updated_p(prob, p; interpret_symbolicmap)
_remake_internal(prob; kwargs..., p)
end

"""
remake(prob::ODEProblem; f = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, _kwargs...)
Expand All @@ -59,37 +77,7 @@ function remake(prob::ODEProblem; f = missing,
tspan = prob.tspan
end

if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
isu0symbolic = eltype(u0) <: Pair && !isempty(u0)
ispsymbolic = eltype(p) <: Pair && !isempty(p) && interpret_symbolicmap
if isu0symbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to" *
"parameter order."))
end
if ispsymbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
end
if isu0symbolic && ispsymbolic
p, u0 = process_p_u0_symbolic(prob, p, u0)
elseif isu0symbolic
_, u0 = process_p_u0_symbolic(prob, prob.p, u0)
elseif ispsymbolic
p, _ = process_p_u0_symbolic(prob, p, prob.u0)
end
end
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

iip = isinplace(prob)

Expand Down Expand Up @@ -141,21 +129,12 @@ end
Remake the given `BVProblem`.
"""
function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, problem_type = missing, _kwargs...)
p = missing, kwargs = missing, problem_type = missing, interpret_symbolicmap = true, _kwargs...)
if tspan === missing
tspan = prob.tspan
end

if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
end
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

iip = isinplace(prob)

Expand Down Expand Up @@ -211,20 +190,15 @@ function remake(prob::SDEProblem;
p = missing,
noise = missing,
noise_rate_prototype = missing,
interpret_symbolicmap = true,
seed = missing,
kwargs = missing,
_kwargs...)
if tspan === missing
tspan = prob.tspan
end

if p === missing
p = prob.p
end

if u0 === missing
u0 = prob.u0
end
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

if noise === missing
noise = prob.noise
Expand Down Expand Up @@ -280,38 +254,8 @@ function remake(prob::OptimizationProblem;
kwargs = missing,
interpret_symbolicmap = true,
_kwargs...)
if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
isu0symbolic = eltype(u0) <: Pair && !isempty(u0)
ispsymbolic = eltype(p) <: Pair && !isempty(p) && interpret_symbolicmap
if isu0symbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to" *
"parameter order."))
end
if ispsymbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
end
if isu0symbolic && ispsymbolic
p, u0 = process_p_u0_symbolic(prob, p, u0)
elseif isu0symbolic
_, u0 = process_p_u0_symbolic(prob, prob.p, u0)
elseif ispsymbolic
p, _ = process_p_u0_symbolic(prob, p, prob.u0)
end
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
if f === missing
f = prob.f
end
Expand Down Expand Up @@ -362,38 +306,7 @@ function remake(prob::NonlinearProblem;
kwargs = missing,
interpret_symbolicmap = true,
_kwargs...)
if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
isu0symbolic = eltype(u0) <: Pair && !isempty(u0)
ispsymbolic = eltype(p) <: Pair && !isempty(p) && interpret_symbolicmap
if isu0symbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remke`" *
"with the `u0` keyword argument as a vector of values, paying attention to" *
"parameter order."))
end
if ispsymbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
end
if isu0symbolic && ispsymbolic
p, u0 = process_p_u0_symbolic(prob, p, u0)
elseif isu0symbolic
_, u0 = process_p_u0_symbolic(prob, prob.p, u0)
elseif ispsymbolic
p, _ = process_p_u0_symbolic(prob, p, prob.u0)
end
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
if f === missing
f = prob.f
end
Expand All @@ -418,17 +331,8 @@ end
Remake the given `NonlinearLeastSquaresProblem`.
"""
function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing,
kwargs = missing, _kwargs...)
if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
end
interpret_symbolicmap = true, kwargs = missing, _kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

if f === missing
f = prob.f
Expand All @@ -442,6 +346,83 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
end
end

function updated_u0_p(prob, u0, p; interpret_symbolicmap = true)
newp = updated_p(prob, p; interpret_symbolicmap)
newu0 = updated_u0(prob, u0, p)
return newu0, newp
end

function updated_u0(prob, u0, p)
if u0 === missing || u0 isa Function
return state_values(prob)
end
if u0 isa Number
return u0
end
if eltype(u0) <: Pair
u0 = Dict(u0)
else
return u0
end
if !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to the order."))
end
newu0 = copy(state_values(prob))
if all(==(NotSymbolic()), symbolic_type.(values(u0)))
setu(prob, collect(keys(u0)))(newu0, collect(values(u0)))
else
value_syms = [k for (k, v) in u0 if symbolic_type(v) === NotSymbolic()]
dependent_syms = [k for (k, v) in u0 if symbolic_type(v) !== NotSymbolic()]
setu(prob, value_syms)(newu0, getindex.((u0,), value_syms))
obs = SymbolicIndexingInterface.observed(prob, getindex.((u0,), dependent_syms))
if is_time_dependent(prob)
dependent_vals = obs(newu0, p, current_time(prob))
else
dependent_vals = obs(newu0, p)
end
setu(prob, dependent_syms)(newu0, dependent_vals)
end
return newu0
end

function updated_p(prob, p; interpret_symbolicmap = true)
if p === missing
return parameter_values(prob)
end
if eltype(p) <: Pair
if interpret_symbolicmap
has_sys(prob.f) || throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
else
return p
end
p = Dict(p)
else
return p
end

newp = copy(parameter_values(prob))
if all(==(NotSymbolic()), symbolic_type.(values(p)))
setp(prob, collect(keys(p)))(newp, collect(values(p)))
else
value_syms = [k for (k, v) in p if symbolic_type(v) === NotSymbolic()]
dependent_syms = [k for (k, v) in p if symbolic_type(v) !== NotSymbolic()]
setp(prob, value_syms)(newp, getindex.((p,), value_syms))
obs = SymbolicIndexingInterface.observed(prob, getindex.((p,), dependent_syms))
if is_time_dependent(prob)
dependent_vals = obs(state_values(prob), newp, current_time(prob))
else
dependent_vals = obs(state_values(prob), newp)
end
setp(prob, dependent_syms)(newp, dependent_vals)
end
return newp
end

# overloaded in MTK to intercept symbolic remake
function process_p_u0_symbolic(prob, p, u0)
if prob isa Union{AbstractDEProblem, OptimizationProblem, NonlinearProblem}
Expand Down
2 changes: 2 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Expand All @@ -19,6 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
BoundaryValueDiffEq = "5"
ForwardDiff = "0.10"
JumpProcesses = "9.10"
ModelingToolkit = "8.37, 9"
NonlinearSolve = "2, 3"
Optimization = "3"
Expand Down
Loading

0 comments on commit 0974769

Please sign in to comment.