Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 24, 2024
1 parent 2aef090 commit d33b2e3
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.15"
version = "0.1.16"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
23 changes: 23 additions & 0 deletions test/amdgpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,26 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.farray isa Fill
end
end

if LuxAMDGPU.functional()
ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10))
ps_cpu = deepcopy(ps)
cdev = cpu_device()
for idx in 1:length(AMDGPU.devices())
amdgpu_device = gpu_device(idx)
@test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice
@test AMDGPU.device_id(amdgpu_device.device) == idx

ps = ps |> amdgpu_device
@test ps.weight isa ROCArray
@test ps.bias isa ROCArray
@test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx
@test AMDGPU.device_id(AMDGPU.device(ps.bias)) == idx
@test isequal(cdev(ps.weight), ps_cpu.weight)
@test isequal(cdev(ps.bias), ps_cpu.bias)
end

ps = ps |> cdev
@test ps.weight isa Array
@test ps.bias isa Array
end
23 changes: 23 additions & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,26 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.farray isa Fill
end
end

if LuxCUDA.functional()
ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10))
ps_cpu = deepcopy(ps)
cdev = cpu_device()
for idx in 1:length(CUDA.devices())
cuda_device = gpu_device(idx)
@test typeof(cuda_device.device) <: CUDA.CuDevice
@test cuda_device.device.handle == (idx - 1)

ps = ps |> cuda_device
@test ps.weight isa CuArray
@test ps.bias isa CuArray
@test CUDA.device(ps.weight).handle == idx - 1
@test CUDA.device(ps.bias).handle == idx - 1
@test isequal(cdev(ps.weight), ps_cpu.weight)
@test isequal(cdev(ps.bias), ps_cpu.bias)
end

ps = ps |> cdev
@test ps.weight isa Array
@test ps.bias isa Array
end

0 comments on commit d33b2e3

Please sign in to comment.