From 88dd51ec7036f6ffa71b9f1615b675111033375d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 7 Jun 2024 11:41:37 +0530 Subject: [PATCH 1/3] test: remove unnecessary test --- test/downstream/modelingtoolkit_remake.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 4c4206197..33ee43cd8 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -147,9 +147,7 @@ for (sys, prob) in zip(syss, probs) prob2 = @inferred baseType remake(prob; u0 = [sys.x => 0.5σ + 1], p = [sys.β => 0.5x + 1]) @test ugetter(prob2) ≈ [15.0, 0.0, 0.0] @test pgetter(prob2) ≈ [28.0, 8.5, 10.0] - prob2 = @inferred baseType remake(prob; u0 = [:x => 0.5σ + 1], p = [:β => 0.5x + 1]) - @test ugetter(prob2) ≈ [15.0, 0.0, 0.0] - @test pgetter(prob2) ≈ [28.0, 8.5, 10.0] + # Not testing `Symbol => expr` since nested substitution doesn't work with that end @variables ud(t) xd(t) yd(t) From b1c999fe88c98897dfda23a5db64758b48629dad Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 7 Jun 2024 12:12:19 +0530 Subject: [PATCH 2/3] test: fix `DiscreteProblem` initialization in mtk-remake tests --- test/downstream/modelingtoolkit_remake.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 33ee43cd8..e1e2c9da6 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -69,8 +69,9 @@ k = ShiftIndex(t) # Roundabout method to avoid having to specify values for previous timestep fn = DiscreteFunction(discsys) ps = ModelingToolkit.MTKParameters(discsys, p) +discu0 = Dict([u0..., x(k-1) => 0.0, y(k-1) => 0.0, z(k-1) => 0.0]) push!(syss, discsys) -push!(probs, DiscreteProblem(fn, [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0, 10), ps)) +push!(probs, DiscreteProblem(fn, getindex.((discu0,), unknowns(discsys)), (0, 10), ps)) for (sys, prob) in zip(syss, probs) @test parameter_values(prob) isa ModelingToolkit.MTKParameters From 5112e2756365a4410733e64ee33d3a97d3a8e01e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 7 Jun 2024 14:47:02 +0530 Subject: [PATCH 3/3] fix: add adjoints for symbolic indexing with `::Colon` --- ext/SciMLBaseZygoteExt.jl | 45 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 58e7bb309..9da8fadff 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -125,6 +125,32 @@ end VA[sym], ODESolution_getindex_pullback end +@adjoint function Base.getindex(VA::ODESolution, sym, ::Colon) + function ODESolution_getindex_pullback(Δ) + i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym + if is_observed(VA, sym) + f = observed(VA, sym) + p = parameter_values(VA) + tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p) + u = state_values(VA) + t = current_time(VA) + y, back = Zygote.pullback(u, tunables) do u, tunables + f.(u, Ref(tunables), t) + end + gs = back(Δ) + (gs[1], nothing) + elseif i === nothing + throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) + else + @show i + Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] + for (x, j) in zip(VA.u, 1:length(VA))] + (Δ′, nothing) + end + end + VA[sym, :], ODESolution_getindex_pullback +end + function obs_grads(VA, sym, obs_idx, Δ) y, back = Zygote.pullback(VA) do sol getindex.(Ref(sol), sym[obs_idx]) @@ -172,6 +198,25 @@ end VA[sym], ODESolution_getindex_pullback end +@adjoint function Base.getindex( + VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}, ::Colon) where {T} + function ODESolution_getindex_pullback(Δ) + sym = sym isa Tuple ? collect(sym) : sym + i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym) + + obs_idx = findall(s -> is_observed(VA, s), sym) + not_obs_idx = setdiff(1:length(sym), obs_idx) + + gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ) + gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ) + + a = Zygote.accum(gs_obs[1], gs_not_obs) + + (a, nothing) + end + VA[sym], ODESolution_getindex_pullback +end + @adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 }(u, args...) where {T1, T2, T3, T4, T5, T6, T7, T8,