Skip to content

Commit

Permalink
Merge pull request #422 from apache/tdigest
Browse files Browse the repository at this point in the history
Tdigest
  • Loading branch information
AlexanderSaydakov authored Feb 13, 2024
2 parents 50ad1ba + 5f94bdd commit fa0237a
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 0 deletions.
17 changes: 17 additions & 0 deletions common/include/common_defs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,23 @@ static inline void write(std::ostream& os, const T* ptr, size_t size_bytes) {
os.write(reinterpret_cast<const char*>(ptr), size_bytes);
}

// Returns a copy of the given value with the order of its bytes reversed.
// The value is treated as sizeof(T) raw bytes, so this is meant for
// arithmetic types (integers, float, double).
// Single-byte types are returned unchanged.
template<typename T>
T byteswap(T value) {
  char* ptr = static_cast<char*>(static_cast<void*>(&value));
  // size_t rather than int: keeps the loop comparison free of signed/unsigned mixing
  const size_t len = sizeof(T);
  for (size_t i = 0; i < len / 2; ++i) {
    std::swap(ptr[i], ptr[len - i - 1]);
  }
  return value;
}

// Reads sizeof(T) raw bytes from the stream and returns them with the byte order reversed.
// The swap is unconditional, so the result is the big-endian interpretation
// only on a little-endian host.
// The state of the stream is not checked here.
template<typename T>
static inline T read_big_endian(std::istream& is) {
  T raw;
  char* const dst = reinterpret_cast<char*>(&raw);
  is.read(dst, sizeof(T));
  return byteswap(raw);
}

// wrapper for iterators to implement operator-> returning temporary value
template<typename T>
class return_value_holder {
Expand Down
7 changes: 7 additions & 0 deletions tdigest/include/tdigest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ class tdigest {
// serial version of the native format; deserialize() compares the value read from the preamble with it
static const uint8_t SERIAL_VERSION = 1;
// sketch type of the native format; deserialize() compares the value read from the preamble with it
static const uint8_t SKETCH_TYPE = 20;

// type byte of the reference implementation format, read by deserialize_compat()
// after the leading zero bytes:
// COMPAT_DOUBLE selects the layout with values stored as doubles (asBytes())
static const uint8_t COMPAT_DOUBLE = 1;
// COMPAT_FLOAT selects the layout with k and centroids stored as floats (asSmallBytes())
static const uint8_t COMPAT_FLOAT = 2;

enum flags { IS_EMPTY, REVERSE_MERGE };

// for deserialize
Expand All @@ -238,6 +241,10 @@ class tdigest {
void merge_new_values(uint16_t k);

static double weighted_average(double x1, double w1, double x2, double w2);

// for compatibility with format of the reference implementation
// deserialize() falls back to these when the leading preamble bytes
// (preamble longs, serial version, sketch type) are all zero.
// The input is expected to be positioned right after those bytes, at the type byte
// (COMPAT_DOUBLE or COMPAT_FLOAT).
// Throws std::invalid_argument if the type byte is not one of these two values.
static tdigest deserialize_compat(std::istream& is, const Allocator& allocator = Allocator());
static tdigest deserialize_compat(const void* bytes, size_t size, const Allocator& allocator = Allocator());
};

} /* namespace datasketches */
Expand Down
120 changes: 120 additions & 0 deletions tdigest/include/tdigest_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ tdigest<T, A> tdigest<T, A>::deserialize(std::istream& is, const A& allocator) {
const auto serial_version = read<uint8_t>(is);
const auto sketch_type = read<uint8_t>(is);
if (sketch_type != SKETCH_TYPE) {
if (preamble_longs == 0 && serial_version == 0 && sketch_type == 0) return deserialize_compat(is, allocator);
throw std::invalid_argument("sketch type mismatch: expected " + std::to_string(SKETCH_TYPE) + ", actual " + std::to_string(sketch_type));
}
if (serial_version != SERIAL_VERSION) {
Expand Down Expand Up @@ -391,6 +392,7 @@ tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A
const uint8_t serial_version = *ptr++;
const uint8_t sketch_type = *ptr++;
if (sketch_type != SKETCH_TYPE) {
if (preamble_longs == 0 && serial_version == 0 && sketch_type == 0) return deserialize_compat(ptr, end_ptr - ptr, allocator);
throw std::invalid_argument("sketch type mismatch: expected " + std::to_string(SKETCH_TYPE) + ", actual " + std::to_string(sketch_type));
}
if (serial_version != SERIAL_VERSION) {
Expand Down Expand Up @@ -426,6 +428,124 @@ tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A
return tdigest(reverse_merge, k, min, max, std::move(centroids), total_weight, allocator);
}

// compatibility with the format of the reference implementation
// default byte order of ByteBuffer is used there, which is big endian
//
// The stream must be positioned right after the three leading zero bytes,
// that is at the type byte (COMPAT_DOUBLE or COMPAT_FLOAT).
// Throws std::invalid_argument if the type byte is not recognized.
// Throws std::runtime_error if the stream is not in a good state after reading
// (for instance, truncated input).
template<typename T, typename A>
tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& allocator) {
  // this method was called because the first three bytes were zeros
  // so read one more byte to see if it looks like the reference implementation format
  const auto type = read<uint8_t>(is);
  if (!is.good()) throw std::runtime_error("error reading from std::istream");
  if (type != COMPAT_DOUBLE && type != COMPAT_FLOAT) {
    throw std::invalid_argument("unexpected sketch preamble: 0 0 0 " + std::to_string(type));
  }
  if (type == COMPAT_DOUBLE) { // compatibility with asBytes()
    const auto min = read_big_endian<double>(is);
    const auto max = read_big_endian<double>(is);
    const auto k = static_cast<uint16_t>(read_big_endian<double>(is));
    const auto num_centroids = read_big_endian<uint32_t>(is);
    // a failed read leaves num_centroids meaningless, so check before using it to size the allocation
    if (!is.good()) throw std::runtime_error("error reading from std::istream");
    vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
    uint64_t total_weight = 0;
    for (auto& c: centroids) {
      // weight is stored before mean, both as doubles
      const uint64_t weight = static_cast<uint64_t>(read_big_endian<double>(is));
      const auto mean = read_big_endian<double>(is);
      c = centroid(mean, weight);
      total_weight += weight;
    }
    if (!is.good()) throw std::runtime_error("error reading from std::istream");
    return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
  }
  // COMPAT_FLOAT: compatibility with asSmallBytes()
  const auto min = read_big_endian<double>(is); // reference implementation uses doubles for min and max
  const auto max = read_big_endian<double>(is);
  const auto k = static_cast<uint16_t>(read_big_endian<float>(is));
  // reference implementation stores capacities of the array of centroids and the buffer as shorts
  // they can be derived from k in the constructor
  read<uint32_t>(is); // unused
  const auto num_centroids = read_big_endian<uint16_t>(is);
  // a failed read leaves num_centroids meaningless, so check before using it to size the allocation
  if (!is.good()) throw std::runtime_error("error reading from std::istream");
  vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
  uint64_t total_weight = 0;
  for (auto& c: centroids) {
    // weight is stored before mean, both as floats
    const uint64_t weight = static_cast<uint64_t>(read_big_endian<float>(is));
    const auto mean = read_big_endian<float>(is);
    c = centroid(mean, weight);
    total_weight += weight;
  }
  if (!is.good()) throw std::runtime_error("error reading from std::istream");
  return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
}

// compatibility with the format of the reference implementation
// default byte order of ByteBuffer is used there, which is big endian
//
// bytes must point at the type byte (COMPAT_DOUBLE or COMPAT_FLOAT),
// that is right after the three leading zero bytes; size is the number of bytes
// available starting from there.
// Throws std::invalid_argument if the type byte is not recognized.
// Insufficient input is reported by ensure_minimum_memory().
template<typename T, typename A>
tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size, const A& allocator) {
  // the type byte itself must be present before it is dereferenced
  ensure_minimum_memory(size, 1);
  const char* ptr = static_cast<const char*>(bytes);
  // this method was called because the first three bytes were zeros
  // so read one more byte to see if it looks like the reference implementation format
  // read as unsigned so that the error message agrees with the stream overload for bytes >= 0x80
  const uint8_t type = static_cast<uint8_t>(*ptr++);
  if (type != COMPAT_DOUBLE && type != COMPAT_FLOAT) {
    throw std::invalid_argument("unexpected sketch preamble: 0 0 0 " + std::to_string(type));
  }
  const char* end_ptr = static_cast<const char*>(bytes) + size;
  if (type == COMPAT_DOUBLE) { // compatibility with asBytes()
    // fixed part: min, max, k (all doubles) and the number of centroids
    ensure_minimum_memory(end_ptr - ptr, sizeof(double) * 3 + sizeof(uint32_t));
    double min;
    ptr += copy_from_mem(ptr, min);
    min = byteswap(min);
    double max;
    ptr += copy_from_mem(ptr, max);
    max = byteswap(max);
    double k_double;
    ptr += copy_from_mem(ptr, k_double);
    const uint16_t k = static_cast<uint16_t>(byteswap(k_double));
    uint32_t num_centroids;
    ptr += copy_from_mem(ptr, num_centroids);
    num_centroids = byteswap(num_centroids);
    // each centroid is a pair of doubles: weight, then mean
    ensure_minimum_memory(end_ptr - ptr, sizeof(double) * num_centroids * 2);
    vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
    uint64_t total_weight = 0;
    for (auto& c: centroids) {
      double weight;
      ptr += copy_from_mem(ptr, weight);
      weight = byteswap(weight);
      double mean;
      ptr += copy_from_mem(ptr, mean);
      mean = byteswap(mean);
      c = centroid(mean, static_cast<uint64_t>(weight));
      total_weight += static_cast<uint64_t>(weight);
    }
    return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
  }
  // COMPAT_FLOAT: compatibility with asSmallBytes()
  // fixed part: min and max (doubles), k (float), two unused shorts and the number of centroids (short)
  ensure_minimum_memory(end_ptr - ptr, sizeof(double) * 2 + sizeof(float) + sizeof(uint16_t) * 3);
  double min; // reference implementation uses doubles for min and max
  ptr += copy_from_mem(ptr, min);
  min = byteswap(min);
  double max;
  ptr += copy_from_mem(ptr, max);
  max = byteswap(max);
  float k_float;
  ptr += copy_from_mem(ptr, k_float);
  const uint16_t k = static_cast<uint16_t>(byteswap(k_float));
  // reference implementation stores capacities of the array of centroids and the buffer as shorts
  // they can be derived from k in the constructor
  ptr += sizeof(uint32_t); // unused
  uint16_t num_centroids;
  ptr += copy_from_mem(ptr, num_centroids);
  num_centroids = byteswap(num_centroids);
  // each centroid is a pair of floats: weight, then mean
  ensure_minimum_memory(end_ptr - ptr, sizeof(float) * num_centroids * 2);
  vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
  uint64_t total_weight = 0;
  for (auto& c: centroids) {
    float weight;
    ptr += copy_from_mem(ptr, weight);
    weight = byteswap(weight);
    float mean;
    ptr += copy_from_mem(ptr, mean);
    mean = byteswap(mean);
    c = centroid(mean, static_cast<uint64_t>(weight));
    total_weight += static_cast<uint64_t>(weight);
  }
  return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
}

template<typename T, typename A>
tdigest<T, A>::tdigest(bool reverse_merge, uint16_t k, T min, T max, vector_centroid&& centroids, uint64_t total_weight, const A& allocator):
allocator_(allocator),
Expand Down
Binary file added tdigest/test/tdigest_ref_k100_n10000_double.sk
Binary file not shown.
Binary file added tdigest/test/tdigest_ref_k100_n10000_float.sk
Binary file not shown.
67 changes: 67 additions & 0 deletions tdigest/test/tdigest_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <catch2/catch.hpp>
#include <iostream>
#include <fstream>

#include "tdigest.hpp"

Expand Down Expand Up @@ -229,4 +230,70 @@ TEST_CASE("serialize deserialize steam and bytes equivalence", "[tdigest]") {
REQUIRE(deserialized_td1.get_quantile(0.5) == deserialized_td2.get_quantile(0.5));
}

TEST_CASE("deserialize from reference implementation stream double", "[tdigest]") {
  // binary file with a sketch in the reference implementation format, values stored as doubles
  const std::string path = std::string(TEST_BINARY_INPUT_PATH) + "tdigest_ref_k100_n10000_double.sk";
  std::ifstream is;
  // exceptions are enabled before opening so that a missing file fails the test right away
  is.exceptions(std::ios::failbit | std::ios::badbit);
  is.open(path, std::ios::binary);
  const auto td = tdigest<double>::deserialize(is);
  const size_t n = 10000;
  // the checks below expect the sketch to hold n values ranging from 0 to n - 1
  REQUIRE(td.get_total_weight() == n);
  REQUIRE(td.get_min_value() == 0);
  REQUIRE(td.get_max_value() == n - 1);
  REQUIRE(td.get_rank(0) == Approx(0).margin(0.0001));
  REQUIRE(td.get_rank(n / 4) == Approx(0.25).margin(0.0001));
  REQUIRE(td.get_rank(n / 2) == Approx(0.5).margin(0.0001));
  REQUIRE(td.get_rank(n * 3 / 4) == Approx(0.75).margin(0.0001));
  REQUIRE(td.get_rank(n) == 1);
}

TEST_CASE("deserialize from reference implementation stream float", "[tdigest]") {
  // binary file with a sketch in the reference implementation format, values stored as floats
  const std::string path = std::string(TEST_BINARY_INPUT_PATH) + "tdigest_ref_k100_n10000_float.sk";
  std::ifstream is;
  // exceptions are enabled before opening so that a missing file fails the test right away
  is.exceptions(std::ios::failbit | std::ios::badbit);
  is.open(path, std::ios::binary);
  const auto td = tdigest<float>::deserialize(is);
  const size_t n = 10000;
  // the checks below expect the sketch to hold n values ranging from 0 to n - 1
  REQUIRE(td.get_total_weight() == n);
  REQUIRE(td.get_min_value() == 0);
  REQUIRE(td.get_max_value() == n - 1);
  REQUIRE(td.get_rank(0) == Approx(0).margin(0.0001));
  REQUIRE(td.get_rank(n / 4) == Approx(0.25).margin(0.0001));
  REQUIRE(td.get_rank(n / 2) == Approx(0.5).margin(0.0001));
  REQUIRE(td.get_rank(n * 3 / 4) == Approx(0.75).margin(0.0001));
  REQUIRE(td.get_rank(n) == 1);
}

TEST_CASE("deserialize from reference implementation bytes double", "[tdigest]") {
  // binary file with a sketch in the reference implementation format, values stored as doubles
  const std::string path = std::string(TEST_BINARY_INPUT_PATH) + "tdigest_ref_k100_n10000_double.sk";
  std::ifstream is;
  // exceptions are enabled before opening so that a missing file fails the test right away
  is.exceptions(std::ios::failbit | std::ios::badbit);
  is.open(path, std::ios::binary);
  // load the whole file into memory to exercise the overload taking bytes
  std::istreambuf_iterator<char> first(is);
  std::istreambuf_iterator<char> last;
  std::vector<char> bytes(first, last);
  const auto td = tdigest<double>::deserialize(bytes.data(), bytes.size());
  const size_t n = 10000;
  // the checks below expect the sketch to hold n values ranging from 0 to n - 1
  REQUIRE(td.get_total_weight() == n);
  REQUIRE(td.get_min_value() == 0);
  REQUIRE(td.get_max_value() == n - 1);
  REQUIRE(td.get_rank(0) == Approx(0).margin(0.0001));
  REQUIRE(td.get_rank(n / 4) == Approx(0.25).margin(0.0001));
  REQUIRE(td.get_rank(n / 2) == Approx(0.5).margin(0.0001));
  REQUIRE(td.get_rank(n * 3 / 4) == Approx(0.75).margin(0.0001));
  REQUIRE(td.get_rank(n) == 1);
}

TEST_CASE("deserialize from reference implementation bytes float", "[tdigest]") {
  // binary file with a sketch in the reference implementation format, values stored as floats
  const std::string path = std::string(TEST_BINARY_INPUT_PATH) + "tdigest_ref_k100_n10000_float.sk";
  std::ifstream is;
  // exceptions are enabled before opening so that a missing file fails the test right away
  is.exceptions(std::ios::failbit | std::ios::badbit);
  is.open(path, std::ios::binary);
  // load the whole file into memory to exercise the overload taking bytes
  std::istreambuf_iterator<char> first(is);
  std::istreambuf_iterator<char> last;
  std::vector<char> bytes(first, last);
  // the float format is read into a sketch of doubles here
  const auto td = tdigest<double>::deserialize(bytes.data(), bytes.size());
  const size_t n = 10000;
  // the checks below expect the sketch to hold n values ranging from 0 to n - 1
  REQUIRE(td.get_total_weight() == n);
  REQUIRE(td.get_min_value() == 0);
  REQUIRE(td.get_max_value() == n - 1);
  REQUIRE(td.get_rank(0) == Approx(0).margin(0.0001));
  REQUIRE(td.get_rank(n / 4) == Approx(0.25).margin(0.0001));
  REQUIRE(td.get_rank(n / 2) == Approx(0.5).margin(0.0001));
  REQUIRE(td.get_rank(n * 3 / 4) == Approx(0.75).margin(0.0001));
  REQUIRE(td.get_rank(n) == 1);
}

} /* namespace datasketches */

0 comments on commit fa0237a

Please sign in to comment.