diff --git a/Project.toml b/Project.toml index 34ef2a560..78a09fd27 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.7.1" +version = "2.7.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 09ce9bdb5..557a3844d 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -1,10 +1,12 @@ module SciMLBaseZygoteExt using Zygote -using Zygote: pullback, ZygoteRules -using ZygoteRules: @adjoint +using Zygote: @adjoint, pullback +import Zygote: literal_getproperty using SciMLBase -using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved +using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, + getobserved, build_solution, EnsembleSolution, + NonlinearSolution, AbstractTimeseriesSolution # This method resolves the ambiguity with the pullback defined in # RecursiveArrayToolsZygoteExt @@ -82,7 +84,7 @@ end VA[i], ODESolution_getindex_pullback end -@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, +@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution, ::Val{:u}) sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),) end @@ -140,32 +142,32 @@ end NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint end -@adjoint function ZygoteRules.literal_getproperty(sol::AbstractTimeseriesSolution, +@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution, ::Val{:u}) function solu_adjoint(Δ) zerou = zero(sol.prob.u0) _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) - (DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) + (build_solution(sol.prob, sol.alg, sol.t, _Δ),) end sol.u, solu_adjoint end -@adjoint function ZygoteRules.literal_getproperty(sol::AbstractNoTimeSolution, +@adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution, ::Val{:u}) function solu_adjoint(Δ) zerou = zero(sol.prob.u0) _Δ = @. ifelse(Δ === nothing, zerou, Δ) - (DiffEqBase.build_solution(sol.prob, sol.alg, _Δ, sol.resid),) + (build_solution(sol.prob, sol.alg, _Δ, sol.resid),) end sol.u, solu_adjoint end -@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.OptimizationSolution, +@adjoint function literal_getproperty(sol::SciMLBase.OptimizationSolution, ::Val{:u}) function solu_adjoint(Δ) zerou = zero(sol.u) _Δ = @. ifelse(Δ === nothing, zerou, Δ) - (DiffEqBase.build_solution(sol.cache, sol.alg, _Δ, sol.objective),) + (build_solution(sol.cache, sol.alg, _Δ, sol.objective),) end sol.u, solu_adjoint end