Skip to content

Commit

Permalink
chore: format
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed May 16, 2024
1 parent 95cf416 commit 4ce8257
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
9 changes: 5 additions & 4 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ using SciMLBase
using SciMLBase: ODESolution, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, observed, parameter_values
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed,
observed, parameter_values
using RecursiveArrayTools

# This method resolves the ambiguity with the pullback defined in
Expand Down Expand Up @@ -113,7 +114,7 @@ end
y, back = Zygote.pullback(VA) do sol
f = observed(sol, sym)
p = parameter_values(sol)
f.(sol.u,Ref(p), sol.t)
f.(sol.u, Ref(p), sol.t)

Check warning on line 117 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L113-L117

Added lines #L113 - L117 were not covered by tests
end
gs = back(Δ)
(gs[1], nothing)
Expand All @@ -133,15 +134,15 @@ function obs_grads(VA, sym, obss_idx, Δ)
getindex.(Ref(sol), sym[obss_idx])

Check warning on line 134 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L132-L134

Added lines #L132 - L134 were not covered by tests
end
Dprime = reduce(hcat, Δ)
Dobss = eachrow(Dprime[obss_idx, :])
Dobss = eachrow(Dprime[obss_idx, :])
back(Dobss)

Check warning on line 138 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L136-L138

Added lines #L136 - L138 were not covered by tests
end

function obs_grads(VA, sym, ::Nothing, Δ)
Zygote.nt_nothing(VA)

Check warning on line 142 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L141-L142

Added lines #L141 - L142 were not covered by tests
end

function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where T
function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where {T}
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
map(enumerate(us)) do (u_idx, u)
if u_idx in i
Expand Down
14 changes: 7 additions & 7 deletions test/downstream/observables_autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ sol = solve(prob, Tsit5())
gs, = gradient(sol) do sol
sum(sol[sys.w])
end
du_ = [0., 1., 1., 1.]
du = [du_ for _ = sol.u]
du_ = [0.0, 1.0, 1.0, 1.0]
du = [du_ for _ in sol.u]
@test du == gs.u
end

Expand Down Expand Up @@ -100,15 +100,15 @@ function create_model(; C₁ = 3e-5, C₂ = 1e-6)
@named ampermeter = MSL.Electrical.CurrentSensor()

eqs = [connect(input_signal.output, source.V)
connect(source.p, capacitor1.n, capacitor2.n)
connect(source.n, resistor1.p, resistor2.p, ground.g)
connect(resistor1.n, capacitor1.p, ampermeter.n)
connect(resistor2.n, capacitor2.p, ampermeter.p)]
connect(source.p, capacitor1.n, capacitor2.n)
connect(source.n, resistor1.p, resistor2.p, ground.g)
connect(resistor1.n, capacitor1.p, ampermeter.n)
connect(resistor2.n, capacitor2.p, ampermeter.p)]

@named circuit_model = ODESystem(eqs, t,
systems = [
resistor1, resistor2, capacitor1, capacitor2,
source, input_signal, ground, ampermeter,
source, input_signal, ground, ampermeter
])
end

Expand Down

0 comments on commit 4ce8257

Please sign in to comment.