diff --git a/src/derivative_utils.jl b/src/derivative_utils.jl index 5fe107e2f8..c671634a02 100644 --- a/src/derivative_utils.jl +++ b/src/derivative_utils.jl @@ -874,7 +874,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, # be overridden with concrete_jac. _f = islin ? (isode ? f.f : f.f1.f) : f - jacvec = JacVec(UJacobianWrapper(_f, t, p), copy(u), p, t; + jacvec = JacVec((du, u, p, t) -> _f(du, u, p, t), copy(u), p, t; autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag()) J = jacvec W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec) @@ -890,7 +890,12 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, else deepcopy(f.jac_prototype) end - jacvec = JacVec(UJacobianWrapper(_f, t, p), copy(u), p, t; + __f = if IIP + (du, u, p, t) -> _f(du, u, p, t) + else + (u, p, t) -> _f(u, p, t) + end + jacvec = JacVec(__f, copy(u), p, t; autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag()) W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)