diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 0e84ccbc3..a30073222 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -3,6 +3,7 @@ module SciMLBaseChainRulesCoreExt using SciMLBase import ChainRulesCore import ChainRulesCore: NoTangent, @non_differentiable +using SymbolicIndexingInterface function ChainRulesCore.rrule( config::ChainRulesCore.RuleConfig{ @@ -13,7 +14,7 @@ function ChainRulesCore.rrule( sym, j::Integer) function ODESolution_getindex_pullback(Δ) - i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym + i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym if i === nothing getter = getobserved(VA) grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) @@ -66,7 +67,7 @@ end function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym) function ODESolution_getindex_pullback(Δ) - i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym + i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym if i === nothing throw(error("AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) else diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index c7ba1827e..0147dbd4a 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -13,7 +13,7 @@ using RecursiveArrayTools # This method resolves the ambiguity with the pullback defined in # RecursiveArrayToolsZygoteExt # https://github.com/SciML/RecursiveArrayTools.jl/blob/d06ecb856f43bc5e37cbaf50e5f63c578bf3f1bd/ext/RecursiveArrayToolsZygoteExt.jl#L67 -@adjoint function getindex(VA::ODESolution, i::Int, j::Int) +@adjoint function Base.getindex(VA::ODESolution, i::Int, j::Int) function ODESolution_getindex_pullback(Δ) du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : zero(VA.u[1]) for m in 1:length(VA.u)] @@ -38,7 +38,7 @@ using RecursiveArrayTools VA[i, j], ODESolution_getindex_pullback end -@adjoint function getindex(VA::ODESolution, sym, j::Int) +@adjoint function Base.getindex(VA::ODESolution, sym, j::Int) function ODESolution_getindex_pullback(Δ) i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym du, dprob = if i === nothing @@ -92,7 +92,7 @@ end out, EnsembleSolution_adjoint end -@adjoint function getindex(VA::ODESolution, i::Int) +@adjoint function Base.getindex(VA::ODESolution, i::Int) function ODESolution_getindex_pullback(Δ) Δ′ = [(i == j ? Δ : Zygote.FillArrays.Fill(zero(eltype(x)), size(x))) for (x, j) in zip(VA.u, 1:length(VA))] @@ -106,7 +106,7 @@ end sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),) end -@adjoint function getindex(VA::ODESolution, sym) +@adjoint function Base.getindex(VA::ODESolution, sym) function ODESolution_getindex_pullback(Δ) i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym if i === nothing @@ -120,6 +120,30 @@ end VA[sym], ODESolution_getindex_pullback end +@adjoint function Base.getindex( + VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T} + function ODESolution_getindex_pullback(Δ) + sym = sym isa Tuple ? collect(sym) : sym + i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym) + if i === nothing + throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) + else + Δ′ = map(enumerate(VA.u)) do (t_idx, us) + map(enumerate(us)) do (u_idx, u) + if u_idx in i + idx = findfirst(isequal(u_idx), i) + Δ[t_idx][idx] + else + zero(T) + end + end + end + (Δ′, nothing) + end + end + VA[sym], ODESolution_getindex_pullback +end + @adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 }(u, args...) where {T1, T2, T3, T4, T5, T6, T7, T8, diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 98d07c5ba..699839b90 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -1,4 +1,5 @@ -using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Test +using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, + Zygote, Test using Optimization, OptimizationOptimJL using ModelingToolkit: t_nounits as t, D_nounits as D @@ -97,6 +98,39 @@ end @test length(sol[[lorenz1.x, lorenz2.x], :]) == length(sol) @test length(sol[[lorenz1.x, lorenz2.x], :][1]) == 2 +gs_sym, = Zygote.gradient(sol) do sol + sum(sol[lorenz1.x]) +end +idx_sym = SymbolicIndexingInterface.variable_index(sys, lorenz1.x) +true_grad_sym = zeros(length(ModelingToolkit.unknowns(sys))) +true_grad_sym[idx_sym] = 1.0 + +@test all(map(x -> x == true_grad_sym, gs_sym)) + +gs_vec, = Zygote.gradient(sol) do sol + sum(sum.(sol[[lorenz1.x, lorenz2.x]])) +end +idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x]) +true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys))) +true_grad_vecsym[idx_vecsym] .= 1.0 + +@test all(map(x -> x == true_grad_vecsym, gs_vec)) + +gs_tup, = Zygote.gradient(sol) do sol + sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)]))) +end +idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x]) +true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys))) +true_grad_tupsym[idx_tupsym] .= 1.0 + +@test all(map(x -> x == true_grad_tupsym, gs_tup)) + +gs_ts, = Zygote.gradient(sol) do sol + sum(sol[[lorenz1.x, lorenz2.x], :]) +end + +@test all(map(x -> x == true_grad_vecsym, gs_ts)) + @variables q(t)[1:2] = [1.0, 2.0] eqs = [D(q[1]) ~ 2q[1] D(q[2]) ~ 2.0]