Skip to content

Commit

Permalink
add LuxCore extension tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianM-C committed Feb 14, 2024
1 parent 2f25742 commit 9dbd42d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ julia = "1.10"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Expand All @@ -98,4 +99,4 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random"]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random", "Lux"]
13 changes: 13 additions & 0 deletions test/extensions/lux.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using Symbolics, Lux, Random, Test

@testset "Symbolics extension" begin
model = Dense(5, 6)
rng = Random.default_rng()
x = randn(rng, Float32, 5)
ps, st = LuxCore.setup(rng, model)

Symbolics.@variables sym_ps[1:5] = Float32[1, 2, 3, 4, 5]

out = LuxCore.partial_apply(model, sym_ps, ps, st)
@test out isa Symbolics.Arr
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Integral Test" begin include("integral.jl") end
@safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end
@safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end
@safetestset "LuxCore extensions Test" begin include("extensions/lux.jl") end
@safetestset "Registration without using Test" begin include("registration_without_using.jl") end
end

Expand Down

0 comments on commit 9dbd42d

Please sign in to comment.