diff --git a/test/reactant/tracing_tests.jl b/test/reactant/tracing_tests.jl new file mode 100644 index 000000000..7b710544b --- /dev/null +++ b/test/reactant/tracing_tests.jl @@ -0,0 +1,22 @@ +@testitem "Tracing" tags = [:reactant] skip = :(Sys.iswindows()) begin + using Reactant, Lux, Random + + model = Chain(Dense(2 => 3, relu), BatchNorm(3), Dense(3 => 2)) + ps, st = Lux.setup(Random.default_rng(), model) + + smodel = StatefulLuxLayer{true}(model, ps, st) + smodel_ra = Reactant.to_rarray(smodel) + + @test get_device_type(smodel_ra.ps) <: ReactantDevice + @test get_device_type(smodel_ra.st) <: ReactantDevice + @test smodel_ra.st_any === nothing + @test smodel_ra.fixed_state_type == smodel.fixed_state_type + + smodel = StatefulLuxLayer{false}(model, ps, st) + smodel_ra = Reactant.to_rarray(smodel) + + @test get_device_type(smodel_ra.ps) <: ReactantDevice + @test get_device_type(smodel_ra.st_any) <: ReactantDevice + @test smodel_ra.st === nothing + @test smodel_ra.fixed_state_type == smodel.fixed_state_type +end