Skip to content

Commit

Permalink
Merge pull request #84 from AayushSabharwal/as/initialization
Browse files Browse the repository at this point in the history
feat: add initialization support
  • Loading branch information
ChrisRackauckas authored Dec 14, 2024
2 parents 4912067 + 662e455 commit 4c1ae36
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Reexport = "1.0"
SciMLBase = "2.59.2"
SparseArrays = "1.9"
StaticArrays = "1.0"
StochasticDiffEq = "6.19"
StochasticDiffEq = "6.72.1"
UnPack = "0.1, 1.0"
julia = "1.9"

Expand Down
20 changes: 7 additions & 13 deletions src/functionwrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end
(g::SDEDiffusionTermWrapper{false})(u, p, t) = g.g(u, g.h, p, t)

struct SDEFunctionWrapper{iip, F, G, H, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, GG,
TCV} <: DiffEqBase.AbstractRODEFunction{iip}
TCV, ID, S} <: DiffEqBase.AbstractRODEFunction{iip}
f::F
g::G
h::H
Expand All @@ -26,6 +26,8 @@ struct SDEFunctionWrapper{iip, F, G, H, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, T
paramjac::TPJ
ggprime::GG
colorvec::TCV
initialization_data::ID
sys::S
end

(f::SDEFunctionWrapper{true})(du, u, p, t) = f.f(du, u, f.h, p, t)
Expand Down Expand Up @@ -53,17 +55,9 @@ function wrap_functions_and_history(f::SDDEFunction, g, h)
typeof(f.analytic), typeof(f.tgrad), typeof(jac), typeof(f.jvp),
typeof(f.vjp), typeof(f.jac_prototype), typeof(f.sparsity),
typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.paramjac),
typeof(f.ggprime), typeof(f.colorvec)}(f.f, gwh, h,
f.mass_matrix,
f.analytic,
f.tgrad, jac,
f.jvp, f.vjp,
f.jac_prototype,
f.sparsity,
f.Wfact,
f.Wfact_t,
f.paramjac,
f.ggprime,
f.colorvec),
typeof(f.ggprime), typeof(f.colorvec), typeof(f.initialization_data),
typeof(f.sys)}(f.f, gwh, h, f.mass_matrix, f.analytic, f.tgrad, jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.paramjac,
f.ggprime, f.colorvec, f.initialization_data, f.sys),
gwh
end
3 changes: 2 additions & 1 deletion src/integrators/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ end
mutable struct SDDEIntegrator{algType, IIP, uType, uEltype, tType, P, eigenType,
tTypeNoUnits, uEltypeNoUnits, randType, randType2, rateType,
solType, cacheType, F, G, F6, OType, noiseType,
EventErrorType, CallbackCacheType, H, IType} <:
EventErrorType, CallbackCacheType, H, IType, IA} <:
AbstractSDDEIntegrator{algType, IIP, uType, tType}
f::F
g::G
Expand Down Expand Up @@ -81,6 +81,7 @@ mutable struct SDDEIntegrator{algType, IIP, uType, uEltype, tType, P, eigenType,
history::H
stats::DiffEqBase.Stats
integrator::IType # history integrator
initializealg::IA
end

function (integrator::SDDEIntegrator)(t, deriv::Type = Val{0}; idxs = nothing)
Expand Down
8 changes: 5 additions & 3 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ function DiffEqBase.__init(prob::AbstractSDDEProblem,# TODO DiffEqBasee.Abstract
# Keywords for Delay problems (from DDE)
discontinuity_interp_points::Int = 10,
discontinuity_abstol = eltype(prob.tspan)(1 // Int64(10)^12),
discontinuity_reltol = 0, kwargs...) where {recompile_flag}
discontinuity_reltol = 0,
initializealg = StochasticDiffEq.SDEDefaultInit(), kwargs...) where {recompile_flag}

# alg = getalg(alg0);
if prob.f isa Tuple
Expand Down Expand Up @@ -468,7 +469,7 @@ function DiffEqBase.__init(prob::AbstractSDDEProblem,# TODO DiffEqBasee.Abstract
typeof(c),
typeof(opts), typeof(noise), typeof(last_event_error),
typeof(callback_cache), typeof(history),
typeof(sde_integrator)}(f_with_history,
typeof(sde_integrator), typeof(initializealg)}(f_with_history,
g_with_history, c, noise, uprev,
tprev,
order_discontinuity_t0,
Expand All @@ -486,9 +487,10 @@ function DiffEqBase.__init(prob::AbstractSDDEProblem,# TODO DiffEqBasee.Abstract
P,
opts, iter, success_iter, eigen_est,
EEst, q, QT(qoldinit), q11, history,
stats, sde_integrator)
stats, sde_integrator, initializealg)

if initialize_integrator
DiffEqBase.initialize_dae!(integrator)
StochasticDiffEq.initialize_callbacks!(integrator, initialize_save)
initialize!(integrator, integrator.cache)

Expand Down

0 comments on commit 4c1ae36

Please sign in to comment.