diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 2b262572f..a401e34f9 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -56,7 +56,7 @@ end VA[sym, j], ODESolution_getindex_pullback end -ZygoteRules.@adjoint function DiffEqBase.EnsembleSolution(sim, time, converged, stats) +ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats) out = EnsembleSolution(sim, time, converged) function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]