Skip to content

Commit

Permalink
chore: format
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed May 2, 2024
1 parent 453a933 commit 071ad2f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
3 changes: 2 additions & 1 deletion ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function Base.getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T
@adjoint function Base.getindex(

Check warning on line 123 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L123

Added line #L123 was not covered by tests
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T}
function ODESolution_getindex_pullback(Δ)
sym = sym isa Tuple ? collect(sym) : sym
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
Expand Down
15 changes: 8 additions & 7 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Zygote, Test
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface,
Zygote, Test
using Optimization, OptimizationOptimJL
using ModelingToolkit: t_nounits as t, D_nounits as D

Expand Down Expand Up @@ -98,29 +99,29 @@ end
@test size(sol[[lorenz1.x, lorenz2.x], :]) == size(sol[[1, 2], :]) == size(sol[1:2, :])

gs_sym, = Zygote.gradient(sol) do sol
sum(sol[lorenz1.x])
sum(sol[lorenz1.x])
end
idx_sym = SymbolicIndexingInterface.variable_index(sys, lorenz1.x)
true_grad_sym = zeros(length(ModelingToolkit.unknowns(sys)))
true_grad_sym[idx_sym] = 1.
true_grad_sym[idx_sym] = 1.0

@test all(map(x -> x == true_grad_sym, gs_sym))

gs_vec, = Zygote.gradient(sol) do sol
sum(sum.(sol[[lorenz1.x, lorenz2.x]]))
sum(sum.(sol[[lorenz1.x, lorenz2.x]]))
end
idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys)))
true_grad_vecsym[idx_vecsym] .= 1.
true_grad_vecsym[idx_vecsym] .= 1.0

@test all(map(x -> x == true_grad_vecsym, gs_vec))

gs_tup, = Zygote.gradient(sol) do sol
sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)])))
sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)])))
end
idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys)))
true_grad_tupsym[idx_tupsym] .= 1.
true_grad_tupsym[idx_tupsym] .= 1.0

@test all(map(x -> x == true_grad_tupsym, gs_tup))

Expand Down

0 comments on commit 071ad2f

Please sign in to comment.