From d187fd99599ef41a63cfd45128d4c104434d87a5 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 5 Apr 2024 10:27:32 -0300 Subject: [PATCH] Fix return value of randn! on empty inputs (#528) --- .gitignore | 3 +++ src/host/construction.jl | 2 +- src/host/random.jl | 3 ++- test/testsuite/construction.jl | 6 ++++++ test/testsuite/random.jl | 12 ++++++++++++ 5 files changed, 24 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index cf6b5669..9b73c974 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ Manifest.toml .*.swp .*.swo *~ + +# MacOS generated files +*.DS_Store diff --git a/src/host/construction.jl b/src/host/construction.jl index a456606b..d80bce2d 100644 --- a/src/host/construction.jl +++ b/src/host/construction.jl @@ -10,7 +10,7 @@ Base.convert(::Type{T}, a::AbstractArray) where {T<:AbstractGPUArray} = a isa T ## convenience constructors function Base.fill!(A::AnyGPUArray{T}, x) where T - length(A) == 0 && return A + isempty(A) && return A gpu_call(A, convert(T, x)) do ctx, a, val idx = @linearidx(a) @inbounds a[idx] = val diff --git a/src/host/random.jl b/src/host/random.jl index 09e4257d..b7e5dc74 100644 --- a/src/host/random.jl +++ b/src/host/random.jl @@ -84,6 +84,7 @@ function Random.seed!(rng::RNG, seed::Vector{UInt32}) end function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number + isempty(A) && return A gpu_call(A, rng.state) do ctx, a, randstates idx = linear_index(ctx) idx > length(a) && return @@ -94,8 +95,8 @@ function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number end function Random.randn!(rng::RNG, A::AnyGPUArray{T}) where T <: Number + isempty(A) && return A threads = (length(A) - 1) รท 2 + 1 - length(A) == 0 && return gpu_call(A, rng.state; elements = threads) do ctx, a, randstates idx = 2*(linear_index(ctx) - 1) + 1 U1 = gpu_rand(T, ctx, randstates) diff --git a/test/testsuite/construction.jl b/test/testsuite/construction.jl index 8982e16e..ed42518b 100644 --- a/test/testsuite/construction.jl +++ b/test/testsuite/construction.jl @@ -102,6 +102,12 @@ @testset "convenience" begin for T in eltypes + A = AT(rand(T, 0)) + b = rand(T) + fill!(A, b) + @test A isa AT{T,1} + @test Array(A) == fill(b, 0) + A = AT(rand(T, 3)) b = rand(T) fill!(A, b) diff --git a/test/testsuite/random.jl b/test/testsuite/random.jl index 5488484b..f2cf832a 100644 --- a/test/testsuite/random.jl +++ b/test/testsuite/random.jl @@ -33,6 +33,12 @@ fill!(A, true) rand!(rng, A) @test false in Array(A) + + # AT of length 0 + B = AT{Float32}(undef, 0) + fill!(B, 1f0) + rand!(rng, B) + @test isempty(Array(B)) end @testset "randn" begin # normally-distributed @@ -56,5 +62,11 @@ randn!(cpu_rng, A) end end + + # AT of length 0 + A = AT{Float32}(undef, 0) + fill!(A, 1f0) + randn!(rng, A) + @test isempty(Array(A)) end end