From 6565e9f45a2d1520b505acbd7494610620b6b1fb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 18 Oct 2024 13:05:23 -0400 Subject: [PATCH] test: RNG movement --- test/amdgpu_tests.jl | 8 ++++++++ test/cuda_tests.jl | 8 ++++++++ test/metal_tests.jl | 8 ++++++++ test/oneapi_tests.jl | 8 ++++++++ test/xla_tests.jl | 8 ++++++++ 5 files changed, 40 insertions(+) diff --git a/test/amdgpu_tests.jl b/test/amdgpu_tests.jl index 67edff4..f29c279 100644 --- a/test/amdgpu_tests.jl +++ b/test/amdgpu_tests.jl @@ -57,7 +57,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa AMDGPUDevice + @test get_device_type(ps_xpu.rng_default) <: AMDGPUDevice @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing if MLDataDevices.functional(AMDGPUDevice) @test ps_xpu.one_elem isa ROCArray @@ -83,7 +87,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(AMDGPUDevice) @test ps_cpu.one_elem isa Array diff --git a/test/cuda_tests.jl b/test/cuda_tests.jl index 92c0a27..bd8a234 100644 --- a/test/cuda_tests.jl +++ b/test/cuda_tests.jl @@ -56,7 +56,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa CUDADevice + @test get_device_type(ps_xpu.rng_default) <: CUDADevice @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing if MLDataDevices.functional(CUDADevice) @test ps_xpu.one_elem isa CuArray @@ -82,7 +86,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(CUDADevice) @test ps_cpu.one_elem isa Array diff --git a/test/metal_tests.jl b/test/metal_tests.jl index 789fa49..3e4634c 100644 --- a/test/metal_tests.jl +++ b/test/metal_tests.jl @@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa MetalDevice + @test get_device_type(ps_xpu.rng_default) <: MetalDevice @test ps_xpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(MetalDevice) @test ps_xpu.one_elem isa MtlArray @@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(MetalDevice) @test ps_cpu.one_elem isa Array diff --git a/test/oneapi_tests.jl b/test/oneapi_tests.jl index 7731c43..3441bad 100644 --- a/test/oneapi_tests.jl +++ b/test/oneapi_tests.jl @@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa oneAPIDevice + @test get_device_type(ps_xpu.rng_default) <: oneAPIDevice @test ps_xpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(oneAPIDevice) @test ps_xpu.one_elem isa oneArray @@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(oneAPIDevice) @test ps_cpu.one_elem isa Array diff --git a/test/xla_tests.jl b/test/xla_tests.jl index 81ae929..3cd6b15 100644 --- a/test/xla_tests.jl +++ b/test/xla_tests.jl @@ -54,7 +54,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_xpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(XLADevice) @test ps_xpu.one_elem isa Reactant.RArray @@ -80,7 +84,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(XLADevice) @test ps_cpu.one_elem isa Array