From 795eaaf7ebd145e8367407e66a60521250d3b13f Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 24 Apr 2024 06:49:28 +0000 Subject: [PATCH 1/7] fix: correct gradients for vector of symbols --- ext/SciMLBaseZygoteExt.jl | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index c7ba1827e..6a68193ae 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -120,6 +120,29 @@ end VA[sym], ODESolution_getindex_pullback end +@adjoint function getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T + function ODESolution_getindex_pullback(Δ) + sym = sym isa Tuple ? collect(sym) : sym + i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym) + if i === nothing + throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) + else + Δ′ = map(enumerate(VA.u)) do (t_idx, us) + map(enumerate(us)) do (u_idx, u) + if u_idx in i + idx = findfirst(isequal(u_idx), i) + Δ[t_idx][idx] + else + zero(T) + end + end + end + (Δ′, nothing) + end + end + VA[sym], ODESolution_getindex_pullback +end + @adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 }(u, args...) where {T1, T2, T3, T4, T5, T6, T7, T8, From a19208a46febeea79a7559bc98e4ed807444e3f1 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 29 Apr 2024 08:08:46 +0000 Subject: [PATCH 2/7] test: symbolic indexing adjoint --- test/downstream/symbol_indexing.jl | 33 ++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 893fdf437..a0caf5174 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -97,6 +97,39 @@ end @test size(sol[[lorenz1.x, lorenz2.x], :]) == (2, length(sol)) @test size(sol[[lorenz1.x, lorenz2.x], :]) == size(sol[[1, 2], :]) == size(sol[1:2, :]) +gs_sym, = Zygote.gradient(sol) do sol + sum(sol[lorenz1.x]) +end +idx_sym = SymbolicIndexingInterface.variable_index(sys, lorenz1.x) +true_grad_sym = zeros(length(ModelingToolkit.unknowns(sys))) +true_grad_sym[idx_sym] .= 1. + +@test "Symbolic Indexing Adjoint: Symbol" all(x -> x == true_grad_sym, gs_sym) + +gs_vec, = Zygote.gradient(sol) do sol + sum(sum.(sol[[lorenz1.x, lorenz2]])) +end +idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x]) +true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys))) +true_grad_vecsym[idx_vecsym] .= 1. + +@test "Symbolic Indexing Adjoint: Vector{Symbol}" all(x -> x == true_grad_vecsym, gs_vec) + +gs_tup, = Zygote.gradient(sol) do sol + sum(sum.(collect.(sol[(lorenz1.x, lorenz2)]))) +end +idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x]) +true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys))) +true_grad_tupsym[idx_tupsym] .= 1. + +@test "Symbolic Indexing Adjoint: Tuple{Symbol}" all(x -> x == true_grad_tupsym, gs_tup) + +gs_ts, = Zygote.gradient(sol) do sol + sum(sol[[lorenz1.x, lorenz2], :]) +end + +@test "Symbolic Indexing Adjoint: Timeseries/ Vector{Symbol}" all(x -> x == true_grad_vecsym, gs_ts) + @variables q(t)[1:2] = [1.0, 2.0] eqs = [D(q[1]) ~ 2q[1] D(q[2]) ~ 2.0] From cc2316f5d1e6301a9f0b07edb86820d868f5a1a5 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 29 Apr 2024 19:41:59 +0530 Subject: [PATCH 3/7] test: index with correct symbols --- test/downstream/symbol_indexing.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index a0caf5174..143e8d519 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -1,4 +1,4 @@ -using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Test +using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Zygote, Test using Optimization, OptimizationOptimJL using ModelingToolkit: t_nounits as t, D_nounits as D @@ -107,7 +107,7 @@ true_grad_sym[idx_sym] .= 1. @test "Symbolic Indexing Adjoint: Symbol" all(x -> x == true_grad_sym, gs_sym) gs_vec, = Zygote.gradient(sol) do sol - sum(sum.(sol[[lorenz1.x, lorenz2]])) + sum(sum.(sol[[lorenz1.x, lorenz2.x]])) end idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x]) true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys))) @@ -116,7 +116,7 @@ true_grad_vecsym[idx_vecsym] .= 1. @test "Symbolic Indexing Adjoint: Vector{Symbol}" all(x -> x == true_grad_vecsym, gs_vec) gs_tup, = Zygote.gradient(sol) do sol - sum(sum.(collect.(sol[(lorenz1.x, lorenz2)]))) + sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)]))) end idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x]) true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys))) From bba5c5215100cf623fc4e0375a15b9c34f4cbbdf Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Wed, 1 May 2024 16:08:05 +0530 Subject: [PATCH 4/7] fix: update SII syntax in CR ext also --- ext/SciMLBaseChainRulesCoreExt.jl | 5 +++-- ext/SciMLBaseZygoteExt.jl | 10 +++++----- test/downstream/symbol_indexing.jl | 10 +++++----- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 0e84ccbc3..a30073222 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -3,6 +3,7 @@ module SciMLBaseChainRulesCoreExt using SciMLBase import ChainRulesCore import ChainRulesCore: NoTangent, @non_differentiable +using SymbolicIndexingInterface function ChainRulesCore.rrule( config::ChainRulesCore.RuleConfig{ @@ -13,7 +14,7 @@ function ChainRulesCore.rrule( sym, j::Integer) function ODESolution_getindex_pullback(Δ) - i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym + i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym if i === nothing getter = getobserved(VA) grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) @@ -66,7 +67,7 @@ end function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym) function ODESolution_getindex_pullback(Δ) - i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym + i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym if i === nothing throw(error("AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) else diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 6a68193ae..1d2104cf4 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -13,7 +13,7 @@ using RecursiveArrayTools # This method resolves the ambiguity with the pullback defined in # RecursiveArrayToolsZygoteExt # https://github.com/SciML/RecursiveArrayTools.jl/blob/d06ecb856f43bc5e37cbaf50e5f63c578bf3f1bd/ext/RecursiveArrayToolsZygoteExt.jl#L67 -@adjoint function getindex(VA::ODESolution, i::Int, j::Int) +@adjoint function Base.getindex(VA::ODESolution, i::Int, j::Int) function ODESolution_getindex_pullback(Δ) du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : zero(VA.u[1]) for m in 1:length(VA.u)] @@ -38,7 +38,7 @@ using RecursiveArrayTools VA[i, j], ODESolution_getindex_pullback end -@adjoint function getindex(VA::ODESolution, sym, j::Int) +@adjoint function Base.getindex(VA::ODESolution, sym, j::Int) function ODESolution_getindex_pullback(Δ) i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym du, dprob = if i === nothing @@ -92,7 +92,7 @@ end out, EnsembleSolution_adjoint end -@adjoint function getindex(VA::ODESolution, i::Int) +@adjoint function Base.getindex(VA::ODESolution, i::Int) function ODESolution_getindex_pullback(Δ) Δ′ = [(i == j ? Δ : Zygote.FillArrays.Fill(zero(eltype(x)), size(x))) for (x, j) in zip(VA.u, 1:length(VA))] @@ -106,7 +106,7 @@ end sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),) end -@adjoint function getindex(VA::ODESolution, sym) +@adjoint function Base.getindex(VA::ODESolution, sym) function ODESolution_getindex_pullback(Δ) i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym if i === nothing @@ -120,7 +120,7 @@ end VA[sym], ODESolution_getindex_pullback end -@adjoint function getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T +@adjoint function Base.getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T function ODESolution_getindex_pullback(Δ) sym = sym isa Tuple ? collect(sym) : sym i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 143e8d519..b7f54b851 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -102,9 +102,9 @@ gs_sym, = Zygote.gradient(sol) do sol end idx_sym = SymbolicIndexingInterface.variable_index(sys, lorenz1.x) true_grad_sym = zeros(length(ModelingToolkit.unknowns(sys))) -true_grad_sym[idx_sym] .= 1. +true_grad_sym[idx_sym] = 1. -@test "Symbolic Indexing Adjoint: Symbol" all(x -> x == true_grad_sym, gs_sym) +@test all(map(x -> x == true_grad_sym, gs_sym)) gs_vec, = Zygote.gradient(sol) do sol sum(sum.(sol[[lorenz1.x, lorenz2.x]])) @@ -113,7 +113,7 @@ idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lor true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys))) true_grad_vecsym[idx_vecsym] .= 1. -@test "Symbolic Indexing Adjoint: Vector{Symbol}" all(x -> x == true_grad_vecsym, gs_vec) +@test all(map(x -> x == true_grad_vecsym, gs_vec)) gs_tup, = Zygote.gradient(sol) do sol sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)]))) @@ -122,13 +122,13 @@ idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lor true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys))) true_grad_tupsym[idx_tupsym] .= 1. -@test "Symbolic Indexing Adjoint: Tuple{Symbol}" all(x -> x == true_grad_tupsym, gs_tup) +@test all(x -> x == true_grad_tupsym, gs_tup) gs_ts, = Zygote.gradient(sol) do sol sum(sol[[lorenz1.x, lorenz2], :]) end -@test "Symbolic Indexing Adjoint: Timeseries/ Vector{Symbol}" all(x -> x == true_grad_vecsym, gs_ts) +@test all(x -> x == true_grad_vecsym, gs_ts) @variables q(t)[1:2] = [1.0, 2.0] eqs = [D(q[1]) ~ 2q[1] From 59c24b2ae19cfe479175d06272080da80d13b2c6 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Wed, 1 May 2024 16:41:08 +0530 Subject: [PATCH 5/7] test: typo --- test/downstream/symbol_indexing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index b7f54b851..74b759b04 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -125,7 +125,7 @@ true_grad_tupsym[idx_tupsym] .= 1. @test all(x -> x == true_grad_tupsym, gs_tup) gs_ts, = Zygote.gradient(sol) do sol - sum(sol[[lorenz1.x, lorenz2], :]) + sum(sol[[lorenz1.x, lorenz2.x], :]) end @test all(x -> x == true_grad_vecsym, gs_ts) From 453a933524668e4f61b4686d8d1aeb62f712d5ad Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Wed, 1 May 2024 17:34:11 +0530 Subject: [PATCH 6/7] test: make test use map --- test/downstream/symbol_indexing.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 74b759b04..e6283ff34 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -122,13 +122,13 @@ idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lor true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys))) true_grad_tupsym[idx_tupsym] .= 1. -@test all(x -> x == true_grad_tupsym, gs_tup) +@test all(map(x -> x == true_grad_tupsym, gs_tup)) gs_ts, = Zygote.gradient(sol) do sol sum(sol[[lorenz1.x, lorenz2.x], :]) end -@test all(x -> x == true_grad_vecsym, gs_ts) +@test all(map(x -> x == true_grad_vecsym, gs_ts)) @variables q(t)[1:2] = [1.0, 2.0] eqs = [D(q[1]) ~ 2q[1] From 071ad2fe570ea4e27ee6fe80843c7f1c5bf25faa Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 2 May 2024 17:28:09 +0530 Subject: [PATCH 7/7] chore: format --- ext/SciMLBaseZygoteExt.jl | 3 ++- test/downstream/symbol_indexing.jl | 15 ++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 1d2104cf4..0147dbd4a 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -120,7 +120,8 @@ end VA[sym], ODESolution_getindex_pullback end -@adjoint function Base.getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T +@adjoint function Base.getindex( + VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T} function ODESolution_getindex_pullback(Δ) sym = sym isa Tuple ? collect(sym) : sym i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index e6283ff34..cec0b458e 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -1,4 +1,5 @@ -using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Zygote, Test +using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, + Zygote, Test using Optimization, OptimizationOptimJL using ModelingToolkit: t_nounits as t, D_nounits as D @@ -98,29 +99,29 @@ end @test size(sol[[lorenz1.x, lorenz2.x], :]) == size(sol[[1, 2], :]) == size(sol[1:2, :]) gs_sym, = Zygote.gradient(sol) do sol - sum(sol[lorenz1.x]) + sum(sol[lorenz1.x]) end idx_sym = SymbolicIndexingInterface.variable_index(sys, lorenz1.x) true_grad_sym = zeros(length(ModelingToolkit.unknowns(sys))) -true_grad_sym[idx_sym] = 1. +true_grad_sym[idx_sym] = 1.0 @test all(map(x -> x == true_grad_sym, gs_sym)) gs_vec, = Zygote.gradient(sol) do sol - sum(sum.(sol[[lorenz1.x, lorenz2.x]])) + sum(sum.(sol[[lorenz1.x, lorenz2.x]])) end idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x]) true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys))) -true_grad_vecsym[idx_vecsym] .= 1. +true_grad_vecsym[idx_vecsym] .= 1.0 @test all(map(x -> x == true_grad_vecsym, gs_vec)) gs_tup, = Zygote.gradient(sol) do sol - sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)]))) + sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)]))) end idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x]) true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys))) -true_grad_tupsym[idx_tupsym] .= 1. +true_grad_tupsym[idx_tupsym] .= 1.0 @test all(map(x -> x == true_grad_tupsym, gs_tup))