diff --git a/test/misc_tests.jl b/test/misc_tests.jl index 34b3e7e..1a3093d 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -4,20 +4,22 @@ using ChainRulesTestUtils: test_rrule using ReverseDiff, Tracker, ForwardDiff using SparseArrays, FillArrays, Zygote, RecursiveArrayTools -@testset "https://github.com/LuxDL/MLDataDevices.jl/issues/10 patch" begin - dev = CPUDevice() - ps = (; weight=randn(10, 1), bias=randn(1)) +@testset "Issues Patches" begin + @testset "#10 patch" begin + dev = CPUDevice() + ps = (; weight=randn(10, 1), bias=randn(1)) - ps_ca = ps |> ComponentArray + ps_ca = ps |> ComponentArray - ps_ca_dev = ps_ca |> dev + ps_ca_dev = ps_ca |> dev - @test ps_ca_dev isa ComponentArray + @test ps_ca_dev isa ComponentArray - @test ps_ca_dev.weight == ps.weight - @test ps_ca_dev.bias == ps.bias + @test ps_ca_dev.weight == ps.weight + @test ps_ca_dev.bias == ps.bias - @test ps_ca_dev == (ps |> dev |> ComponentArray) + @test ps_ca_dev == (ps |> dev |> ComponentArray) + end end @testset "AD Types" begin