Skip to content

Commit

Permalink
chore: try to avoid returning object
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed May 19, 2024
1 parent ff9bb2c commit 032b927
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ end
VA[sym], ODESolution_getindex_pullback
end

function obs_grads(VA, sym, obss_idx, Δ)
function obs_grads(VA, sym, obs_idx, Δ)
y, back = Zygote.pullback(VA) do sol
getindex.(Ref(sol), sym[obss_idx])
getindex.(Ref(sol), sym[obs_idx])

Check warning on line 134 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L132-L134

Added lines #L132 - L134 were not covered by tests
end
Dprime = reduce(hcat, Δ)
Dobss = eachrow(Dprime[obss_idx, :])
back(Dobss)
Δreduced = reduce(hcat, Δ)
Δobs = eachrow(Δreduced[obs_idx, :])
back(Δobs)

Check warning on line 138 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L136-L138

Added lines #L136 - L138 were not covered by tests
end

function obs_grads(VA, sym, ::Nothing, Δ)
Expand Down Expand Up @@ -164,11 +164,11 @@ end
sym = sym isa Tuple ? collect(sym) : sym
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)

obss_idx = findall(s -> is_observed(VA, s), sym)
not_obss_idx = setdiff(1:length(sym), obss_idx)
obs_idx = findall(s -> is_observed(VA, s), sym)
not_obs_idx = setdiff(1:length(sym), obs_idx)

Check warning on line 168 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L167-L168

Added lines #L167 - L168 were not covered by tests

gs_obs = obs_grads(VA, sym, isempty(obss_idx) ? nothing : obss_idx, Δ)
gs_not_obs = not_obs_grads(VA, sym, not_obss_idx, i, Δ)
gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ)
gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ)

Check warning on line 171 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L170-L171

Added lines #L170 - L171 were not covered by tests

a = Zygote.accum(gs_obs[1], gs_not_obs)
(a, nothing)

Check warning on line 174 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L173-L174

Added lines #L173 - L174 were not covered by tests
Expand Down Expand Up @@ -220,7 +220,9 @@ end
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
= @. ifelse=== nothing, (zerou,), Δ)
(build_solution(sol.prob, sol.alg, sol.t, _Δ),)
nt = Zygote.nt_nothing(sol)
gs = Zygote.accum(nt, (u = _Δ,))
(gs,)

Check warning on line 225 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L223-L225

Added lines #L223 - L225 were not covered by tests
end
sol.u, solu_adjoint
end
Expand Down

0 comments on commit 032b927

Please sign in to comment.