Skip to content

Commit

Permalink
test: StatefulLuxLayer tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 7, 2025
1 parent 1cc3a47 commit 5c1122c
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions test/reactant/tracing_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
@testitem "Tracing" tags = [:reactant] skip = :(Sys.iswindows()) begin
using Reactant, Lux, Random

model = Chain(Dense(2 => 3, relu), BatchNorm(3), Dense(3 => 2))
ps, st = Lux.setup(Random.default_rng(), model)

smodel = StatefulLuxLayer{true}(model, ps, st)
smodel_ra = Reactant.to_rarray(smodel)

@test get_device_type(smodel_ra.ps) <: ReactantDevice
@test get_device_type(smodel_ra.st) <: ReactantDevice
@test smodel_ra.st_any === nothing
@test smodel_ra.fixed_state_type == smodel.fixed_state_type

smodel = StatefulLuxLayer{false}(model, ps, st)
smodel_ra = Reactant.to_rarray(smodel)

@test get_device_type(smodel_ra.ps) <: ReactantDevice
@test get_device_type(smodel_ra.st_any) <: ReactantDevice
@test smodel_ra.st === nothing
@test smodel_ra.fixed_state_type == smodel.fixed_state_type
end

0 comments on commit 5c1122c

Please sign in to comment.