Skip to content

Commit

Permalink
Merge pull request #424 from apache/tdigest
Browse files Browse the repository at this point in the history
uint32_t weight in tdigest<float> + cross-language test
  • Loading branch information
AlexanderSaydakov authored Feb 26, 2024
2 parents 4052e03 + 70fafa4 commit 5bf5a9f
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 16 deletions.
11 changes: 7 additions & 4 deletions tdigest/include/tdigest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ using tdigest_double = tdigest<double>;
*/
template <typename T, typename Allocator>
class tdigest {
static_assert(std::is_floating_point<T>::value, "Floating-point type expected");
// exclude long double by not using std::is_floating_point
static_assert(std::is_same<T, double>::value || std::is_same<T, float>::value, "Either double or float type expected");
static_assert(std::numeric_limits<T>::is_iec559, "IEEE 754 compatibility required");
public:
using value_type = T;
Expand All @@ -84,18 +85,20 @@ class tdigest {
static const bool USE_TWO_LEVEL_COMPRESSION = true;
static const bool USE_WEIGHT_LIMIT = true;

using W = typename std::conditional<std::is_same<T, double>::value, uint64_t, uint32_t>::type;

class centroid {
public:
centroid(T value, uint64_t weight): mean_(value), weight_(weight) {}
centroid(T value, W weight): mean_(value), weight_(weight) {}
void add(const centroid& other) {
weight_ += other.weight_;
mean_ += (other.mean_ - mean_) * other.weight_ / weight_;
}
T get_mean() const { return mean_; }
uint64_t get_weight() const { return weight_; }
W get_weight() const { return weight_; }
private:
T mean_;
uint64_t weight_;
W weight_;
};
using vector_centroid = std::vector<centroid, typename std::allocator_traits<Allocator>::template rebind_alloc<centroid>>;
using vector_bytes = std::vector<uint8_t, typename std::allocator_traits<Allocator>::template rebind_alloc<uint8_t>>;
Expand Down
10 changes: 5 additions & 5 deletions tdigest/include/tdigest_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ template<typename T, typename A>
auto tdigest<T, A>::serialize(unsigned header_size_bytes) const -> vector_bytes {
const_cast<tdigest*>(this)->merge_buffered(); // side effect
const uint8_t preamble_longs = is_empty() ? PREAMBLE_LONGS_EMPTY : PREAMBLE_LONGS_NON_EMPTY;
const size_t size_bytes = preamble_longs * sizeof(uint64_t) + sizeof(T) * 2 + sizeof(centroid) * centroids_.size();
const size_t size_bytes = preamble_longs * sizeof(uint64_t) + (is_empty() ? 0 : sizeof(T) * 2 + sizeof(centroid) * centroids_.size());
vector_bytes bytes(size_bytes, 0, allocator_);
uint8_t* ptr = bytes.data() + header_size_bytes;

Expand Down Expand Up @@ -445,7 +445,7 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
uint64_t total_weight = 0;
for (auto& c: centroids) {
const uint64_t weight = static_cast<uint64_t>(read_big_endian<double>(is));
const W weight = static_cast<W>(read_big_endian<double>(is));
const auto mean = read_big_endian<double>(is);
c = centroid(mean, weight);
total_weight += weight;
Expand All @@ -463,7 +463,7 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
uint64_t total_weight = 0;
for (auto& c: centroids) {
const uint64_t weight = static_cast<uint64_t>(read_big_endian<float>(is));
const W weight = static_cast<W>(read_big_endian<float>(is));
const auto mean = read_big_endian<float>(is);
c = centroid(mean, weight);
total_weight += weight;
Expand Down Expand Up @@ -507,7 +507,7 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
double mean;
ptr += copy_from_mem(ptr, mean);
mean = byteswap(mean);
c = centroid(mean, static_cast<uint64_t>(weight));
c = centroid(mean, static_cast<W>(weight));
total_weight += static_cast<uint64_t>(weight);
}
return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
Expand Down Expand Up @@ -539,7 +539,7 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
float mean;
ptr += copy_from_mem(ptr, mean);
mean = byteswap(mean);
c = centroid(mean, static_cast<uint64_t>(weight));
c = centroid(mean, static_cast<W>(weight));
total_weight += static_cast<uint64_t>(weight);
}
return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
Expand Down
8 changes: 8 additions & 0 deletions tdigest/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,12 @@ add_test(
target_sources(tdigest_test
PRIVATE
tdigest_test.cpp
tdigest_custom_allocator_test.cpp
)

if (GENERATE)
target_sources(tdigest_test
PRIVATE
tdigest_serialize_for_java.cpp
)
endif()
43 changes: 43 additions & 0 deletions tdigest/test/tdigest_custom_allocator_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <catch2/catch.hpp>

#include "tdigest.hpp"
#include "test_allocator.hpp"

namespace datasketches {

using alloc_d = test_allocator<double>;
using tdigest_d = tdigest<double, alloc_d>;

TEST_CASE("tdigest custom allocator", "[tdigest]") {
test_allocator_total_bytes = 0;
test_allocator_net_allocations = 0;
{
tdigest_d td(100, alloc_d(0));
for (int i = 0; i < 10000; ++i) td.update(static_cast<double>(i));
REQUIRE(test_allocator_total_bytes != 0);
REQUIRE(test_allocator_net_allocations != 0);
}
REQUIRE(test_allocator_total_bytes == 0);
REQUIRE(test_allocator_net_allocations == 0);
}

} /* namespace datasketches */
47 changes: 47 additions & 0 deletions tdigest/test/tdigest_serialize_for_java.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <catch2/catch.hpp>
#include <fstream>

#include "tdigest.hpp"

namespace datasketches {

TEST_CASE("tdigest double generate", "[serialize_for_java]") {
const unsigned n_arr[] = {0, 1, 10, 100, 1000, 10000, 100000, 1000000};
for (const unsigned n: n_arr) {
tdigest_double td(100);
for (unsigned i = 1; i <= n; ++i) td.update(i);
std::ofstream os("tdigest_double_n" + std::to_string(n) + "_cpp.sk", std::ios::binary);
td.serialize(os);
}
}

TEST_CASE("tdigest float generate", "[serialize_for_java]") {
const unsigned n_arr[] = {0, 1, 10, 100, 1000, 10000, 100000, 1000000};
for (const unsigned n: n_arr) {
tdigest_float td(100);
for (unsigned i = 1; i <= n; ++i) td.update(i);
std::ofstream os("tdigest_float_n" + std::to_string(n) + "_cpp.sk", std::ios::binary);
td.serialize(os);
}
}

} /* namespace datasketches */
46 changes: 39 additions & 7 deletions tdigest/test/tdigest_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,34 @@ TEST_CASE("serialize deserialize bytes non empty", "[tdigest]") {
REQUIRE(td.get_quantile(0.5) == deserialized_td.get_quantile(0.5));
}

TEST_CASE("serialize deserialize steam and bytes equivalence empty", "[tdigest]") {
tdigest<double> td(100);
std::stringstream s(std::ios::in | std::ios::out | std::ios::binary);
td.serialize(s);
auto bytes = td.serialize();

REQUIRE(bytes.size() == static_cast<size_t>(s.tellp()));
for (size_t i = 0; i < bytes.size(); ++i) {
REQUIRE(((char*)bytes.data())[i] == (char)s.get());
}

s.seekg(0); // rewind
auto deserialized_td1 = tdigest<double>::deserialize(s);
auto deserialized_td2 = tdigest<double>::deserialize(bytes.data(), bytes.size());
REQUIRE(bytes.size() == static_cast<size_t>(s.tellg()));

REQUIRE(deserialized_td1.is_empty());
REQUIRE(deserialized_td2.is_empty());
REQUIRE(deserialized_td1.get_k() == 100);
REQUIRE(deserialized_td2.get_k() == 100);
REQUIRE(deserialized_td1.get_total_weight() == 0);
REQUIRE(deserialized_td2.get_total_weight() == 0);
}

TEST_CASE("serialize deserialize steam and bytes equivalence", "[tdigest]") {
tdigest<double> td(100);
for (int i = 0; i < 1000; ++i) td.update(i);
const int n = 1000;
for (int i = 0; i < n; ++i) td.update(i);
std::stringstream s(std::ios::in | std::ios::out | std::ios::binary);
td.serialize(s);
auto bytes = td.serialize();
Expand All @@ -221,12 +246,19 @@ TEST_CASE("serialize deserialize steam and bytes equivalence", "[tdigest]") {
auto deserialized_td2 = tdigest<double>::deserialize(bytes.data(), bytes.size());
REQUIRE(bytes.size() == static_cast<size_t>(s.tellg()));

REQUIRE(deserialized_td1.get_k() == deserialized_td2.get_k());
REQUIRE(deserialized_td1.get_total_weight() == deserialized_td2.get_total_weight());
REQUIRE(deserialized_td1.is_empty() == deserialized_td2.is_empty());
REQUIRE(deserialized_td1.get_min_value() == deserialized_td2.get_min_value());
REQUIRE(deserialized_td1.get_max_value() == deserialized_td2.get_max_value());
REQUIRE(deserialized_td1.get_rank(500) == deserialized_td2.get_rank(500));
REQUIRE_FALSE(deserialized_td1.is_empty());
REQUIRE(deserialized_td1.get_k() == 100);
REQUIRE(deserialized_td1.get_total_weight() == n);
REQUIRE(deserialized_td1.get_min_value() == 0);
REQUIRE(deserialized_td1.get_max_value() == n - 1);

REQUIRE_FALSE(deserialized_td2.is_empty());
REQUIRE(deserialized_td2.get_k() == 100);
REQUIRE(deserialized_td2.get_total_weight() == n);
REQUIRE(deserialized_td2.get_min_value() == 0);
REQUIRE(deserialized_td2.get_max_value() == n - 1);

REQUIRE(deserialized_td1.get_rank(n / 2) == deserialized_td2.get_rank(n / 2));
REQUIRE(deserialized_td1.get_quantile(0.5) == deserialized_td2.get_quantile(0.5));
}

Expand Down

0 comments on commit 5bf5a9f

Please sign in to comment.