diff --git a/test/ext_cuda/runtests.jl b/test/ext_cuda/runtests.jl index 0ec32797a0..aa1f431fe7 100644 --- a/test/ext_cuda/runtests.jl +++ b/test/ext_cuda/runtests.jl @@ -32,3 +32,7 @@ end @testset "ctc" begin include("ctc.jl") end +@testset "utils" begin + include("utils.jl") +end + diff --git a/test/ext_cuda/utils.jl b/test/ext_cuda/utils.jl index edcb1f0161..53cd900638 100644 --- a/test/ext_cuda/utils.jl +++ b/test/ext_cuda/utils.jl @@ -9,3 +9,9 @@ @test y isa Wrapped{<:CuArray} end end + +@testset "rng_from_array" begin + x = cu(randn(2,2)) + rng = Flux.rng_from_array(x) + @test rng == CUDA.default_rng() +end