From 795eaaf7ebd145e8367407e66a60521250d3b13f Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 24 Apr 2024 06:49:28 +0000 Subject: [PATCH] fix: correct gradients for vector of symbols --- ext/SciMLBaseZygoteExt.jl | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index c7ba1827e..6a68193ae 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -120,6 +120,29 @@ end VA[sym], ODESolution_getindex_pullback end +@adjoint function getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T + function ODESolution_getindex_pullback(Δ) + sym = sym isa Tuple ? collect(sym) : sym + i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym) + if i === nothing + throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) + else + Δ′ = map(enumerate(VA.u)) do (t_idx, us) + map(enumerate(us)) do (u_idx, u) + if u_idx in i + idx = findfirst(isequal(u_idx), i) + Δ[t_idx][idx] + else + zero(T) + end + end + end + (Δ′, nothing) + end + end + VA[sym], ODESolution_getindex_pullback +end + @adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 }(u, args...) where {T1, T2, T3, T4, T5, T6, T7, T8,