From 30c0342bb4604efaa812bf5fb12a0a47b7b8ba91 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 28 Oct 2023 09:10:30 +0100 Subject: [PATCH] Update SciMLBaseZygoteExt.jl --- ext/SciMLBaseZygoteExt.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 58ebef8bd..2b262572f 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -56,18 +56,18 @@ end VA[sym, j], ODESolution_getindex_pullback end -ZygoteRules.@adjoint function DiffEqBase.EnsembleSolution(sim, time, converged) +ZygoteRules.@adjoint function DiffEqBase.EnsembleSolution(sim, time, converged, stats) out = EnsembleSolution(sim, time, converged) function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]] - (EnsembleSolution(arrarr, 0.0, true), nothing, nothing) + (EnsembleSolution(arrarr, 0.0, true), nothing, nothing, nothing) end function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) - (EnsembleSolution(p̄, 0.0, true), nothing, nothing) + (EnsembleSolution(p̄, 0.0, true), nothing, nothing, nothing) end function EnsembleSolution_adjoint(p̄::EnsembleSolution) - (p̄, nothing, nothing) + (p̄, nothing, nothing, nothing) end out, EnsembleSolution_adjoint end