Skip to content

Commit

Permalink
Fix return value of randn! on empty inputs (#528)
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd authored Apr 5, 2024
1 parent 4623226 commit d187fd9
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ Manifest.toml
.*.swp
.*.swo
*~

# MacOS generated files
*.DS_Store
2 changes: 1 addition & 1 deletion src/host/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Base.convert(::Type{T}, a::AbstractArray) where {T<:AbstractGPUArray} = a isa T
## convenience constructors

function Base.fill!(A::AnyGPUArray{T}, x) where T
length(A) == 0 && return A
isempty(A) && return A
gpu_call(A, convert(T, x)) do ctx, a, val
idx = @linearidx(a)
@inbounds a[idx] = val
Expand Down
3 changes: 2 additions & 1 deletion src/host/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ function Random.seed!(rng::RNG, seed::Vector{UInt32})
end

function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
isempty(A) && return A
gpu_call(A, rng.state) do ctx, a, randstates
idx = linear_index(ctx)
idx > length(a) && return
Expand All @@ -94,8 +95,8 @@ function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
end

function Random.randn!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
isempty(A) && return A
threads = (length(A) - 1) ÷ 2 + 1
length(A) == 0 && return
gpu_call(A, rng.state; elements = threads) do ctx, a, randstates
idx = 2*(linear_index(ctx) - 1) + 1
U1 = gpu_rand(T, ctx, randstates)
Expand Down
6 changes: 6 additions & 0 deletions test/testsuite/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@

@testset "convenience" begin
for T in eltypes
A = AT(rand(T, 0))
b = rand(T)
fill!(A, b)
@test A isa AT{T,1}
@test Array(A) == fill(b, 0)

A = AT(rand(T, 3))
b = rand(T)
fill!(A, b)
Expand Down
12 changes: 12 additions & 0 deletions test/testsuite/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
fill!(A, true)
rand!(rng, A)
@test false in Array(A)

# AT of length 0
B = AT{Float32}(undef, 0)
fill!(B, 1f0)
rand!(rng, B)
@test isempty(Array(B))
end

@testset "randn" begin # normally-distributed
Expand All @@ -56,5 +62,11 @@
randn!(cpu_rng, A)
end
end

# AT of length 0
A = AT{Float32}(undef, 0)
fill!(A, 1f0)
randn!(rng, A)
@test isempty(Array(A))
end
end

0 comments on commit d187fd9

Please sign in to comment.