diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 8af8f3dce..85ede1a0f 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -125,32 +125,6 @@ end VA[sym], ODESolution_getindex_pullback end -@adjoint function Base.getindex(VA::ODESolution, sym, ::Colon) - function ODESolution_getindex_pullback(Δ) - i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym - if is_observed(VA, sym) - f = observed(VA, sym) - p = parameter_values(VA) - tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p) - u = state_values(VA) - t = current_time(VA) - y, back = Zygote.pullback(u, tunables) do u, tunables - f.(u, Ref(tunables), t) - end - gs = back(Δ) - (gs[1], nothing) - elseif i === nothing - throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) - else - @show i - Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] - for (x, j) in zip(VA.u, 1:length(VA))] - (Δ′, nothing) - end - end - VA[sym, :], ODESolution_getindex_pullback -end - function obs_grads(VA, sym, obs_idx, Δ) y, back = Zygote.pullback(VA) do sol getindex.(Ref(sol), sym[obs_idx]) @@ -198,25 +172,6 @@ end VA[sym], ODESolution_getindex_pullback end -@adjoint function Base.getindex( - VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}, ::Colon) where {T} - function ODESolution_getindex_pullback(Δ) - sym = sym isa Tuple ? collect(sym) : sym - i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym) - - obs_idx = findall(s -> is_observed(VA, s), sym) - not_obs_idx = setdiff(1:length(sym), obs_idx) - - gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ) - gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ) - - a = Zygote.accum(gs_obs[1], gs_not_obs) - - (a, nothing) - end - VA[sym], ODESolution_getindex_pullback -end - @adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14 }(u, args...) where {T1, T2, T3, T4, T5, T6, T7, T8,