diff --git a/ext/FluxCUDAExt/FluxCUDAExt.jl b/ext/FluxCUDAExt/FluxCUDAExt.jl index 3fcdc5c263..9948c5f4c0 100644 --- a/ext/FluxCUDAExt/FluxCUDAExt.jl +++ b/ext/FluxCUDAExt/FluxCUDAExt.jl @@ -43,6 +43,7 @@ end ChainRulesCore.@non_differentiable check_use_cuda() include("functor.jl") +include("utils.jl") function __init__() Flux.CUDA_LOADED[] = true diff --git a/ext/FluxCUDAExt/utils.jl b/ext/FluxCUDAExt/utils.jl index f6ba3751ad..07500e9eb9 100644 --- a/ext/FluxCUDAExt/utils.jl +++ b/ext/FluxCUDAExt/utils.jl @@ -1 +1 @@ -rng_from_array(::CuArray) = CUDA.default_rng() \ No newline at end of file +Flux.rng_from_array(::CuArray) = CUDA.default_rng()