Skip to content

Commit

Permalink
implemented get_PMF() and get_CDF()
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderSaydakov committed Oct 23, 2024
1 parent 8b86cf1 commit f0d4cb7
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 3 deletions.
50 changes: 50 additions & 0 deletions tdigest/include/tdigest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class tdigest {
using vector_t = std::vector<T, Allocator>;
using vector_centroid = std::vector<centroid, typename std::allocator_traits<Allocator>::template rebind_alloc<centroid>>;
using vector_bytes = std::vector<uint8_t, typename std::allocator_traits<Allocator>::template rebind_alloc<uint8_t>>;
using vector_double = std::vector<double, typename std::allocator_traits<Allocator>::template rebind_alloc<double>>;

struct centroid_cmp {
centroid_cmp() {}
Expand Down Expand Up @@ -142,20 +143,67 @@ class tdigest {
*/
uint64_t get_total_weight() const;

/**
* Returns an instance of the allocator for this t-Digest.
* @return allocator
*/
Allocator get_allocator() const;

/**
* Compute approximate normalized rank of the given value.
*
* <p>If the sketch is empty this throws std::runtime_error.
*
* @param value to be ranked
* @return normalized rank (from 0 to 1 inclusive)
*/
double get_rank(T value) const;

/**
* Compute approximate quantile value corresponding to the given normalized rank
*
* <p>If the sketch is empty this throws std::runtime_error.
*
* @param rank normalized rank (from 0 to 1 inclusive)
* @return quantile value corresponding to the given rank
*/
T get_quantile(double rank) const;

/**
* Returns an approximation to the Probability Mass Function (PMF) of the input stream
* given a set of split points.
*
* <p>If the sketch is empty this throws std::runtime_error.
*
* @param split_points an array of <i>m</i> unique, monotonically increasing values
* that divide the input domain into <i>m+1</i> consecutive disjoint intervals (bins).
*
* @param size the number of split points in the array
*
* @return an array of m+1 doubles each of which is an approximation
* to the fraction of the input stream values (the mass) that fall into one of those intervals.
*/
vector_double get_PMF(const T* split_points, uint32_t size) const;

/**
* Returns an approximation to the Cumulative Distribution Function (CDF), which is the
* cumulative analog of the PMF, of the input stream given a set of split points.
*
* <p>If the sketch is empty this throws std::runtime_error.
*
* @param split_points an array of <i>m</i> unique, monotonically increasing values
* that divide the input domain into <i>m+1</i> consecutive disjoint intervals.
*
* @param size the number of split points in the array
*
* @return an array of m+1 doubles, which are a consecutive approximation to the CDF
* of the input stream given the split_points. The value at array position j of the returned
* CDF array is the sum of the returned values in positions 0 through j of the returned PMF
* array. This can be viewed as array of ranks of the given split points plus one more value
* that is always 1.
*/
vector_double get_CDF(const T* split_points, uint32_t size) const;

/**
* @return parameter k (compression) that was used to configure this t-Digest
*/
Expand Down Expand Up @@ -245,6 +293,8 @@ class tdigest {
// for compatibility with format of the reference implementation
static tdigest deserialize_compat(std::istream& is, const Allocator& allocator = Allocator());
static tdigest deserialize_compat(const void* bytes, size_t size, const Allocator& allocator = Allocator());

static inline void check_split_points(const T* values, uint32_t size);
};

} /* namespace datasketches */
Expand Down
36 changes: 36 additions & 0 deletions tdigest/include/tdigest_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ uint64_t tdigest<T, A>::get_total_weight() const {
return centroids_weight_ + buffer_.size();
}

template<typename T, typename A>
A tdigest<T, A>::get_allocator() const {
return buffer_.get_allocator();
}

template<typename T, typename A>
double tdigest<T, A>::get_rank(T value) const {
if (is_empty()) throw std::runtime_error("operation is undefined for an empty sketch");
Expand Down Expand Up @@ -191,6 +196,25 @@ T tdigest<T, A>::get_quantile(double rank) const {
return weighted_average(centroids_.back().get_weight(), w1, max_, w2);
}

template<typename T, typename A>
auto tdigest<T, A>::get_PMF(const T* split_points, uint32_t size) const -> vector_double {
auto buckets = get_CDF(split_points, size);
for (uint32_t i = size; i > 0; --i) {
buckets[i] -= buckets[i - 1];
}
return buckets;
}

template<typename T, typename A>
auto tdigest<T, A>::get_CDF(const T* split_points, uint32_t size) const -> vector_double {
check_split_points(split_points, size);
vector_double ranks(get_allocator());
ranks.reserve(size + 1);
for (uint32_t i = 0; i < size; ++i) ranks.push_back(get_rank(split_points[i]));
ranks.push_back(1);
return ranks;
}

template<typename T, typename A>
uint16_t tdigest<T, A>::get_k() const {
return k_;
Expand Down Expand Up @@ -591,6 +615,18 @@ buffer_(std::move(buffer))
buffer_.reserve(centroids_capacity_ * BUFFER_MULTIPLIER);
}

template<typename T, typename A>
void tdigest<T, A>::check_split_points(const T* values, uint32_t size) {
for (uint32_t i = 0; i < size ; i++) {
if (std::isnan(values[i])) {
throw std::invalid_argument("Values must not be NaN");
}
if ((i < (size - 1)) && !(values[i] < values[i + 1])) {
throw std::invalid_argument("Values must be unique and monotonically increasing");
}
}
}

} /* namespace datasketches */

#endif // _TDIGEST_IMPL_HPP_
15 changes: 12 additions & 3 deletions tdigest/test/tdigest_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ TEST_CASE("empty", "[tdigest]") {
REQUIRE_THROWS_AS(td.get_max_value(), std::runtime_error);
REQUIRE_THROWS_AS(td.get_rank(0), std::runtime_error);
REQUIRE_THROWS_AS(td.get_quantile(0.5), std::runtime_error);
const double split_points[1] {0};
REQUIRE_THROWS_AS(td.get_PMF(split_points, 1), std::runtime_error);
REQUIRE_THROWS_AS(td.get_CDF(split_points, 1), std::runtime_error);
}

TEST_CASE("one value", "[tdigest]") {
Expand All @@ -56,9 +59,6 @@ TEST_CASE("many values", "[tdigest]") {
const size_t n = 10000;
tdigest_double td;
for (size_t i = 0; i < n; ++i) td.update(i);
// std::cout << td.to_string(true);
// td.compress();
// std::cout << td.to_string(true);
REQUIRE_FALSE(td.is_empty());
REQUIRE(td.get_total_weight() == n);
REQUIRE(td.get_min_value() == 0);
Expand All @@ -73,6 +73,15 @@ TEST_CASE("many values", "[tdigest]") {
REQUIRE(td.get_quantile(0.9) == Approx(n * 0.9).epsilon(0.01));
REQUIRE(td.get_quantile(0.95) == Approx(n * 0.95).epsilon(0.01));
REQUIRE(td.get_quantile(1) == n - 1);
const double split_points[1] {n / 2};
const auto pmf = td.get_PMF(split_points, 1);
REQUIRE(pmf.size() == 2);
REQUIRE(pmf[0] == Approx(0.5).margin(0.0001));
REQUIRE(pmf[1] == Approx(0.5).margin(0.0001));
const auto cdf = td.get_CDF(split_points, 1);
REQUIRE(cdf.size() == 2);
REQUIRE(cdf[0] == Approx(0.5).margin(0.0001));
REQUIRE(cdf[1] == 1);
}

TEST_CASE("rank - two values", "[tdigest]") {
Expand Down

0 comments on commit f0d4cb7

Please sign in to comment.