Skip to content

Commit

Permalink
Merge branch 'master' into dg/obsfn
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi authored May 8, 2024
2 parents 8600a8d + a0fab7a commit a69d087
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 53 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "2.34.0"
version = "2.36.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -76,15 +76,15 @@ PyCall = "1.96"
PythonCall = "0.9.15"
RCall = "0.14.0"
RecipesBase = "1.3.4"
RecursiveArrayTools = "3.8.0"
RecursiveArrayTools = "3.14.0"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5.12"
SciMLOperators = "0.3.7"
SciMLStructures = "1.1"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
SymbolicIndexingInterface = "0.3.15"
SymbolicIndexingInterface = "0.3.20"
Tables = "1.11"
Zygote = "0.6.67"
julia = "1.10"
Expand Down
7 changes: 4 additions & 3 deletions ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,16 @@ end

function ChainRulesCore.rrule(
::Type{
<:ODESolution{uType, tType, isinplace, P, NP, F, G, K,
<:RODESolution{uType, tType, isinplace, P, NP, F, G, K,
ND
}}, u,
args...) where {uType, tType, isinplace, P, NP, F, G, K, ND}
function SDESolutionAdjoint(ȳ)
function RODESolutionAdjoint(ȳ)
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
end

SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint
RODESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...),
RODESolutionAdjoint
end

function ChainRulesCore.rrule(::SciMLBase.EnsembleSolution, sim, time, converged)
Expand Down
7 changes: 4 additions & 3 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where T
Zygote.accum(nt, (u = Δ′,))

Check warning on line 157 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L156-L157

Added lines #L156 - L157 were not covered by tests
end

@adjoint function getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T
@adjoint function Base.getindex(
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T}
function ODESolution_getindex_pullback(Δ)
sym = sym isa Tuple ? collect(sym) : sym
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
Expand Down Expand Up @@ -189,11 +190,11 @@ end
@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u,
args...) where
{uType, tType, isinplace, P, NP, F, G, K, ND}
function SDESolutionAdjoint(ȳ)
function SDEProblemAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end

SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint
SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDEProblemAdjoint
end

@adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u,
Expand Down
2 changes: 1 addition & 1 deletion src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -643,8 +643,8 @@ Same as `check_error` but also set solution's return code
"""
function check_error!(integrator::DEIntegrator)
code = check_error(integrator)
integrator.sol = solution_new_retcode(integrator.sol, code)
if code != ReturnCode.Success
integrator.sol = solution_new_retcode(integrator.sol, code)
postamble!(integrator)
end
return code
Expand Down
17 changes: 17 additions & 0 deletions src/retcodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,23 @@ EnumX.@enumx ReturnCode begin
- successful_retcode = false
"""
Stalled

"""
`ReturnCode.InternalLinearSolveFailed`
The linear problem inside another problem (for example inside a NonlinearProblem)
could not be solved.
## Common Reasons for Seeing this Return Code
- If a rank-deficient matrix originated inside the nonlinear solve and the
provided linear solver is incapable of handling those cases.
## Properties
- successful_retcode = false
"""
InternalLinearSolveFailed
end

Base.:(!=)(retcode::ReturnCode.T, s::Symbol) = Symbol(retcode) != s
Expand Down
11 changes: 9 additions & 2 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ end
function (sol::AbstractODESolution)(t::Number, ::Type{deriv},
idxs::AbstractVector{<:Integer},
continuity) where {deriv}
if eltype(sol.u) <: Number
idxs = only(idxs)
end
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
end
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
Expand All @@ -183,6 +186,9 @@ end
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
idxs::AbstractVector{<:Integer},
continuity) where {deriv}
if eltype(sol.u) <: Number
idxs = only(idxs)
end
A = sol.interp(t, idxs, deriv, sol.prob.p, continuity)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
return DiffEqArray(A.u, A.t, p, sol)
Expand All @@ -203,7 +209,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)
[is_parameter(sol, idx) ? getp(sol, idx)(sol) : first(interp_sol[idx]) for idx in idxs]
first(interp_sol[idxs])
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
Expand All @@ -224,8 +230,9 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
indexed_sol = interp_sol[idxs]
return DiffEqArray(
[[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, p, sol)
[indexed_sol[i] for i in 1:length(t)], t, p, sol)
end

function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
Expand Down
15 changes: 14 additions & 1 deletion src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,18 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic,
plot_vecs = []
labels = String[]
varsyms = variable_symbols(sol)
batch_symbolic_vars = []
for x in vars
for j in 2:length(x)
if (x[j] isa Integer && x[j] == 0) || isequal(x[j], getindepsym_defaultt(sol))
else
push!(batch_symbolic_vars, x[j])
end
end
end
batch_symbolic_vars = identity.(batch_symbolic_vars)
indexed_solution = sol(plott; idxs = batch_symbolic_vars)
idxx = 0
for x in vars
tmp = []
strs = String[]
Expand All @@ -444,7 +456,8 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic,
push!(tmp, plott)
push!(strs, "t")
else
push!(tmp, sol(plott; idxs = x[j]))
idxx += 1
push!(tmp, indexed_solution[idxx, :])
if !isempty(varsyms) && x[j] isa Integer
push!(strs, String(getname(varsyms[x[j]])))
elseif hasname(x[j])
Expand Down
92 changes: 92 additions & 0 deletions test/downstream/adjoints.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Zygote, Test
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters σ ρ β
@variables x(t) y(t) z(t)

eqs = [D(x) ~ σ * (y - x),
D(y) ~ x *- z) - y,
D(z) ~ x * y - β * z]

@named lorenz1 = ODESystem(eqs, t)
@named lorenz2 = ODESystem(eqs, t)

@parameters γ
@variables a(t) α(t)
connections = [0 ~ lorenz1.x + lorenz2.y + a * γ,
α ~ 2lorenz1.x + a * γ]
@mtkbuild sys = ODESystem(connections, t, [a, α], [γ], systems = [lorenz1, lorenz2])

u0 = [lorenz1.x => 1.0,
lorenz1.y => 0.0,
lorenz1.z => 0.0,
lorenz2.x => 0.0,
lorenz2.y => 1.0,
lorenz2.z => 0.0,
a => 2.0]

p = [lorenz1.σ => 10.0,
lorenz1.ρ => 28.0,
lorenz1.β => 8 / 3,
lorenz2.σ => 10.0,
lorenz2.ρ => 28.0,
lorenz2.β => 8 / 3,
γ => 2.0]

tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p)
sol = solve(prob, Rodas4())

gs_sym, = Zygote.gradient(sol) do sol
sum(sol[lorenz1.x])
end
idx_sym = SymbolicIndexingInterface.variable_index(sys, lorenz1.x)
true_grad_sym = zeros(length(ModelingToolkit.unknowns(sys)))
true_grad_sym[idx_sym] = 1.0

@test all(map(x -> x == true_grad_sym, gs_sym))

gs_vec, = Zygote.gradient(sol) do sol
sum(sum.(sol[[lorenz1.x, lorenz2.x]]))
end
idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys)))
true_grad_vecsym[idx_vecsym] .= 1.0

@test all(map(x -> x == true_grad_vecsym, gs_vec))

gs_tup, = Zygote.gradient(sol) do sol
sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)])))
end
idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys)))
true_grad_tupsym[idx_tupsym] .= 1.0

@test all(map(x -> x == true_grad_tupsym, gs_tup))

gs_ts, = Zygote.gradient(sol) do sol
sum(sum.(sol[[lorenz1.x, lorenz2.x], :]))
end

@test all(map(x -> x == true_grad_vecsym, gs_ts))

# BatchedInterface AD
@variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0
@named sys1 = ODESystem([D(x) ~ x + y, D(y) ~ y * z, D(z) ~ z * t * x], t)
sys1 = complete(sys1)
prob1 = ODEProblem(sys1, [], (0.0, 10.0))
@named sys2 = ODESystem([D(x) ~ x + w, D(y) ~ w * t, D(w) ~ x + y + w], t)
sys2 = complete(sys2)
prob2 = ODEProblem(sys2, [], (0.0, 10.0))

bi = BatchedInterface((sys1, [x, y, z]), (sys2, [x, y, w]))
getter = getu(bi)

p1grad, p2grad = Zygote.gradient(prob1, prob2) do prob1, prob2
sum(getter(prob1, prob2))
end

@test p1grad.u0 ones(3)
testp2grad = zeros(3)
testp2grad[variable_index(prob2, w)] = 1.0
@test p2grad.u0 testp2grad
3 changes: 2 additions & 1 deletion test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ eqs = [D(x) ~ Hold(ud)
xd ~ Sample(t, dt)(x)]
@mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [p3 => 2p1])
prob = ODEProblem(sys, [x => 1.0], (0.0, 5.0),
[p1 => 1.0, p2 => 2, ud(k - 1) => 3.0, xd(k - 1) => 4.0, xd(k - 2) => 5.0])
[p1 => 1.0, p2 => 2, ud(k - 1) => 3.0,
xd(k - 1) => 4.0, xd(k - 2) => 5.0, yd(k - 1) => 0.0])

# parameter dependencies
prob2 = @inferred ODEProblem remake(prob; p = [p1 => 2.0])
Expand Down
42 changes: 5 additions & 37 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Zygote, Test
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface,
Test
using Optimization, OptimizationOptimJL
using ModelingToolkit: t_nounits as t, D_nounits as D

Expand Down Expand Up @@ -93,42 +94,9 @@ end
@test length(sol[(lorenz1.x, lorenz2.x)]) == length(sol)
@test all(length.(sol[(lorenz1.x, lorenz2.x)]) .== 2)

@test sol[[lorenz1.x, lorenz2.x], :] isa Matrix{Float64}
@test size(sol[[lorenz1.x, lorenz2.x], :]) == (2, length(sol))
@test size(sol[[lorenz1.x, lorenz2.x], :]) == size(sol[[1, 2], :]) == size(sol[1:2, :])

gs_sym, = Zygote.gradient(sol) do sol
sum(sol[lorenz1.x])
end
idx_sym = SymbolicIndexingInterface.variable_index(sys, lorenz1.x)
true_grad_sym = zeros(length(ModelingToolkit.unknowns(sys)))
true_grad_sym[idx_sym] = 1.

@test all(map(x -> x == true_grad_sym, gs_sym))

gs_vec, = Zygote.gradient(sol) do sol
sum(sum.(sol[[lorenz1.x, lorenz2.x]]))
end
idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys)))
true_grad_vecsym[idx_vecsym] .= 1.

@test all(map(x -> x == true_grad_vecsym, gs_vec))

gs_tup, = Zygote.gradient(sol) do sol
sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)])))
end
idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys)))
true_grad_tupsym[idx_tupsym] .= 1.

@test all(map(x -> x == true_grad_tupsym, gs_tup))

gs_ts, = Zygote.gradient(sol) do sol
sum(sol[[lorenz1.x, lorenz2.x], :])
end

@test all(map(x -> x == true_grad_vecsym, gs_ts))
@test sol[[lorenz1.x, lorenz2.x], :] isa Vector{Vector{Float64}}
@test length(sol[[lorenz1.x, lorenz2.x], :]) == length(sol)
@test length(sol[[lorenz1.x, lorenz2.x], :][1]) == 2

@variables q(t)[1:2] = [1.0, 2.0]
eqs = [D(q[1]) ~ 2q[1]
Expand Down
21 changes: 19 additions & 2 deletions test/integrator_tests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
using SciMLBase
mutable struct DummySolution
retcode::Any

struct DummySolution
retcode::SciMLBase.ReturnCode.T
end

SciMLBase.solution_new_retcode(::DummySolution, code) = DummySolution(code)

mutable struct DummyIntegrator{Alg, IIP, U, T} <: SciMLBase.DEIntegrator{Alg, IIP, U, T}
uprev::U
tprev::T
Expand Down Expand Up @@ -46,6 +49,9 @@ function SciMLBase.done(integrator::DummyIntegrator)
integrator.t > 10
end

SciMLBase.check_error(::DummyIntegrator) = ReturnCode.Success
SciMLBase.postamble!(::DummyIntegrator) = nothing

integrator = DummyIntegrator()
@test step_dt!(integrator, 1.5) == 2
@test step_dt!(integrator, 1.5, true) == 1.5
Expand All @@ -62,3 +68,14 @@ for (uprev, tprev, u, t) in intervals(DummyIntegrator())
end
@test eltype(collect(intervals(DummyIntegrator()))) ==
Tuple{Vector{Float64}, Float64, Vector{Float64}, Float64}

@test integrator.sol.retcode == ReturnCode.Default
@test check_error(integrator) == ReturnCode.Success
@test integrator.sol.retcode == ReturnCode.Default
@test SciMLBase.check_error!(integrator) == ReturnCode.Success
@test integrator.sol.retcode == ReturnCode.Success

let
integrator = DummyIntegrator()
@test 0 == @allocated SciMLBase.check_error!(integrator)
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ end
include("downstream/problem_interface.jl")
end
end
@time @safetestset "Adjoints" begin
include("downstream/adjoints.jl")
end
end

if !is_APPVEYOR && GROUP == "Python"
Expand Down

0 comments on commit a69d087

Please sign in to comment.