From 5112e2756365a4410733e64ee33d3a97d3a8e01e Mon Sep 17 00:00:00 2001
From: Aayush Sabharwal
Date: Fri, 7 Jun 2024 14:47:02 +0530
Subject: [PATCH] fix: add adjoints for symbolic indexing with `::Colon`

---
Review note: the added lines were revised after this commit was exported, so
the post-image blob id on the `index` line below no longer matches them.

 ext/SciMLBaseZygoteExt.jl | 58 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl
index 58e7bb309..9da8fadff 100644
--- a/ext/SciMLBaseZygoteExt.jl
+++ b/ext/SciMLBaseZygoteExt.jl
@@ -125,6 +125,38 @@ end
     VA[sym], ODESolution_getindex_pullback
 end
 
+# Adjoint for `VA[sym, :]` with a single index `sym`.
+# The pullback returns one gradient per primal argument: (VA, sym, ::Colon).
+@adjoint function Base.getindex(VA::ODESolution, sym, ::Colon)
+    function ODESolution_getindex_pullback(Δ)
+        # Map a symbolic `sym` through `variable_index`; otherwise use `sym` as-is.
+        i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
+        if is_observed(VA, sym)
+            # Pull back through `observed(VA, sym)` broadcast over
+            # `state_values(VA)` and `current_time(VA)`.
+            f = observed(VA, sym)
+            p = parameter_values(VA)
+            tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
+            u = state_values(VA)
+            t = current_time(VA)
+            _, back = Zygote.pullback(u, tunables) do u, tunables
+                f.(u, Ref(tunables), t)
+            end
+            gs = back(Δ)
+            # Only the gradient w.r.t. `u` is returned; `gs[2]` is not propagated.
+            (gs[1], nothing, nothing)
+        elseif i === nothing
+            error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")
+        else
+            # Put Δ[j] at position `i` of an otherwise-zero vector per entry of `VA.u`.
+            Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
+                  for (x, j) in zip(VA.u, 1:length(VA))]
+            (Δ′, nothing, nothing)
+        end
+    end
+    VA[sym, :], ODESolution_getindex_pullback
+end
+
 function obs_grads(VA, sym, obs_idx, Δ)
     y, back = Zygote.pullback(VA) do sol
         getindex.(Ref(sol), sym[obs_idx])
@@ -172,6 +204,32 @@ end
     VA[sym], ODESolution_getindex_pullback
 end
 
+# Adjoint for `VA[sym, :]` where `sym` is a tuple or vector of indices.
+# Observed and non-observed entries are handled separately (`obs_grads`,
+# `not_obs_grads`) and their gradients are combined with `Zygote.accum`.
+@adjoint function Base.getindex(
+        VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}, ::Colon) where {T}
+    function ODESolution_getindex_pullback(Δ)
+        # Bind a fresh local rather than reassigning the captured `sym`, which
+        # would force the closure variable to be boxed.
+        syms = sym isa Tuple ? collect(sym) : sym
+        i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, syms)
+
+        obs_idx = findall(s -> is_observed(VA, s), syms)
+        not_obs_idx = setdiff(1:length(syms), obs_idx)
+
+        gs_obs = obs_grads(VA, syms, isempty(obs_idx) ? nothing : obs_idx, Δ)
+        gs_not_obs = not_obs_grads(VA, syms, not_obs_idx, i, Δ)
+
+        a = Zygote.accum(gs_obs[1], gs_not_obs)
+
+        # One gradient per primal argument: (VA, sym, ::Colon).
+        (a, nothing, nothing)
+    end
+    # The primal is the value of the call being differentiated, `VA[sym, :]`.
+    VA[sym, :], ODESolution_getindex_pullback
+end
+
 @adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
 }(u, args...) where {T1, T2, T3, T4, T5, T6, T7, T8,