Skip to content

Commit

Permalink
feat: adjoints through observable functions
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed May 2, 2024
1 parent 453a933 commit 92ad6a8
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using SciMLBase
using SciMLBase: ODESolution, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, observed, parameter_values
using RecursiveArrayTools

# This method resolves the ambiguity with the pullback defined in
Expand Down Expand Up @@ -109,7 +109,15 @@ end
@adjoint function Base.getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
if i === nothing
if is_observed(VA, sym)
y, back = Zygote.pullback(VA) do sol
f = observed(sol, sym)
p = parameter_values(sol)
f.(sol.u,Ref(p), sol.t)

Check warning on line 116 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L112-L116

Added lines #L112 - L116 were not covered by tests
end
gs = back(Δ)
(gs[1], nothing)
elseif i === nothing

Check warning on line 120 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L118-L120

Added lines #L118 - L120 were not covered by tests
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
Expand All @@ -122,6 +130,7 @@ end

@adjoint function Base.getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T
function ODESolution_getindex_pullback(Δ)
@show typeof(Δ)

Check warning on line 133 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L133

Added line #L133 was not covered by tests
sym = sym isa Tuple ? collect(sym) : sym
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
if i === nothing
Expand Down Expand Up @@ -182,15 +191,15 @@ end
NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint
end

@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution,
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
= @. ifelse=== nothing, (zerou,), Δ)
(build_solution(sol.prob, sol.alg, sol.t, _Δ),)
end
sol.u, solu_adjoint
end
# @adjoint function literal_getproperty(sol::AbstractTimeseriesSolution,
# ::Val{:u})
# function solu_adjoint(Δ)
# zerou = zero(sol.prob.u0)
# _Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
# (build_solution(sol.prob, sol.alg, sol.t, _Δ),)
# end
# sol.u, solu_adjoint
# end

@adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution,
::Val{:u})
Expand Down

0 comments on commit 92ad6a8

Please sign in to comment.