From 785b052f2b0c9934b08d418c33d44a32e1d3caf6 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Sun, 12 May 2024 23:58:04 +0530 Subject: [PATCH] test: add test for observable functions --- test/downstream/observables_autodiff.jl | 34 +++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 test/downstream/observables_autodiff.jl diff --git a/test/downstream/observables_autodiff.jl b/test/downstream/observables_autodiff.jl new file mode 100644 index 000000000..e198bd703 --- /dev/null +++ b/test/downstream/observables_autodiff.jl @@ -0,0 +1,34 @@ +using ModelingToolkit, OrdinaryDiffEq +using Zygote + +@parameters σ ρ β +@variables x(t) y(t) z(t) w(t) + +eqs = [D(D(x)) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z, + w ~ x + y + z + 2 * β] + +@mtkbuild sys = ODESystem(eqs, t) + +u0 = [D(x) => 2.0, + x => 1.0, + y => 0.0, + z => 0.0] + +p = [σ => 28.0, + ρ => 10.0, + β => 8 / 3] + +tspan = (0.0, 100.0) +prob = ODEProblem(sys, u0, tspan, p, jac = true) +sol = solve(prob, Tsit5()) + +@testset "AutoDiff Observable Functions" begin + gs, = gradient(sol) do sol + sum(sol[sys.w]) + end + du_ = [0., 1., 1., 1.] + du = [du_ for _ = sol.u] + @test du == gs.u +end