Skip to content

Commit

Permalink
Merge pull request #594 from AayushSabharwal/as/override-init
Browse files Browse the repository at this point in the history
feat: support `initialize_dae!` for `SDEIntegrator`
  • Loading branch information
ChrisRackauckas authored Dec 11, 2024
2 parents 9a28c57 + 7508787 commit 71f24ae
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 5 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <[email protected]>"]
version = "6.71.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -20,6 +21,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand All @@ -32,6 +34,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
ADTypes = "1"
Adapt = "3, 4"
ArrayInterface = "6, 7"
DataStructures = "0.18"
Expand All @@ -48,11 +51,12 @@ Logging = "1.6"
MuladdMacro = "0.2.1"
NLsolve = "4"
OrdinaryDiffEq = "6.87"
OrdinaryDiffEqCore = "1.12.1"
Random = "1.6"
RandomNumbers = "1.5.3"
RecursiveArrayTools = "2, 3"
Reexport = "0.2, 1.0"
SciMLBase = "2.59.2"
SciMLBase = "2.65"
SciMLOperators = "0.2.9, 0.3"
SparseArrays = "1.6"
SparseDiffTools = "2"
Expand Down
7 changes: 6 additions & 1 deletion src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ using DocStringExtensions
using Reexport
@reexport using DiffEqBase

import ADTypes

import OrdinaryDiffEq
import OrdinaryDiffEq: default_controller, isstandard, ispredictive,
beta2_default, beta1_default, gamma_default,
Expand Down Expand Up @@ -41,7 +43,7 @@ using DocStringExtensions
import DiffEqBase: step!, initialize!, DEAlgorithm,
AbstractSDEAlgorithm, AbstractRODEAlgorithm, DEIntegrator, AbstractDiffEqInterpolation,
DECache, AbstractSDEIntegrator, AbstractRODEIntegrator, AbstractContinuousCallback,
Tableau
Tableau, AbstractSDDEIntegrator

# Integrator Interface
import DiffEqBase: resize!,deleteat!,addat!,full_cache,user_cache,u_cache,du_cache,
Expand All @@ -58,6 +60,8 @@ using OrdinaryDiffEq: nlsolvefail, isnewton, set_new_W!, get_W, _vec, _reshape

using OrdinaryDiffEq: NLSolver

import OrdinaryDiffEqCore

if isdefined(OrdinaryDiffEq,:FastConvergence)
using OrdinaryDiffEq:
FastConvergence, Convergence, SlowConvergence, VerySlowConvergence, Divergence
Expand Down Expand Up @@ -119,6 +123,7 @@ end
include("cache_utils.jl")
include("integrators/integrator_interface.jl")
include("iterator_interface.jl")
include("initialize_dae.jl")
include("solve.jl")
include("initdt.jl")
include("perform_step/low_order.jl")
Expand Down
25 changes: 25 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,31 @@ SciMLBase.forwarddiffs_model(alg::Union{StochasticDiffEqNewtonAlgorithm,
StochasticDiffEqNewtonAdaptiveAlgorithm,StochasticDiffEqJumpNewtonAdaptiveAlgorithm,
StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm}) = OrdinaryDiffEq.alg_autodiff(alg)

# Required for initialization, because ODECore._initialize_dae! calls it during
# OverrideInit
OrdinaryDiffEqCore.has_autodiff(::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm,StochasticDiffEqJumpAlgorithm}) = false
for T in [
StochasticDiffEqNewtonAlgorithm, StochasticDiffEqNewtonAdaptiveAlgorithm,
StochasticDiffEqJumpNewtonAdaptiveAlgorithm,
StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm]
@eval OrdinaryDiffEqCore.has_autodiff(::$T) = true
end

_alg_autodiff(::StochasticDiffEqNewtonAlgorithm{T, AD}) where {T, AD} = Val{AD}()
_alg_autodiff(::StochasticDiffEqNewtonAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}()
_alg_autodiff(::StochasticDiffEqJumpNewtonAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}()
_alg_autodiff(::StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}()

function OrdinaryDiffEqCore.alg_autodiff(alg)
ad = _alg_autodiff(alg)
if ad == Val(false)
return ADTypes.AutoFiniteDiff()
elseif ad == Val(true)
return ADTypes.AutoForwardDiff()
else
return SciMLBase._unwrap_val(ad)
end
end

isadaptive(alg::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm}) = false
isadaptive(alg::Union{StochasticDiffEqAdaptiveAlgorithm,StochasticDiffEqRODEAdaptiveAlgorithm,StochasticDiffEqJumpAdaptiveAlgorithm,StochasticDiffEqJumpDiffusionAdaptiveAlgorithm}) = true
Expand Down
13 changes: 13 additions & 0 deletions src/initialize_dae.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
struct SDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end

function DiffEqBase.initialize_dae!(integrator::Union{AbstractSDEIntegrator, AbstractSDDEIntegrator}, initializealg = integrator.initializealg)
OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob)))
end

function OrdinaryDiffEqCore._initialize_dae!(integrator::Union{AbstractSDEIntegrator, AbstractSDDEIntegrator}, prob, ::SDEDefaultInit, isinplace)
if SciMLBase.has_initializeprob(prob.f)
OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace)
elseif SciMLBase.__has_mass_matrix(prob.f)
OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace)
end
end
3 changes: 2 additions & 1 deletion src/integrators/type.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenType,tTypeNoUnits,uEltypeNoUnits,randType,randType2,rateType,solType,cacheType,F4,F5,F6,OType,noiseType,EventErrorType,CallbackCacheType,RCs} <: AbstractSDEIntegrator{algType,IIP,uType,tType}
mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenType,tTypeNoUnits,uEltypeNoUnits,randType,randType2,rateType,solType,cacheType,F4,F5,F6,OType,noiseType,EventErrorType,CallbackCacheType,RCs,IA} <: AbstractSDEIntegrator{algType,IIP,uType,tType}
f::F4
g::F5
c::F6
Expand Down Expand Up @@ -43,4 +43,5 @@ mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenTy
qold::tTypeNoUnits
q11::tTypeNoUnits
stats::DiffEqBase.Stats
initializealg::IA
end
7 changes: 5 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ function DiffEqBase.__init(
userdata=nothing,
initialize_integrator=true,
seed = UInt64(0), alias_u0=false, alias_jumps = Threads.threadid()==1,
initializealg = SDEDefaultInit(),
kwargs...) where recompile_flag

prob = concrete_prob(_prob)
Expand Down Expand Up @@ -587,7 +588,8 @@ function DiffEqBase.__init(
uBottomEltype,tType,typeof(tdir),typeof(p),
typeof(eigen_est),QT,
uEltypeNoUnits,typeof(W),typeof(P),rateType,typeof(sol),typeof(cache),
FType,GType,CType,typeof(opts),typeof(noise),typeof(last_event_error),typeof(callback_cache),typeof(rate_constants)}(
FType,GType,CType,typeof(opts),typeof(noise),typeof(last_event_error),typeof(callback_cache),typeof(rate_constants),
typeof(initializealg)}(
f,g,c,noise,uprev,tprev,t,u,p,tType(dt),tType(dt),tType(dt),dtcache,tspan[2],tdir,
just_hit_tstop,do_error_check,isout,event_last_time,
vector_event_last_time,last_event_error,accept_step,
Expand All @@ -597,9 +599,10 @@ function DiffEqBase.__init(
alg,sol,
cache,callback_cache,tType(dt),W,P,rate_constants,
opts,iter,success_iter,eigen_est,EEst,q,
QT(qoldinit),q11,stats)
QT(qoldinit),q11,stats,initializealg)

if initialize_integrator
DiffEqBase.initialize_dae!(integrator)
initialize_callbacks!(integrator, initialize_save)
initialize!(integrator,integrator.cache)
save_start && alg isa Union{StochasticDiffEqCompositeAlgorithm,
Expand Down

0 comments on commit 71f24ae

Please sign in to comment.