diff --git a/test/misc_tests.jl b/test/misc_tests.jl index a84b8be..614cc20 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -167,12 +167,13 @@ end x end Functors.@functor Tleaf - MLDataDevices.isleaf(::Tleaf) = true - Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x)) cpu = cpu_device() t = Tleaf(ones(2)) - @test cpu(t).x == 2 .* ones(2) + t = cpu(t) + @test y.x == 2 .* ones(2) + y = cpu([(t,)]) + @test y[1][1].x == 2 .* ones(2) end