Skip to content

Commit

Permalink
feat: add initialization_data, sys to SDEFunctionWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 3, 2024
1 parent aa87692 commit 3c333b7
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions src/functionwrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end
(g::SDEDiffusionTermWrapper{false})(u, p, t) = g.g(u, g.h, p, t)

struct SDEFunctionWrapper{iip, F, G, H, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, GG,
TCV} <: DiffEqBase.AbstractRODEFunction{iip}
TCV, ID, S} <: DiffEqBase.AbstractRODEFunction{iip}
f::F
g::G
h::H
Expand All @@ -26,6 +26,8 @@ struct SDEFunctionWrapper{iip, F, G, H, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, T
paramjac::TPJ
ggprime::GG
colorvec::TCV
initialization_data::ID
sys::S
end

(f::SDEFunctionWrapper{true})(du, u, p, t) = f.f(du, u, f.h, p, t)
Expand Down Expand Up @@ -53,17 +55,9 @@ function wrap_functions_and_history(f::SDDEFunction, g, h)
typeof(f.analytic), typeof(f.tgrad), typeof(jac), typeof(f.jvp),
typeof(f.vjp), typeof(f.jac_prototype), typeof(f.sparsity),
typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.paramjac),
typeof(f.ggprime), typeof(f.colorvec)}(f.f, gwh, h,
f.mass_matrix,
f.analytic,
f.tgrad, jac,
f.jvp, f.vjp,
f.jac_prototype,
f.sparsity,
f.Wfact,
f.Wfact_t,
f.paramjac,
f.ggprime,
f.colorvec),
typeof(f.ggprime), typeof(f.colorvec), typeof(f.initialization_data),
typeof(f.sys)}(f.f, gwh, h, f.mass_matrix, f.analytic, f.tgrad, jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.paramjac,
f.ggprime, f.colorvec, f.initialization_data, f.sys),
gwh
end

0 comments on commit 3c333b7

Please sign in to comment.