Skip to content

Commit

Permalink
Merge branch 'master' into optensemble
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 authored Nov 7, 2023
2 parents 1d87d19 + ffe68ae commit e85bdb3
Show file tree
Hide file tree
Showing 20 changed files with 356 additions and 150 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
julia-version: [1,1.6]
julia-version: [1]
os: [ubuntu-latest]
package:
- {user: SciML, repo: DelayDiffEq.jl, group: Interface}
Expand Down
16 changes: 11 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "2.5.0"
version = "2.7.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -32,16 +31,17 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SciMLBaseChainRulesCoreExt = "ChainRulesCore"
SciMLBasePartialFunctionsExt = "PartialFunctions"
SciMLBasePyCallExt = "PyCall"
SciMLBasePythonCallExt = "PythonCall"
Expand All @@ -54,14 +54,19 @@ ArrayInterface = "6, 7"
ChainRulesCore = "1.16"
CommonSolve = "0.2.4"
ConstructionBase = "1"
Distributed = "1.9"
DocStringExtensions = "0.8, 0.9"
EnumX = "1"
FillArrays = "1.6"
FunctionWrappersWrappers = "0.1.3"
IteratorInterfaceExtensions = "^0.1, ^1"
LinearAlgebra = "1.9"
Logging = "1.9"
Markdown = "1.9"
PartialFunctions = "1.1"
PrecompileTools = "1"
Preferences = "1.3"
Printf = "1.9"
RCall = "0.13.18"
RecipesBase = "0.7.0, 0.8, 1.0"
RecursiveArrayTools = "2.33"
Expand All @@ -75,11 +80,12 @@ SymbolicIndexingInterface = "0.2"
Tables = "1"
TruncatedStacktraces = "1"
QuasiMonteCarlo = "0.3"
ZygoteRules = "0.2"
julia = "1.6"
Zygote = "0.6"
julia = "1.9"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Expand Down
63 changes: 63 additions & 0 deletions src/solutions/chainrules.jl → ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
module SciMLBaseChainRulesCoreExt

using SciMLBase
import ChainRulesCore
import ChainRulesCore: NoTangent, @non_differentiable

function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{
>:ChainRulesCore.HasReverseMode,
},
Expand Down Expand Up @@ -70,3 +76,60 @@ function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
end
VA[sym], ODESolution_getindex_pullback
end

function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...)
function ODEProblemAdjoint(ȳ)
(NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
end

ODEProblem(args...; kwargs...), ODEProblemAdjoint
end

function ChainRulesCore.rrule(::Type{SDEProblem}, args...; kwargs...)
function SDEProblemAdjoint(ȳ)
(NoTangent(), ȳ.f, ȳ.g, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
end

SDEProblem(args...; kwargs...), SDEProblemAdjoint
end

function ChainRulesCore.rrule(::Type{
<:ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
T11, T12,
}}, u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11,
T12}
function ODESolutionAdjoint(ȳ)
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
end

ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),
ODESolutionAdjoint
end

function ChainRulesCore.rrule(::Type{
<:ODESolution{uType, tType, isinplace, P, NP, F, G, K,
ND,
}}, u,
args...) where {uType, tType, isinplace, P, NP, F, G, K, ND}
function SDESolutionAdjoint(ȳ)
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
end

SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint
end

function ChainRulesCore.rrule(::SciMLBase.EnsembleSolution, sim, time, converged)
out = EnsembleSolution(sim, time, converged)
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
(NoTangent(), EnsembleSolution(arrarr, 0.0, true), NoTangent(), NoTangent())
end
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
(NoTangent(), p̄, NoTangent(), NoTangent())
end
out, EnsembleSolution_adjoint
end

end
162 changes: 152 additions & 10 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
module SciMLBaseZygoteExt

using Zygote: pullback
using ZygoteRules: @adjoint
import ZygoteRules
using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved
using Zygote
using Zygote: @adjoint, pullback
import Zygote: literal_getproperty
using SciMLBase
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand Down Expand Up @@ -56,25 +59,164 @@ end
VA[sym, j], ODESolution_getindex_pullback
end

ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats)
out = EnsembleSolution(sim, time, converged)
@adjoint function EnsembleSolution(sim, time, converged, stats)
out = EnsembleSolution(sim, time, converged, stats)
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
(EnsembleSolution(arrarr, 0.0, true), nothing, nothing, nothing)
(EnsembleSolution(arrarr, 0.0, true, stats), nothing, nothing, nothing)
end
function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
(EnsembleSolution(p̄, 0.0, true), nothing, nothing, nothing)
(EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing)
end
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
(p̄, nothing, nothing, nothing)
end
out, EnsembleSolution_adjoint
end

ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,
@adjoint function getindex(VA::ODESolution, i::Int)
function ODESolution_getindex_pullback(Δ)
Δ′ = [(i == j ? Δ : Zygote.FillArrays.Fill(zero(eltype(x)), size(x)))
for (x, j) in zip(VA.u, 1:length(VA))]
(Δ′, nothing)
end
VA[i], ODESolution_getindex_pullback
end

@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution,
::Val{:u})
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),)
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),)
end

@adjoint function getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
for (x, j) in zip(VA.u, 1:length(VA))]
(Δ′, nothing)
end
end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
}(u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
T9, T10, T11, T12}
function ODESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end

ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),
ODESolutionAdjoint
end

@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u,
args...) where
{uType, tType, isinplace, P, NP, F, G, K, ND}
function SDESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end

SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint
end

@adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u,
args...) where {
T,
N,
uType,
R,
P,
A,
O,
uType2,
}
function NonlinearSolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end
NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint
end

@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution,
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
= @. ifelse=== nothing, (zerou,), Δ)
(build_solution(sol.prob, sol.alg, sol.t, _Δ),)
end
sol.u, solu_adjoint
end

@adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution,
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
= @. ifelse=== nothing, zerou, Δ)
(build_solution(sol.prob, sol.alg, _Δ, sol.resid),)
end
sol.u, solu_adjoint
end

@adjoint function literal_getproperty(sol::SciMLBase.OptimizationSolution,
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.u)
= @. ifelse=== nothing, zerou, Δ)
(build_solution(sol.cache, sol.alg, _Δ, sol.objective),)
end
sol.u, solu_adjoint
end

function ∇tmap(cx, f, args...)
ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> (NoTangent(), NoTangent())
else
ys, backs = Zygote.unzip(ys_and_backs)
function ∇tmap_internal(Δ)
Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ)
Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
Δf = reduce(Zygote.accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
end
ys, ∇tmap_internal
end
end

function ∇responsible_map(cx, f, args...)
ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...),
args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> (NoTangent(), NoTangent())
else
ys, backs = Zygote.unzip(ys_and_backs)
ys,
function ∇responsible_map_internal(Δ)
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ),
Zygote._tryreverse(SciMLBase.responsible_map,
backs, Δ)...)
Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map,
Δf_and_args_zipped))
Δf = reduce(Zygote.accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
end
end
end

@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...)
∇tmap(__context__, f, args...)
end

@adjoint function SciMLBase.responsible_map(f,
args::Union{AbstractArray, Tuple
}...)
∇responsible_map(__context__, f, args...)
end

end
3 changes: 0 additions & 3 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ import RuntimeGeneratedFunctions
import EnumX
import TruncatedStacktraces
import ADTypes: AbstractADType
import ChainRulesCore
import ZygoteRules: @adjoint
import FillArrays
import QuasiMonteCarlo
using Reexport
Expand Down Expand Up @@ -716,7 +714,6 @@ include("solutions/optimization_solutions.jl")
include("solutions/dae_solutions.jl")
include("solutions/pde_solutions.jl")
include("solutions/solution_interface.jl")
include("solutions/zygote.jl")

include("ensemble/ensemble_solutions.jl")
include("ensemble/ensemble_problems.jl")
Expand Down
6 changes: 3 additions & 3 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function batch_func(i, prob, alg; kwargs...)
new_prob = prob.prob_func(_prob, i, iter)
rerun = true
x = prob.output_func(solve(new_prob, alg; kwargs...), i)
if !(typeof(x) <: Tuple)
if !(x isa Tuple)
rerun_warn()
_x = (x, false)
else
Expand All @@ -117,7 +117,7 @@ function batch_func(i, prob, alg; kwargs...)
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
new_prob = prob.prob_func(_prob, i, iter)
x = prob.output_func(solve(new_prob, alg; kwargs...), i)
if !(typeof(x) <: Tuple)
if !(x isa Tuple)
rerun_warn()
_x = (x, false)
else
Expand Down Expand Up @@ -170,7 +170,7 @@ function solve_batch(prob, alg, ensemblealg::EnsembleThreads, II, pmap_batch_siz
return solve_batch(prob, alg, EnsembleSerial(), II, pmap_batch_size; kwargs...)
end

if typeof(prob.prob) <: AbstractJumpProblem && length(II) != 1
if prob.prob isa AbstractJumpProblem && length(II) != 1
probs = [deepcopy(prob.prob) for i in 1:nthreads]
else
probs = prob.prob
Expand Down
Loading

0 comments on commit e85bdb3

Please sign in to comment.