diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index 21b23c60fa..5f70501692 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -711,7 +711,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, _f = islin ? (isode ? f.f : f.f1.f) : f jacvec = JacVec((du, u, p, t) -> _f(du, u, p, t), copy(u), p, t; - autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag()) + autodiff = constructorof(alg_autodiff(alg))(), tag = OrdinaryDiffEqTag()) J = jacvec W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec) elseif alg.linsolve !== nothing && !LinearSolve.needs_concrete_A(alg.linsolve) || @@ -734,7 +734,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, (u, p, t) -> _f(u, p, t) end jacvec = JacVec(__f, copy(u), p, t; - autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag()) + autodiff = constructorof(alg_autodiff(alg))(), tag = OrdinaryDiffEqTag()) WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec) end else