From 92ad6a84de65ee99da3bdb008ea7c9b192485289 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 2 May 2024 16:13:48 +0530 Subject: [PATCH] feat: adjoints through observable functions --- ext/SciMLBaseZygoteExt.jl | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 1d2104cf4..c3b24d458 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -7,7 +7,7 @@ using SciMLBase using SciMLBase: ODESolution, remake, getobserved, build_solution, EnsembleSolution, NonlinearSolution, AbstractTimeseriesSolution -using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index +using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, observed, parameter_values using RecursiveArrayTools # This method resolves the ambiguity with the pullback defined in @@ -109,7 +109,15 @@ end @adjoint function Base.getindex(VA::ODESolution, sym) function ODESolution_getindex_pullback(Δ) i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym - if i === nothing + if is_observed(VA, sym) + y, back = Zygote.pullback(VA) do sol + f = observed(sol, sym) + p = parameter_values(sol) + f.(sol.u,Ref(p), sol.t) + end + gs = back(Δ) + (gs[1], nothing) + elseif i === nothing throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) else Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] @@ -122,6 +130,7 @@ end @adjoint function Base.getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T function ODESolution_getindex_pullback(Δ) + @show typeof(Δ) sym = sym isa Tuple ? collect(sym) : sym i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym) if i === nothing @@ -182,15 +191,15 @@ end NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint end -@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution, - ::Val{:u}) - function solu_adjoint(Δ) - zerou = zero(sol.prob.u0) - _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) - (build_solution(sol.prob, sol.alg, sol.t, _Δ),) - end - sol.u, solu_adjoint -end +# @adjoint function literal_getproperty(sol::AbstractTimeseriesSolution, +# ::Val{:u}) +# function solu_adjoint(Δ) +# zerou = zero(sol.prob.u0) +# _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) +# (build_solution(sol.prob, sol.alg, sol.t, _Δ),) +# end +# sol.u, solu_adjoint +# end @adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution, ::Val{:u})