Skip to content

Commit

Permalink
build: add MSL to test deps
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed May 19, 2024
1 parent 032b927 commit 44bfc91
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 61 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Logging = "1.10"
Makie = "0.20"
Markdown = "1.10"
ModelingToolkit = "8.75, 9"
ModelingToolkitStandardLibrary = "2.7"
PartialFunctions = "1.1"
PrecompileTools = "1.2"
Preferences = "1.3"
Expand Down Expand Up @@ -96,6 +97,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -111,4 +113,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "ModelingToolkitStandardLibrary", "OrdinaryDiffEq", "ForwardDiff"]
4 changes: 1 addition & 3 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,7 @@ end
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
= @. ifelse=== nothing, (zerou,), Δ)
nt = Zygote.nt_nothing(sol)
gs = Zygote.accum(nt, (u = _Δ,))
(gs,)
(build_solution(sol.prob, sol.alg, sol.t, _Δ),)
end
sol.u, solu_adjoint
end
Expand Down
79 changes: 22 additions & 57 deletions test/downstream/observables_autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ using Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D
import SymbolicIndexingInterface as SII
import SciMLStructures as SS
using ModelingToolkitStandardLibrary
import ModelingToolkitStandardLibrary as MSL

@parameters σ ρ β
@variables x(t) y(t) z(t) w(t)
Expand Down Expand Up @@ -34,60 +36,18 @@ sol = solve(prob, Tsit5())
du_ = [0.0, 1.0, 1.0, 1.0]
du = [du_ for _ in sol.u]
@test du == gs.u
end

# Lorenz

@parameters σ ρ β
@variables x(t) y(t) z(t)

eqs = [D(x) ~ σ * (y - x),
D(y) ~ x *- z) - y,
D(z) ~ x * y - β * z]

@named lorenz1 = ODESystem(eqs, t)
@named lorenz2 = ODESystem(eqs, t)

@parameters γ
@variables a(t) α(t)
connections = [0 ~ lorenz1.x + lorenz2.y + a * γ,
α ~ 2lorenz1.x + a * γ]
@mtkbuild sys = ODESystem(connections, t, [a, α], [γ], systems = [lorenz1, lorenz2])

u0 = [lorenz1.x => 1.0,
lorenz1.y => 0.0,
lorenz1.z => 0.0,
lorenz2.x => 0.0,
lorenz2.y => 1.0,
lorenz2.z => 0.0,
a => 2.0]

p = [lorenz1.σ => 10.0,
lorenz1.ρ => 28.0,
lorenz1.β => 8 / 3,
lorenz2.σ => 10.0,
lorenz2.ρ => 28.0,
lorenz2.β => 8 / 3,
γ => 2.0]

tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p)
integ = init(prob, Rodas4())
sol = solve(prob, Rodas4())

gt = reduce(hcat, sol[[sys.a, sys.α]]) .+ randn.()

gs, = Zygote.gradient(sol) do sol
mean(abs.(sol[[sys.a, sys.α]] .- gt), dims = 2)
# Observable in a vector
gs, = gradient(sol) do sol
sum(sum.(sol[[sys.w, sys.x]]))
end
du_ = [0.0, 1.0, 1.0, 2.0]
du = [du_ for _ in sol.u]
@test du == gs.u
end

# DAE

using ModelingToolkit, OrdinaryDiffEq, Zygote
using ModelingToolkitStandardLibrary
import ModelingToolkitStandardLibrary as MSL
using SciMLStructures

function create_model(; C₁ = 3e-5, C₂ = 1e-6)
@variables t
@named resistor1 = MSL.Electrical.Resistor(R = 5.0)
Expand All @@ -112,15 +72,20 @@ function create_model(; C₁ = 3e-5, C₂ = 1e-6)
])
end

model = create_model()
sys = structural_simplify(model)
@testset "DAE Observable function AD" begin
model = create_model()
sys = structural_simplify(model)

prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob, Rodas4())

prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob, Rodas4())
pf = getp(sol, sys.resistor1.R)
mtkparams = SII.parameter_values(sol)
tunables, _, _ = SS.canonicalize(SS.Tunable(), mtkparams)
p_new = rand(length(tunables))
gs, = gradient(sol) do sol
sum(sol[sys.ampermeter.i])
end
du_ = [0.2, 1.0]
du = [du_ for _ in sol.u]
@test gs.u == du
end

# @testset "Adjoints with DAE" begin
# gs_mtkp, gs_p_new = gradient(mtkparams, p_new) do p, new_tunables
Expand Down

0 comments on commit 44bfc91

Please sign in to comment.