From ef4c8e7e2faa9f0b170ac5fde23035ec30f1c6e9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 28 Mar 2024 11:49:18 +0530 Subject: [PATCH 1/5] refactor: improve remake, use SII.remake_buffer, respect model defaults --- Project.toml | 2 +- src/remake.jl | 208 +++++++++++++++++++++++++++++++++++--------------- 2 files changed, 147 insertions(+), 63 deletions(-) diff --git a/Project.toml b/Project.toml index 19fbdfaef..e42611684 100644 --- a/Project.toml +++ b/Project.toml @@ -83,7 +83,7 @@ SciMLStructures = "1.1" StaticArrays = "1.7" StaticArraysCore = "1.4" Statistics = "1.10" -SymbolicIndexingInterface = "0.3.3" +SymbolicIndexingInterface = "0.3.15" Tables = "1.11" Zygote = "0.6.67" julia = "1.10" diff --git a/src/remake.jl b/src/remake.jl index ff64ad592..e239c824d 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -348,82 +348,166 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p end end -function updated_u0_p(prob, u0, p; interpret_symbolicmap = true) - newp = updated_p(prob, p; interpret_symbolicmap) - newu0 = updated_u0(prob, u0, p) - return newu0, newp +function varmap_has_var(varmap, var) + haskey(varmap, var) || hasname(var) && haskey(varmap, getname(var)) end -function updated_u0(prob, u0, p) - if u0 === missing || u0 isa Function - return state_values(prob) +function varmap_get(varmap, var, default = nothing) + if haskey(varmap, var) + return varmap[var] end - if u0 isa Number - return u0 + if hasname(var) + name = getname(var) + if haskey(varmap, name) + return varmap[name] + end end - if eltype(u0) <: Pair - u0 = Dict(u0) - else - return u0 + return default +end + +anydict(d) = Dict{Any, Any}(d) + +function _updated_u0_p_internal(prob, ::Missing, p; interpret_symbolicmap = true) + u0 = state_values(prob) + if p isa AbstractArray && isempty(p) + return _updated_u0_p_internal( + prob, u0, parameter_values(prob); interpret_symbolicmap) end - if !has_sys(prob.f) - throw(ArgumentError("This problem does not support symbolic maps with" * - " remake, i.e. it does not have a symbolic origin. Please use `remake`" * - "with the `u0` keyword argument as a vector of values, paying attention to the order.")) + eltype(p) <: Pair && interpret_symbolicmap || return u0, p + defs = default_values(prob) + p = fill_p(prob, anydict(p), defs) + return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true)) +end + +function _updated_u0_p_internal(prob, u0, ::Missing; interpret_symbolicmap = true) + p = parameter_values(prob) + eltype(u0) <: Pair || return u0, p + defs = default_values(prob) + u0 = fill_u0(prob, anydict(u0), defs) + return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false)) +end + +function _updated_u0_p_internal(prob, u0, p; interpret_symbolicmap = true) + isu0symbolic = eltype(u0) <: Pair + ispsymbolic = eltype(p) <: Pair && interpret_symbolicmap + + if !isu0symbolic && !ispsymbolic + return u0, p end - newu0 = copy(state_values(prob)) - if all(==(NotSymbolic()), symbolic_type.(values(u0))) - setu(prob, collect(keys(u0)))(newu0, collect(values(u0))) - else - value_syms = [k for (k, v) in u0 if symbolic_type(v) === NotSymbolic()] - dependent_syms = [k for (k, v) in u0 if symbolic_type(v) !== NotSymbolic()] - setu(prob, value_syms)(newu0, getindex.((u0,), value_syms)) - obs = SymbolicIndexingInterface.observed(prob, getindex.((u0,), dependent_syms)) - if is_time_dependent(prob) - dependent_vals = obs(newu0, p, current_time(prob)) - else - dependent_vals = obs(newu0, p) - end - setu(prob, dependent_syms)(newu0, dependent_vals) + defs = default_values(prob) + if isu0symbolic + u0 = fill_u0(prob, anydict(u0), defs) + end + if ispsymbolic + p = fill_p(prob, anydict(p), defs) end - return newu0 + return _updated_u0_p_symmap(prob, u0, Val(isu0symbolic), p, Val(ispsymbolic)) end -function updated_p(prob, p; interpret_symbolicmap = true) - if p === missing - return parameter_values(prob) - end - if eltype(p) <: Pair - if interpret_symbolicmap - has_sys(prob.f) || - throw(ArgumentError("This problem does not support symbolic maps with " * - "`remake`, i.e. it does not have a symbolic origin. Please use `remake`" * - "with the `p` keyword argument as a vector of values (paying attention to" * - "parameter order) or pass `interpret_symbolicmap = false` as a keyword argument")) - else - return p - end - p = Dict(p) - else +function fill_u0(prob, u0, defs) + vsyms = variable_symbols(prob) + if length(u0) == length(vsyms) + return u0 + end + newvals = anydict(sym => if varmap_has_var(defs, sym) + varmap_get(defs, sym) + else + getu(prob, sym)(prob) + end for sym in vsyms if !varmap_has_var(u0, sym)) + return merge(u0, newvals) +end + +function fill_p(prob, p, defs) + psyms = parameter_symbols(prob)::Vector + if length(p) == length(psyms) return p end + newvals = anydict(sym => if varmap_has_var(defs, sym) + varmap_get(defs, sym) + else + getp(prob, sym)(prob) + end for sym in psyms if !varmap_has_var(p, sym)) + return merge(p, newvals) +end - newp = copy(parameter_values(prob)) - if all(==(NotSymbolic()), symbolic_type.(values(p))) - setp(prob, collect(keys(p)))(newp, collect(values(p))) - else - value_syms = [k for (k, v) in p if symbolic_type(v) === NotSymbolic()] - dependent_syms = [k for (k, v) in p if symbolic_type(v) !== NotSymbolic()] - setp(prob, value_syms)(newp, getindex.((p,), value_syms)) - obs = SymbolicIndexingInterface.observed(prob, getindex.((p,), dependent_syms)) - if is_time_dependent(prob) - dependent_vals = obs(state_values(prob), newp, current_time(prob)) - else - dependent_vals = obs(state_values(prob), newp) +function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}) + isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0) + isdep || return remake_buffer(prob, state_values(prob), u0), p + + u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0) + for (k, v) in u0) + + isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0) + isdep || return remake_buffer(prob, state_values(prob), u0), p + + temp_state = ProblemState(; p = p) + u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state) + for (k, v) in u0) + return remake_buffer(prob, state_values(prob), u0), p +end + +function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}) + isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p) + isdep || return u0, remake_buffer(prob, parameter_values(prob), p) + + p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p) + for (k, v) in p) + + isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p) + isdep || return u0, remake_buffer(prob, parameter_values(prob), p) + + temp_state = ProblemState(; u = u0) + p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state) + for (k, v) in p) + return u0, remake_buffer(prob, parameter_values(prob), p) +end + +function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}) + isu0dep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0) + ispdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p) + + if !isu0dep && !ispdep + return remake_buffer(prob, state_values(prob), u0), + remake_buffer(prob, parameter_values(prob), p) + end + if !isu0dep + u0 = remake_buffer(prob, state_values(prob), u0) + return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true)) + end + if !ispdep + p = remake_buffer(prob, parameter_values(prob), p) + return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false)) + end + + varmap = merge(u0, p) + u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap) + for (k, v) in u0) + p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap) + for (k, v) in p) + return remake_buffer(prob, state_values(prob), u0), + remake_buffer(prob, parameter_values(prob), p) +end + +function updated_u0_p(prob, u0, p; interpret_symbolicmap = true) + if u0 === missing && p === missing + return state_values(prob), parameter_values(prob) + end + if !has_sys(prob.f) + if interpret_symbolicmap && eltype(p) <: Pair + throw(ArgumentError("This problem does not support symbolic maps with " * + "`remake`, i.e. it does not have a symbolic origin. Please use `remake`" * + "with the `p` keyword argument as a vector of values (paying attention to" * + "parameter order) or pass `interpret_symbolicmap = false` as a keyword argument")) + end + if eltype(u0) <: Pair + throw(ArgumentError("This problem does not support symbolic maps with" * + " remake, i.e. it does not have a symbolic origin. Please use `remake`" * + "with the `u0` keyword argument as a vector of values, paying attention to the order.")) end - setp(prob, dependent_syms)(newp, dependent_vals) + return (u0 === missing ? state_values(prob) : u0), + (p === missing ? parameter_values(prob) : p) end - return newp + return _updated_u0_p_internal(prob, u0, p; interpret_symbolicmap) end # overloaded in MTK to intercept symbolic remake From e1f5bfcfbd5bf759fd0ddd5705a536ab5825f911 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 28 Mar 2024 11:51:05 +0530 Subject: [PATCH 2/5] test: improve remake tests, move to SII test group --- Project.toml | 4 +- test/downstream/modelingtoolkit_remake.jl | 306 ++++++++++++---------- test/remake_tests.jl | 301 +++++++++++++-------- test/runtests.jl | 10 +- 4 files changed, 360 insertions(+), 261 deletions(-) diff --git a/Project.toml b/Project.toml index e42611684..6bdb52548 100644 --- a/Project.toml +++ b/Project.toml @@ -60,6 +60,7 @@ DataFrames = "1.6" Distributed = "1.10" DocStringExtensions = "0.9" EnumX = "1" +ForwardDiff = "0.10.36" FunctionWrappersWrappers = "0.1.3" IteratorInterfaceExtensions = "^1" LinearAlgebra = "1.10" @@ -93,6 +94,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" @@ -109,4 +111,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq"] +test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"] diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index fdc883f71..da8992739 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -1,18 +1,23 @@ using ModelingToolkit, SymbolicIndexingInterface using JumpProcesses using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq +using Optimization +using OptimizationOptimJL -@parameters σ ρ β +probs = [] +syss = [] + +@parameters σ ρ β q @variables x(t) y(t) z(t) -eqs = [D(D(x)) ~ σ * (y - x), +eqs = [D(x) ~ σ * (y - x), D(y) ~ x * (ρ - z) - y, D(z) ~ x * y - β * z] -@named sys = ODESystem(eqs, t) -sys = structural_simplify(sys) -u0 = [D(x) => 2.0, - x => 1.0, +@named sys = ODESystem(eqs, t; parameter_dependencies = [q => 3β]) +sys = complete(sys) +u0 = [x => 1.0, y => 0.0, z => 0.0] @@ -21,145 +26,156 @@ p = [σ => 28.0, β => 8 / 3] tspan = (0.0, 100.0) -oprob = ODEProblem(sys, u0, tspan, p, jac = true) - -@inferred typeof(oprob) remake(oprob; u0 = [x => 2.0], p = [σ => 29.0]) -oprob2 = remake( - oprob; - u0 = [x => 2.0, sys.y => 1.2, :z => 1.0], - p = [σ => 29.0, sys.ρ => 11.0, :β => 3.0] -) -@test oprob2.u0 isa Vector{<:Number} -@test oprob2.p isa ModelingToolkit.MTKParameters -@test oprob2[x] == oprob2[sys.x] == oprob2[:x] == 2.0 -@test oprob2[y] == oprob2[sys.y] == oprob2[:y] == 1.2 -@test oprob2[z] == oprob2[sys.z] == oprob2[:z] == 1.0 -@test getp(sys, σ)(oprob2) == 29.0 -@test getp(sys, sys.ρ)(oprob2) == 11.0 -@test getp(sys, :β)(oprob2) == 3.0 - -oprob3 = remake(oprob; u0 = [x => 3.0], p = [σ => 30.0]) # partial update -@test oprob3[x] == 3.0 -@test getp(sys, σ)(oprob3) == 30.0 - -# SDEProblem. -noiseeqs = [0.1 * x, - 0.1 * y, - 0.1 * z] -@named noise_sys = SDESystem(sys, noiseeqs) -noise_sys = complete(noise_sys) -sprob = SDEProblem(noise_sys, u0, (0.0, 100.0), p) - -@inferred typeof(sprob) remake(sprob; u0 = [x => 2.0], p = [σ => 29.0]) -sprob2 = remake( - sprob; - u0 = [x => 2.0, sys.y => 1.2, :z => 1.0], - p = [σ => 29.0, sys.ρ => 11.0, :β => 3.0] -) -@test sprob2.u0 isa Vector{<:Number} -@test sprob2.p isa ModelingToolkit.MTKParameters -@test sprob2[x] == sprob2[sys.x] == sprob2[:x] == 2.0 -@test sprob2[y] == sprob2[sys.y] == sprob2[:y] == 1.2 -@test sprob2[z] == sprob2[sys.z] == sprob2[:z] == 1.0 -@test getp(sys, σ)(sprob2) == 29.0 -@test getp(sys, sys.ρ)(sprob2) == 11.0 -@test getp(sys, :β)(sprob2) == 3.0 - -sprob3 = remake(sprob; u0 = [x => 3.0], p = [σ => 30.0]) # partial update -@test sprob3[x] == 3.0 -@test getp(sys, σ)(sprob3) == 30.0 - -# DiscreteProblem -# @named de = DiscreteSystem( -# [D(x) ~ σ*(y-x), -# D(y) ~ x*(ρ-z)-y, -# D(z) ~ x*y - β*z], -# t, -# [x, y, z], -# [σ, ρ, β], -# ) -# dprob = DiscreteProblem(de, u0, tspan, p) - -# @inferred typeof(dprob) remake(dprob; u0 = [x => 2.0], p = [σ => 29.0]) -# dprob2 = remake( -# dprob; -# u0 = [x => 2.0, sys.y => 1.2, :z => 1.0], -# p = [σ => 29.0, sys.ρ => 11.0, :β => 3.0] -# ) -# @test dprob2.u0 isa Vector{<:Number} -# @test dprob2.p isa ModelingToolkit.MTKParameters -# @test dprob2[x] == dprob2[sys.x] == dprob2[:x] == 2.0 -# @test dprob2[y] == dprob2[sys.y] == dprob2[:y] == 1.2 -# @test dprob2[z] == dprob2[sys.z] == dprob2[:z] == 1.0 -# @test getp(de, σ)(dprob2) == 29.0 -# @test getp(de, sys.ρ)(dprob2) == 11.0 -# @test getp(de, :β)(dprob2) == 3.0 - -# dprob3 = remake(dprob; u0 = [x => 3.0], p = [σ => 30.0]) # partial update -# @test dprob3[x] == 3.0 -# @test getp(de, σ)(dprob3) == 30.0 - -# NonlinearProblem -@named ns = NonlinearSystem( - [0 ~ σ * (y - x), - 0 ~ x * (ρ - z) - y, - 0 ~ x * y - β * z], - [x, y, z], - [σ, ρ, β] -) -ns = complete(ns) -nlprob = NonlinearProblem(ns, u0, p) - -@inferred typeof(nlprob) remake(nlprob; u0 = [x => 2.0], p = [σ => 29.0]) -nlprob2 = remake( - nlprob; - u0 = [x => 2.0, sys.y => 1.2, :z => 1.0], - p = [σ => 29.0, sys.ρ => 11.0, :β => 3.0] -) -@test nlprob2.u0 isa Vector{<:Number} -@test nlprob2.p isa ModelingToolkit.MTKParameters -@test nlprob2[x] == nlprob2[sys.x] == nlprob2[:x] == 2.0 -@test nlprob2[y] == nlprob2[sys.y] == nlprob2[:y] == 1.2 -@test nlprob2[z] == nlprob2[sys.z] == nlprob2[:z] == 1.0 -@test getp(ns, σ)(nlprob2) == 29.0 -@test getp(ns, sys.ρ)(nlprob2) == 11.0 -@test getp(ns, :β)(nlprob2) == 3.0 - -nlprob3 = remake(nlprob; u0 = [x => 3.0], p = [σ => 30.0]) # partial update -@test nlprob3[x] == 3.0 -@test getp(ns, σ)(nlprob3) == 30.0 - -@parameters β γ -@variables S(t) I(t) R(t) -rate₁ = β * S * I -affect₁ = [S ~ S - 1, I ~ I + 1] -rate₂ = γ * I -affect₂ = [I ~ I - 1, R ~ R + 1] + +push!(syss, sys) +push!(probs, ODEProblem(sys, u0, tspan, p, jac = true)) + +noise_eqs = [0.1x, 0.1y, 0.1z] +@named sdesys = SDESystem(sys, noise_eqs) +sdesys = complete(sdesys) + +push!(syss, sdesys) +push!(probs, SDEProblem(sdesys, u0, tspan, p, jac = true)) + +@named nsys = NonlinearSystem([0 ~ eq.rhs for eq in eqs], [x, y, z], [σ, β, ρ]) +nsys = complete(nsys) + +push!(syss, nsys) +push!(probs, NonlinearProblem(nsys, u0, p, jac = true)) + +rate₁ = β * x * y +affect₁ = [x ~ x - σ, y ~ y + σ] +rate₂ = ρ * y +affect₂ = [y ~ y - 1, z ~ z + 1] j₁ = ConstantRateJump(rate₁, affect₁) j₂ = ConstantRateJump(rate₂, affect₂) -j₃ = MassActionJump(2 * β + γ, [R => 1], [S => 1, R => -1]) -@named js = JumpSystem([j₁, j₂, j₃], t, [S, I, R], [β, γ]) +j₃ = MassActionJump(2 * β + ρ, [z => 1], [x => 1, z => -1]) +@named js = JumpSystem([j₁, j₂, j₃], t, [x, y, z], [σ, β, ρ]) js = complete(js) -u₀map = [S => 999, I => 1, R => 0.0] -parammap = [β => 0.1 / 1000, γ => 0.01] -tspan = (0.0, 250.0) -jump_dprob = DiscreteProblem(js, u₀map, tspan, parammap) -jprob = JumpProblem(js, jump_dprob, Direct()) - -@inferred typeof(jprob) remake(jprob; u0 = [S => 900], p = [β => 0.2e-3]) -jprob2 = remake( - jprob; - u0 = [S => 900, js.I => 2, :R => 0.1], - p = [β => 0.2 / 1000, js.γ => 11.0] -) -@test jprob2.prob.u0 isa Vector{<:Number} -@test jprob2.prob.p isa ModelingToolkit.MTKParameters -@test jprob2[S] == jprob2[js.S] == jprob2[:S] == 900.0 -@test jprob2[I] == jprob2[js.I] == jprob2[:I] == 2.0 -@test jprob2[R] == jprob2[js.R] == jprob2[:R] == 0.1 -@test getp(js, β)(jprob2) == 0.2 / 1000 -@test getp(js, js.γ)(jprob2) == 11.0 - -jprob3 = remake(jprob; u0 = [S => 901], p = [:β => 0.3 / 1000]) # partial update -@test jprob3[S] == 901 -@test getp(js, β)(jprob3) == 0.3 / 1000 +jump_dprob = DiscreteProblem(js, u0, tspan, p) + +push!(syss, js) +push!(probs, JumpProblem(js, jump_dprob, Direct())) + +@named optsys = OptimizationSystem(sum(eq.lhs for eq in eqs), [x, y, z], [σ, ρ, β]) +optsys = complete(optsys) +push!(syss, optsys) +push!(probs, OptimizationProblem(optsys, u0, p)) + +k = ShiftIndex(t) +@mtkbuild discsys = DiscreteSystem( + [x ~ x(k - 1) * ρ + y(k - 2), y ~ y(k - 1) * σ - z(k - 2), z ~ z(k - 1) * β + x(k - 2)], + t) +# Roundabout method to avoid having to specify values for previous timestep +fn = DiscreteFunction(discsys) +ps = ModelingToolkit.MTKParameters(discsys, p) +push!(syss, discsys) +push!(probs, DiscreteProblem(fn, [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0, 10), ps)) + +for (sys, prob) in zip(syss, probs) + @test parameter_values(prob) isa ModelingToolkit.MTKParameters + + @inferred typeof(prob) remake(prob) + + baseType = Base.typename(typeof(prob)).wrapper + ugetter = getu(prob, [x, y, z]) + prob2 = @inferred baseType remake(prob; u0 = [x => 2.0, y => 3.0, z => 4.0]) + @test ugetter(prob2) == [2.0, 3.0, 4.0] + prob2 = @inferred baseType remake(prob; u0 = [sys.x => 2.0, sys.y => 3.0, sys.z => 4.0]) + @test ugetter(prob2) == [2.0, 3.0, 4.0] + prob2 = @inferred baseType remake(prob; u0 = [:x => 2.0, :y => 3.0, :z => 4.0]) + @test ugetter(prob2) == [2.0, 3.0, 4.0] + prob2 = @inferred baseType remake(prob; u0 = [x => 2.0, sys.y => 3.0, :z => 4.0]) + @test ugetter(prob2) == [2.0, 3.0, 4.0] + + prob2 = @inferred baseType remake(prob; u0 = [x => 12.0]) + @test ugetter(prob2) == [12.0, 0.0, 0.0] + prob2 = @inferred baseType remake(prob; u0 = [sys.x => 12.0]) + @test ugetter(prob2) == [12.0, 0.0, 0.0] + prob2 = @inferred baseType remake(prob; u0 = [:x => 12.0]) + @test ugetter(prob2) == [12.0, 0.0, 0.0] + + pgetter = getp(prob, [σ, β, ρ]) + prob2 = @inferred baseType remake(prob; p = [σ => 0.1, β => 0.2, ρ => 0.3]) + @test pgetter(prob2) == [0.1, 0.2, 0.3] + if prob isa ODEProblem + @test prob2.ps[q] ≈ 0.6 + end + prob2 = @inferred baseType remake(prob; p = [sys.σ => 0.1, sys.β => 0.2, sys.ρ => 0.3]) + @test pgetter(prob2) == [0.1, 0.2, 0.3] + if prob isa ODEProblem + @test prob2.ps[q] ≈ 0.6 + end + prob2 = @inferred baseType remake(prob; p = [:σ => 0.1, :β => 0.2, :ρ => 0.3]) + @test pgetter(prob2) == [0.1, 0.2, 0.3] + if prob isa ODEProblem + @test prob2.ps[q] ≈ 0.6 + end + prob2 = @inferred baseType remake(prob; p = [σ => 0.1, sys.β => 0.2, :ρ => 0.3]) + @test pgetter(prob2) == [0.1, 0.2, 0.3] + if prob isa ODEProblem + @test prob2.ps[q] ≈ 0.6 + end + + prob2 = @inferred baseType remake(prob; p = [σ => 0.5]) + @test pgetter(prob2) == [0.5, 8 / 3, 10.0] + prob2 = @inferred baseType remake(prob; p = [sys.σ => 0.5]) + @test pgetter(prob2) == [0.5, 8 / 3, 10.0] + prob2 = @inferred baseType remake(prob; p = [:σ => 0.5]) + @test pgetter(prob2) == [0.5, 8 / 3, 10.0] +end + +@variables ud(t) xd(t) yd(t) +@parameters p1 p2::Int [tunable = false] p3 +dt = 0.1 +c = Clock(t, dt) +k = ShiftIndex(c) + +eqs = [D(x) ~ Hold(ud) + ud ~ ud(k - 1) * p1 + yd + yd ~ p2 * yd(k - 1) + xd(k - 2) * p3 + xd ~ Sample(t, dt)(x)] +@mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [p3 => 2p1]) +prob = ODEProblem(sys, [x => 1.0], (0.0, 5.0), + [p1 => 1.0, p2 => 2, ud(k - 1) => 3.0, xd(k - 1) => 4.0, xd(k - 2) => 5.0]) + +# parameter dependencies +prob2 = @inferred ODEProblem remake(prob; p = [p1 => 2.0]) +@test prob2.ps[p1] == 2.0 +@test prob2.ps[p3] == 4.0 +@test prob2.ps[p2] isa Int # type is preserved + +# ignore dependent parameter, preserve type +# Vector of pairs automatically promotes the `3` to a `Float64` +prob2 = @inferred ODEProblem remake(prob; p = Dict(p1 => 2.0, p3 => 1.0, p2 => 3)) +@test prob2.ps[p3] == 4.0 +@test prob2.ps[p2] isa Int + +# discrete portion +prob2 = @inferred ODEProblem remake(prob; p = [ud => 2.5, xd => 3.5, xd(k - 1) => 4.5]) +@test prob2.ps[ud] == 2.5 +@test prob2.ps[xd] == 3.5 +@test prob2.ps[xd(k - 1)] == 4.5 + +# Optimization +@parameters p +@mtkbuild sys = ODESystem([D(x) ~ -p * x], t) +odeprob = ODEProblem(sys, [x => 1.0], (0.0, 10.0), [p => 0.5]) + +ts = 0.0:0.5:10.0 +data = exp.(-2.5 .* ts) + +function loss(x, p) + prob = p[1] + + prob = @inferred ODEProblem remake( + prob; p = [prob.f.sys.p => x[1]], u0 = typeof(x)(prob.u0)) + sol = solve(prob, Tsit5()) + vals = sol(ts; idxs = prob.f.sys.x).u + return sum((data .- vals) .^ 2) / length(ts) +end + +f = OptimizationFunction(loss, Optimization.AutoForwardDiff()) +prob = OptimizationProblem(f, [0.5], [odeprob]) +sol = solve(prob, BFGS()) +@test sol.u[1]≈2.5 rtol=1e-4 diff --git a/test/remake_tests.jl b/test/remake_tests.jl index e5d6f6600..96d722f36 100644 --- a/test/remake_tests.jl +++ b/test/remake_tests.jl @@ -1,130 +1,207 @@ using SciMLBase using SymbolicIndexingInterface +using StaticArrays +using ForwardDiff +probs = [] +containerTypes = [Vector, Tuple, SVector{3}, MVector{3}, SizedVector{3}] # ODE function lorenz!(du, u, p, t) du[1] = p[1] * (u[2] - u[1]) du[2] = u[1] * (p[2] - u[3]) - u[2] du[3] = u[1] * u[2] - p[3] * u[3] end -u0 = [1.0; 0.0; 0.0] +u0 = [1.0; 2.0; 3.0] tspan = (0.0, 100.0) -p = [10.0, 28.0, 8 / 3] -fn = ODEFunction(lorenz!; sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)) -prob = ODEProblem(fn, u0, tspan, p) - -@test remake(prob).u0 == u0 -@test remake(prob).p == p -@test remake(prob; u0 = [2.0, 3.0, 4.0]).u0 == [2.0, 3.0, 4.0] -@test remake(prob; u0 = [:x => 2.0, :z => 4.0, :y => 3.0]).u0 == [2.0, 3.0, 4.0] -@test remake(prob; p = [11.0, 12.0, 13.0]).p == [11.0, 12.0, 13.0] -@test remake(prob; p = [:a => 11.0, :c => 13.0, :b => 12.0]).p == [11.0, 12.0, 13.0] -@test remake(prob; p = (11.0, 12.0, 13)).p == (11.0, 12.0, 13) -@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 0.0, 0.0] -@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 8 / 3] - -# BVP -g = 9.81 -L = 1.0 -tspan = (0.0, pi / 2) -function simplependulum!(du, u, p, t) - θ = u[1] - dθ = u[2] - du[1] = dθ - du[2] = -(p[1] / p[2]) * sin(θ) +p = [10.0, 20.0, 30.0] +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) +fn = ODEFunction(lorenz!; sys) +for T in containerTypes + push!(probs, ODEProblem(fn, u0, tspan, T(p))) end -function bc1!(residual, u, p, t) - residual[1] = u[end ÷ 2][1] + pi / 2 # the solution at the middle of the time span should be -pi/2 - residual[2] = u[end][1] - pi / 2 # the solution at the end of the time span should be pi/2 + +function residual!(resid, u, p, t) + resid[1] = u[1] - 0.5 + resid[2] = u[2] - 0.5 + resid[3] = u[3] - 0.5 end -u0 = [pi / 2, pi / 2] -p = [g, L] -fn = BVPFunction(simplependulum!, bc1!; sys = SymbolCache([:x, :y], [:a, :b], :t)) -prob = BVProblem(fn, u0, tspan, p) - -@test remake(prob).u0 == u0 -@test remake(prob).p == p -@test remake(prob; u0 = [2.0, 3.0]).u0 == [2.0, 3.0] -@test remake(prob; u0 = [:x => 2.0, :y => 3.0]).u0 == [2.0, 3.0] -@test remake(prob; p = [11.0, 12.0]).p == [11.0, 12.0] -@test remake(prob; p = [:a => 11.0, :b => 12.0]).p == [11.0, 12.0] -@test remake(prob; p = (11.0, 12.0)).p == (11.0, 12.0) -@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, pi / 2] -@test remake(prob; p = [:b => 11.0]).p == [g, 11.0] - -# SDE -function sdef(du, u, p, t) - du[1] = p[1] * (u[2] - u[1]) - du[2] = u[1] * (p[2] - u[3]) - u[2] - du[3] = u[1] * u[2] - p[3] * u[3] +fn = BVPFunction(lorenz!, residual!; sys) +for T in containerTypes + push!(probs, BVProblem(fn, u0, tspan, T(p))) end -function sdeg(du, u, p, t) + +function noise!(du, u, p, t) du .= 0.1u end +fn = SDEFunction(lorenz!, noise!; sys) +for T in containerTypes + push!(probs, SDEProblem(fn, u0, tspan, T(p))) +end -u0 = [1.0, 0.0, 0.0] -p = [10, 2.33, 26] -tspan = (0, 100) -fn = SDEFunction(sdef, sdeg; sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)) -prob = SDEProblem(fn, u0, tspan, p) - -@test remake(prob).u0 == u0 -@test remake(prob).p == p -@test remake(prob; u0 = [2.0, 3.0, 4.0]).u0 == [2.0, 3.0, 4.0] -@test remake(prob; u0 = [:x => 2.0, :z => 4.0, :y => 3.0]).u0 == [2.0, 3.0, 4.0] -@test remake(prob; p = [11.0, 12.0, 13.0]).p == [11.0, 12.0, 13.0] -@test remake(prob; p = [:a => 11.0, :c => 13.0, :b => 12.0]).p == [11.0, 12.0, 13.0] -@test remake(prob; p = (11.0, 12.0, 13)).p == (11.0, 12.0, 13) -@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 0.0, 0.0] -@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 26.0] - -# OptimizationProblem -function loss(u, p) - return (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2 +function loss(x, p) + du = similar(x) + lorenz!(du, u, p, 0.0) + return sum(du) end -u0 = [1.0, 2.0] -p = [1.0, 100.0] -fn = OptimizationFunction(loss; sys = SymbolCache([:x, :y], [:a, :b], :t)) -prob = OptimizationProblem(fn, u0, p) -@test remake(prob).u0 == u0 -@test remake(prob).p == p -@test remake(prob; u0 = [2.0, 3.0]).u0 == [2.0, 3.0] -@test remake(prob; u0 = [:x => 2.0, :y => 3.0]).u0 == [2.0, 3.0] -@test remake(prob; p = [11.0, 12.0]).p == [11.0, 12.0] -@test remake(prob; p = [:a => 11.0, :b => 12.0]).p == [11.0, 12.0] -@test remake(prob; p = (11.0, 12.0)).p == (11.0, 12.0) -@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 2.0] -@test remake(prob; p = [:b => 11.0]).p == [1.0, 11.0] - -# NonlinearProblem -function nlf(du, u, p) - du[1] = p[1] * (u[2] - u[1]) - du[2] = u[1] * (p[2] - u[3]) - u[2] - du[3] = u[1] * u[2] - p[3] * u[3] + +fn = OptimizationFunction(loss; sys) +for T in containerTypes + push!(probs, OptimizationProblem(fn, u0, T(p))) +end + +function nllorenz!(du, u, p) + lorenz!(du, u, p, 0.0) +end + +fn = NonlinearFunction(nllorenz!; sys) +for T in containerTypes + push!(probs, NonlinearProblem(fn, u0, T(p))) +end + +for T in containerTypes + push!(probs, NonlinearLeastSquaresProblem(fn, u0, T(p))) +end + +for prob in deepcopy(probs) + prob2 = @inferred remake(prob) + @test prob2.u0 == u0 + @test prob2.p == typeof(prob.p)(p) + baseType = Base.typename(typeof(prob)).wrapper + for T in containerTypes + if T !== Tuple + local u0 = T([2.0, 3.0, 4.0]) + prob2 = @inferred baseType remake(prob; u0 = deepcopy(u0)) + @test prob2.u0 == u0 + @test prob2.u0 isa T + end + local p = T([11.0, 12.0, 13.0]) + prob2 = @inferred baseType remake(prob; p = deepcopy(p)) + @test prob2.p == p + @test prob2.p isa T + end + + for T in [Float32, Float64] + local u0 = [:x => T(2.0), :z => T(4.0), :y => T(3.0)] + prob2 = @inferred baseType remake(prob; u0) + @test all(prob2.u0 .≈ T[2.0, 3.0, 4.0]) + @test eltype(prob2.u0) == T + + local u0 = [:x => T(2.0)] + prob2 = @inferred baseType remake(prob; u0) + @test all(prob2.u0 .≈ [2.0, 2.0, 3.0]) + @test eltype(prob2.u0) == Float64 # partial update promotes, since fallback is Float64 + + local p = [:a => T(11.0), :b => T(12.0), :c => T(13.0)] + prob2 = @inferred baseType remake(prob; p) + @test all(prob2.p .≈ T[11.0, 12.0, 13.0]) + @test eltype(prob2.p) == T + + local p = [:a => T(11.0)] + prob2 = @inferred baseType remake(prob; p) + @test all(prob2.p .≈ [11.0, 20.0, 30.0]) + if prob.p isa Tuple + @test prob2.p isa Tuple{T, Float64, Float64} + else + @test eltype(prob2.p) == Float64 + end + end + + # constant defaults + begin + prob.f.sys.defaults[:a] = 0.1 + prob.f.sys.defaults[:x] = 0.1 + # remake with no updates should use existing values + prob2 = @inferred baseType remake(prob) + @test prob2.u0 == u0 + @test prob2.p == typeof(prob.p)(p) + + # respect defaults (:x), fallback to existing value (:z) + prob2 = @inferred baseType remake(prob; u0 = [:y => 0.2]) + @test prob2.u0 ≈ [0.1, 0.2, 3.0] + @test prob2.p == typeof(prob.p)(p) # params unaffected + + # override defaults + prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2]) + @test prob2.u0 ≈ [0.2, 2.0, 3.0] + @test prob2.p == typeof(prob.p)(p) + + prob2 = @inferred baseType remake(prob; p = [:b => 0.2]) + @test prob2.u0 == u0 + @test all(prob2.p .≈ [0.1, 0.2, 30.0]) + + prob2 = @inferred baseType remake(prob; p = [:a => 0.2]) + @test prob2.u0 == u0 + @test all(prob2.p .≈ [0.2, 20.0, 30.0]) + + empty!(prob.f.sys.defaults) + end + + # dependent defaults + begin + prob.f.sys.defaults[:b] = :(3a) + prob.f.sys.defaults[:y] = :(3x) + # remake with no updates should use existing values + prob2 = @inferred baseType remake(prob) + @test prob2.u0 == u0 + @test prob2.p == typeof(prob.p)(p) + + # respect defaults (:y), fallback to existing value (:z) + prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2]) + @test prob2.u0 ≈ [0.2, 0.6, 3.0] + @test prob2.p == typeof(prob.p)(p) # params unaffected + + # override defaults + prob2 = @inferred baseType remake(prob; u0 = [:y => 0.2]) + @test prob2.u0 ≈ [1.0, 0.2, 3.0] + @test prob2.p == typeof(prob.p)(p) + + prob2 = @inferred baseType remake(prob; p = [:a => 0.2]) + @test prob2.u0 == u0 + @test all(prob2.p .≈ [0.2, 0.6, 30.0]) + + prob2 = @inferred baseType remake(prob; p = [:b => 0.2]) + @test prob2.u0 == u0 + @test all(prob2.p .≈ [10.0, 0.2, 30.0]) + + empty!(prob.f.sys.defaults) + end + + # defaults dependent on each other (params <-> states) + begin + prob.f.sys.defaults[:b] = :(3x) + prob.f.sys.defaults[:y] = :(3a) + # remake with no updates should use existing values + prob2 = @inferred baseType remake(prob) + @test prob2.u0 == u0 + @test prob2.p == typeof(prob.p)(p) + + # need to pass empty `Dict()` to prevent defaulting to existing values + prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2], p = Dict()) + @test prob2.u0 ≈ [0.2, 30.0, 3.0] + @test all(prob2.p .≈ [10.0, 0.6, 30.0]) + + # override defaults + prob2 = @inferred baseType remake(prob; u0 = [:y => 0.2], p = Dict()) + @test prob2.u0 ≈ [1.0, 0.2, 3.0] + @test all(prob2.p .≈ [10.0, 3.0, 30.0]) + + prob2 = @inferred baseType remake(prob; p = [:a => 0.2], u0 = Dict()) + @test prob2.u0 ≈ [1.0, 0.6, 3.0] + @test all(prob2.p .≈ [0.2, 3.0, 30.0]) + + prob2 = @inferred baseType remake(prob; p = [:b => 0.2], u0 = Dict()) + @test prob2.u0 ≈ [1.0, 30.0, 3.0] + @test all(prob2.p .≈ [10.0, 0.2, 30.0]) + + empty!(prob.f.sys.defaults) + end + + if !isa(prob.p, Tuple) + function fakeloss!(p) + prob2 = @inferred baseType remake(prob; p = [:a => p]) + @test eltype(prob2.p) <: ForwardDiff.Dual + return prob2.ps[:a] + end + ForwardDiff.derivative(fakeloss!, 1.0) + end end -u0 = [1.0, 0.0, 0.0] -p = [10, 2.33, 26] -fn = NonlinearFunction(nlf; sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)) -prob = NonlinearProblem(fn, u0, p) - -@test remake(prob).u0 == u0 -@test remake(prob).p == p -@test remake(prob; u0 = [2.0, 3.0, 4.0]).u0 == [2.0, 3.0, 4.0] -@test remake(prob; u0 = [:x => 2.0, :z => 4.0, :y => 3.0]).u0 == [2.0, 3.0, 4.0] -@test remake(prob; p = [11.0, 12.0, 13.0]).p == [11.0, 12.0, 13.0] -@test remake(prob; p = [:a => 11.0, :c => 13.0, :b => 12.0]).p == [11.0, 12.0, 13.0] -@test remake(prob; p = (11.0, 12.0, 13)).p == (11.0, 12.0, 13) -@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 0.0, 0.0] -@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 26.0] - -# NonlinearLeastSquaresProblem -prob = NonlinearLeastSquaresProblem(fn, u0, p) -@test remake(prob).u0 == u0 -@test remake(prob).p == p -@test remake(prob; u0 = [2.0, 3.0, 4.0]).u0 == [2.0, 3.0, 4.0] -@test remake(prob; u0 = [:x => 2.0, :z => 4.0, :y => 3.0]).u0 == [2.0, 3.0, 4.0] -@test remake(prob; p = [11.0, 12.0, 13.0]).p == [11.0, 12.0, 13.0] -@test remake(prob; p = [:a => 11.0, :c => 13.0, :b => 12.0]).p == [11.0, 12.0, 13.0] -@test remake(prob; p = (11.0, 12.0, 13)).p == (11.0, 12.0, 13) -@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 0.0, 0.0] -@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 26.0] diff --git a/test/runtests.jl b/test/runtests.jl index 3fcbc36b4..3cd39c34a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,6 +64,10 @@ end @time @safetestset "Problem building tests" begin include("problem_building_test.jl") end + end + + if !is_APPVEYOR && + (GROUP == "Core" || GROUP == "All" || GROUP == "SymbolicIndexingInterface") @time @safetestset "Remake" begin include("remake_tests.jl") end @@ -106,15 +110,15 @@ end @time @safetestset "Partial Functions" begin include("downstream/partial_functions.jl") end - @time @safetestset "ModelingToolkit Remake" begin - include("downstream/modelingtoolkit_remake.jl") - end end if !is_APPVEYOR && (GROUP == "Downstream" || GROUP == "SymbolicIndexingInterface") if GROUP != "Downstream" activate_downstream_env() end + @time @safetestset "ModelingToolkit Remake" begin + include("downstream/modelingtoolkit_remake.jl") + end @time @safetestset "Symbol and integer based indexing of interpolated solutions" begin include("downstream/symbol_indexing.jl") end From e7965966da5f8bebeb255eeb566af1c0fb698290 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 28 Mar 2024 20:50:54 +0530 Subject: [PATCH 3/5] feat: support observed for SciMLFunctions containing a SymbolCache --- src/scimlfunctions.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 0df138559..eba6f8dd9 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4083,13 +4083,18 @@ function SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym) end function SymbolicIndexingInterface.observed(fn::AbstractSciMLFunction, sym) - if has_observed(fn) + if has_observed(fn) && fn.observed !== DEFAULT_OBSERVED && + fn.observed !== DEFAULT_OBSERVED_NO_TIME if hasmethod(fn.observed, Tuple{Any}) return fn.observed(sym) else return (args...) -> fn.observed(sym, args...) end end + if has_sys(fn) && + hasmethod(SymbolicIndexingInterface.observed, Tuple{typeof(fn.sys), typeof(sym)}) + return SymbolicIndexingInterface.observed(fn.sys, sym) + end error("SciMLFunction does not have observed") end From a16bc61828deb53ed2e44193b147c7f46085e320 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 30 Mar 2024 13:46:48 +0530 Subject: [PATCH 4/5] fix: fix solution indexing in plot --- src/solutions/solution_interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 1e40aa490..f72284ee9 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -490,7 +490,7 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic, push!(strs, "u[$(x[j])]") end else - _tmp = Vector{eltype(sol[1])}(undef, length(plot_analytic_timeseries)) + _tmp = Vector{eltype(sol.u[1])}(undef, length(plot_analytic_timeseries)) for n in 1:length(plot_analytic_timeseries) _tmp[n] = plot_analytic_timeseries[n][x[j]] end From a4370265a0bb738e00582c6eeb5bac6390afb62b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 1 Apr 2024 13:09:43 +0530 Subject: [PATCH 5/5] refactor: use existing values in remake, add `use_defaults` keyword --- src/remake.jl | 78 ++++++++++++++++++++++++++++---------------- test/remake_tests.jl | 43 +++++++++++++++++------- 2 files changed, 81 insertions(+), 40 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index e239c824d..96d3e743d 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -45,9 +45,21 @@ function isrecompile(prob::ODEProblem{iip}) where {iip} (prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true end +""" + remake(prob::AbstractSciMLProblem; u0 = missing, p = missing, interpret_symbolicmap = true, use_defaults = false) + +Remake the given problem `prob`. If `u0` or `p` are given, they will be used instead +of the unknowns/parameters of the problem. Either of them can be a symbolic map if +the problem has an associated system. If `interpret_symbolicmap == false`, `p` will never +be interpreted as a symbolic map and used as-is for parameters. `use_defaults` allows +controlling whether the default values from the system will be used to calculate missing +values in the symbolic map passed to `u0` or `p`. It is only valid when either `u0` or +`p` have been explicitly provided as a symbolic map and the problem has an associated +system. +""" function remake(prob::AbstractSciMLProblem; u0 = missing, - p = missing, interpret_symbolicmap = true, kwargs...) - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) + p = missing, interpret_symbolicmap = true, use_defaults = false, kwargs...) + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) _remake_internal(prob; kwargs..., u0, p) end @@ -56,8 +68,8 @@ function remake(prob::AbstractNoiseProblem; kwargs...) end function remake( - prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...) - p = updated_p(prob, p; interpret_symbolicmap) + prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, use_defaults = false, kwargs...) + _, p = updated_u0_p(prob, nothing, p; interpret_symbolicmap, use_defaults) _remake_internal(prob; kwargs..., p) end @@ -74,12 +86,13 @@ function remake(prob::ODEProblem; f = missing, p = missing, kwargs = missing, interpret_symbolicmap = true, + use_defaults = false, _kwargs...) if tspan === missing tspan = prob.tspan end - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) iip = isinplace(prob) @@ -132,12 +145,13 @@ Remake the given `BVProblem`. """ function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = missing, u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing, - interpret_symbolicmap = true, _kwargs...) where {uType, tType, iip, nlls} + interpret_symbolicmap = true, use_defaults = false, _kwargs...) where { + uType, tType, iip, nlls} if tspan === missing tspan = prob.tspan end - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) if problem_type === missing problem_type = prob.problem_type @@ -194,6 +208,7 @@ function remake(prob::SDEProblem; noise = missing, noise_rate_prototype = missing, interpret_symbolicmap = true, + use_defaults = false, seed = missing, kwargs = missing, _kwargs...) @@ -201,7 +216,7 @@ function remake(prob::SDEProblem; tspan = prob.tspan end - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) if noise === missing noise = prob.noise @@ -256,8 +271,9 @@ function remake(prob::OptimizationProblem; sense = missing, kwargs = missing, interpret_symbolicmap = true, + use_defaults = false, _kwargs...) - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) if f === missing f = prob.f end @@ -307,8 +323,9 @@ function remake(prob::NonlinearProblem; problem_type = missing, kwargs = missing, interpret_symbolicmap = true, + use_defaults = false, _kwargs...) - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) if f === missing f = prob.f end @@ -333,8 +350,8 @@ end Remake the given `NonlinearLeastSquaresProblem`. """ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing, - interpret_symbolicmap = true, kwargs = missing, _kwargs...) - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) + interpret_symbolicmap = true, use_defaults = false, kwargs = missing, _kwargs...) + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) if f === missing f = prob.f @@ -367,49 +384,54 @@ end anydict(d) = Dict{Any, Any}(d) -function _updated_u0_p_internal(prob, ::Missing, p; interpret_symbolicmap = true) +function _updated_u0_p_internal( + prob, ::Missing, p; interpret_symbolicmap = true, use_defaults = false) u0 = state_values(prob) + if p isa AbstractArray && isempty(p) return _updated_u0_p_internal( prob, u0, parameter_values(prob); interpret_symbolicmap) end eltype(p) <: Pair && interpret_symbolicmap || return u0, p - defs = default_values(prob) - p = fill_p(prob, anydict(p), defs) + defs = use_defaults ? default_values(prob) : nothing + p = fill_p(prob, anydict(p); defs) return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true)) end -function _updated_u0_p_internal(prob, u0, ::Missing; interpret_symbolicmap = true) +function _updated_u0_p_internal( + prob, u0, ::Missing; interpret_symbolicmap = true, use_defaults = false) p = parameter_values(prob) + eltype(u0) <: Pair || return u0, p - defs = default_values(prob) - u0 = fill_u0(prob, anydict(u0), defs) + defs = use_defaults ? default_values(prob) : nothing + u0 = fill_u0(prob, anydict(u0); defs) return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false)) end -function _updated_u0_p_internal(prob, u0, p; interpret_symbolicmap = true) +function _updated_u0_p_internal( + prob, u0, p; interpret_symbolicmap = true, use_defaults = false) isu0symbolic = eltype(u0) <: Pair ispsymbolic = eltype(p) <: Pair && interpret_symbolicmap if !isu0symbolic && !ispsymbolic return u0, p end - defs = default_values(prob) + defs = use_defaults ? default_values(prob) : nothing if isu0symbolic - u0 = fill_u0(prob, anydict(u0), defs) + u0 = fill_u0(prob, anydict(u0); defs) end if ispsymbolic - p = fill_p(prob, anydict(p), defs) + p = fill_p(prob, anydict(p); defs) end return _updated_u0_p_symmap(prob, u0, Val(isu0symbolic), p, Val(ispsymbolic)) end -function fill_u0(prob, u0, defs) +function fill_u0(prob, u0; defs = nothing) vsyms = variable_symbols(prob) if length(u0) == length(vsyms) return u0 end - newvals = anydict(sym => if varmap_has_var(defs, sym) + newvals = anydict(sym => if defs !== nothing && varmap_has_var(defs, sym) varmap_get(defs, sym) else getu(prob, sym)(prob) @@ -417,12 +439,12 @@ function fill_u0(prob, u0, defs) return merge(u0, newvals) end -function fill_p(prob, p, defs) +function fill_p(prob, p; defs = nothing) psyms = parameter_symbols(prob)::Vector if length(p) == length(psyms) return p end - newvals = anydict(sym => if varmap_has_var(defs, sym) + newvals = anydict(sym => if defs !== nothing && varmap_has_var(defs, sym) varmap_get(defs, sym) else getp(prob, sym)(prob) @@ -488,7 +510,7 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}) remake_buffer(prob, parameter_values(prob), p) end -function updated_u0_p(prob, u0, p; interpret_symbolicmap = true) +function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults = false) if u0 === missing && p === missing return state_values(prob), parameter_values(prob) end @@ -507,7 +529,7 @@ function updated_u0_p(prob, u0, p; interpret_symbolicmap = true) return (u0 === missing ? state_values(prob) : u0), (p === missing ? parameter_values(prob) : p) end - return _updated_u0_p_internal(prob, u0, p; interpret_symbolicmap) + return _updated_u0_p_internal(prob, u0, p; interpret_symbolicmap, use_defaults) end # overloaded in MTK to intercept symbolic remake diff --git a/test/remake_tests.jl b/test/remake_tests.jl index 96d722f36..2678bcc27 100644 --- a/test/remake_tests.jl +++ b/test/remake_tests.jl @@ -115,21 +115,26 @@ for prob in deepcopy(probs) @test prob2.u0 == u0 @test prob2.p == typeof(prob.p)(p) - # respect defaults (:x), fallback to existing value (:z) + # not passing use_defaults ignores defaults prob2 = @inferred baseType remake(prob; u0 = [:y => 0.2]) + @test prob2.u0 == [1.0, 0.2, 3.0] + @test prob2.p == typeof(prob.p)(p) + + # respect defaults (:x), fallback to existing value (:z) + prob2 = @inferred baseType remake(prob; u0 = [:y => 0.2], use_defaults = true) @test prob2.u0 ≈ [0.1, 0.2, 3.0] @test prob2.p == typeof(prob.p)(p) # params unaffected # override defaults - prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2]) + prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2], use_defaults = true) @test prob2.u0 ≈ [0.2, 2.0, 3.0] @test prob2.p == typeof(prob.p)(p) - prob2 = @inferred baseType remake(prob; p = [:b => 0.2]) + prob2 = @inferred baseType remake(prob; p = [:b => 0.2], use_defaults = true) @test prob2.u0 == u0 @test all(prob2.p .≈ [0.1, 0.2, 30.0]) - prob2 = @inferred baseType remake(prob; p = [:a => 0.2]) + prob2 = @inferred baseType remake(prob; p = [:a => 0.2], use_defaults = true) @test prob2.u0 == u0 @test all(prob2.p .≈ [0.2, 20.0, 30.0]) @@ -145,21 +150,26 @@ for prob in deepcopy(probs) @test prob2.u0 == u0 @test prob2.p == typeof(prob.p)(p) - # respect defaults (:y), fallback to existing value (:z) + # not passing use_defaults ignores defaults prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2]) + @test prob2.u0 == [0.2, 2.0, 3.0] + @test prob2.p == typeof(prob.p)(p) + + # respect defaults (:y), fallback to existing value (:z) + prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2], use_defaults = true) @test prob2.u0 ≈ [0.2, 0.6, 3.0] @test prob2.p == typeof(prob.p)(p) # params unaffected # override defaults - prob2 = @inferred baseType remake(prob; u0 = [:y => 0.2]) + prob2 = @inferred baseType remake(prob; u0 = [:y => 0.2], use_defaults = true) @test prob2.u0 ≈ [1.0, 0.2, 3.0] @test prob2.p == typeof(prob.p)(p) - prob2 = @inferred baseType remake(prob; p = [:a => 0.2]) + prob2 = @inferred baseType remake(prob; p = [:a => 0.2], use_defaults = true) @test prob2.u0 == u0 @test all(prob2.p .≈ [0.2, 0.6, 30.0]) - prob2 = @inferred baseType remake(prob; p = [:b => 0.2]) + prob2 = @inferred baseType remake(prob; p = [:b => 0.2], use_defaults = true) @test prob2.u0 == u0 @test all(prob2.p .≈ [10.0, 0.2, 30.0]) @@ -175,21 +185,30 @@ for prob in deepcopy(probs) @test prob2.u0 == u0 @test prob2.p == typeof(prob.p)(p) + # not passing use_defaults ignores defaults + prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2]) + @test prob2.u0 == [0.2, 2.0, 3.0] + @test prob2.p == typeof(prob.p)(p) + # need to pass empty `Dict()` to prevent defaulting to existing values - prob2 = @inferred baseType remake(prob; u0 = [:x => 0.2], p = Dict()) + prob2 = @inferred baseType remake( + prob; u0 = [:x => 0.2], p = Dict(), use_defaults = true) @test prob2.u0 ≈ [0.2, 30.0, 3.0] @test all(prob2.p .≈ [10.0, 0.6, 30.0]) # override defaults - prob2 = @inferred baseType remake(prob; u0 = [:y => 0.2], p = Dict()) + prob2 = @inferred baseType remake( + prob; u0 = [:y => 0.2], p = Dict(), use_defaults = true) @test prob2.u0 ≈ [1.0, 0.2, 3.0] @test all(prob2.p .≈ [10.0, 3.0, 30.0]) - prob2 = @inferred baseType remake(prob; p = [:a => 0.2], u0 = Dict()) + prob2 = @inferred baseType remake( + prob; p = [:a => 0.2], u0 = Dict(), use_defaults = true) @test prob2.u0 ≈ [1.0, 0.6, 3.0] @test all(prob2.p .≈ [0.2, 3.0, 30.0]) - prob2 = @inferred baseType remake(prob; p = [:b => 0.2], u0 = Dict()) + prob2 = @inferred baseType remake( + prob; p = [:b => 0.2], u0 = Dict(), use_defaults = true) @test prob2.u0 ≈ [1.0, 30.0, 3.0] @test all(prob2.p .≈ [10.0, 0.2, 30.0])