From 9cb8c14153a32a7c849d4cef5b45b985f99f66c2 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 28 Oct 2023 09:08:59 +0100 Subject: [PATCH] Update SciMLBaseZygoteExt.jl --- ext/SciMLBaseZygoteExt.jl | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 59dc35509..58ebef8bd 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -2,7 +2,8 @@ module SciMLBaseZygoteExt using Zygote: pullback using ZygoteRules: @adjoint -using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved +import ZygoteRules +using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved # This method resolves the ambiguity with the pullback defined in # RecursiveArrayToolsZygoteExt @@ -55,4 +56,25 @@ end VA[sym, j], ODESolution_getindex_pullback end +ZygoteRules.@adjoint function DiffEqBase.EnsembleSolution(sim, time, converged) + out = EnsembleSolution(sim, time, converged) + function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} + arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] + for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]] + (EnsembleSolution(arrarr, 0.0, true), nothing, nothing) + end + function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) + (EnsembleSolution(p̄, 0.0, true), nothing, nothing) + end + function EnsembleSolution_adjoint(p̄::EnsembleSolution) + (p̄, nothing, nothing) + end + out, EnsembleSolution_adjoint +end + +ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, + ::Val{:u}) + sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),) +end + end