Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: adjoints through observable functions #689

Merged
merged 26 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
92ad6a8
feat: adjoints through observable functions
DhairyaLGandhi May 2, 2024
22dc7ec
Update ext/SciMLBaseZygoteExt.jl
DhairyaLGandhi May 6, 2024
0c2b69d
feat: allow observables in collections
DhairyaLGandhi May 8, 2024
c61e08c
chore: handle no observables in collection
DhairyaLGandhi May 8, 2024
8600a8d
fix: typo
DhairyaLGandhi May 8, 2024
a69d087
Merge branch 'master' into dg/obsfn
DhairyaLGandhi May 8, 2024
785b052
test: add test for observable functions
DhairyaLGandhi May 12, 2024
adee4f0
test: add AD testset
DhairyaLGandhi May 12, 2024
2197a30
Update test/downstream/observables_autodiff.jl
ChrisRackauckas May 13, 2024
9172014
test: add a simple DAE example; disable till sensitivities are turned on
DhairyaLGandhi May 15, 2024
95cf416
test: add missing imports
DhairyaLGandhi May 15, 2024
4ce8257
chore: format
DhairyaLGandhi May 16, 2024
839bd63
chore: rm unwanted adjoint
DhairyaLGandhi May 16, 2024
2474a8d
test: check failures with SciMLSensitivity + SII
DhairyaLGandhi May 18, 2024
f68cb05
ci(SciMLSensitivity): checkout SII branch
DhairyaLGandhi May 18, 2024
9ab29d9
ci(SciMLSensitivity): use correct path
DhairyaLGandhi May 18, 2024
a417cdd
ci: revert changes
DhairyaLGandhi May 18, 2024
ff9bb2c
chore: revert literal_getproperty adjoint
DhairyaLGandhi May 19, 2024
032b927
chore: try to avoid returning object
DhairyaLGandhi May 19, 2024
44bfc91
build: add MSL to test deps
DhairyaLGandhi May 19, 2024
de2d6cd
chore: don't return structural tangent
DhairyaLGandhi May 20, 2024
c63dfbf
chore: fix imports
DhairyaLGandhi May 20, 2024
940ea78
test: add MSL to downstream env
DhairyaLGandhi May 20, 2024
8e48f1c
test: rm MSL from test env
DhairyaLGandhi May 20, 2024
d061ce4
chore: format
DhairyaLGandhi May 20, 2024
f817b52
Update CI.yml
ChrisRackauckas May 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Logging = "1.10"
Makie = "0.20"
Markdown = "1.10"
ModelingToolkit = "8.75, 9"
ModelingToolkitStandardLibrary = "2.7"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be gated int Downstream

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added bounds to test/downstream/Project.toml in 940ea78, should I remove anything from the regular test environment or do i need to declare these in both places?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove it from the regular

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

d061ce4 does that

PartialFunctions = "1.1"
PrecompileTools = "1.2"
Preferences = "1.3"
Expand Down Expand Up @@ -96,6 +97,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -111,4 +113,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "ModelingToolkitStandardLibrary", "OrdinaryDiffEq", "ForwardDiff"]
70 changes: 53 additions & 17 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
using SciMLBase: ODESolution, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed,
observed, parameter_values, state_values, current_time
using RecursiveArrayTools
import SciMLStructures

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand Down Expand Up @@ -109,7 +111,18 @@
@adjoint function Base.getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
if i === nothing
if is_observed(VA, sym)
f = observed(VA, sym)
p = parameter_values(VA)
tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
u = state_values(VA)
t = current_time(VA)
y, back = Zygote.pullback(u, tunables) do u, tunables
f.(u, Ref(tunables), t)

Check warning on line 121 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L114-L121

Added lines #L114 - L121 were not covered by tests
end
gs = back(Δ)
(gs[1], nothing)
elseif i === nothing

Check warning on line 125 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L123-L125

Added lines #L123 - L125 were not covered by tests
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
Expand All @@ -120,26 +133,49 @@
VA[sym], ODESolution_getindex_pullback
end

function obs_grads(VA, sym, obs_idx, Δ)
y, back = Zygote.pullback(VA) do sol
getindex.(Ref(sol), sym[obs_idx])

Check warning on line 138 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L136-L138

Added lines #L136 - L138 were not covered by tests
end
Δreduced = reduce(hcat, Δ)
Δobs = eachrow(Δreduced[obs_idx, :])
back(Δobs)

Check warning on line 142 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L140-L142

Added lines #L140 - L142 were not covered by tests
end

function obs_grads(VA, sym, ::Nothing, Δ)
Zygote.nt_nothing(VA)

Check warning on line 146 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L145-L146

Added lines #L145 - L146 were not covered by tests
end

function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where {T}
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
map(enumerate(us)) do (u_idx, u)
if u_idx in i
idx = findfirst(isequal(u_idx), i)
Δ[t_idx][idx]

Check warning on line 154 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L149-L154

Added lines #L149 - L154 were not covered by tests
else
zero(T)

Check warning on line 156 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L156

Added line #L156 was not covered by tests
end
end
end

Δ′

Check warning on line 161 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L161

Added line #L161 was not covered by tests
end

@adjoint function Base.getindex(
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T}
function ODESolution_getindex_pullback(Δ)
sym = sym isa Tuple ? collect(sym) : sym
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
map(enumerate(us)) do (u_idx, u)
if u_idx in i
idx = findfirst(isequal(u_idx), i)
Δ[t_idx][idx]
else
zero(T)
end
end
end
(Δ′, nothing)
end

obs_idx = findall(s -> is_observed(VA, s), sym)
not_obs_idx = setdiff(1:length(sym), obs_idx)

Check warning on line 171 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L170-L171

Added lines #L170 - L171 were not covered by tests

gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ)
gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ)

Check warning on line 174 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L173-L174

Added lines #L173 - L174 were not covered by tests

a = Zygote.accum(gs_obs[1], gs_not_obs)

Check warning on line 176 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L176

Added line #L176 was not covered by tests

(a, nothing)

Check warning on line 178 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L178

Added line #L178 was not covered by tests
end
VA[sym], ODESolution_getindex_pullback
end
Expand Down
102 changes: 102 additions & 0 deletions test/downstream/observables_autodiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using ModelingToolkit, OrdinaryDiffEq
using Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D
import SymbolicIndexingInterface as SII
import SciMLStructures as SS
using ModelingToolkitStandardLibrary
import ModelingToolkitStandardLibrary as MSL

@parameters σ ρ β
@variables x(t) y(t) z(t) w(t)

eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
D(z) ~ x * y - β * z,
w ~ x + y + z + 2 * β]

@mtkbuild sys = ODESystem(eqs, t)

u0 = [D(x) => 2.0,
x => 1.0,
y => 0.0,
z => 0.0]

p = [σ => 28.0,
ρ => 10.0,
β => 8 / 3]

tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p, jac = true)
sol = solve(prob, Tsit5())

@testset "AutoDiff Observable Functions" begin
gs, = gradient(sol) do sol
sum(sol[sys.w])
end
du_ = [0.0, 1.0, 1.0, 1.0]
du = [du_ for _ in sol.u]
@test du == gs

# Observable in a vector
gs, = gradient(sol) do sol
sum(sum.(sol[[sys.w, sys.x]]))
end
du_ = [0.0, 1.0, 1.0, 2.0]
du = [du_ for _ in sol.u]
@test du == gs
end

# DAE

function create_model(; C₁ = 3e-5, C₂ = 1e-6)
@variables t
@named resistor1 = MSL.Electrical.Resistor(R = 5.0)
@named resistor2 = MSL.Electrical.Resistor(R = 2.0)
@named capacitor1 = MSL.Electrical.Capacitor(C = C₁)
@named capacitor2 = MSL.Electrical.Capacitor(C = C₂)
@named source = MSL.Electrical.Voltage()
@named input_signal = MSL.Blocks.Sine(frequency = 100.0)
@named ground = MSL.Electrical.Ground()
@named ampermeter = MSL.Electrical.CurrentSensor()

eqs = [connect(input_signal.output, source.V)
connect(source.p, capacitor1.n, capacitor2.n)
connect(source.n, resistor1.p, resistor2.p, ground.g)
connect(resistor1.n, capacitor1.p, ampermeter.n)
connect(resistor2.n, capacitor2.p, ampermeter.p)]

@named circuit_model = ODESystem(eqs, t,
systems = [
resistor1, resistor2, capacitor1, capacitor2,
source, input_signal, ground, ampermeter
])
end

@testset "DAE Observable function AD" begin
model = create_model()
sys = structural_simplify(model)

prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob, Rodas4())

gs, = gradient(sol) do sol
sum(sol[sys.ampermeter.i])
end
du_ = [0.2, 1.0]
du = [du_ for _ in sol.u]
@test gs == du
end

# @testset "Adjoints with DAE" begin
# gs_mtkp, gs_p_new = gradient(mtkparams, p_new) do p, new_tunables
# new_p = SciMLStructures.replace(SciMLStructures.Tunable(), p, new_tunables)
# new_prob = remake(prob, p = new_p)
# sol = solve(new_prob, Rodas4())
# @show size(sol)
# # mean(abs.(sol[sys.ampermeter.i] .- gt))
# sum(sol[sys.ampermeter.i])
# end
#
# @test isnothing(gs_mtkp)
# @test length(gs_p_new) == length(p_new)
# end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ end
@time @safetestset "Partial Functions" begin
include("downstream/partial_functions.jl")
end
@time @safetestset "Autodiff Observable Functions" begin
include("downstream/observables_autodiff.jl")
end
end

if !is_APPVEYOR && (GROUP == "Downstream" || GROUP == "SymbolicIndexingInterface")
Expand Down
Loading