From 322375dc31062095a3af96cbb1f170254df1516e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 17 Dec 2023 12:42:56 +0530 Subject: [PATCH] fix: fix EnsembleSolution adjoint --- ext/SciMLBaseZygoteExt.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 57d101949..0a2ce6ea1 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -8,6 +8,7 @@ using SciMLBase: ODESolution, sym_to_index, remake, getobserved, build_solution, EnsembleSolution, NonlinearSolution, AbstractTimeseriesSolution using SymbolicIndexingInterface: symbolic_type, NotSymbolic +using RecursiveArrayTools # This method resolves the ambiguity with the pullback defined in # RecursiveArrayToolsZygoteExt @@ -70,6 +71,9 @@ end function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) (EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing) end + function EnsembleSolution_adjoint(p̄::RecursiveArrayTools.AbstractVectorOfArray) + (EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing) + end function EnsembleSolution_adjoint(p̄::EnsembleSolution) (p̄, nothing, nothing, nothing) end