Skip to content

Commit

Permalink
fix: correct gradients for vector of symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed Apr 24, 2024
1 parent 1238b2b commit 795eaaf
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,29 @@ end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T
function ODESolution_getindex_pullback(Δ)
sym = sym isa Tuple ? collect(sym) : sym
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))

Check warning on line 128 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L123-L128

Added lines #L123 - L128 were not covered by tests
else
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
map(enumerate(us)) do (u_idx, u)
if u_idx in i
idx = findfirst(isequal(u_idx), i)
Δ[t_idx][idx]

Check warning on line 134 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L130-L134

Added lines #L130 - L134 were not covered by tests
else
zero(T)

Check warning on line 136 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L136

Added line #L136 was not covered by tests
end
end
end
(Δ′, nothing)

Check warning on line 140 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L140

Added line #L140 was not covered by tests
end
end
VA[sym], ODESolution_getindex_pullback

Check warning on line 143 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L143

Added line #L143 was not covered by tests
end

@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
}(u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
Expand Down

0 comments on commit 795eaaf

Please sign in to comment.