Skip to content

Commit

Permalink
test: add simple adjoint tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 31, 2024
1 parent 783f4b9 commit 3ff8187
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ SafeTestsets = "0.0.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4"
Test = "1"
Zygote = "0.6.67"
julia = "1.10"

[extras]
Expand All @@ -27,6 +28,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays"]
test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays", "Zygote"]
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "BatchedInterface test" begin
@time include("batched_interface_test.jl")
end
@safetestset "Simple Adjoints test" begin
@time include("simple_adjoints_test.jl")
end
end

if GROUP == "All" || GROUP == "Downstream"
Expand Down
17 changes: 17 additions & 0 deletions test/simple_adjoints_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using SymbolicIndexingInterface
using Zygote

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
pstate = ProblemState(; u = rand(3), p = rand(3), t = rand())

getter = getu(sys, :x)
@test Zygote.gradient(getter, pstate)[1].u == [1.0, 0.0, 0.0]

getter = getu(sys, [:x, :z])
@test Zygote.gradient(sum getter, pstate)[1].u == [1.0, 0.0, 1.0]

getter = getu(sys, :a)
@test Zygote.gradient(getter, pstate)[1].p == [1.0, 0.0, 0.0]

getter = getu(sys, [:a, :c])
@test Zygote.gradient(sum getter, pstate)[1].p == [1.0, 0.0, 1.0]

0 comments on commit 3ff8187

Please sign in to comment.