diff --git a/include/cuco/utility/key_generator.cuh b/include/cuco/utility/key_generator.cuh index d58c8cf08..37a3345e0 100644 --- a/include/cuco/utility/key_generator.cuh +++ b/include/cuco/utility/key_generator.cuh @@ -20,7 +20,8 @@ #include #include -#include +#include +#include // TODO include instead once available #include #include #include @@ -80,6 +81,7 @@ struct gaussian : public cuco::detail::strong_type { } // namespace distribution namespace detail { + /** * @brief Generate uniform functor * @@ -94,29 +96,37 @@ struct generate_uniform_fn { * * @param num Number of elements to generate * @param dist Random number distribution + * @param seed Random seed */ - __host__ __device__ constexpr generate_uniform_fn(std::size_t num, Dist dist) - : num_{num}, dist_{dist} + __host__ __device__ constexpr generate_uniform_fn(std::size_t num, Dist dist, std::size_t seed) + : num_{num}, dist_{dist}, seed_{seed} { } /** * @brief Generates a random number of type `T` based on the given `seed` * - * @param seed Random number generator seed + * @param idx Index of the output element * * @return A resulting random number */ - __host__ __device__ constexpr T operator()(std::size_t seed) const noexcept + __host__ __device__ constexpr T operator()(std::size_t idx) const noexcept { RNG rng; - thrust::uniform_int_distribution uniform_dist{1, static_cast(num_ / dist_.value)}; - rng.seed(seed); + // Improved seeding using a linear congruential generator + rng.seed(seed_ + idx * 1664525ull + 1013904223ull); + // Calculate number of unique keys + auto num_unique_keys = cuda::std::max( + 1ull, + static_cast( + cuda::std::ceil(static_cast(num_) / static_cast(dist_.value)))); + thrust::uniform_int_distribution uniform_dist{0, static_cast(num_unique_keys - 1)}; return uniform_dist(rng); } - std::size_t num_; ///< Number of elements to generate - Dist dist_; ///< Random number distribution + std::size_t num_; ///< Number of elements to generate + Dist dist_; ///< Random number distribution + std::size_t seed_; ///< Random seed }; /** @@ -270,18 +280,17 @@ class key_generator { using value_type = typename std::iterator_traits::value_type; if constexpr (std::is_same_v) { - thrust::sequence(exec_policy, out_begin, out_end, 0); + thrust::sequence(exec_policy, out_begin, out_end, value_type{0}); thrust::shuffle(exec_policy, out_begin, out_end, this->rng_); } else if constexpr (std::is_same_v) { size_t num_keys = thrust::distance(out_begin, out_end); - - thrust::counting_iterator seeds(this->rng_()); + size_t seed = this->rng_(); thrust::transform(exec_policy, - seeds, - seeds + num_keys, + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_keys), out_begin, - detail::generate_uniform_fn{num_keys, dist}); + detail::generate_uniform_fn{num_keys, dist, seed}); } else if constexpr (std::is_same_v) { size_t num_keys = thrust::distance(out_begin, out_end);