Skip to content

Commit

Permalink
[SYCLomatic] Refine rng_utils.hpp (#2510)
Browse files Browse the repository at this point in the history
Signed-off-by: Jiang, Zhiwei <[email protected]>
  • Loading branch information
zhiweij1 authored Nov 27, 2024
1 parent a2dc3f5 commit 6ae8d31
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions clang/runtime/dpct-rt/include/dpct/rng_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,29 +338,26 @@ class rng_generator : public rng_generator_base {
/// Set the seed of host rng_generator.
/// \param seed The engine seed.
void set_seed(const std::uint64_t seed) {
if (seed == _seed) {
if (seed == _seed)
return;
}
_seed = seed;
_engine = create_engine(_queue, _seed, _dimensions, _mode);
}

/// Set the dimensions of host rng_generator.
/// \param dimensions The engine dimensions.
void set_dimensions(const std::uint32_t dimensions) {
if (dimensions == _dimensions) {
if (dimensions == _dimensions)
return;
}
_dimensions = dimensions;
_engine = create_engine(_queue, _seed, _dimensions, _mode);
}

/// Set the queue of host rng_generator.
/// \param queue The engine queue.
void set_queue(sycl::queue *queue) {
if (queue == _queue) {
if (queue == _queue)
return;
}
_queue = queue;
_engine = create_engine(_queue, _seed, _dimensions, _mode);
}
Expand All @@ -374,9 +371,8 @@ class rng_generator : public rng_generator_base {
if constexpr (!std::is_same_v<engine_t, oneapi::mkl::rng::mrg32k3a>) {
throw std::runtime_error("Only mrg32k3a engine support this method.");
}
if (mode == _mode) {
if (mode == _mode)
return;
}
_mode = mode;
_engine = create_engine(_queue, _seed, _dimensions, _mode);
#endif
Expand All @@ -390,11 +386,11 @@ class rng_generator : public rng_generator_base {
throw std::runtime_error(OneMKLNotSupport);
#else
if constexpr (std::is_same_v<engine_t, oneapi::mkl::rng::sobol>) {
if (direction_numbers == _direction_numbers) {
if (direction_numbers == _direction_numbers)
return;
}
_direction_numbers = direction_numbers;
_engine = oneapi::mkl::rng::sobol(*_queue, _direction_numbers);
_engine =
create_engine(_queue, _seed, _dimensions, _mode, _direction_numbers);
} else {
throw std::runtime_error("Only Sobol engine supports this method.");
}
Expand All @@ -409,11 +405,11 @@ class rng_generator : public rng_generator_base {
"Interfaces Project does not support this API.");
#else
if constexpr (std::is_same_v<engine_t, oneapi::mkl::rng::mt2203>) {
if (engine_idx == _engine_idx) {
if (engine_idx == _engine_idx)
return;
}
_engine_idx = engine_idx;
_engine = oneapi::mkl::rng::mt2203(*_queue, _seed, _engine_idx);
_engine = create_engine(_queue, _seed, _dimensions, _mode, std::nullopt,
_engine_idx);
} else {
throw std::runtime_error("Only MT2203 engine supports this method.");
}
Expand Down Expand Up @@ -525,10 +521,12 @@ class rng_generator : public rng_generator_base {
}

private:
static inline engine_t create_engine(sycl::queue *queue,
const std::uint64_t seed,
const std::uint32_t dimensions,
const random_mode mode) {
static inline engine_t
create_engine(sycl::queue *queue, const std::uint64_t seed,
const std::uint32_t dimensions, const random_mode mode,
std::optional<std::vector<std::uint32_t>> direction_numbers =
std::nullopt,
std::optional<std::uint32_t> engine_idx = std::nullopt) {
#ifdef __INTEL_MKL__
if constexpr (std::is_same_v<engine_t, oneapi::mkl::rng::mrg32k3a>) {
// oneapi::mkl::rng::mrg32k3a_mode is only supported for GPU device. For
Expand All @@ -546,13 +544,18 @@ class rng_generator : public rng_generator_base {
oneapi::mkl::rng::mrg32k3a_mode::optimal_v);
}
}
} else if constexpr (std::is_same_v<engine_t, oneapi::mkl::rng::mt2203>) {
if (engine_idx.has_value()) {
return engine_t(*queue, seed, engine_idx.value());
}
} else if constexpr (std::is_same_v<engine_t, oneapi::mkl::rng::sobol>) {
if (direction_numbers.has_value()) {
return engine_t(*queue, direction_numbers.value());
}
return engine_t(*queue, dimensions);
}
return std::is_same_v<engine_t, oneapi::mkl::rng::sobol>
? engine_t(*queue, dimensions)
: engine_t(*queue, seed);
#else
return engine_t(*queue, seed);
#endif
return engine_t(*queue, seed);
}

template <typename distr_t, typename buffer_t, class... distr_params_t>
Expand Down

0 comments on commit 6ae8d31

Please sign in to comment.