Skip to content

Commit

Permalink
Update SciMLBaseZygoteExt.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas authored Oct 28, 2023
1 parent a44c420 commit 9cb8c14
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module SciMLBaseZygoteExt

using Zygote: pullback
using ZygoteRules: @adjoint
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved
import ZygoteRules
using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand Down Expand Up @@ -55,4 +56,25 @@ end
VA[sym, j], ODESolution_getindex_pullback
end

ZygoteRules.@adjoint function DiffEqBase.EnsembleSolution(sim, time, converged)
out = EnsembleSolution(sim, time, converged)
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]

Check warning on line 62 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L59-L62

Added lines #L59 - L62 were not covered by tests
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
(EnsembleSolution(arrarr, 0.0, true), nothing, nothing)

Check warning on line 64 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L64

Added line #L64 was not covered by tests
end
function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
(EnsembleSolution(p̄, 0.0, true), nothing, nothing)

Check warning on line 67 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L66-L67

Added lines #L66 - L67 were not covered by tests
end
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
(p̄, nothing, nothing)

Check warning on line 70 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
end
out, EnsembleSolution_adjoint

Check warning on line 72 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L72

Added line #L72 was not covered by tests
end

ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,

Check warning on line 75 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L75

Added line #L75 was not covered by tests
::Val{:u})
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),)

Check warning on line 77 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L77

Added line #L77 was not covered by tests
end

end

0 comments on commit 9cb8c14

Please sign in to comment.