Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix remake for symbolic indexing, add tests for remake #644

Merged
merged 1 commit into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Expand Down Expand Up @@ -78,6 +79,7 @@ RecursiveArrayTools = "3.8.0"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5.12"
SciMLOperators = "0.3.7"
SciMLStructures = "1.1"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
Expand Down
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ if isdefined(Base, :Experimental) &&
end
using ConstructionBase
using RecipesBase, RecursiveArrayTools, Tables
using SciMLStructures
using SymbolicIndexingInterface
using DocStringExtensions
using LinearAlgebra
Expand Down
227 changes: 104 additions & 123 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
arguments.
"""
function remake(thing; kwargs...)
_remake_internal(thing; kwargs...)

Check warning on line 28 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L28

Added line #L28 was not covered by tests
end

function _remake_internal(thing; kwargs...)
T = remaker_of(thing)
if :kwargs ∈ fieldnames(typeof(thing))
if :kwargs ∉ keys(kwargs)
Expand All @@ -41,6 +45,20 @@
(prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true
end

function remake(prob::AbstractSciMLProblem; u0 = missing, p = missing, interpret_symbolicmap = true, kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
_remake_internal(prob; kwargs..., u0, p)
end

function remake(prob::AbstractNoiseProblem; kwargs...)
_remake_internal(prob; kwargs...)

Check warning on line 54 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L53-L54

Added lines #L53 - L54 were not covered by tests
end

function remake(prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...)
p = updated_p(prob, p; interpret_symbolicmap)
_remake_internal(prob; kwargs..., p)

Check warning on line 59 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L57-L59

Added lines #L57 - L59 were not covered by tests
end

"""
remake(prob::ODEProblem; f = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, _kwargs...)
Expand All @@ -59,37 +77,7 @@
tspan = prob.tspan
end

if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
isu0symbolic = eltype(u0) <: Pair && !isempty(u0)
ispsymbolic = eltype(p) <: Pair && !isempty(p) && interpret_symbolicmap
if isu0symbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to" *
"parameter order."))
end
if ispsymbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
end
if isu0symbolic && ispsymbolic
p, u0 = process_p_u0_symbolic(prob, p, u0)
elseif isu0symbolic
_, u0 = process_p_u0_symbolic(prob, prob.p, u0)
elseif ispsymbolic
p, _ = process_p_u0_symbolic(prob, p, prob.u0)
end
end
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

iip = isinplace(prob)

Expand Down Expand Up @@ -141,21 +129,12 @@
Remake the given `BVProblem`.
"""
function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, problem_type = missing, _kwargs...)
p = missing, kwargs = missing, problem_type = missing, interpret_symbolicmap = true, _kwargs...)
if tspan === missing
tspan = prob.tspan
end

if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
end
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

iip = isinplace(prob)

Expand Down Expand Up @@ -211,20 +190,15 @@
p = missing,
noise = missing,
noise_rate_prototype = missing,
interpret_symbolicmap = true,
seed = missing,
kwargs = missing,
_kwargs...)
if tspan === missing
tspan = prob.tspan
end

if p === missing
p = prob.p
end

if u0 === missing
u0 = prob.u0
end
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

if noise === missing
noise = prob.noise
Expand Down Expand Up @@ -280,38 +254,8 @@
kwargs = missing,
interpret_symbolicmap = true,
_kwargs...)
if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
isu0symbolic = eltype(u0) <: Pair && !isempty(u0)
ispsymbolic = eltype(p) <: Pair && !isempty(p) && interpret_symbolicmap
if isu0symbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to" *
"parameter order."))
end
if ispsymbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
end
if isu0symbolic && ispsymbolic
p, u0 = process_p_u0_symbolic(prob, p, u0)
elseif isu0symbolic
_, u0 = process_p_u0_symbolic(prob, prob.p, u0)
elseif ispsymbolic
p, _ = process_p_u0_symbolic(prob, p, prob.u0)
end
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
if f === missing
f = prob.f
end
Expand Down Expand Up @@ -362,38 +306,7 @@
kwargs = missing,
interpret_symbolicmap = true,
_kwargs...)
if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
isu0symbolic = eltype(u0) <: Pair && !isempty(u0)
ispsymbolic = eltype(p) <: Pair && !isempty(p) && interpret_symbolicmap
if isu0symbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remke`" *
"with the `u0` keyword argument as a vector of values, paying attention to" *
"parameter order."))
end
if ispsymbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
end
if isu0symbolic && ispsymbolic
p, u0 = process_p_u0_symbolic(prob, p, u0)
elseif isu0symbolic
_, u0 = process_p_u0_symbolic(prob, prob.p, u0)
elseif ispsymbolic
p, _ = process_p_u0_symbolic(prob, p, prob.u0)
end
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
if f === missing
f = prob.f
end
Expand All @@ -418,17 +331,8 @@
Remake the given `NonlinearLeastSquaresProblem`.
"""
function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing,
kwargs = missing, _kwargs...)
if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
end
interpret_symbolicmap = true, kwargs = missing, _kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

if f === missing
f = prob.f
Expand All @@ -442,6 +346,83 @@
end
end

function updated_u0_p(prob, u0, p; interpret_symbolicmap = true)
newp = updated_p(prob, p; interpret_symbolicmap)
newu0 = updated_u0(prob, u0, p)
return newu0, newp
end

function updated_u0(prob, u0, p)
if u0 === missing || u0 isa Function
return state_values(prob)
end
if u0 isa Number
return u0
end
if eltype(u0) <: Pair
u0 = Dict(u0)
else
return u0
end
if !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *

Check warning on line 368 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L368

Added line #L368 was not covered by tests
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to the order."))
end
newu0 = copy(state_values(prob))
if all(==(NotSymbolic()), symbolic_type.(values(u0)))
setu(prob, collect(keys(u0)))(newu0, collect(values(u0)))
else
value_syms = [k for (k, v) in u0 if symbolic_type(v) === NotSymbolic()]
dependent_syms = [k for (k, v) in u0 if symbolic_type(v) !== NotSymbolic()]
setu(prob, value_syms)(newu0, getindex.((u0,), value_syms))
obs = SymbolicIndexingInterface.observed(prob, getindex.((u0,), dependent_syms))
if is_time_dependent(prob)
dependent_vals = obs(newu0, p, current_time(prob))

Check warning on line 381 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L376-L381

Added lines #L376 - L381 were not covered by tests
else
dependent_vals = obs(newu0, p)

Check warning on line 383 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L383

Added line #L383 was not covered by tests
end
setu(prob, dependent_syms)(newu0, dependent_vals)

Check warning on line 385 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L385

Added line #L385 was not covered by tests
end
return newu0
end

function updated_p(prob, p; interpret_symbolicmap = true)
if p === missing
return parameter_values(prob)
end
if eltype(p) <: Pair
if interpret_symbolicmap
has_sys(prob.f) || throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
else
return p

Check warning on line 401 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L401

Added line #L401 was not covered by tests
end
p = Dict(p)
else
return p
end

newp = copy(parameter_values(prob))
if all(==(NotSymbolic()), symbolic_type.(values(p)))
setp(prob, collect(keys(p)))(newp, collect(values(p)))
else
value_syms = [k for (k, v) in p if symbolic_type(v) === NotSymbolic()]
dependent_syms = [k for (k, v) in p if symbolic_type(v) !== NotSymbolic()]
setp(prob, value_syms)(newp, getindex.((p,), value_syms))
obs = SymbolicIndexingInterface.observed(prob, getindex.((p,), dependent_syms))
if is_time_dependent(prob)
dependent_vals = obs(state_values(prob), newp, current_time(prob))

Check warning on line 417 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L412-L417

Added lines #L412 - L417 were not covered by tests
else
dependent_vals = obs(state_values(prob), newp)

Check warning on line 419 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L419

Added line #L419 was not covered by tests
end
setp(prob, dependent_syms)(newp, dependent_vals)

Check warning on line 421 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L421

Added line #L421 was not covered by tests
end
return newp
end

# overloaded in MTK to intercept symbolic remake
function process_p_u0_symbolic(prob, p, u0)
if prob isa Union{AbstractDEProblem, OptimizationProblem, NonlinearProblem}
Expand Down
2 changes: 2 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Expand All @@ -19,6 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
BoundaryValueDiffEq = "5"
ForwardDiff = "0.10"
JumpProcesses = "9.10"
ModelingToolkit = "8.37, 9"
NonlinearSolve = "2, 3"
Optimization = "3"
Expand Down
Loading
Loading