From ae62632c56de84d1dd97d83192644cd568036379 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 15:57:05 +0530 Subject: [PATCH 01/30] feat: add `constructorof` for `SDEProblem` --- src/problems/sde_problems.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/problems/sde_problems.jl b/src/problems/sde_problems.jl index 60017f740..6592940a1 100644 --- a/src/problems/sde_problems.jl +++ b/src/problems/sde_problems.jl @@ -125,6 +125,17 @@ function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...) SDEProblem{iip}(SDEFunction{iip}(f, g), u0, tspan, p; kwargs...) end +function ConstructionBase.constructorof(::Type{P}) where {P <: SDEProblem} + function ctor(f, g, u0, tspan, p, noise, kw, noise_rate_prototype, seed) + if f isa AbstractSDEFunction + iip = isinplace(f) + else + iip = isinplace(f, 4) + end + return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed) + end +end + """ $(TYPEDEF) """ From 97e1725de9afb298ee9b17bd7c88ccfb2e31ffe6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 15:57:17 +0530 Subject: [PATCH 02/30] feat: add `constructorof` for `SDDEProblem` --- src/problems/sdde_problems.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/problems/sdde_problems.jl b/src/problems/sdde_problems.jl index cfcf1aec5..31d239c36 100644 --- a/src/problems/sdde_problems.jl +++ b/src/problems/sdde_problems.jl @@ -157,3 +157,17 @@ end function SDDEProblem(f::AbstractSDDEFunction, args...; kwargs...) SDDEProblem{isinplace(f)}(f, args...; kwargs...) end + +function ConstructionBase.constructorof(::Type{P}) where {P <: SDDEProblem} + function ctor(f, g, u0, h, tspan, p, noise, constant_lags, dependent_lags, kw, + noise_rate_prototype, seed, neutral, order_discontinuity_t0) + if f isa AbstractSDDEFunction + iip = isinplace(f) + else + iip = isinplace(f, 5) + end + return SDDEProblem{iip}( + f, g, u0, h, tspan, p; kw..., noise, constant_lags, dependent_lags, + noise_rate_prototype, seed, neutral, order_discontinuity_t0) + end +end From 73c5d7dcc3b6048afb7d921603edb5e5150b8043 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 15:57:31 +0530 Subject: [PATCH 03/30] feat: add `constructorof` for `DDEProblem` --- src/problems/dde_problems.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/problems/dde_problems.jl b/src/problems/dde_problems.jl index a93c1d320..641b1651b 100644 --- a/src/problems/dde_problems.jl +++ b/src/problems/dde_problems.jl @@ -253,6 +253,19 @@ struct DDEProblem{uType, tType, lType, lType2, isinplace, P, F, H, K, PT} <: end end +function ConstructionBase.constructorof(::Type{P}) where {P <: DDEProblem} + function ctor(f, u0, h, tspan, p, constant_lags, dependent_lags, + kw, neutral, order_discontinuity_t0, problem_type) + if f isa AbstractDDEFunction + iip = isinplace(f) + else + iip = isinplace(f, 5) + end + return DDEProblem{iip}(f, u0, h, tspan, p; kw..., constant_lags, dependent_lags, + neutral, order_discontinuity_t0, problem_type) + end +end + DDEProblem(f, args...; kwargs...) = DDEProblem(DDEFunction(f), args...; kwargs...) function DDEProblem(f::AbstractDDEFunction, args...; kwargs...) From 69891573076459490fdeda27e8da27c3a0e6158f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 15:58:55 +0530 Subject: [PATCH 04/30] feat: add proper remake for `SDEProblem` --- src/SciMLBase.jl | 2 +- src/remake.jl | 25 +++++++++++++++---------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 99495d93f..f9636caa0 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -22,7 +22,7 @@ import FunctionWrappersWrappers import RuntimeGeneratedFunctions import EnumX import ADTypes: ADTypes, AbstractADType -import Accessors: @set, @reset +import Accessors: @set, @reset, @delete using Expronicon.ADT: @match using Reexport diff --git a/src/remake.jl b/src/remake.jl index d3db11b9c..5b1f69d91 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -354,13 +354,13 @@ function remake(prob::SDEProblem; if prob.f isa SDEFunction f = remake(prob.f; g = g) else - f = SDEFunction(prob.f, g) + f = SDEFunction(prob.f, g; sys = prob.f.sys) end else if f isa SDEFunction f = remake(f; g = g) else - f = SDEFunction(f, g) + f = SDEFunction(f, g; sys = prob.f.sys) end end @@ -387,12 +387,20 @@ end Remake the given `SDEFunction`. """ -function remake(func::SDEFunction; +function remake(func::Union{SDEFunction, SDDEFunction}; f = missing, g = missing, mass_matrix = missing, analytic = missing, + sys = missing, kwargs...) + props = getproperties(func) + props = @delete props.f + props = @delete props.g + @reset props.mass_matrix = coalesce(mass_matrix, func.mass_matrix) + @reset props.analytic = coalesce(analytic, func.analytic) + @reset props.sys = coalesce(sys, func.sys) + if f === missing f = func.f end @@ -401,15 +409,12 @@ function remake(func::SDEFunction; g = func.g end - if mass_matrix === missing - mass_matrix = func.mass_matrix - end - - if analytic === missing - analytic = func.analytic + if f isa AbstractSciMLFunction + f = f.f end - return SDEFunction(f, g; mass_matrix, analytic, kwargs...) + T = func isa SDEFunction ? SDEFunction : SDDEFunction + return T{isinplace(func)}(f, g; props..., kwargs...) end """ From 0e2a089ae069f08d2b0dc25f4b953ea75fc5da82 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 15:59:26 +0530 Subject: [PATCH 05/30] feat: run `remake_initialization_data` when remaking `SDEProblem` --- src/remake.jl | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 5b1f69d91..3875088bd 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -326,12 +326,20 @@ function remake(prob::SDEProblem; use_defaults = false, seed = missing, kwargs = missing, + build_initializeprob = true, _kwargs...) if tspan === missing tspan = prob.tspan end - u0, p = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults) + newu0, newp = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults) + + if build_initializeprob + initialization_data = remake_initialization_data( + prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) + else + initialization_data = nothing + end if noise === missing noise = prob.noise @@ -363,21 +371,21 @@ function remake(prob::SDEProblem; f = SDEFunction(f, g; sys = prob.f.sys) end end - + f = remake(f; initialization_data) iip = isinplace(prob) if kwargs === missing SDEProblem{iip}(f, - u0, + newu0, tspan, - p; + newp; noise, noise_rate_prototype, seed, prob.kwargs..., _kwargs...) else - SDEProblem{iip}(f, u0, tspan, p; noise, noise_rate_prototype, seed, kwargs...) + SDEProblem{iip}(f, newu0, tspan, newp; noise, noise_rate_prototype, seed, kwargs...) end end From 1980fa6d094de9bcf55f04b02d38524a1b596bd2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 16:00:19 +0530 Subject: [PATCH 06/30] feat: add proper `remake` for `DDEProblem` --- src/remake.jl | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/src/remake.jl b/src/remake.jl index 3875088bd..9ae5767b8 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -425,6 +425,79 @@ function remake(func::Union{SDEFunction, SDDEFunction}; return T{isinplace(func)}(f, g; props..., kwargs...) end +function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing, + tspan = missing, p = missing, constant_lags = missing, + dependent_lags = missing, order_discontinuity_t0 = missing, + neutral = missing, kwargs = missing, interpret_symbolicmap = true, + use_defaults = false, build_initializeprob = true, _kwargs...) + if tspan === missing + tspan = prob.tspan + end + + newu0, newp = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults) + + if build_initializeprob + initialization_data = remake_initialization_data( + prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) + else + initialization_data = nothing + end + + if f === missing + f = prob.f + elseif !(f isa DDEFunction) + f = remake(prob.f; f = f) + end + f = remake(f; initialization_data) + + h = coalesce(h, prob.h) + constant_lags = coalesce(constant_lags, prob.constant_lags) + dependent_lags = coalesce(dependent_lags, prob.dependent_lags) + order_discontinuity_t0 = coalesce(order_discontinuity_t0, prob.order_discontinuity_t0) + neutral = coalesce(neutral, prob.neutral) + + iip = isinplace(prob) + + if kwargs === missing + DDEProblem{iip}(f, + newu0, + h, + tspan, + newp; + constant_lags, + dependent_lags, + order_discontinuity_t0, + neutral, + prob.kwargs..., + _kwargs...) + else + DDEProblem{iip}(f, newu0, h, tspan, newp; constant_lags, dependent_lags, + order_discontinuity_t0, neutral, kwargs...) + end +end + +function remake(func::DDEFunction; + f = missing, + mass_matrix = missing, + analytic = missing, + sys = missing, + kwargs...) + props = getproperties(func) + props = @delete props.f + @reset props.mass_matrix = coalesce(mass_matrix, func.mass_matrix) + @reset props.analytic = coalesce(analytic, func.analytic) + @reset props.sys = coalesce(sys, func.sys) + + if f === missing + f = func.f + end + if f isa AbstractSciMLFunction + f = f.f + end + + return DDEFunction{isinplace(func)}(f; props..., kwargs...) +end + """ remake(prob::OptimizationProblem; f = missing, u0 = missing, p = missing, lb = missing, ub = missing, int = missing, lcons = missing, ucons = missing, From 555cf5d5b6c3aa2dd577330672d91602eb1c3e2f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 16:00:41 +0530 Subject: [PATCH 07/30] feat: add proper `remake` for `SDDEProblem` --- src/remake.jl | 94 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/src/remake.jl b/src/remake.jl index 9ae5767b8..7f4ff3c00 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -498,6 +498,100 @@ function remake(func::DDEFunction; return DDEFunction{isinplace(func)}(f; props..., kwargs...) end +function remake(prob::SDDEProblem; + f = missing, + g = missing, + h = missing, + u0 = missing, + tspan = missing, + p = missing, + constant_lags = missing, + dependent_lags = missing, + order_discontinuity_t0 = missing, + neutral = missing, + noise = missing, + noise_rate_prototype = missing, + interpret_symbolicmap = true, + use_defaults = false, + seed = missing, + kwargs = missing, + build_initializeprob = true, + _kwargs...) + if tspan === missing + tspan = prob.tspan + end + + newu0, newp = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults) + + if build_initializeprob + initialization_data = remake_initialization_data( + prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) + else + initialization_data = nothing + end + + if noise === missing + noise = prob.noise + end + + if noise_rate_prototype === missing + noise_rate_prototype = prob.noise_rate_prototype + end + + if seed === missing + seed = prob.seed + end + + if f === missing && g === missing + f = prob.f + g = prob.g + elseif f !== missing && g === missing + g = prob.g + elseif f === missing && g !== missing + if prob.f isa SDEFunction + f = remake(prob.f; g = g) + else + f = SDEFunction(prob.f, g; sys = prob.f.sys) + end + else + if f isa SDEFunction + f = remake(f; g = g) + else + f = SDEFunction(f, g; sys = prob.f.sys) + end + end + f = remake(f; initialization_data) + iip = isinplace(prob) + + h = coalesce(h, prob.h) + constant_lags = coalesce(constant_lags, prob.constant_lags) + dependent_lags = coalesce(dependent_lags, prob.dependent_lags) + order_discontinuity_t0 = coalesce(order_discontinuity_t0, prob.order_discontinuity_t0) + neutral = coalesce(neutral, prob.neutral) + + if kwargs === missing + SDDEProblem{iip}(f, + g, + newu0, + h, + tspan, + newp; + noise, + noise_rate_prototype, + seed, + constant_lags, + dependent_lags, + order_discontinuity_t0, + neutral, + prob.kwargs..., + _kwargs...) + else + SDDEProblem{iip}( + f, newu0, tspan, newp; noise, noise_rate_prototype, seed, constant_lags, + dependent_lags, order_discontinuity_t0, neutral, kwargs...) + end +end + """ remake(prob::OptimizationProblem; f = missing, u0 = missing, p = missing, lb = missing, ub = missing, int = missing, lcons = missing, ucons = missing, From d0b31ca025df09501ca0d6bbf2eaae8d3e602952 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 16:02:30 +0530 Subject: [PATCH 08/30] fix: support non-markovian index providers in `updated_u0_p` --- src/remake.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 7f4ff3c00..489839808 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -920,7 +920,8 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0) # FIXME: need to provide `u` since the observed function expects it. # This is sort of an implicit dependency on MTK. The values of `u` won't actually be # used, since any state symbols in the expression were substituted out earlier. - temp_state = ProblemState(; u = state_values(prob), p = p, t = t0) + temp_state = ProblemState(; u = state_values(prob), p = p, t = t0, + h = is_markovian(prob) ? nothing : get_history_function(prob)) for (k, v) in u0 u0[k] = symbolic_type(v) === NotSymbolic() ? v : getsym(prob, v)(temp_state) end @@ -944,7 +945,8 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0) # FIXME: need to provide `p` since the observed function expects an `MTKParameters` # this is sort of an implicit dependency on MTK. The values of `p` won't actually be # used, since any parameter symbols in the expression were substituted out earlier. - temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0) + temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0, + h = is_markovian(prob) ? nothing : get_history_function(prob)) for (k, v) in p p[k] = symbolic_type(v) === NotSymbolic() ? v : getsym(prob, v)(temp_state) end From abf7ff6f6cfc84bc50930439f99913469507d5b0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 16:11:55 +0530 Subject: [PATCH 09/30] feat: implement `get_history_function` for `AbstractSDDEProblem` --- src/problems/sdde_problems.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/problems/sdde_problems.jl b/src/problems/sdde_problems.jl index 31d239c36..abbee7c9a 100644 --- a/src/problems/sdde_problems.jl +++ b/src/problems/sdde_problems.jl @@ -171,3 +171,5 @@ function ConstructionBase.constructorof(::Type{P}) where {P <: SDDEProblem} noise_rate_prototype, seed, neutral, order_discontinuity_t0) end end + +SymbolicIndexingInterface.get_history_function(prob::AbstractSDDEProblem) = prob.h From 071ddeeb7ad4fa4773d216adb4d4d9a3f8efe311 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 16:12:40 +0530 Subject: [PATCH 10/30] feat: implement `get_history_function` for `AbstractSDDEIntegrator` --- src/integrator_interface.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index b38790410..f3a14bbea 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -925,6 +925,7 @@ function isadaptive(integrator::DEIntegrator) isdefined(integrator.opts, :adaptive) ? integrator.opts.adaptive : false end -function SymbolicIndexingInterface.get_history_function(integ::AbstractDDEIntegrator) +function SymbolicIndexingInterface.get_history_function(integ::Union{ + AbstractDDEIntegrator, AbstractSDDEIntegrator}) DDESolutionHistoryWrapper(get_sol(integ)) end From 0488aaf368cbb6e9aff1c5a05bcb00008a875713 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 11:47:58 +0530 Subject: [PATCH 11/30] build: bump SymbolicIndexingInterface compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1d004d399..cc4131cad 100644 --- a/Project.toml +++ b/Project.toml @@ -88,7 +88,7 @@ StableRNGs = "1.0" StaticArrays = "1.7" StaticArraysCore = "1.4" Statistics = "1.10" -SymbolicIndexingInterface = "0.3.34" +SymbolicIndexingInterface = "0.3.36" Tables = "1.11" Zygote = "0.6.67" julia = "1.10" From 86d61b4fa3a0e33a9b3a2c5ad046dd6664b59500 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 12:45:17 +0530 Subject: [PATCH 12/30] fix: fix type stability of `remake(::SDEProblem)` --- src/remake.jl | 55 +++++++++------------------------------------------ 1 file changed, 9 insertions(+), 46 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 489839808..cc6781a74 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -352,26 +352,9 @@ function remake(prob::SDEProblem; if seed === missing seed = prob.seed end - - if f === missing && g === missing - f = prob.f - g = prob.g - elseif f !== missing && g === missing - g = prob.g - elseif f === missing && g !== missing - if prob.f isa SDEFunction - f = remake(prob.f; g = g) - else - f = SDEFunction(prob.f, g; sys = prob.f.sys) - end - else - if f isa SDEFunction - f = remake(f; g = g) - else - f = SDEFunction(f, g; sys = prob.f.sys) - end - end - f = remake(f; initialization_data) + f = coalesce(f, prob.f) + g = coalesce(g, prob.g) + f = remake(prob.f; f, g, initialization_data) iip = isinplace(prob) if kwargs === missing @@ -443,12 +426,8 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing, initialization_data = nothing end - if f === missing - f = prob.f - elseif !(f isa DDEFunction) - f = remake(prob.f; f = f) - end - f = remake(f; initialization_data) + f = coalesce(f, prob.f) + f = remake(prob.f; f, initialization_data) h = coalesce(h, prob.h) constant_lags = coalesce(constant_lags, prob.constant_lags) @@ -542,25 +521,9 @@ function remake(prob::SDDEProblem; seed = prob.seed end - if f === missing && g === missing - f = prob.f - g = prob.g - elseif f !== missing && g === missing - g = prob.g - elseif f === missing && g !== missing - if prob.f isa SDEFunction - f = remake(prob.f; g = g) - else - f = SDEFunction(prob.f, g; sys = prob.f.sys) - end - else - if f isa SDEFunction - f = remake(f; g = g) - else - f = SDEFunction(f, g; sys = prob.f.sys) - end - end - f = remake(f; initialization_data) + f = coalesce(f, prob.f) + g = coalesce(g, prob.g) + f = remake(prob.f; f, g, initialization_data) iip = isinplace(prob) h = coalesce(h, prob.h) @@ -587,7 +550,7 @@ function remake(prob::SDDEProblem; _kwargs...) else SDDEProblem{iip}( - f, newu0, tspan, newp; noise, noise_rate_prototype, seed, constant_lags, + f, g, newu0, tspan, newp; noise, noise_rate_prototype, seed, constant_lags, dependent_lags, order_discontinuity_t0, neutral, kwargs...) end end From c30969371be31579cdab240f5fc8cd9f35527e5a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 12:45:33 +0530 Subject: [PATCH 13/30] test: test `remake` for `DDEProblem`, `SDDEProblem` --- test/remake_tests.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/remake_tests.jl b/test/remake_tests.jl index dd6fdae5f..0a9b156be 100644 --- a/test/remake_tests.jl +++ b/test/remake_tests.jl @@ -20,6 +20,21 @@ for T in containerTypes push!(probs, ODEProblem(fn, u0, tspan, T(p))) end +function ddelorenz!(du, u, h, p, t) + du[1] = p[1] * (u[2] - u[1]) + du[2] = u[1] * (p[2] - u[3]) - u[2] + du[3] = u[1] * u[2] - p[3] * u[3] +end + +function history(p, t) + return u0 .- t +end + +fn = DDEFunction(ddelorenz!; sys) +for T in containerTypes + push!(probs, DDEProblem(fn, u0, history, tspan, T(p))) +end + function residual!(resid, u, p, t) resid[1] = u[1] - 0.5 resid[2] = u[2] - 0.5 @@ -38,6 +53,11 @@ for T in containerTypes push!(probs, SDEProblem(fn, u0, tspan, T(p))) end +fn = SDDEFunction(ddelorenz!, noise!; sys) +for T in containerTypes + push!(probs, SDDEProblem(fn, noise!, u0, history, tspan, T(p))) +end + function loss(x, p) du = similar(x) lorenz!(du, u, p, 0.0) From 071723fcfad59a20c4d3188d7009229268fae1a4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 13:37:43 +0530 Subject: [PATCH 14/30] fix: add `.initializeprob` syntax to all applicable SciMLFunctions --- src/scimlfunctions.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 4d0442ecd..c18868c8f 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4678,9 +4678,10 @@ end SymbolicIndexingInterface.constant_structure(::AbstractSciMLFunction) = true -function Base.getproperty(x::Union{ODEFunction, SplitFunction, DAEFunction}, sym::Symbol) - if sym == :initializeprob || sym == :update_initializeprob! || - sym == :initializeprobmap || sym == :initializeprobpmap +function Base.getproperty(x::AbstractSciMLFunction, sym::Symbol) + if __has_initialization_data(x) && + (sym == :initializeprob || sym == :update_initializeprob! || + sym == :initializeprobmap || sym == :initializeprobpmap) if x.initialization_data === nothing return nothing else From 68f74f28309315718fc2b05c3507a12239598d07 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 13:38:23 +0530 Subject: [PATCH 15/30] refactor: update `remake_initializeprob` fallback --- src/remake.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index cc6781a74..47340a596 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -213,11 +213,12 @@ if it does. Note that `u0` or `p` may be `missing` if the user does not provide a value for them. """ function remake_initializeprob(sys, scimlfn, u0, t0, p) - if !has_initializeprob(scimlfn) + if !has_initialization_data(scimlfn) return nothing, nothing, nothing, nothing end - return scimlfn.initializeprob, - scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap + initdata = scimlfn.initialization_data + return initdata.initializeprob, initdata.update_initializeprob!, + initdata.initializeprobmap, initdata.initializeprobpmap end """ From 041c5394da291a7664e5d35c9c1e2d8e27f04ecb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 13:39:53 +0530 Subject: [PATCH 16/30] test: fix DDE indexing test --- test/downstream/comprehensive_indexing.jl | 31 +++++------------------ 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/test/downstream/comprehensive_indexing.jl b/test/downstream/comprehensive_indexing.jl index f5c519ce1..5388c31c0 100644 --- a/test/downstream/comprehensive_indexing.jl +++ b/test/downstream/comprehensive_indexing.jl @@ -926,7 +926,7 @@ end @testset "DDEs" begin function oscillator(; name, k = 1.0, τ = 0.01) @parameters k=k τ=τ - @variables x(..)=0.1 y(t)=0.1 jcn(t)=0.0 delx(t) + @variables x(..)=0.1 + t y(t)=0.1 + t jcn(t)=0.0 + t delx(t) eqs = [D(x(t)) ~ y, D(y) ~ -k * x(t - τ) + jcn, delx ~ x(t - τ)] @@ -942,32 +942,13 @@ end @named coupledOsc = compose(coupledOsc, systems) sys = structural_simplify(coupledOsc) prob = DDEProblem(sys, [], (0.0, 10.0); constant_lags = [sys.osc1.τ, sys.osc2.τ]) - # TODO: Remove this hack once MTK can generate appropriate observed functions - fn = prob.f - function fake_observed(_) - return function obsfn(u, h, p, t) - return u + h(p, t - 0.1) - end - end - - struct NonMarkovianWrapper{S} - sys::S - end - SymbolicIndexingInterface.symbolic_container(x::NonMarkovianWrapper) = x.sys - SymbolicIndexingInterface.is_markovian(::NonMarkovianWrapper) = false - fn = DDEFunction(fn.f; observed = fake_observed, sys = NonMarkovianWrapper(fn.sys)) - function fake_hist(p, t) - return ones(length(prob.u0)) .* t - end - prob = DDEProblem( - fn, prob.u0, fake_hist, prob.tspan, prob.p; constant_lags = prob.constant_lags) sym = sys.osc1.delx - @test prob[sym] ≈ prob.u0 .+ (prob.tspan[1] - 0.1) + delay = sys.osc1.τ + original = sys.osc1.x + @test prob[sym] ≈ prob[original] .+ (prob.tspan[1] - prob.ps[delay]) integ = init(prob, MethodOfSteps(Tsit5())) step!(integ, 10.0, true) - # DelayDiffEq wraps `integ.f` and that doesn't contain `.observed` - # so the hack above doesn't work. `@reset` also doesn't work. - @test_broken integ[sym] ≈ integ.u + SciMLBase.get_sol(integ)(9.9) + @test integ[sym] ≈ SciMLBase.get_sol(integ)(integ.t - integ.ps[delay]; idxs = original) sol = solve(prob, MethodOfSteps(Tsit5())) - @test sol[sym] ≈ sol.u .+ sol(sol.t .- 0.1).u + @test sol[sym] ≈ sol(sol.t .- sol.ps[delay]; idxs = original) end From 7bffc07e0b53607e3e1b68ed042039e396067762 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 13:40:04 +0530 Subject: [PATCH 17/30] test: test SDDE indexing --- test/downstream/comprehensive_indexing.jl | 33 ++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/test/downstream/comprehensive_indexing.jl b/test/downstream/comprehensive_indexing.jl index 5388c31c0..0d59c8084 100644 --- a/test/downstream/comprehensive_indexing.jl +++ b/test/downstream/comprehensive_indexing.jl @@ -1,7 +1,7 @@ using ModelingToolkit, JumpProcesses, LinearAlgebra, NonlinearSolve, Optimization, OptimizationOptimJL, OrdinaryDiffEq, RecursiveArrayTools, SciMLBase, SteadyStateDiffEq, StochasticDiffEq, DelayDiffEq, SymbolicIndexingInterface, - DiffEqCallbacks, Test, Plots + DiffEqCallbacks, StochasticDelayDiffEq, Test, Plots using ModelingToolkit: t_nounits as t, D_nounits as D # Sets rnd number. @@ -952,3 +952,34 @@ end sol = solve(prob, MethodOfSteps(Tsit5())) @test sol[sym] ≈ sol(sol.t .- sol.ps[delay]; idxs = original) end + +@testset "SDDEs" begin + function oscillator(; name, k = 1.0, τ = 0.01) + @parameters k=k τ=τ + @brownian a + @variables x(..)=0.1 + t y(t)=0.1 + t jcn(t)=0.0 + t delx(t) + eqs = [D(x(t)) ~ y + a, + D(y) ~ -k * x(t - τ) + jcn, + delx ~ x(t - τ)] + return System(eqs, t; name = name) + end + systems = @named begin + osc1 = oscillator(k = 1.0, τ = 0.01) + osc2 = oscillator(k = 2.0, τ = 0.04) + end + eqs = [osc1.jcn ~ osc2.delx, + osc2.jcn ~ osc1.delx] + @named coupledOsc = System(eqs, t) + @named coupledOsc = compose(coupledOsc, systems) + sys = structural_simplify(coupledOsc) + prob = SDDEProblem(sys, [], (0.0, 10.0); constant_lags = [sys.osc1.τ, sys.osc2.τ]) + sym = sys.osc1.delx + delay = sys.osc1.τ + original = sys.osc1.x + @test prob[sym] ≈ prob[original] .+ (prob.tspan[1] - prob.ps[delay]) + integ = init(prob, ImplicitEM()) + step!(integ, 10.0, true) + @test integ[sym] ≈ SciMLBase.get_sol(integ)(integ.t - integ.ps[delay]; idxs = original) + sol = solve(prob, ImplicitEM()) + @test sol[sym] ≈ sol(sol.t .- sol.ps[delay]; idxs = original) +end From b75e3a43e064ccd8c6fdfef73421bc969c802c68 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 16:24:46 +0530 Subject: [PATCH 18/30] build: bump MTK compat in downstream CI --- test/downstream/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 4b09aecdd..7e01ecb53 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -33,7 +33,7 @@ DelayDiffEq = "5" DiffEqCallbacks = "3, 4" ForwardDiff = "0.10" JumpProcesses = "9.10" -ModelingToolkit = "9.52" +ModelingToolkit = "9.56" ModelingToolkitStandardLibrary = "2.7" NonlinearSolve = "2, 3, 4" Optimization = "4" From 1624c01cad529415e5872fae16f8b17e1a37b0cf Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 5 Dec 2024 13:35:27 +0530 Subject: [PATCH 19/30] feat: add lazy initialization to new `remake` methods --- src/remake.jl | 54 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 47340a596..5a0110eea 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -327,6 +327,7 @@ function remake(prob::SDEProblem; use_defaults = false, seed = missing, kwargs = missing, + lazy_initialization = nothing, build_initializeprob = true, _kwargs...) if tspan === missing @@ -358,7 +359,7 @@ function remake(prob::SDEProblem; f = remake(prob.f; f, g, initialization_data) iip = isinplace(prob) - if kwargs === missing + prob = if kwargs === missing SDEProblem{iip}(f, newu0, tspan, @@ -371,6 +372,20 @@ function remake(prob::SDEProblem; else SDEProblem{iip}(f, newu0, tspan, newp; noise, noise_rate_prototype, seed, kwargs...) end + if lazy_initialization === nothing + lazy_initialization = !is_trivial_initialization(initialization_data) + end + if !lazy_initialization + u0, p, _ = get_initial_values( + prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) + if u0 !== nothing && eltype(u0) == Any && isempty(u0) + u0 = nothing + end + @reset prob.u0 = u0 + @reset prob.p = p + end + + return prob end """ @@ -413,7 +428,8 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing, tspan = missing, p = missing, constant_lags = missing, dependent_lags = missing, order_discontinuity_t0 = missing, neutral = missing, kwargs = missing, interpret_symbolicmap = true, - use_defaults = false, build_initializeprob = true, _kwargs...) + use_defaults = false, lazy_initialization = nothing, build_initializeprob = true, + _kwargs...) if tspan === missing tspan = prob.tspan end @@ -438,7 +454,7 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing, iip = isinplace(prob) - if kwargs === missing + prob = if kwargs === missing DDEProblem{iip}(f, newu0, h, @@ -454,6 +470,20 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing, DDEProblem{iip}(f, newu0, h, tspan, newp; constant_lags, dependent_lags, order_discontinuity_t0, neutral, kwargs...) end + if lazy_initialization === nothing + lazy_initialization = !is_trivial_initialization(initialization_data) + end + if !lazy_initialization + u0, p, _ = get_initial_values( + prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) + if u0 !== nothing && eltype(u0) == Any && isempty(u0) + u0 = nothing + end + @reset prob.u0 = u0 + @reset prob.p = p + end + + return prob end function remake(func::DDEFunction; @@ -495,6 +525,7 @@ function remake(prob::SDDEProblem; use_defaults = false, seed = missing, kwargs = missing, + lazy_initialization = nothing, build_initializeprob = true, _kwargs...) if tspan === missing @@ -533,7 +564,7 @@ function remake(prob::SDDEProblem; order_discontinuity_t0 = coalesce(order_discontinuity_t0, prob.order_discontinuity_t0) neutral = coalesce(neutral, prob.neutral) - if kwargs === missing + prob = if kwargs === missing SDDEProblem{iip}(f, g, newu0, @@ -554,6 +585,21 @@ function remake(prob::SDDEProblem; f, g, newu0, tspan, newp; noise, noise_rate_prototype, seed, constant_lags, dependent_lags, order_discontinuity_t0, neutral, kwargs...) end + + if lazy_initialization === nothing + lazy_initialization = !is_trivial_initialization(initialization_data) + end + if !lazy_initialization + u0, p, _ = get_initial_values( + prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) + if u0 !== nothing && eltype(u0) == Any && isempty(u0) + u0 = nothing + end + @reset prob.u0 = u0 + @reset prob.p = p + end + + return prob end """ From 93483efadcceeb52c6da5c0bafb8d1709ef9866c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Dec 2024 17:04:58 +0530 Subject: [PATCH 20/30] feat: add `constructorof` for `NonlinearProblem`, `NonlinearLeastSquaresProblem` --- src/problems/nonlinear_problems.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/problems/nonlinear_problems.jl b/src/problems/nonlinear_problems.jl index c7d5d36b2..fe5448dba 100644 --- a/src/problems/nonlinear_problems.jl +++ b/src/problems/nonlinear_problems.jl @@ -222,6 +222,17 @@ function NonlinearProblem(f::AbstractODEFunction, u0, p = NullParameters(); kwar NonlinearProblem{isinplace(f)}(f, u0, p; kwargs...) end +function ConstructionBase.constructorof(::Type{P}) where {P <: NonlinearProblem} + function ctor(f, u0, p, pt, kw) + if f isa AbstractNonlinearFunction + iip = isinplace(f) + else + iip = isinplace(f, 4) + end + return NonlinearProblem{iip}(f, u0, p, pt; kw...) + end +end + """ $(SIGNATURES) @@ -322,6 +333,17 @@ function NonlinearLeastSquaresProblem(f, u0, p = NullParameters(); kwargs...) return NonlinearLeastSquaresProblem(NonlinearFunction(f), u0, p; kwargs...) end +function ConstructionBase.constructorof(::Type{P}) where {P <: NonlinearLeastSquaresProblem} + function ctor(f, u0, p, kw) + if f isa AbstractNonlinearFunction + iip = isinplace(f) + else + iip = isinplace(f, 4) + end + return NonlinearProblem{iip}(f, u0, p; kw...) + end +end + @doc doc""" SCCNonlinearProblem(probs, explicitfuns!) From 8b8d8f217528b23cf5e9046de461695dce59bc79 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Dec 2024 17:07:10 +0530 Subject: [PATCH 21/30] feat: add proper `remake` for `NonlinearProblem`, `NonlinearLeastSquaresProblem` --- src/remake.jl | 89 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 76 insertions(+), 13 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 5a0110eea..4cf8396ed 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -675,23 +675,63 @@ function remake(prob::NonlinearProblem; kwargs = missing, interpret_symbolicmap = true, use_defaults = false, + lazy_initialization = nothing, + build_initializeprob = true, _kwargs...) - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) - if f === missing - f = prob.f + newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) + + if build_initializeprob + initialization_data = remake_initialization_data( + prob.f.sys, prob.f, u0, nothing, p, newu0, newp) + else + initialization_data = nothing end + + f = remake(prob.f; f, initialization_data) + if problem_type === missing problem_type = prob.problem_type end - if kwargs === missing - NonlinearProblem{isinplace(prob)}(f = f, u0 = u0, p = p, + prob = if kwargs === missing + NonlinearProblem{isinplace(prob)}(f = f, u0 = newu0, p = newp, problem_type = problem_type; prob.kwargs..., _kwargs...) else - NonlinearProblem{isinplace(prob)}(f = f, u0 = u0, p = p, + NonlinearProblem{isinplace(prob)}(f = f, u0 = newu0, p = newp, problem_type = problem_type; kwargs...) end + + if lazy_initialization === nothing + lazy_initialization = !is_trivial_initialization(initialization_data) + end + if !lazy_initialization + u0, p, _ = get_initial_values( + prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) + if u0 !== nothing && eltype(u0) == Any && isempty(u0) + u0 = nothing + end + @reset prob.u0 = u0 + @reset prob.p = p + end + + return prob +end + +function remake(func::NonlinearFunction; + f = missing, + kwargs...) + props = getproperties(func) + props = @delete props.f + + if f === missing + f = func.f + end + if f isa AbstractSciMLFunction + f = f.f + end + + return NonlinearFunction{isinplace(func)}(f; props..., kwargs...) end """ @@ -701,19 +741,42 @@ end Remake the given `NonlinearLeastSquaresProblem`. """ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing, - interpret_symbolicmap = true, use_defaults = false, kwargs = missing, _kwargs...) - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) + interpret_symbolicmap = true, use_defaults = false, kwargs = missing, + lazy_initialization = nothing, build_initializeprob = true, _kwargs...) + newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) - if f === missing - f = prob.f + if build_initializeprob + initialization_data = remake_initialization_data( + prob.f.sys, prob.f, u0, nothing, p, newu0, newp) + else + initialization_data = nothing end - if kwargs === missing - return NonlinearLeastSquaresProblem{isinplace(prob)}(; f, u0, p, prob.kwargs..., + f = remake(prob.f; f, initialization_data) + + prob = if kwargs === missing + return NonlinearLeastSquaresProblem{isinplace(prob)}(; + f, u0 = newu0, p = newp, prob.kwargs..., _kwargs...) else - return NonlinearLeastSquaresProblem{isinplace(prob)}(; f, u0, p, kwargs...) + return NonlinearLeastSquaresProblem{isinplace(prob)}(; + f, u0 = newu0, p = newp, kwargs...) + end + + if lazy_initialization === nothing + lazy_initialization = !is_trivial_initialization(initialization_data) + end + if !lazy_initialization + u0, p, _ = get_initial_values( + prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) + if u0 !== nothing && eltype(u0) == Any && isempty(u0) + u0 = nothing + end + @reset prob.u0 = u0 + @reset prob.p = p end + + return prob end """ From 21881042ea2970ea1a80f77a53043d9118f51633 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Dec 2024 17:47:58 +0530 Subject: [PATCH 22/30] fix: allow specifying `f` for `remake` of `SCCNonlinearProblem` --- src/remake.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 4cf8396ed..9a9b6e80d 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -792,7 +792,7 @@ error and require that `probs` be specified. `probs` is the collection of subpro override the values in `probs`. `sys` is the index provider for the full system. """ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = missing, - parameters_alias = prob.parameters_alias, sys = missing, + parameters_alias = prob.parameters_alias, f = missing, sys = missing, interpret_symbolicmap = true, use_defaults = false, explicitfuns! = missing) if p !== missing && !parameters_alias && probs === missing throw(ArgumentError("`parameters_alias` is `false` for the given `SCCNonlinearProblem`. Please provide the subproblems using the keyword `probs` with the parameters updated appropriately in each.")) @@ -816,11 +816,13 @@ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = mi return subprob end end - if sys === missing - sys = prob.f.sys - end + f = coalesce(f, prob.f) + f = remake(f; sys) + props = getproperties(f) + props = @delete props.f + return SCCNonlinearProblem( - probs, explicitfuns!, newp, parameters_alias; sys) + probs, explicitfuns!, newp, parameters_alias; props...) end function varmap_has_var(varmap, var) From 58f68527e73637d07033e08589eb2f152a7f0c03 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 9 Dec 2024 16:36:26 +0530 Subject: [PATCH 23/30] fix: handle `initialization_data` in `f` passed to `remake` --- src/remake.jl | 61 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 9a9b6e80d..32cada1f0 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -124,15 +124,18 @@ function remake(prob::ODEProblem; f = missing, iip = isinplace(prob) - initialization_data = prob.f.initialization_data - - if f === missing - if build_initializeprob - initialization_data = remake_initialization_data_compat_wrapper( - prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) + if build_initializeprob + if f !== missing && has_initialization_data(f) + initialization_data = f.initialization_data else - initialization_data = nothing + initialization_data = remake_initialization_data( + prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) end + else + initialization_data = nothing + end + + if f === missing if specialization(prob.f) === FunctionWrapperSpecialize ptspan = promote_tspan(tspan) if iip @@ -337,8 +340,12 @@ function remake(prob::SDEProblem; newu0, newp = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults) if build_initializeprob - initialization_data = remake_initialization_data( - prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) + if f !== missing && has_initialization_data(f) + initialization_data = f.initialization_data + else + initialization_data = remake_initialization_data( + prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) + end else initialization_data = nothing end @@ -437,8 +444,12 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing, newu0, newp = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults) if build_initializeprob - initialization_data = remake_initialization_data( - prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) + if f !== missing && has_initialization_data(f) + initialization_data = f.initialization_data + else + initialization_data = remake_initialization_data( + prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) + end else initialization_data = nothing end @@ -535,8 +546,12 @@ function remake(prob::SDDEProblem; newu0, newp = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults) if build_initializeprob - initialization_data = remake_initialization_data( - prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) + if f !== missing && has_initialization_data(f) + initialization_data = f.initialization_data + else + initialization_data = remake_initialization_data( + prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) + end else initialization_data = nothing end @@ -681,8 +696,12 @@ function remake(prob::NonlinearProblem; newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) if build_initializeprob - initialization_data = remake_initialization_data( - prob.f.sys, prob.f, u0, nothing, p, newu0, newp) + if f !== missing && has_initialization_data(f) + initialization_data = f.initialization_data + else + initialization_data = remake_initialization_data( + prob.f.sys, prob.f, u0, nothing, p, newu0, newp) + end else initialization_data = nothing end @@ -746,8 +765,12 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) if build_initializeprob - initialization_data = remake_initialization_data( - prob.f.sys, prob.f, u0, nothing, p, newu0, newp) + if f !== missing && has_initialization_data(f) + initialization_data = f.initialization_data + else + initialization_data = remake_initialization_data( + prob.f.sys, prob.f, u0, nothing, p, newu0, newp) + end else initialization_data = nothing end @@ -755,11 +778,11 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p f = remake(prob.f; f, initialization_data) prob = if kwargs === missing - return NonlinearLeastSquaresProblem{isinplace(prob)}(; + prob = NonlinearLeastSquaresProblem{isinplace(prob)}(; f, u0 = newu0, p = newp, prob.kwargs..., _kwargs...) else - return NonlinearLeastSquaresProblem{isinplace(prob)}(; + prob = NonlinearLeastSquaresProblem{isinplace(prob)}(; f, u0 = newu0, p = newp, kwargs...) end From fbcc39e26054888e5995dc031cea662d4465cb46 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 9 Dec 2024 16:38:07 +0530 Subject: [PATCH 24/30] test: test lazy initialization in `remake` for supported problem types --- test/downstream/modelingtoolkit_remake.jl | 46 ++++++++++++++++++++--- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 7b69e7746..00cb97d46 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -337,11 +337,45 @@ end @test sccprob4.p !== sccprob4.probs[2].p end +# TODO: Rewrite this test when MTK build initialization for everything @testset "Lazy initialization" begin - @variables x(t) [guess = 1.0] y(t) [guess = 1.0] - @parameters p=missing [guess = 1.0] - @mtkbuild sys = ODESystem([D(x) ~ x, x + y ~ p], t) - prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0)) - prob2 = remake(prob; u0 = [x => 2.0]) - @test prob2.ps[p] ≈ 3.0 + @variables _x(..) [guess = 1.0] y(t) [guess = 1.0] + @parameters p=1.0 [guess = 1.0] + @brownian a + x = _x(t) + + initprob = NonlinearProblem(nothing) do args... + return 0.0 + end + iprobmap = (_...) -> [1.0, 1.0] + iprobpmap = function (orig, sol) + ps = parameter_values(orig) + setp(orig, p)(ps, 3.0) + return ps + end + initdata = SciMLBase.OverrideInitData(initprob, nothing, iprobmap, iprobpmap) + @test SciMLBase.is_trivial_initialization(initdata) + + @testset "$Problem" for (SystemT, rhss, Problem, Func) in [ + (ODESystem, 0.0, ODEProblem, ODEFunction), + (System, a, SDEProblem, SDEFunction), + (ODESystem, _x(t - 0.1), DDEProblem, DDEFunction), + (System, _x(t - 0.1) + a, SDDEProblem, SDDEFunction), + (NonlinearSystem, y + 2, NonlinearProblem, NonlinearFunction), + (NonlinearSystem, y + 2, NonlinearLeastSquaresProblem, NonlinearFunction) + ] + is_nlsolve = SystemT == NonlinearSystem + D = is_nlsolve ? (v) -> v^3 : Differential(t) + sys_args = is_nlsolve ? () : (t,) + prob_args = is_nlsolve ? () : ((0.0, 1.0),) + + @mtkbuild sys = SystemT([D(x) ~ x + rhss, x + y ~ p], sys_args...) + prob = Problem(sys, [x => 1.0, y => 1.0], prob_args...) + func_args = isdefined(prob.f, :g) ? (prob.f.g,) : () + func = Func{true, SciMLBase.FullSpecialize}( + prob.f.f, func_args...; initialization_data = initdata, sys = prob.f.sys) + prob2 = remake(prob; f = func) + @test SciMLBase.is_trivial_initialization(prob2) + @test prob2.ps[p] ≈ 3.0 + end end From 0b8e4607e0f7074d47357d4263f45df09e74c5cb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 9 Dec 2024 17:12:40 +0530 Subject: [PATCH 25/30] fix: check if `initialization_data` exists before running eager initialization --- src/remake.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 32cada1f0..53e1e2911 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -187,7 +187,7 @@ function remake(prob::ODEProblem; f = missing, if lazy_initialization === nothing lazy_initialization = !is_trivial_initialization(initialization_data) end - if !lazy_initialization + if initialization_data !== nothing && !lazy_initialization u0, p, _ = get_initial_values( prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) if u0 !== nothing && eltype(u0) == Any && isempty(u0) @@ -382,7 +382,7 @@ function remake(prob::SDEProblem; if lazy_initialization === nothing lazy_initialization = !is_trivial_initialization(initialization_data) end - if !lazy_initialization + if initialization_data !== nothing && !lazy_initialization u0, p, _ = get_initial_values( prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) if u0 !== nothing && eltype(u0) == Any && isempty(u0) @@ -484,7 +484,7 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing, if lazy_initialization === nothing lazy_initialization = !is_trivial_initialization(initialization_data) end - if !lazy_initialization + if initialization_data !== nothing && !lazy_initialization u0, p, _ = get_initial_values( prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) if u0 !== nothing && eltype(u0) == Any && isempty(u0) @@ -604,7 +604,7 @@ function remake(prob::SDDEProblem; if lazy_initialization === nothing lazy_initialization = !is_trivial_initialization(initialization_data) end - if !lazy_initialization + if initialization_data !== nothing && !lazy_initialization u0, p, _ = get_initial_values( prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) if u0 !== nothing && eltype(u0) == Any && isempty(u0) @@ -724,7 +724,7 @@ function remake(prob::NonlinearProblem; if lazy_initialization === nothing lazy_initialization = !is_trivial_initialization(initialization_data) end - if !lazy_initialization + if initialization_data !== nothing && !lazy_initialization u0, p, _ = get_initial_values( prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) if u0 !== nothing && eltype(u0) == Any && isempty(u0) @@ -789,7 +789,7 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p if lazy_initialization === nothing lazy_initialization = !is_trivial_initialization(initialization_data) end - if !lazy_initialization + if initialization_data !== nothing && !lazy_initialization u0, p, _ = get_initial_values( prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) if u0 !== nothing && eltype(u0) == Any && isempty(u0) From 665582eeb00432736a70502a2059b6cbdfe9d4f5 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 9 Dec 2024 17:42:40 +0530 Subject: [PATCH 26/30] test: add StochasticDelayDiffEq to downstream CI --- test/downstream/Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 7e01ecb53..17ad5e447 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -19,6 +19,7 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" +StochasticDelayDiffEq = "29a0d76e-afc8-11e9-03a4-eda52ae4b960" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" @@ -37,8 +38,8 @@ ModelingToolkit = "9.56" ModelingToolkitStandardLibrary = "2.7" NonlinearSolve = "2, 3, 4" Optimization = "4" -OptimizationOptimJL = "0.4" OptimizationMOI = "0.5" +OptimizationOptimJL = "0.4" OrdinaryDiffEq = "6.33" PartialFunctions = "1" Plots = "1.40" From dcde577f9db7840456741e9ad1bc448ccb44d96e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 9 Dec 2024 18:15:41 +0530 Subject: [PATCH 27/30] test: do not test unimplemented SDDE integrator stepping --- test/downstream/comprehensive_indexing.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/downstream/comprehensive_indexing.jl b/test/downstream/comprehensive_indexing.jl index 0d59c8084..3104ddde3 100644 --- a/test/downstream/comprehensive_indexing.jl +++ b/test/downstream/comprehensive_indexing.jl @@ -977,9 +977,6 @@ end delay = sys.osc1.τ original = sys.osc1.x @test prob[sym] ≈ prob[original] .+ (prob.tspan[1] - prob.ps[delay]) - integ = init(prob, ImplicitEM()) - step!(integ, 10.0, true) - @test integ[sym] ≈ SciMLBase.get_sol(integ)(integ.t - integ.ps[delay]; idxs = original) sol = solve(prob, ImplicitEM()) @test sol[sym] ≈ sol(sol.t .- sol.ps[delay]; idxs = original) end From 089e31a2b5335450b6ccd40c1face875666e5173 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 10 Dec 2024 14:47:35 +0530 Subject: [PATCH 28/30] feat: generalize `get_history_function` to `AbstractODESolution` --- src/solutions/ode_solutions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index ccc3d7b77..c1c8c16d8 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -388,7 +388,7 @@ function (w::DDESolutionHistoryWrapper)( w.sol(out, t, deriv; idxs) end -function SymbolicIndexingInterface.get_history_function(sol::ODESolution) +function SymbolicIndexingInterface.get_history_function(sol::AbstractODESolution) DDESolutionHistoryWrapper(sol) end From ee118ef5150434c732b65533a091c637a0763fda Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 10 Dec 2024 15:34:12 +0530 Subject: [PATCH 29/30] feat: generalize `CheckInit` to DDEs --- src/initialization.jl | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/initialization.jl b/src/initialization.jl index 8b45bb6a6..ddf9a96d0 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -111,6 +111,27 @@ function _evaluate_f(integrator, f, isinplace::Val{false}, args...) return f(args...) end +""" +Utility function to evaluate the RHS, adding extra arguments (such as history function for +DDEs) wherever necessary. +""" +function evaluate_f(integrator::DEIntegrator, prob, f, isinplace, u, p, t) + return _evaluate_f(integrator, f, isinplace, u, p, t) +end + +function evaluate_f( + integrator::DEIntegrator, prob::AbstractDAEProblem, f, isinplace, u, p, t) + return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t) +end + +function evaluate_f(integrator::AbstractDDEIntegrator, prob::AbstractDDEProblem, f, isinplace, u, p, t) + return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t) +end + +function evaluate_f(integrator::AbstractSDDEIntegrator, prob::AbstractSDDEProblem, f, isinplace, u, p, t) + return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t) +end + """ $(TYPEDSIGNATURES) @@ -147,7 +168,7 @@ function get_initial_values( algebraic_eqs = [all(iszero, x) for x in eachrow(M)] (iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true update_coefficients!(M, u0, p, t) - tmp = _evaluate_f(integrator, f, isinplace, u0, p, t) + tmp = evaluate_f(integrator, prob, f, isinplace, u0, p, t) tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) normresid = isdefined(integrator.opts, :internalnorm) ? @@ -165,7 +186,7 @@ function get_initial_values( p = parameter_values(integrator) t = current_time(integrator) - resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t) + resid = evaluate_f(integrator, prob, f, isinplace, u0, p, t) normresid = isdefined(integrator.opts, :internalnorm) ? integrator.opts.internalnorm(resid, t) : norm(resid) From dfcb2099726a7c2e1057c92c33d36c8d46d535f8 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 10 Dec 2024 08:21:41 -0500 Subject: [PATCH 30/30] Update ensemble_nondes.jl --- test/downstream/ensemble_nondes.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/downstream/ensemble_nondes.jl b/test/downstream/ensemble_nondes.jl index 449a5d545..4a9af1129 100644 --- a/test/downstream/ensemble_nondes.jl +++ b/test/downstream/ensemble_nondes.jl @@ -13,11 +13,11 @@ ensembleprob = Optimization.EnsembleProblem( sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 4, maxiters = 5) -@test findmin(i -> sol.u[i].objective, 1:4)[1] < sol1.objective +@test findmin(i -> sol.u[i].objective, 1:4)[1] <= sol1.objective sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 4, maxiters = 5) -@test findmin(i -> sol.u[i].objective, 1:4)[1] < sol1.objective +@test findmin(i -> sol.u[i].objective, 1:4)[1] <= sol1.objective prob = OptimizationProblem(optf, x0, lb = [-0.5, -0.5], ub = [0.5, 0.5]) ensembleprob = Optimization.EnsembleProblem( @@ -25,11 +25,11 @@ ensembleprob = Optimization.EnsembleProblem( sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 5, maxiters = 5) -@test findmin(i -> sol.u[i].objective, 1:4)[1] < sol1.objective +@test findmin(i -> sol.u[i].objective, 1:4)[1] <= sol1.objective sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 5, maxiters = 5) -@test findmin(i -> sol.u[i].objective, 1:4)[1] < sol1.objective +@test findmin(i -> sol.u[i].objective, 1:4)[1] <= sol1.objective using NonlinearSolve