diff --git a/src/internal/jacobian.jl b/src/internal/jacobian.jl index 50c6dfc51..aa1a3c83b 100644 --- a/src/internal/jacobian.jl +++ b/src/internal/jacobian.jl @@ -63,7 +63,9 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing, if !has_analytic_jac && needs_jac autodiff = construct_concrete_adtype(f, autodiff) - di_extras = if iip + di_extras = if !iip && autodiff isa AutoEnzyme + Enzyme.onehot(u) + elseif iip DI.prepare_jacobian(f, fu, autodiff, u, Constant(prob.p)) else DI.prepare_jacobian(f, autodiff, u, Constant(prob.p)) @@ -90,8 +92,8 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing, if iip DI.jacobian(f, fu, di_extras, autodiff, u, Constant(p)) else - if autodiff <: AutoEnzyme() - hcat(Enzyme.autodiff(Forward, f, BatchDuplicated(u, Enzyme.onehot(u)), Const(p))[1]...) + if autodiff isa AutoEnzyme + hcat(Enzyme.autodiff(Forward, f, BatchDuplicated(u, di_extras), Const(p))[1]...) else DI.jacobian(f, di_extras, autodiff, u, Constant(p)) end @@ -159,6 +161,8 @@ function (cache::JacobianCache{iip})( else if SciMLBase.has_jac(cache.f) return cache.f.jac(u, p) + elseif cache.autodiff isa AutoEnzyme + hcat(Enzyme.autodiff(Forward, cache.f, BatchDuplicated(u, cache.di_extras), Const(p))[1]...) else return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p)) end