diff --git a/Project.toml b/Project.toml index 4dd31fd39..c3e2b2ad2 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" @@ -78,6 +79,7 @@ RecursiveArrayTools = "3.8.0" Reexport = "1" RuntimeGeneratedFunctions = "0.5.12" SciMLOperators = "0.3.7" +SciMLStructures = "1.1" StaticArrays = "1.7" StaticArraysCore = "1.4" Statistics = "1.10" diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index fd396383f..0736503d1 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -5,6 +5,7 @@ if isdefined(Base, :Experimental) && end using ConstructionBase using RecipesBase, RecursiveArrayTools, Tables +using SciMLStructures using SymbolicIndexingInterface using DocStringExtensions using LinearAlgebra diff --git a/src/remake.jl b/src/remake.jl index 291ce9994..c0f05dc8b 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -25,6 +25,10 @@ Re-construct `thing` with new field values specified by the keyword arguments. """ function remake(thing; kwargs...) + _remake_internal(thing; kwargs...) +end + +function _remake_internal(thing; kwargs...) T = remaker_of(thing) if :kwargs ∈ fieldnames(typeof(thing)) if :kwargs ∉ keys(kwargs) @@ -41,6 +45,20 @@ function isrecompile(prob::ODEProblem{iip}) where {iip} (prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true end +function remake(prob::AbstractSciMLProblem; u0 = missing, p = missing, interpret_symbolicmap = true, kwargs...) + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) + _remake_internal(prob; kwargs..., u0, p) +end + +function remake(prob::AbstractNoiseProblem; kwargs...) + _remake_internal(prob; kwargs...) +end + +function remake(prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...) + p = updated_p(prob, p; interpret_symbolicmap) + _remake_internal(prob; kwargs..., p) +end + """ remake(prob::ODEProblem; f = missing, u0 = missing, tspan = missing, p = missing, kwargs = missing, _kwargs...) @@ -59,37 +77,7 @@ function remake(prob::ODEProblem; f = missing, tspan = prob.tspan end - if p === missing && u0 === missing - p, u0 = prob.p, prob.u0 - else # at least one of them has a value - if p === missing - p = prob.p - end - if u0 === missing - u0 = prob.u0 - end - isu0symbolic = eltype(u0) <: Pair && !isempty(u0) - ispsymbolic = eltype(p) <: Pair && !isempty(p) && interpret_symbolicmap - if isu0symbolic && !has_sys(prob.f) - throw(ArgumentError("This problem does not support symbolic maps with" * - " remake, i.e. it does not have a symbolic origin. Please use `remake`" * - "with the `u0` keyword argument as a vector of values, paying attention to" * - "parameter order.")) - end - if ispsymbolic && !has_sys(prob.f) - throw(ArgumentError("This problem does not support symbolic maps with " * - "`remake`, i.e. it does not have a symbolic origin. Please use `remake`" * - "with the `p` keyword argument as a vector of values (paying attention to" * - "parameter order) or pass `interpret_symbolicmap = false` as a keyword argument")) - end - if isu0symbolic && ispsymbolic - p, u0 = process_p_u0_symbolic(prob, p, u0) - elseif isu0symbolic - _, u0 = process_p_u0_symbolic(prob, prob.p, u0) - elseif ispsymbolic - p, _ = process_p_u0_symbolic(prob, p, prob.u0) - end - end + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) iip = isinplace(prob) @@ -141,21 +129,12 @@ end Remake the given `BVProblem`. """ function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan = missing, - p = missing, kwargs = missing, problem_type = missing, _kwargs...) + p = missing, kwargs = missing, problem_type = missing, interpret_symbolicmap = true, _kwargs...) if tspan === missing tspan = prob.tspan end - if p === missing && u0 === missing - p, u0 = prob.p, prob.u0 - else # at least one of them has a value - if p === missing - p = prob.p - end - if u0 === missing - u0 = prob.u0 - end - end + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) iip = isinplace(prob) @@ -211,6 +190,7 @@ function remake(prob::SDEProblem; p = missing, noise = missing, noise_rate_prototype = missing, + interpret_symbolicmap = true, seed = missing, kwargs = missing, _kwargs...) @@ -218,13 +198,7 @@ function remake(prob::SDEProblem; tspan = prob.tspan end - if p === missing - p = prob.p - end - - if u0 === missing - u0 = prob.u0 - end + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) if noise === missing noise = prob.noise @@ -280,38 +254,8 @@ function remake(prob::OptimizationProblem; kwargs = missing, interpret_symbolicmap = true, _kwargs...) - if p === missing && u0 === missing - p, u0 = prob.p, prob.u0 - else # at least one of them has a value - if p === missing - p = prob.p - end - if u0 === missing - u0 = prob.u0 - end - isu0symbolic = eltype(u0) <: Pair && !isempty(u0) - ispsymbolic = eltype(p) <: Pair && !isempty(p) && interpret_symbolicmap - if isu0symbolic && !has_sys(prob.f) - throw(ArgumentError("This problem does not support symbolic maps with" * - " remake, i.e. it does not have a symbolic origin. Please use `remake`" * - "with the `u0` keyword argument as a vector of values, paying attention to" * - "parameter order.")) - end - if ispsymbolic && !has_sys(prob.f) - throw(ArgumentError("This problem does not support symbolic maps with " * - "`remake`, i.e. it does not have a symbolic origin. Please use `remake`" * - "with the `p` keyword argument as a vector of values (paying attention to" * - "parameter order) or pass `interpret_symbolicmap = false` as a keyword argument")) - end - if isu0symbolic && ispsymbolic - p, u0 = process_p_u0_symbolic(prob, p, u0) - elseif isu0symbolic - _, u0 = process_p_u0_symbolic(prob, prob.p, u0) - elseif ispsymbolic - p, _ = process_p_u0_symbolic(prob, p, prob.u0) - end - end + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) if f === missing f = prob.f end @@ -362,38 +306,7 @@ function remake(prob::NonlinearProblem; kwargs = missing, interpret_symbolicmap = true, _kwargs...) - if p === missing && u0 === missing - p, u0 = prob.p, prob.u0 - else # at least one of them has a value - if p === missing - p = prob.p - end - if u0 === missing - u0 = prob.u0 - end - isu0symbolic = eltype(u0) <: Pair && !isempty(u0) - ispsymbolic = eltype(p) <: Pair && !isempty(p) && interpret_symbolicmap - if isu0symbolic && !has_sys(prob.f) - throw(ArgumentError("This problem does not support symbolic maps with" * - " remake, i.e. it does not have a symbolic origin. Please use `remke`" * - "with the `u0` keyword argument as a vector of values, paying attention to" * - "parameter order.")) - end - if ispsymbolic && !has_sys(prob.f) - throw(ArgumentError("This problem does not support symbolic maps with " * - "`remake`, i.e. it does not have a symbolic origin. Please use `remake`" * - "with the `p` keyword argument as a vector of values (paying attention to" * - "parameter order) or pass `interpret_symbolicmap = false` as a keyword argument")) - end - if isu0symbolic && ispsymbolic - p, u0 = process_p_u0_symbolic(prob, p, u0) - elseif isu0symbolic - _, u0 = process_p_u0_symbolic(prob, prob.p, u0) - elseif ispsymbolic - p, _ = process_p_u0_symbolic(prob, p, prob.u0) - end - end - + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) if f === missing f = prob.f end @@ -418,17 +331,8 @@ end Remake the given `NonlinearLeastSquaresProblem`. """ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing, - kwargs = missing, _kwargs...) - if p === missing && u0 === missing - p, u0 = prob.p, prob.u0 - else # at least one of them has a value - if p === missing - p = prob.p - end - if u0 === missing - u0 = prob.u0 - end - end + interpret_symbolicmap = true, kwargs = missing, _kwargs...) + u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) if f === missing f = prob.f @@ -442,6 +346,83 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p end end +function updated_u0_p(prob, u0, p; interpret_symbolicmap = true) + newp = updated_p(prob, p; interpret_symbolicmap) + newu0 = updated_u0(prob, u0, p) + return newu0, newp +end + +function updated_u0(prob, u0, p) + if u0 === missing || u0 isa Function + return state_values(prob) + end + if u0 isa Number + return u0 + end + if eltype(u0) <: Pair + u0 = Dict(u0) + else + return u0 + end + if !has_sys(prob.f) + throw(ArgumentError("This problem does not support symbolic maps with" * + " remake, i.e. it does not have a symbolic origin. Please use `remake`" * + "with the `u0` keyword argument as a vector of values, paying attention to the order.")) + end + newu0 = copy(state_values(prob)) + if all(==(NotSymbolic()), symbolic_type.(values(u0))) + setu(prob, collect(keys(u0)))(newu0, collect(values(u0))) + else + value_syms = [k for (k, v) in u0 if symbolic_type(v) === NotSymbolic()] + dependent_syms = [k for (k, v) in u0 if symbolic_type(v) !== NotSymbolic()] + setu(prob, value_syms)(newu0, getindex.((u0,), value_syms)) + obs = SymbolicIndexingInterface.observed(prob, getindex.((u0,), dependent_syms)) + if is_time_dependent(prob) + dependent_vals = obs(newu0, p, current_time(prob)) + else + dependent_vals = obs(newu0, p) + end + setu(prob, dependent_syms)(newu0, dependent_vals) + end + return newu0 +end + +function updated_p(prob, p; interpret_symbolicmap = true) + if p === missing + return parameter_values(prob) + end + if eltype(p) <: Pair + if interpret_symbolicmap + has_sys(prob.f) || throw(ArgumentError("This problem does not support symbolic maps with " * + "`remake`, i.e. it does not have a symbolic origin. Please use `remake`" * + "with the `p` keyword argument as a vector of values (paying attention to" * + "parameter order) or pass `interpret_symbolicmap = false` as a keyword argument")) + else + return p + end + p = Dict(p) + else + return p + end + + newp = copy(parameter_values(prob)) + if all(==(NotSymbolic()), symbolic_type.(values(p))) + setp(prob, collect(keys(p)))(newp, collect(values(p))) + else + value_syms = [k for (k, v) in p if symbolic_type(v) === NotSymbolic()] + dependent_syms = [k for (k, v) in p if symbolic_type(v) !== NotSymbolic()] + setp(prob, value_syms)(newp, getindex.((p,), value_syms)) + obs = SymbolicIndexingInterface.observed(prob, getindex.((p,), dependent_syms)) + if is_time_dependent(prob) + dependent_vals = obs(state_values(prob), newp, current_time(prob)) + else + dependent_vals = obs(state_values(prob), newp) + end + setp(prob, dependent_syms)(newp, dependent_vals) + end + return newp +end + # overloaded in MTK to intercept symbolic remake function process_p_u0_symbolic(prob, p, u0) if prob isa Union{AbstractDEProblem, OptimizationProblem, NonlinearProblem} diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 747b83a7e..2400e6be9 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -1,6 +1,7 @@ [deps] BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" @@ -19,6 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] BoundaryValueDiffEq = "5" ForwardDiff = "0.10" +JumpProcesses = "9.10" ModelingToolkit = "8.37, 9" NonlinearSolve = "2, 3" Optimization = "3" diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl new file mode 100644 index 000000000..f23648fd6 --- /dev/null +++ b/test/downstream/modelingtoolkit_remake.jl @@ -0,0 +1,165 @@ +using ModelingToolkit, SymbolicIndexingInterface +using JumpProcesses +using ModelingToolkit: t_nounits as t, D_nounits as D + +@parameters σ ρ β +@variables x(t) y(t) z(t) + +eqs = [D(D(x)) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z] + +@named sys = ODESystem(eqs, t) +sys = structural_simplify(sys) +u0 = [D(x) => 2.0, + x => 1.0, + y => 0.0, + z => 0.0] + +p = [σ => 28.0, + ρ => 10.0, + β => 8 / 3] + +tspan = (0.0, 100.0) +oprob = ODEProblem(sys, u0, tspan, p, jac = true) + +@inferred typeof(oprob) remake(oprob; u0 = [x => 2.0], p = [σ => 29.0]) +oprob2 = remake( + oprob; + u0 = [x => 2.0, sys.y => 1.2, :z => 1.0], + p = [σ => 29.0, sys.ρ => 11.0, :β => 3.0] +) +@test oprob2.u0 isa Vector{<:Number} +@test oprob2.p isa ModelingToolkit.MTKParameters +@test oprob2[x] == oprob2[sys.x] == oprob2[:x] == 2.0 +@test oprob2[y] == oprob2[sys.y] == oprob2[:y] == 1.2 +@test oprob2[z] == oprob2[sys.z] == oprob2[:z] == 1.0 +@test getp(sys, σ)(oprob2) == 29.0 +@test getp(sys, sys.ρ)(oprob2) == 11.0 +@test getp(sys, :β)(oprob2) == 3.0 + +oprob3 = remake(oprob; u0 = [x => 3.0], p = [σ => 30.0]) # partial update +@test oprob3[x] == 3.0 +@test getp(sys, σ)(oprob3) == 30.0 + +# SDEProblem. +noiseeqs = [0.1 * x, + 0.1 * y, + 0.1 * z] +@named noise_sys = SDESystem(sys, noiseeqs) +noise_sys = complete(noise_sys) +sprob = SDEProblem(noise_sys, u0, (0.0, 100.0), p) + +@inferred typeof(sprob) remake(sprob; u0 = [x => 2.0], p = [σ => 29.0]) +sprob2 = remake( + sprob; + u0 = [x => 2.0, sys.y => 1.2, :z => 1.0], + p = [σ => 29.0, sys.ρ => 11.0, :β => 3.0] +) +@test sprob2.u0 isa Vector{<:Number} +@test sprob2.p isa ModelingToolkit.MTKParameters +@test sprob2[x] == sprob2[sys.x] == sprob2[:x] == 2.0 +@test sprob2[y] == sprob2[sys.y] == sprob2[:y] == 1.2 +@test sprob2[z] == sprob2[sys.z] == sprob2[:z] == 1.0 +@test getp(sys, σ)(sprob2) == 29.0 +@test getp(sys, sys.ρ)(sprob2) == 11.0 +@test getp(sys, :β)(sprob2) == 3.0 + +sprob3 = remake(sprob; u0 = [x => 3.0], p = [σ => 30.0]) # partial update +@test sprob3[x] == 3.0 +@test getp(sys, σ)(sprob3) == 30.0 + +# DiscreteProblem +# @named de = DiscreteSystem( +# [D(x) ~ σ*(y-x), +# D(y) ~ x*(ρ-z)-y, +# D(z) ~ x*y - β*z], +# t, +# [x, y, z], +# [σ, ρ, β], +# ) +# dprob = DiscreteProblem(de, u0, tspan, p) + +# @inferred typeof(dprob) remake(dprob; u0 = [x => 2.0], p = [σ => 29.0]) +# dprob2 = remake( +# dprob; +# u0 = [x => 2.0, sys.y => 1.2, :z => 1.0], +# p = [σ => 29.0, sys.ρ => 11.0, :β => 3.0] +# ) +# @test dprob2.u0 isa Vector{<:Number} +# @test dprob2.p isa ModelingToolkit.MTKParameters +# @test dprob2[x] == dprob2[sys.x] == dprob2[:x] == 2.0 +# @test dprob2[y] == dprob2[sys.y] == dprob2[:y] == 1.2 +# @test dprob2[z] == dprob2[sys.z] == dprob2[:z] == 1.0 +# @test getp(de, σ)(dprob2) == 29.0 +# @test getp(de, sys.ρ)(dprob2) == 11.0 +# @test getp(de, :β)(dprob2) == 3.0 + +# dprob3 = remake(dprob; u0 = [x => 3.0], p = [σ => 30.0]) # partial update +# @test dprob3[x] == 3.0 +# @test getp(de, σ)(dprob3) == 30.0 + +# NonlinearProblem +@named ns = NonlinearSystem( + [0 ~ σ*(y-x), + 0 ~ x*(ρ-z)-y, + 0 ~ x*y - β*z], + [x,y,z], + [σ,ρ,β] +) +ns = complete(ns) +nlprob = NonlinearProblem(ns, u0, p) + +@inferred typeof(nlprob) remake(nlprob; u0 = [x => 2.0], p = [σ => 29.0]) +nlprob2 = remake( + nlprob; + u0 = [x => 2.0, sys.y => 1.2, :z => 1.0], + p = [σ => 29.0, sys.ρ => 11.0, :β => 3.0] +) +@test nlprob2.u0 isa Vector{<:Number} +@test nlprob2.p isa ModelingToolkit.MTKParameters +@test nlprob2[x] == nlprob2[sys.x] == nlprob2[:x] == 2.0 +@test nlprob2[y] == nlprob2[sys.y] == nlprob2[:y] == 1.2 +@test nlprob2[z] == nlprob2[sys.z] == nlprob2[:z] == 1.0 +@test getp(ns, σ)(nlprob2) == 29.0 +@test getp(ns, sys.ρ)(nlprob2) == 11.0 +@test getp(ns, :β)(nlprob2) == 3.0 + +nlprob3 = remake(nlprob; u0 = [x => 3.0], p = [σ => 30.0]) # partial update +@test nlprob3[x] == 3.0 +@test getp(ns, σ)(nlprob3) == 30.0 + +@parameters β γ +@variables S(t) I(t) R(t) +rate₁ = β*S*I +affect₁ = [S ~ S - 1, I ~ I + 1] +rate₂ = γ*I +affect₂ = [I ~ I - 1, R ~ R + 1] +j₁ = ConstantRateJump(rate₁,affect₁) +j₂ = ConstantRateJump(rate₂,affect₂) +j₃ = MassActionJump(2*β+γ, [R => 1], [S => 1, R => -1]) +@named js = JumpSystem([j₁,j₂,j₃], t, [S,I,R], [β,γ]) +js = complete(js) +u₀map = [S => 999, I => 1, R => 0.0] +parammap = [β => 0.1 / 1000, γ => 0.01] +tspan = (0.0, 250.0) +jump_dprob = DiscreteProblem(js, u₀map, tspan, parammap) +jprob = JumpProblem(js, jump_dprob, Direct()) + +@inferred typeof(jprob) remake(jprob; u0 = [S => 900], p = [β => 0.2e-3]) +jprob2 = remake( + jprob; + u0 = [S => 900, js.I => 2, :R => 0.1], + p = [β => 0.2 / 1000, js.γ => 11.0] +) +@test jprob2.prob.u0 isa Vector{<:Number} +@test jprob2.prob.p isa ModelingToolkit.MTKParameters +@test jprob2[S] == jprob2[js.S] == jprob2[:S] == 900.0 +@test jprob2[I] == jprob2[js.I] == jprob2[:I] == 2.0 +@test jprob2[R] == jprob2[js.R] == jprob2[:R] == 0.1 +@test getp(js, β)(jprob2) == 0.2 / 1000 +@test getp(js, js.γ)(jprob2) == 11.0 + +jprob3 = remake(jprob; u0 = [S => 901], p = [:β => 0.3 / 1000]) # partial update +@test jprob3[S] == 901 +@test getp(js, β)(jprob3) == 0.3 / 1000 diff --git a/test/remake_tests.jl b/test/remake_tests.jl new file mode 100644 index 000000000..6b0937297 --- /dev/null +++ b/test/remake_tests.jl @@ -0,0 +1,131 @@ +using SciMLBase +using SymbolicIndexingInterface + +# ODE +function lorenz!(du,u,p,t) + du[1] = p[1] * (u[2]-u[1]) + du[2] = u[1]*(p[2]-u[3]) - u[2] + du[3] = u[1]*u[2] - p[3]*u[3] +end +u0 = [1.0;0.0;0.0] +tspan = (0.0,100.0) +p = [10.0, 28.0, 8/3] +fn = ODEFunction(lorenz!; sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)) +prob = ODEProblem(fn, u0, tspan, p) + +@test remake(prob).u0 == u0 +@test remake(prob).p == p +@test remake(prob; u0 = [2.0, 3.0, 4.0]).u0 == [2.0, 3.0, 4.0] +@test remake(prob; u0 = [:x => 2.0, :z => 4.0, :y => 3.0]).u0 == [2.0, 3.0, 4.0] +@test remake(prob; p = [11.0, 12.0, 13.0]).p == [11.0, 12.0, 13.0] +@test remake(prob; p = [:a => 11.0, :c => 13.0, :b => 12.0]).p == [11.0, 12.0, 13.0] +@test remake(prob; p = (11.0, 12.0, 13)).p == (11.0, 12.0, 13) +@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 0.0, 0.0] +@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 8/3] + +# BVP +g = 9.81 +L = 1.0 +tspan = (0.0, pi / 2) +function simplependulum!(du, u, p, t) + θ = u[1] + dθ = u[2] + du[1] = dθ + du[2] = -(p[1] / p[2]) * sin(θ) +end +function bc1!(residual, u, p, t) + residual[1] = u[end ÷ 2][1] + pi / 2 # the solution at the middle of the time span should be -pi/2 + residual[2] = u[end][1] - pi / 2 # the solution at the end of the time span should be pi/2 +end +u0 = [pi / 2, pi / 2] +p = [g, L] +fn = BVPFunction(simplependulum!, bc1!; sys = SymbolCache([:x, :y], [:a, :b], :t) ) +prob = BVProblem(fn, u0, tspan, p) + +@test remake(prob).u0 == u0 +@test remake(prob).p == p +@test remake(prob; u0 = [2.0, 3.0]).u0 == [2.0, 3.0] +@test remake(prob; u0 = [:x => 2.0, :y => 3.0]).u0 == [2.0, 3.0] +@test remake(prob; p = [11.0, 12.0]).p == [11.0, 12.0] +@test remake(prob; p = [:a => 11.0, :b => 12.0]).p == [11.0, 12.0] +@test remake(prob; p = (11.0, 12.0)).p == (11.0, 12.0) +@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, pi / 2] +@test remake(prob; p = [:b => 11.0]).p == [g, 11.0] + +# SDE +function sdef(du, u, p, t) + du[1] = p[1] * (u[2] - u[1]) + du[2] = u[1] * (p[2] - u[3]) - u[2] + du[3] = u[1] * u[2] - p[3] * u[3] +end +function sdeg(du, u, p, t) + du .= 0.1u +end + +u0 = [1.0, 0.0, 0.0] +p = [10, 2.33, 26] +tspan = (0, 100) +fn = SDEFunction(sdef, sdeg; sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)) +prob = SDEProblem(fn, u0, tspan, p) + +@test remake(prob).u0 == u0 +@test remake(prob).p == p +@test remake(prob; u0 = [2.0, 3.0, 4.0]).u0 == [2.0, 3.0, 4.0] +@test remake(prob; u0 = [:x => 2.0, :z => 4.0, :y => 3.0]).u0 == [2.0, 3.0, 4.0] +@test remake(prob; p = [11.0, 12.0, 13.0]).p == [11.0, 12.0, 13.0] +@test remake(prob; p = [:a => 11.0, :c => 13.0, :b => 12.0]).p == [11.0, 12.0, 13.0] +@test remake(prob; p = (11.0, 12.0, 13)).p == (11.0, 12.0, 13) +@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 0.0, 0.0] +@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 26.0] + +# OptimizationProblem +function loss(u, p) + return (p[1] - u[1]) ^ 2 + p[2] * (u[2] - u[1] ^ 2) ^ 2 +end +u0 = [1.0, 2.0] +p = [1.0, 100.0] +fn = OptimizationFunction(loss; sys = SymbolCache([:x, :y], [:a, :b], :t) ) +prob = OptimizationProblem(fn, u0, p) +@test remake(prob).u0 == u0 +@test remake(prob).p == p +@test remake(prob; u0 = [2.0, 3.0]).u0 == [2.0, 3.0] +@test remake(prob; u0 = [:x => 2.0, :y => 3.0]).u0 == [2.0, 3.0] +@test remake(prob; p = [11.0, 12.0]).p == [11.0, 12.0] +@test remake(prob; p = [:a => 11.0, :b => 12.0]).p == [11.0, 12.0] +@test remake(prob; p = (11.0, 12.0)).p == (11.0, 12.0) +@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 2.0] +@test remake(prob; p = [:b => 11.0]).p == [1.0, 11.0] + +# NonlinearProblem +function nlf(du, u, p) + du[1] = p[1] * (u[2] - u[1]) + du[2] = u[1] * (p[2] - u[3]) - u[2] + du[3] = u[1] * u[2] - p[3] * u[3] +end +u0 = [1.0, 0.0, 0.0] +p = [10, 2.33, 26] +fn = NonlinearFunction(nlf; sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)) +prob = NonlinearProblem(fn, u0, p) + +@test remake(prob).u0 == u0 +@test remake(prob).p == p +@test remake(prob; u0 = [2.0, 3.0, 4.0]).u0 == [2.0, 3.0, 4.0] +@test remake(prob; u0 = [:x => 2.0, :z => 4.0, :y => 3.0]).u0 == [2.0, 3.0, 4.0] +@test remake(prob; p = [11.0, 12.0, 13.0]).p == [11.0, 12.0, 13.0] +@test remake(prob; p = [:a => 11.0, :c => 13.0, :b => 12.0]).p == [11.0, 12.0, 13.0] +@test remake(prob; p = (11.0, 12.0, 13)).p == (11.0, 12.0, 13) +@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 0.0, 0.0] +@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 26.0] + +# NonlinearLeastSquaresProblem +prob = NonlinearLeastSquaresProblem(fn, u0, p) +@test remake(prob).u0 == u0 +@test remake(prob).p == p +@test remake(prob; u0 = [2.0, 3.0, 4.0]).u0 == [2.0, 3.0, 4.0] +@test remake(prob; u0 = [:x => 2.0, :z => 4.0, :y => 3.0]).u0 == [2.0, 3.0, 4.0] +@test remake(prob; p = [11.0, 12.0, 13.0]).p == [11.0, 12.0, 13.0] +@test remake(prob; p = [:a => 11.0, :c => 13.0, :b => 12.0]).p == [11.0, 12.0, 13.0] +@test remake(prob; p = (11.0, 12.0, 13)).p == (11.0, 12.0, 13) +@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 0.0, 0.0] +@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 26.0] + diff --git a/test/runtests.jl b/test/runtests.jl index c8c11626c..3fcbc36b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,6 +64,9 @@ end @time @safetestset "Problem building tests" begin include("problem_building_test.jl") end + @time @safetestset "Remake" begin + include("remake_tests.jl") + end end if !is_APPVEYOR && GROUP == "Downstream" @@ -103,6 +106,9 @@ end @time @safetestset "Partial Functions" begin include("downstream/partial_functions.jl") end + @time @safetestset "ModelingToolkit Remake" begin + include("downstream/modelingtoolkit_remake.jl") + end end if !is_APPVEYOR && (GROUP == "Downstream" || GROUP == "SymbolicIndexingInterface")