Skip to content

Commit

Permalink
Merge pull request #678 from DhairyaLGandhi/dg/vecsym
Browse files Browse the repository at this point in the history
fix: Correct gradients for vector of symbols while indexing
  • Loading branch information
ChrisRackauckas authored May 2, 2024
2 parents fc5a573 + 071ad2f commit c261a03
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 7 deletions.
5 changes: 3 additions & 2 deletions ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module SciMLBaseChainRulesCoreExt
using SciMLBase
import ChainRulesCore
import ChainRulesCore: NoTangent, @non_differentiable
using SymbolicIndexingInterface

function ChainRulesCore.rrule(
config::ChainRulesCore.RuleConfig{
Expand All @@ -13,7 +14,7 @@ function ChainRulesCore.rrule(
sym,
j::Integer)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
if i === nothing
getter = getobserved(VA)
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
Expand Down Expand Up @@ -66,7 +67,7 @@ end

function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
if i === nothing
throw(error("AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Expand Down
32 changes: 28 additions & 4 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using RecursiveArrayTools
# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
# https://github.com/SciML/RecursiveArrayTools.jl/blob/d06ecb856f43bc5e37cbaf50e5f63c578bf3f1bd/ext/RecursiveArrayToolsZygoteExt.jl#L67
@adjoint function getindex(VA::ODESolution, i::Int, j::Int)
@adjoint function Base.getindex(VA::ODESolution, i::Int, j::Int)
function ODESolution_getindex_pullback(Δ)
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
zero(VA.u[1]) for m in 1:length(VA.u)]
Expand All @@ -38,7 +38,7 @@ using RecursiveArrayTools
VA[i, j], ODESolution_getindex_pullback
end

@adjoint function getindex(VA::ODESolution, sym, j::Int)
@adjoint function Base.getindex(VA::ODESolution, sym, j::Int)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
du, dprob = if i === nothing
Expand Down Expand Up @@ -92,7 +92,7 @@ end
out, EnsembleSolution_adjoint
end

@adjoint function getindex(VA::ODESolution, i::Int)
@adjoint function Base.getindex(VA::ODESolution, i::Int)
function ODESolution_getindex_pullback(Δ)
Δ′ = [(i == j ? Δ : Zygote.FillArrays.Fill(zero(eltype(x)), size(x)))
for (x, j) in zip(VA.u, 1:length(VA))]
Expand All @@ -106,7 +106,7 @@ end
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),)
end

@adjoint function getindex(VA::ODESolution, sym)
@adjoint function Base.getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
if i === nothing
Expand All @@ -120,6 +120,30 @@ end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function Base.getindex(
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T}
function ODESolution_getindex_pullback(Δ)
sym = sym isa Tuple ? collect(sym) : sym
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
map(enumerate(us)) do (u_idx, u)
if u_idx in i
idx = findfirst(isequal(u_idx), i)
Δ[t_idx][idx]
else
zero(T)
end
end
end
(Δ′, nothing)
end
end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
}(u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
Expand Down
36 changes: 35 additions & 1 deletion test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Test
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface,
Zygote, Test
using Optimization, OptimizationOptimJL
using ModelingToolkit: t_nounits as t, D_nounits as D

Expand Down Expand Up @@ -97,6 +98,39 @@ end
@test length(sol[[lorenz1.x, lorenz2.x], :]) == length(sol)
@test length(sol[[lorenz1.x, lorenz2.x], :][1]) == 2

gs_sym, = Zygote.gradient(sol) do sol
sum(sol[lorenz1.x])
end
idx_sym = SymbolicIndexingInterface.variable_index(sys, lorenz1.x)
true_grad_sym = zeros(length(ModelingToolkit.unknowns(sys)))
true_grad_sym[idx_sym] = 1.0

@test all(map(x -> x == true_grad_sym, gs_sym))

gs_vec, = Zygote.gradient(sol) do sol
sum(sum.(sol[[lorenz1.x, lorenz2.x]]))
end
idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys)))
true_grad_vecsym[idx_vecsym] .= 1.0

@test all(map(x -> x == true_grad_vecsym, gs_vec))

gs_tup, = Zygote.gradient(sol) do sol
sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)])))
end
idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys)))
true_grad_tupsym[idx_tupsym] .= 1.0

@test all(map(x -> x == true_grad_tupsym, gs_tup))

gs_ts, = Zygote.gradient(sol) do sol
sum(sol[[lorenz1.x, lorenz2.x], :])
end

@test all(map(x -> x == true_grad_vecsym, gs_ts))

@variables q(t)[1:2] = [1.0, 2.0]
eqs = [D(q[1]) ~ 2q[1]
D(q[2]) ~ 2.0]
Expand Down

0 comments on commit c261a03

Please sign in to comment.