Skip to content

Commit

Permalink
Merge pull request #885 from AayushSabharwal/as/init-infra
Browse files Browse the repository at this point in the history
feat: add infrastructure for initialization of different problem types
  • Loading branch information
ChrisRackauckas authored Dec 10, 2024
2 parents b9ac7b5 + dfcb209 commit f970ee7
Show file tree
Hide file tree
Showing 16 changed files with 538 additions and 109 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ StableRNGs = "1.0"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
SymbolicIndexingInterface = "0.3.34"
SymbolicIndexingInterface = "0.3.36"
Tables = "1.11"
Zygote = "0.6.67"
julia = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import FunctionWrappersWrappers
import RuntimeGeneratedFunctions
import EnumX
import ADTypes: ADTypes, AbstractADType
import Accessors: @set, @reset
import Accessors: @set, @reset, @delete
using Expronicon.ADT: @match

using Reexport
Expand Down
25 changes: 23 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,27 @@ function _evaluate_f(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

"""
Utility function to evaluate the RHS, adding extra arguments (such as history function for
DDEs) wherever necessary.
"""
function evaluate_f(integrator::DEIntegrator, prob, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, p, t)
end

function evaluate_f(
integrator::DEIntegrator, prob::AbstractDAEProblem, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t)
end

function evaluate_f(integrator::AbstractDDEIntegrator, prob::AbstractDDEProblem, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
end

function evaluate_f(integrator::AbstractSDDEIntegrator, prob::AbstractSDDEProblem, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
end

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -147,7 +168,7 @@ function get_initial_values(
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true
update_coefficients!(M, u0, p, t)
tmp = _evaluate_f(integrator, f, isinplace, u0, p, t)
tmp = evaluate_f(integrator, prob, f, isinplace, u0, p, t)
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))

normresid = isdefined(integrator.opts, :internalnorm) ?
Expand All @@ -165,7 +186,7 @@ function get_initial_values(
p = parameter_values(integrator)
t = current_time(integrator)

resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t)
resid = evaluate_f(integrator, prob, f, isinplace, u0, p, t)
normresid = isdefined(integrator.opts, :internalnorm) ?
integrator.opts.internalnorm(resid, t) : norm(resid)

Expand Down
3 changes: 2 additions & 1 deletion src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,7 @@ function isadaptive(integrator::DEIntegrator)
isdefined(integrator.opts, :adaptive) ? integrator.opts.adaptive : false
end

function SymbolicIndexingInterface.get_history_function(integ::AbstractDDEIntegrator)
function SymbolicIndexingInterface.get_history_function(integ::Union{
AbstractDDEIntegrator, AbstractSDDEIntegrator})
DDESolutionHistoryWrapper(get_sol(integ))
end
13 changes: 13 additions & 0 deletions src/problems/dde_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,19 @@ struct DDEProblem{uType, tType, lType, lType2, isinplace, P, F, H, K, PT} <:
end
end

function ConstructionBase.constructorof(::Type{P}) where {P <: DDEProblem}
function ctor(f, u0, h, tspan, p, constant_lags, dependent_lags,
kw, neutral, order_discontinuity_t0, problem_type)
if f isa AbstractDDEFunction
iip = isinplace(f)
else
iip = isinplace(f, 5)
end
return DDEProblem{iip}(f, u0, h, tspan, p; kw..., constant_lags, dependent_lags,
neutral, order_discontinuity_t0, problem_type)
end
end

DDEProblem(f, args...; kwargs...) = DDEProblem(DDEFunction(f), args...; kwargs...)

function DDEProblem(f::AbstractDDEFunction, args...; kwargs...)
Expand Down
22 changes: 22 additions & 0 deletions src/problems/nonlinear_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,17 @@ function NonlinearProblem(f::AbstractODEFunction, u0, p = NullParameters(); kwar
NonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
end

function ConstructionBase.constructorof(::Type{P}) where {P <: NonlinearProblem}
function ctor(f, u0, p, pt, kw)
if f isa AbstractNonlinearFunction
iip = isinplace(f)
else
iip = isinplace(f, 4)
end
return NonlinearProblem{iip}(f, u0, p, pt; kw...)
end
end

"""
$(SIGNATURES)
Expand Down Expand Up @@ -322,6 +333,17 @@ function NonlinearLeastSquaresProblem(f, u0, p = NullParameters(); kwargs...)
return NonlinearLeastSquaresProblem(NonlinearFunction(f), u0, p; kwargs...)
end

function ConstructionBase.constructorof(::Type{P}) where {P <: NonlinearLeastSquaresProblem}
function ctor(f, u0, p, kw)
if f isa AbstractNonlinearFunction
iip = isinplace(f)
else
iip = isinplace(f, 4)
end
return NonlinearProblem{iip}(f, u0, p; kw...)
end
end

@doc doc"""
SCCNonlinearProblem(probs, explicitfuns!)
Expand Down
16 changes: 16 additions & 0 deletions src/problems/sdde_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,19 @@ end
function SDDEProblem(f::AbstractSDDEFunction, args...; kwargs...)
SDDEProblem{isinplace(f)}(f, args...; kwargs...)
end

function ConstructionBase.constructorof(::Type{P}) where {P <: SDDEProblem}
function ctor(f, g, u0, h, tspan, p, noise, constant_lags, dependent_lags, kw,
noise_rate_prototype, seed, neutral, order_discontinuity_t0)
if f isa AbstractSDDEFunction
iip = isinplace(f)
else
iip = isinplace(f, 5)
end
return SDDEProblem{iip}(
f, g, u0, h, tspan, p; kw..., noise, constant_lags, dependent_lags,
noise_rate_prototype, seed, neutral, order_discontinuity_t0)
end
end

SymbolicIndexingInterface.get_history_function(prob::AbstractSDDEProblem) = prob.h
11 changes: 11 additions & 0 deletions src/problems/sde_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...)
SDEProblem{iip}(SDEFunction{iip}(f, g), u0, tspan, p; kwargs...)
end

function ConstructionBase.constructorof(::Type{P}) where {P <: SDEProblem}
function ctor(f, g, u0, tspan, p, noise, kw, noise_rate_prototype, seed)
if f isa AbstractSDEFunction
iip = isinplace(f)
else
iip = isinplace(f, 4)
end
return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed)
end
end

"""
$(TYPEDEF)
"""
Expand Down
Loading

0 comments on commit f970ee7

Please sign in to comment.