Skip to content

Commit

Permalink
fix: fix remake for symbolic indexing, add tests for remake
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Mar 6, 2024
1 parent 0998e07 commit 92a3ebc
Show file tree
Hide file tree
Showing 7 changed files with 410 additions and 123 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Expand Down Expand Up @@ -78,6 +79,7 @@ RecursiveArrayTools = "3.8.0"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5.12"
SciMLOperators = "0.3.7"
SciMLStructures = "1.1"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
Expand Down
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ if isdefined(Base, :Experimental) &&
end
using ConstructionBase
using RecipesBase, RecursiveArrayTools, Tables
using SciMLStructures
using SymbolicIndexingInterface
using DocStringExtensions
using LinearAlgebra
Expand Down
226 changes: 103 additions & 123 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ Re-construct `thing` with new field values specified by the keyword
arguments.
"""
function remake(thing; kwargs...)
_remake_internal(thing; kwargs...)

Check warning on line 28 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L28

Added line #L28 was not covered by tests
end

function _remake_internal(thing; kwargs...)

Check warning on line 31 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L31

Added line #L31 was not covered by tests
T = remaker_of(thing)
if :kwargs fieldnames(typeof(thing))
if :kwargs keys(kwargs)
Expand All @@ -41,6 +45,11 @@ function isrecompile(prob::ODEProblem{iip}) where {iip}
(prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true
end

function remake(prob::AbstractSciMLProblem; u0 = missing, p = missing, interpret_symbolicmap = true, kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
_remake_internal(prob; kwargs..., u0, p)

Check warning on line 50 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L48-L50

Added lines #L48 - L50 were not covered by tests
end

"""
remake(prob::ODEProblem; f = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, _kwargs...)
Expand All @@ -59,37 +68,7 @@ function remake(prob::ODEProblem; f = missing,
tspan = prob.tspan
end

if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
isu0symbolic = eltype(u0) <: Pair && !isempty(u0)
ispsymbolic = eltype(p) <: Pair && !isempty(p) && interpret_symbolicmap
if isu0symbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to" *
"parameter order."))
end
if ispsymbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
end
if isu0symbolic && ispsymbolic
p, u0 = process_p_u0_symbolic(prob, p, u0)
elseif isu0symbolic
_, u0 = process_p_u0_symbolic(prob, prob.p, u0)
elseif ispsymbolic
p, _ = process_p_u0_symbolic(prob, p, prob.u0)
end
end
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

iip = isinplace(prob)

Expand Down Expand Up @@ -141,21 +120,12 @@ end
Remake the given `BVProblem`.
"""
function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, problem_type = missing, _kwargs...)
p = missing, kwargs = missing, problem_type = missing, interpret_symbolicmap = true, _kwargs...)
if tspan === missing
tspan = prob.tspan
end

if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
end
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

iip = isinplace(prob)

Expand Down Expand Up @@ -211,20 +181,15 @@ function remake(prob::SDEProblem;
p = missing,
noise = missing,
noise_rate_prototype = missing,
interpret_symbolicmap = true,
seed = missing,
kwargs = missing,
_kwargs...)
if tspan === missing
tspan = prob.tspan
end

if p === missing
p = prob.p
end

if u0 === missing
u0 = prob.u0
end
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

if noise === missing
noise = prob.noise
Expand Down Expand Up @@ -280,38 +245,8 @@ function remake(prob::OptimizationProblem;
kwargs = missing,
interpret_symbolicmap = true,
_kwargs...)
if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
isu0symbolic = eltype(u0) <: Pair && !isempty(u0)
ispsymbolic = eltype(p) <: Pair && !isempty(p) && interpret_symbolicmap
if isu0symbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to" *
"parameter order."))
end
if ispsymbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
end
if isu0symbolic && ispsymbolic
p, u0 = process_p_u0_symbolic(prob, p, u0)
elseif isu0symbolic
_, u0 = process_p_u0_symbolic(prob, prob.p, u0)
elseif ispsymbolic
p, _ = process_p_u0_symbolic(prob, p, prob.u0)
end
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
if f === missing
f = prob.f
end
Expand Down Expand Up @@ -362,38 +297,7 @@ function remake(prob::NonlinearProblem;
kwargs = missing,
interpret_symbolicmap = true,
_kwargs...)
if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
isu0symbolic = eltype(u0) <: Pair && !isempty(u0)
ispsymbolic = eltype(p) <: Pair && !isempty(p) && interpret_symbolicmap
if isu0symbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *
" remake, i.e. it does not have a symbolic origin. Please use `remke`" *
"with the `u0` keyword argument as a vector of values, paying attention to" *
"parameter order."))
end
if ispsymbolic && !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
end
if isu0symbolic && ispsymbolic
p, u0 = process_p_u0_symbolic(prob, p, u0)
elseif isu0symbolic
_, u0 = process_p_u0_symbolic(prob, prob.p, u0)
elseif ispsymbolic
p, _ = process_p_u0_symbolic(prob, p, prob.u0)
end
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
if f === missing
f = prob.f
end
Expand All @@ -418,17 +322,8 @@ end
Remake the given `NonlinearLeastSquaresProblem`.
"""
function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing,
kwargs = missing, _kwargs...)
if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
end
interpret_symbolicmap = true, kwargs = missing, _kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

if f === missing
f = prob.f
Expand All @@ -442,6 +337,91 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
end
end

function updated_u0_p(prob, u0, p; interpret_symbolicmap = true)
newp = updated_p(prob, p; interpret_symbolicmap)
newu0 = updated_u0(prob, u0, p)
return newu0, newp
end

function updated_u0(prob, u0, p)
if u0 === missing
return copy(state_values(prob))
end
if u0 isa Number
return u0
end
if eltype(u0) <: Pair
u0 = Dict(u0)
elseif u0 isa AbstractArray
if length(u0) == length(state_values(prob))
return u0
else
throw(ArgumentError("Invalid value for u0: $u0. Must be an array of appropriate length or a symbolic map"))

Check warning on line 359 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L359

Added line #L359 was not covered by tests
end
else
throw(ArgumentError("Invalid value for u0: $u0. Must be an array of appropriate length or a symbolic map"))

Check warning on line 362 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L362

Added line #L362 was not covered by tests
end
if !has_sys(prob.f)
throw(ArgumentError("This problem does not support symbolic maps with" *

Check warning on line 365 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L365

Added line #L365 was not covered by tests
" remake, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `u0` keyword argument as a vector of values, paying attention to the order."))
end
newu0 = copy(state_values(prob))
if all(==(NotSymbolic()), symbolic_type.(values(u0)))
setu(prob, collect(keys(u0)))(newu0, collect(values(u0)))
else
value_syms = [k for (k, v) in u0 if symbolic_type(v) === NotSymbolic()]
dependent_syms = [k for (k, v) in u0 if symbolic_type(v) !== NotSymbolic()]
setu(prob, value_syms)(newu0, getindex.((u0,), value_syms))
obs = SymbolicIndexingInterface.observed(prob, getindex.((u0,), dependent_syms))
if is_time_dependent(prob)
dependent_vals = obs(newu0, p, current_time(prob))

Check warning on line 378 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L373-L378

Added lines #L373 - L378 were not covered by tests
else
dependent_vals = obs(newu0, p)

Check warning on line 380 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L380

Added line #L380 was not covered by tests
end
setu(prob, dependent_syms)(newu0, dependent_vals)

Check warning on line 382 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L382

Added line #L382 was not covered by tests
end
return newu0
end

function updated_p(prob, p; interpret_symbolicmap = true)
if p === missing
return copy(parameter_values(prob))
end
if eltype(p) <: Pair
if interpret_symbolicmap
has_sys(prob.f) || throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
else
return p

Check warning on line 398 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L398

Added line #L398 was not covered by tests
end
p = Dict(p)
elseif SciMLStructures.isscimlstructure(p) || p isa Union{AbstractArray, Tuple} || typeof(p) === typeof(parameter_values(prob))
return p
else
throw(ArgumentError("Invalid value for p: $p. Must be an array of appropriate length, a symbolic map, or a SciMLStructure."))

Check warning on line 404 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L404

Added line #L404 was not covered by tests
end

newp = copy(parameter_values(prob))
if all(==(NotSymbolic()), symbolic_type.(values(p)))
setp(prob, collect(keys(p)))(newp, collect(values(p)))
else
value_syms = [k for (k, v) in p if symbolic_type(v) === NotSymbolic()]
dependent_syms = [k for (k, v) in p if symbolic_type(v) !== NotSymbolic()]
setp(prob, value_syms)(newp, getindex.((p,), value_syms))
obs = SymbolicIndexingInterface.observed(prob, getindex.((p,), dependent_syms))
if is_time_dependent(prob)
dependent_vals = obs(state_values(prob), newp, current_time(prob))

Check warning on line 416 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L411-L416

Added lines #L411 - L416 were not covered by tests
else
dependent_vals = obs(state_values(prob), newp)

Check warning on line 418 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L418

Added line #L418 was not covered by tests
end
setp(prob, dependent_syms)(newp, dependent_vals)

Check warning on line 420 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L420

Added line #L420 was not covered by tests
end
return newp
end

# overloaded in MTK to intercept symbolic remake
function process_p_u0_symbolic(prob, p, u0)
if prob isa Union{AbstractDEProblem, OptimizationProblem, NonlinearProblem}
Expand Down
2 changes: 2 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Expand All @@ -17,6 +18,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
BoundaryValueDiffEq = "5"
ForwardDiff = "0.10"
JumpProcesses = "9.10"
ModelingToolkit = "8.37, 9"
NonlinearSolve = "2, 3"
Optimization = "3"
Expand Down
Loading

0 comments on commit 92a3ebc

Please sign in to comment.