Skip to content

Commit

Permalink
✨ Add operators >>, ==, !=, and tests for them
Browse files Browse the repository at this point in the history
  • Loading branch information
heavywatal committed Nov 20, 2024
1 parent b11a880 commit fd2f59b
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 3 deletions.
64 changes: 61 additions & 3 deletions include/pcglite/pcglite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#define PCGLITE_PCGLITE_HPP_

#include <array>
#include <charconv>
#include <ios>
#include <limits>
#include <type_traits>
Expand Down Expand Up @@ -207,11 +208,33 @@ class permuted_congruential_engine {
state_ += increment_;
}

template <class T>
friend bool
operator==(const permuted_congruential_engine<T>&, const permuted_congruential_engine<T>&) noexcept;

template <class T>
friend bool
operator!=(const permuted_congruential_engine<T>&, const permuted_congruential_engine<T>&) noexcept;

template <class CharT, class Traits, class T>
friend std::basic_ostream<CharT, Traits>&
operator<<(std::basic_ostream<CharT, Traits>&, const permuted_congruential_engine<T>&);

template <class CharT, class Traits, class T>
friend std::basic_istream<CharT, Traits>&
operator>>(std::basic_istream<CharT, Traits>&, permuted_congruential_engine<T>&);
};

template <class T> bool
operator==(const permuted_congruential_engine<T>& x, const permuted_congruential_engine<T>& y) noexcept {
return (x.state_ == y.state_) && (x.increment_ == y.increment_);
}

template <class T> bool
operator!=(const permuted_congruential_engine<T>& x, const permuted_congruential_engine<T>& y) noexcept {
return !(x == y);
}

template <class CharT, class Traits>
std::basic_ostream<CharT, Traits>&
operator<<(std::basic_ostream<CharT, Traits>& ost, __uint128_t x) {
Expand All @@ -226,12 +249,47 @@ operator<<(std::basic_ostream<CharT, Traits>& ost, __uint128_t x) {
return ost;
}

template <class CharT, class Traits>
std::basic_istream<CharT, Traits>&
operator>>(std::basic_istream<CharT, Traits>& ist, __uint128_t& x) {
uint64_t high{}, low{};
char buffer[33];
ist.getline(buffer, 33, ' ');
const auto size = std::strlen(buffer); // ist.gcount() includes trailing \0
const auto end = buffer + size;
const auto begin_low = (size > 16) ? (end - 16) : buffer;
std::from_chars(buffer, begin_low, high, 16);
std::from_chars(begin_low, end, low, 16);
x = detail::constexpr_uint128(high, low);
return ist;
}

template <class CharT, class Traits, class T>
inline std::basic_ostream<CharT, Traits>&
operator<<(std::basic_ostream<CharT, Traits>& ost, const permuted_congruential_engine<T>& x) {
return ost << x.multiplier << " "
<< x.increment_ << " "
<< x.state_;
auto fillch = ost.fill();
auto flags = ost.flags(std::ios_base::dec | std::ios_base::left);
ost << x.multiplier << ' '
<< x.increment_ << ' '
<< x.state_;
ost.fill(fillch);
ost.flags(flags);
return ost;
}

template <class CharT, class Traits, class T>
std::basic_istream<CharT, Traits>&
operator>>(std::basic_istream<CharT, Traits>& ist, permuted_congruential_engine<T>& x) {
auto flags = ist.flags(std::ios_base::dec | std::ios_base::skipws);
typename permuted_congruential_engine<T>::state_type multiplier{}, increment{}, state{};
ist >> multiplier >> increment >> state;
if (!ist.fail()) {
if (multiplier != x.multiplier) ist.clear(std::ios_base::failbit);
x.increment_ = increment;
x.state_ = state;
}
ist.flags(flags);
return ist;
}

using pcg32 = permuted_congruential_engine<uint32_t>;
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
set(source_files
example.cpp
static.cpp
std.cpp
)

find_package(pcg)
Expand Down
50 changes: 50 additions & 0 deletions test/std.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include <pcglite/pcglite.hpp>

#include <iostream>
#include <sstream>
#include <random>

inline int test_uint128_stream() {
using pcglite::operator<<;
using pcglite::operator>>;
__uint128_t x{42u}, z{};
auto y = pcglite::detail::constexpr_uint128(42u, 54u);
std::stringstream sst;
sst << x;
sst >> z;
std::cout << x << std::endl;
std::cout << z << std::endl;
if (x != z) return 1;
sst.clear();
sst << y;
sst >> z;
std::cout << y << std::endl;
std::cout << z << std::endl;
if (y != z) return 1;
return 0;
}

template <class URBG> inline
int test_rng_stream(URBG rng) {
URBG copy(rng);
if (rng != copy) return 1;
std::cout << rng << std::endl;
rng.operator()();
std::cout << rng << std::endl;
if (rng == copy) return 1;
std::stringstream sst;
sst << copy;
sst >> rng;
std::cout << rng << std::endl;
if (rng != copy) return 1;
return 0;
}

int main() {
int ret = 0;
ret |= test_uint128_stream();
ret |= test_rng_stream(std::minstd_rand{});
ret |= test_rng_stream(pcglite::pcg32{});
ret |= test_rng_stream(pcglite::pcg64{});
return ret;
}

0 comments on commit fd2f59b

Please sign in to comment.