Skip to content

Commit

Permalink
fix: fix ODESolution getindex adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Apr 15, 2024
1 parent 2ee252f commit 65c7b3f
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@ using RecursiveArrayTools
dprob = remake(VA.prob, p = dp)
du, dprob
T = eltype(eltype(VA.u))
N = length(VA.prob.p)
Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, typeof(VA.t),
typeof(VA.k), typeof(dprob), typeof(VA.alg), typeof(VA.interp),
typeof(VA.stats), typeof(VA.alg_choice)}(du, nothing, nothing,
if dprob.u0 === nothing
N = 2
elseif dprob isa SciMLBase.BVProblem && !hasmethod(size, Tuple{typeof(dprob.u0)})
__u0 = hasmethod(dprob.u0, Tuple{typeof(dprob.p), typeof(first(dprob.tspan))}) ?

Check warning on line 27 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L24-L27

Added lines #L24 - L27 were not covered by tests
dprob.u0(dprob.p, first(dprob.tspan)) : dprob.u0(first(dprob.tspan))
N = length((size(__u0)..., length(du)))

Check warning on line 29 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L29

Added line #L29 was not covered by tests
else
N = length((size(dprob.u0)..., length(du)))

Check warning on line 31 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L31

Added line #L31 was not covered by tests
end
Δ′ = ODESolution{T, N}(du, nothing, nothing,

Check warning on line 33 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L33

Added line #L33 was not covered by tests
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.alg_choice, VA.retcode)
(Δ′, nothing, nothing)
Expand All @@ -50,10 +56,16 @@ end
du, dprob
end
T = eltype(eltype(VA.u))
N = length(VA.prob.p)
Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, typeof(VA.t),
typeof(VA.k), typeof(dprob), typeof(VA.alg), typeof(VA.interp),
typeof(VA.stats), typeof(VA.alg_choice)}(du, nothing, nothing,
if dprob.u0 === nothing
N = 2
elseif dprob isa SciMLBase.BVProblem && !hasmethod(size, Tuple{typeof(dprob.u0)})
__u0 = hasmethod(dprob.u0, Tuple{typeof(dprob.p), typeof(first(dprob.tspan))}) ?

Check warning on line 62 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L59-L62

Added lines #L59 - L62 were not covered by tests
dprob.u0(dprob.p, first(dprob.tspan)) : dprob.u0(first(dprob.tspan))
N = length((size(__u0)..., length(du)))

Check warning on line 64 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L64

Added line #L64 was not covered by tests
else
N = length((size(dprob.u0)..., length(du)))

Check warning on line 66 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L66

Added line #L66 was not covered by tests
end
Δ′ = ODESolution{T, N}(du, nothing, nothing,

Check warning on line 68 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L68

Added line #L68 was not covered by tests
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.alg_choice, VA.retcode)
(Δ′, nothing, nothing)
Expand Down

0 comments on commit 65c7b3f

Please sign in to comment.