diff --git a/common/include/common_defs.hpp b/common/include/common_defs.hpp index d8e3e6ca..8a61ff30 100644 --- a/common/include/common_defs.hpp +++ b/common/include/common_defs.hpp @@ -91,6 +91,23 @@ static inline void write(std::ostream& os, const T* ptr, size_t size_bytes) { os.write(reinterpret_cast(ptr), size_bytes); } +template +T byteswap(T value) { + char* ptr = static_cast(static_cast(&value)); + const int len = sizeof(T); + for (size_t i = 0; i < len / 2; ++i) { + std::swap(ptr[i], ptr[len - i - 1]); + } + return value; +} + +template +static inline T read_big_endian(std::istream& is) { + T value; + is.read(reinterpret_cast(&value), sizeof(T)); + return byteswap(value); +} + // wrapper for iterators to implement operator-> returning temporary value template class return_value_holder { diff --git a/tdigest/include/tdigest.hpp b/tdigest/include/tdigest.hpp index 357f203b..d60a3969 100644 --- a/tdigest/include/tdigest.hpp +++ b/tdigest/include/tdigest.hpp @@ -228,6 +228,9 @@ class tdigest { static const uint8_t SERIAL_VERSION = 1; static const uint8_t SKETCH_TYPE = 20; + static const uint8_t COMPAT_DOUBLE = 1; + static const uint8_t COMPAT_FLOAT = 2; + enum flags { IS_EMPTY, REVERSE_MERGE }; // for deserialize @@ -238,6 +241,10 @@ class tdigest { void merge_new_values(uint16_t k); static double weighted_average(double x1, double w1, double x2, double w2); + + // for compatibility with format of the reference implementation + static tdigest deserialize_compat(std::istream& is, const Allocator& allocator = Allocator()); + static tdigest deserialize_compat(const void* bytes, size_t size, const Allocator& allocator = Allocator()); }; } /* namespace datasketches */ diff --git a/tdigest/include/tdigest_impl.hpp b/tdigest/include/tdigest_impl.hpp index c47202e6..1c3b3451 100644 --- a/tdigest/include/tdigest_impl.hpp +++ b/tdigest/include/tdigest_impl.hpp @@ -352,6 +352,7 @@ tdigest tdigest::deserialize(std::istream& is, const A& allocator) { const auto serial_version = read(is); const auto sketch_type = read(is); if (sketch_type != SKETCH_TYPE) { + if (preamble_longs == 0 && serial_version == 0 && sketch_type == 0) return deserialize_compat(is, allocator); throw std::invalid_argument("sketch type mismatch: expected " + std::to_string(SKETCH_TYPE) + ", actual " + std::to_string(sketch_type)); } if (serial_version != SERIAL_VERSION) { @@ -391,6 +392,7 @@ tdigest tdigest::deserialize(const void* bytes, size_t size, const A const uint8_t serial_version = *ptr++; const uint8_t sketch_type = *ptr++; if (sketch_type != SKETCH_TYPE) { + if (preamble_longs == 0 && serial_version == 0 && sketch_type == 0) return deserialize_compat(ptr, end_ptr - ptr, allocator); throw std::invalid_argument("sketch type mismatch: expected " + std::to_string(SKETCH_TYPE) + ", actual " + std::to_string(sketch_type)); } if (serial_version != SERIAL_VERSION) { @@ -426,6 +428,124 @@ tdigest tdigest::deserialize(const void* bytes, size_t size, const A return tdigest(reverse_merge, k, min, max, std::move(centroids), total_weight, allocator); } +// compatibility with the format of the reference implementation +// default byte order of ByteBuffer is used there, which is big endian +template +tdigest tdigest::deserialize_compat(std::istream& is, const A& allocator) { + // this method was called because the first three bytes were zeros + // so read one more byte to see if it looks like the reference implementation format + const auto type = read(is); + if (type != COMPAT_DOUBLE && type != COMPAT_FLOAT) { + throw std::invalid_argument("unexpected sketch preamble: 0 0 0 " + std::to_string(type)); + } + if (type == COMPAT_DOUBLE) { // compatibility with asBytes() + const auto min = read_big_endian(is); + const auto max = read_big_endian(is); + const auto k = static_cast(read_big_endian(is)); + const auto num_centroids = read_big_endian(is); + vector_centroid centroids(num_centroids, centroid(0, 0), allocator); + uint64_t total_weight = 0; + for (auto& c: centroids) { + const uint64_t weight = static_cast(read_big_endian(is)); + const auto mean = read_big_endian(is); + c = centroid(mean, weight); + total_weight += weight; + } + return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator); + } + // COMPAT_FLOAT: compatibility with asSmallBytes() + const auto min = read_big_endian(is); // reference implementation uses doubles for min and max + const auto max = read_big_endian(is); + const auto k = static_cast(read_big_endian(is)); + // reference implementation stores capacities of the array of centroids and the buffer as shorts + // they can be derived from k in the constructor + read(is); // unused + const auto num_centroids = read_big_endian(is); + vector_centroid centroids(num_centroids, centroid(0, 0), allocator); + uint64_t total_weight = 0; + for (auto& c: centroids) { + const uint64_t weight = static_cast(read_big_endian(is)); + const auto mean = read_big_endian(is); + c = centroid(mean, weight); + total_weight += weight; + } + return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator); +} + +// compatibility with the format of the reference implementation +// default byte order of ByteBuffer is used there, which is big endian +template +tdigest tdigest::deserialize_compat(const void* bytes, size_t size, const A& allocator) { + const char* ptr = static_cast(bytes); + // this method was called because the first three bytes were zeros + // so read one more byte to see if it looks like the reference implementation format + const auto type = *ptr++; + if (type != COMPAT_DOUBLE && type != COMPAT_FLOAT) { + throw std::invalid_argument("unexpected sketch preamble: 0 0 0 " + std::to_string(type)); + } + const char* end_ptr = static_cast(bytes) + size; + if (type == COMPAT_DOUBLE) { // compatibility with asBytes() + ensure_minimum_memory(end_ptr - ptr, sizeof(double) * 3 + sizeof(uint32_t)); + double min; + ptr += copy_from_mem(ptr, min); + min = byteswap(min); + double max; + ptr += copy_from_mem(ptr, max); + max = byteswap(max); + double k_double; + ptr += copy_from_mem(ptr, k_double); + const uint16_t k = static_cast(byteswap(k_double)); + uint32_t num_centroids; + ptr += copy_from_mem(ptr, num_centroids); + num_centroids = byteswap(num_centroids); + ensure_minimum_memory(end_ptr - ptr, sizeof(double) * num_centroids * 2); + vector_centroid centroids(num_centroids, centroid(0, 0), allocator); + uint64_t total_weight = 0; + for (auto& c: centroids) { + double weight; + ptr += copy_from_mem(ptr, weight); + weight = byteswap(weight); + double mean; + ptr += copy_from_mem(ptr, mean); + mean = byteswap(mean); + c = centroid(mean, static_cast(weight)); + total_weight += static_cast(weight); + } + return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator); + } + // COMPAT_FLOAT: compatibility with asSmallBytes() + ensure_minimum_memory(end_ptr - ptr, sizeof(double) * 2 + sizeof(float) + sizeof(uint16_t) * 3); + double min; // reference implementation uses doubles for min and max + ptr += copy_from_mem(ptr, min); + min = byteswap(min); + double max; + ptr += copy_from_mem(ptr, max); + max = byteswap(max); + float k_float; + ptr += copy_from_mem(ptr, k_float); + const uint16_t k = static_cast(byteswap(k_float)); + // reference implementation stores capacities of the array of centroids and the buffer as shorts + // they can be derived from k in the constructor + ptr += sizeof(uint32_t); // unused + uint16_t num_centroids; + ptr += copy_from_mem(ptr, num_centroids); + num_centroids = byteswap(num_centroids); + ensure_minimum_memory(end_ptr - ptr, sizeof(float) * num_centroids * 2); + vector_centroid centroids(num_centroids, centroid(0, 0), allocator); + uint64_t total_weight = 0; + for (auto& c: centroids) { + float weight; + ptr += copy_from_mem(ptr, weight); + weight = byteswap(weight); + float mean; + ptr += copy_from_mem(ptr, mean); + mean = byteswap(mean); + c = centroid(mean, static_cast(weight)); + total_weight += static_cast(weight); + } + return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator); +} + template tdigest::tdigest(bool reverse_merge, uint16_t k, T min, T max, vector_centroid&& centroids, uint64_t total_weight, const A& allocator): allocator_(allocator), diff --git a/tdigest/test/tdigest_ref_k100_n10000_double.sk b/tdigest/test/tdigest_ref_k100_n10000_double.sk new file mode 100644 index 00000000..f6f4510e Binary files /dev/null and b/tdigest/test/tdigest_ref_k100_n10000_double.sk differ diff --git a/tdigest/test/tdigest_ref_k100_n10000_float.sk b/tdigest/test/tdigest_ref_k100_n10000_float.sk new file mode 100644 index 00000000..16d79811 Binary files /dev/null and b/tdigest/test/tdigest_ref_k100_n10000_float.sk differ diff --git a/tdigest/test/tdigest_test.cpp b/tdigest/test/tdigest_test.cpp index b1627d58..cfed6303 100644 --- a/tdigest/test/tdigest_test.cpp +++ b/tdigest/test/tdigest_test.cpp @@ -19,6 +19,7 @@ #include #include +#include #include "tdigest.hpp" @@ -229,4 +230,70 @@ TEST_CASE("serialize deserialize steam and bytes equivalence", "[tdigest]") { REQUIRE(deserialized_td1.get_quantile(0.5) == deserialized_td2.get_quantile(0.5)); } +TEST_CASE("deserialize from reference implementation stream double", "[tdigest]") { + std::ifstream is; + is.exceptions(std::ios::failbit | std::ios::badbit); + is.open(std::string(TEST_BINARY_INPUT_PATH) + "tdigest_ref_k100_n10000_double.sk", std::ios::binary); + const auto td = tdigest::deserialize(is); + const size_t n = 10000; + REQUIRE(td.get_total_weight() == n); + REQUIRE(td.get_min_value() == 0); + REQUIRE(td.get_max_value() == n - 1); + REQUIRE(td.get_rank(0) == Approx(0).margin(0.0001)); + REQUIRE(td.get_rank(n / 4) == Approx(0.25).margin(0.0001)); + REQUIRE(td.get_rank(n / 2) == Approx(0.5).margin(0.0001)); + REQUIRE(td.get_rank(n * 3 / 4) == Approx(0.75).margin(0.0001)); + REQUIRE(td.get_rank(n) == 1); +} + +TEST_CASE("deserialize from reference implementation stream float", "[tdigest]") { + std::ifstream is; + is.exceptions(std::ios::failbit | std::ios::badbit); + is.open(std::string(TEST_BINARY_INPUT_PATH) + "tdigest_ref_k100_n10000_float.sk", std::ios::binary); + const auto td = tdigest::deserialize(is); + const size_t n = 10000; + REQUIRE(td.get_total_weight() == n); + REQUIRE(td.get_min_value() == 0); + REQUIRE(td.get_max_value() == n - 1); + REQUIRE(td.get_rank(0) == Approx(0).margin(0.0001)); + REQUIRE(td.get_rank(n / 4) == Approx(0.25).margin(0.0001)); + REQUIRE(td.get_rank(n / 2) == Approx(0.5).margin(0.0001)); + REQUIRE(td.get_rank(n * 3 / 4) == Approx(0.75).margin(0.0001)); + REQUIRE(td.get_rank(n) == 1); +} + +TEST_CASE("deserialize from reference implementation bytes double", "[tdigest]") { + std::ifstream is; + is.exceptions(std::ios::failbit | std::ios::badbit); + is.open(std::string(TEST_BINARY_INPUT_PATH) + "tdigest_ref_k100_n10000_double.sk", std::ios::binary); + std::vector bytes((std::istreambuf_iterator(is)), (std::istreambuf_iterator())); + const auto td = tdigest::deserialize(bytes.data(), bytes.size()); + const size_t n = 10000; + REQUIRE(td.get_total_weight() == n); + REQUIRE(td.get_min_value() == 0); + REQUIRE(td.get_max_value() == n - 1); + REQUIRE(td.get_rank(0) == Approx(0).margin(0.0001)); + REQUIRE(td.get_rank(n / 4) == Approx(0.25).margin(0.0001)); + REQUIRE(td.get_rank(n / 2) == Approx(0.5).margin(0.0001)); + REQUIRE(td.get_rank(n * 3 / 4) == Approx(0.75).margin(0.0001)); + REQUIRE(td.get_rank(n) == 1); +} + +TEST_CASE("deserialize from reference implementation bytes float", "[tdigest]") { + std::ifstream is; + is.exceptions(std::ios::failbit | std::ios::badbit); + is.open(std::string(TEST_BINARY_INPUT_PATH) + "tdigest_ref_k100_n10000_float.sk", std::ios::binary); + std::vector bytes((std::istreambuf_iterator(is)), (std::istreambuf_iterator())); + const auto td = tdigest::deserialize(bytes.data(), bytes.size()); + const size_t n = 10000; + REQUIRE(td.get_total_weight() == n); + REQUIRE(td.get_min_value() == 0); + REQUIRE(td.get_max_value() == n - 1); + REQUIRE(td.get_rank(0) == Approx(0).margin(0.0001)); + REQUIRE(td.get_rank(n / 4) == Approx(0.25).margin(0.0001)); + REQUIRE(td.get_rank(n / 2) == Approx(0.5).margin(0.0001)); + REQUIRE(td.get_rank(n * 3 / 4) == Approx(0.75).margin(0.0001)); + REQUIRE(td.get_rank(n) == 1); +} + } /* namespace datasketches */