diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index b8952f910..13eb86d13 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -7,7 +7,8 @@ using SciMLBase using SciMLBase: ODESolution, remake, getobserved, build_solution, EnsembleSolution, NonlinearSolution, AbstractTimeseriesSolution -using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, observed, parameter_values +using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, + observed, parameter_values using RecursiveArrayTools # This method resolves the ambiguity with the pullback defined in @@ -113,7 +114,7 @@ end y, back = Zygote.pullback(VA) do sol f = observed(sol, sym) p = parameter_values(sol) - f.(sol.u,Ref(p), sol.t) + f.(sol.u, Ref(p), sol.t) end gs = back(Δ) (gs[1], nothing) @@ -133,7 +134,7 @@ function obs_grads(VA, sym, obss_idx, Δ) getindex.(Ref(sol), sym[obss_idx]) end Dprime = reduce(hcat, Δ) - Dobss = eachrow(Dprime[obss_idx, :]) + Dobss = eachrow(Dprime[obss_idx, :]) back(Dobss) end @@ -141,7 +142,7 @@ function obs_grads(VA, sym, ::Nothing, Δ) Zygote.nt_nothing(VA) end -function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where T +function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where {T} Δ′ = map(enumerate(VA.u)) do (t_idx, us) map(enumerate(us)) do (u_idx, u) if u_idx in i diff --git a/test/downstream/observables_autodiff.jl b/test/downstream/observables_autodiff.jl index bfd96b85f..554b6e6f9 100644 --- a/test/downstream/observables_autodiff.jl +++ b/test/downstream/observables_autodiff.jl @@ -31,8 +31,8 @@ sol = solve(prob, Tsit5()) gs, = gradient(sol) do sol sum(sol[sys.w]) end - du_ = [0., 1., 1., 1.] - du = [du_ for _ = sol.u] + du_ = [0.0, 1.0, 1.0, 1.0] + du = [du_ for _ in sol.u] @test du == gs.u end @@ -100,15 +100,15 @@ function create_model(; C₁ = 3e-5, C₂ = 1e-6) @named ampermeter = MSL.Electrical.CurrentSensor() eqs = [connect(input_signal.output, source.V) - connect(source.p, capacitor1.n, capacitor2.n) - connect(source.n, resistor1.p, resistor2.p, ground.g) - connect(resistor1.n, capacitor1.p, ampermeter.n) - connect(resistor2.n, capacitor2.p, ampermeter.p)] + connect(source.p, capacitor1.n, capacitor2.n) + connect(source.n, resistor1.p, resistor2.p, ground.g) + connect(resistor1.n, capacitor1.p, ampermeter.n) + connect(resistor2.n, capacitor2.p, ampermeter.p)] @named circuit_model = ODESystem(eqs, t, systems = [ resistor1, resistor2, capacitor1, capacitor2, - source, input_signal, ground, ampermeter, + source, input_signal, ground, ampermeter ]) end