Skip to content

Commit

Permalink
hotfix missing imports
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Nov 3, 2023
1 parent 8b5e58d commit 17b664c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "2.7.1"
version = "2.7.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
22 changes: 12 additions & 10 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
module SciMLBaseZygoteExt

using Zygote
using Zygote: pullback, ZygoteRules
using ZygoteRules: @adjoint
using Zygote: @adjoint, pullback
import Zygote: literal_getproperty
using SciMLBase
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand Down Expand Up @@ -82,7 +84,7 @@ end
VA[i], ODESolution_getindex_pullback
end

@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,
@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution,
::Val{:u})
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),)
end
Expand Down Expand Up @@ -140,32 +142,32 @@ end
NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint
end

@adjoint function ZygoteRules.literal_getproperty(sol::AbstractTimeseriesSolution,
@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution,
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
= @. ifelse=== nothing, (zerou,), Δ)
(DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
(build_solution(sol.prob, sol.alg, sol.t, _Δ),)
end
sol.u, solu_adjoint
end

@adjoint function ZygoteRules.literal_getproperty(sol::AbstractNoTimeSolution,
@adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution,
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
= @. ifelse=== nothing, zerou, Δ)
(DiffEqBase.build_solution(sol.prob, sol.alg, _Δ, sol.resid),)
(build_solution(sol.prob, sol.alg, _Δ, sol.resid),)
end
sol.u, solu_adjoint
end

@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.OptimizationSolution,
@adjoint function literal_getproperty(sol::SciMLBase.OptimizationSolution,
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.u)
= @. ifelse=== nothing, zerou, Δ)
(DiffEqBase.build_solution(sol.cache, sol.alg, _Δ, sol.objective),)
(build_solution(sol.cache, sol.alg, _Δ, sol.objective),)
end
sol.u, solu_adjoint
end
Expand Down

0 comments on commit 17b664c

Please sign in to comment.