Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
test: RNG movement
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 18, 2024
1 parent 46fccbb commit 6565e9f
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 0 deletions.
8 changes: 8 additions & 0 deletions test/amdgpu_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa AMDGPUDevice
@test get_device_type(ps_xpu.rng_default) <: AMDGPUDevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_xpu.rng) === nothing
@test get_device_type(ps_xpu.rng) <: Nothing

if MLDataDevices.functional(AMDGPUDevice)
@test ps_xpu.one_elem isa ROCArray
Expand All @@ -83,7 +87,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(AMDGPUDevice)
@test ps_cpu.one_elem isa Array
Expand Down
8 changes: 8 additions & 0 deletions test/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa CUDADevice
@test get_device_type(ps_xpu.rng_default) <: CUDADevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_xpu.rng) === nothing
@test get_device_type(ps_xpu.rng) <: Nothing

if MLDataDevices.functional(CUDADevice)
@test ps_xpu.one_elem isa CuArray
Expand All @@ -82,7 +86,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(CUDADevice)
@test ps_cpu.one_elem isa Array
Expand Down
8 changes: 8 additions & 0 deletions test/metal_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa MetalDevice
@test get_device_type(ps_xpu.rng_default) <: MetalDevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(MetalDevice)
@test ps_xpu.one_elem isa MtlArray
Expand All @@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(MetalDevice)
@test ps_cpu.one_elem isa Array
Expand Down
8 changes: 8 additions & 0 deletions test/oneapi_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa oneAPIDevice
@test get_device_type(ps_xpu.rng_default) <: oneAPIDevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(oneAPIDevice)
@test ps_xpu.one_elem isa oneArray
Expand All @@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(oneAPIDevice)
@test ps_cpu.one_elem isa Array
Expand Down
8 changes: 8 additions & 0 deletions test/xla_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_xpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(XLADevice)
@test ps_xpu.one_elem isa Reactant.RArray
Expand All @@ -80,7 +84,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(XLADevice)
@test ps_cpu.one_elem isa Array
Expand Down

0 comments on commit 6565e9f

Please sign in to comment.