diff --git a/sampling/include/ebpps_sample.hpp b/sampling/include/ebpps_sample.hpp index 97c8b569..7ddb2828 100644 --- a/sampling/include/ebpps_sample.hpp +++ b/sampling/include/ebpps_sample.hpp @@ -37,14 +37,14 @@ class ebpps_sample { public: explicit ebpps_sample(uint32_t k, const A& allocator = A()); - // constructor used to create a sample to merge one item - template - ebpps_sample(TT&& item, double theta, const A& allocator = A()); - // for deserialization class items_deleter; ebpps_sample(std::vector&& data, optional&& partial_item, double c, const A& allocator = A()); + // used instead of having a single-item constructor for update/merge calls + template + void replace_content(TT&& item, double theta); + void reset(); void downsample(double theta); diff --git a/sampling/include/ebpps_sample_impl.hpp b/sampling/include/ebpps_sample_impl.hpp index f379841b..88a86ae0 100644 --- a/sampling/include/ebpps_sample_impl.hpp +++ b/sampling/include/ebpps_sample_impl.hpp @@ -41,22 +41,6 @@ ebpps_sample::ebpps_sample(uint32_t reserved_size, const A& allocator) : data_.reserve(reserved_size); } -template -template -ebpps_sample::ebpps_sample(TT&& item, double theta, const A& allocator) : - allocator_(allocator), - c_(theta), - partial_item_(), - data_(allocator) - { - if (theta == 1.0) { - data_.reserve(1); - data_.emplace_back(std::forward(item)); - } else { - partial_item_.emplace(std::forward(item)); - } - } - template ebpps_sample::ebpps_sample(std::vector&& data, optional&& partial_item, double c, const A& allocator) : allocator_(allocator), @@ -65,6 +49,19 @@ ebpps_sample::ebpps_sample(std::vector&& data, optional&& partial_ data_(data, allocator) {} +template +template +void ebpps_sample::replace_content(TT&& item, double theta) { + c_ = theta; + data_.clear(); + partial_item_.reset(); + if (theta == 1.0) { + data_.emplace_back(std::forward(item)); + } else { + partial_item_.emplace(std::forward(item)); + } +} + template auto ebpps_sample::get_sample() const -> result_type { double unused; diff --git a/sampling/include/ebpps_sketch.hpp b/sampling/include/ebpps_sketch.hpp index 51bcc4f1..038b5a30 100644 --- a/sampling/include/ebpps_sketch.hpp +++ b/sampling/include/ebpps_sketch.hpp @@ -256,6 +256,8 @@ class ebpps_sketch { ebpps_sample sample_; // Object holding the current state of the sample + ebpps_sample tmp_; // Temporary sample of size 1 used in updates + // handles merge after ensuring other.cumulative_wt_ <= this->cumulative_wt_ // so we can send items in individually template diff --git a/sampling/include/ebpps_sketch_impl.hpp b/sampling/include/ebpps_sketch_impl.hpp index e4dc0198..299b7f5c 100644 --- a/sampling/include/ebpps_sketch_impl.hpp +++ b/sampling/include/ebpps_sketch_impl.hpp @@ -40,7 +40,8 @@ ebpps_sketch::ebpps_sketch(uint32_t k, const A& allocator) : cumulative_wt_(0.0), wt_max_(0.0), rho_(1.0), - sample_(check_k(k), allocator) + sample_(check_k(k), allocator), + tmp_(1, allocator) {} template @@ -53,7 +54,8 @@ ebpps_sketch::ebpps_sketch(uint32_t k, uint64_t n, double cumulative_wt, cumulative_wt_(cumulative_wt), wt_max_(wt_max), rho_(rho), - sample_(sample) + sample_(sample), + tmp_(1, allocator) {} template @@ -148,9 +150,8 @@ void ebpps_sketch::internal_update(FwdItem&& item, double weight) { if (cumulative_wt_ > 0.0) sample_.downsample(new_rho / rho_); - ebpps_sample tmp(conditional_forward(item), new_rho * weight, allocator_); - - sample_.merge(tmp); + tmp_.replace_content(conditional_forward(item), new_rho * weight); + sample_.merge(tmp_); cumulative_wt_ = new_cum_wt; wt_max_ = new_wt_max; @@ -240,9 +241,8 @@ void ebpps_sketch::internal_merge(O&& sk) { if (cumulative_wt_ > 0.0) sample_.downsample(new_rho / rho_); - ebpps_sample tmp(conditional_forward(items[i]), new_rho * avg_wt, allocator_); - - sample_.merge(tmp); + tmp_.replace_content(conditional_forward(items[i]), new_rho * avg_wt); + sample_.merge(tmp_); cumulative_wt_ = new_cum_wt; rho_ = new_rho; @@ -259,9 +259,8 @@ void ebpps_sketch::internal_merge(O&& sk) { if (cumulative_wt_ > 0.0) sample_.downsample(new_rho / rho_); - ebpps_sample tmp(conditional_forward(other_sample.get_partial_item()), new_rho * other_c_frac * avg_wt, allocator_); - - sample_.merge(tmp); + tmp_.replace_content(conditional_forward(other_sample.get_partial_item()), new_rho * other_c_frac * avg_wt); + sample_.merge(tmp_); cumulative_wt_ = new_cum_wt; rho_ = new_rho; diff --git a/sampling/test/ebpps_sample_test.cpp b/sampling/test/ebpps_sample_test.cpp index c83cded6..e39cba7d 100644 --- a/sampling/test/ebpps_sample_test.cpp +++ b/sampling/test/ebpps_sample_test.cpp @@ -42,14 +42,15 @@ TEST_CASE("ebpps sample: basic initialization", "[ebpps_sketch]") { TEST_CASE("ebpps sample: pre-initialized", "[ebpps_sketch]") { double theta = 1.0; - ebpps_sample sample = ebpps_sample(-1, theta); + ebpps_sample sample(1); + sample.replace_content(-1, theta); REQUIRE(sample.get_c() == theta); REQUIRE(sample.get_num_retained_items() == 1); REQUIRE(sample.get_sample().size() == 1); REQUIRE(sample.has_partial_item() == false); theta = 1e-300; - sample = ebpps_sample(-1, theta); + sample.replace_content(-1, theta); REQUIRE(sample.get_c() == theta); REQUIRE(sample.get_num_retained_items() == 1); REQUIRE(sample.get_sample().size() == 0); // assuming the random number is > 1e-300 @@ -57,7 +58,8 @@ TEST_CASE("ebpps sample: pre-initialized", "[ebpps_sketch]") { } TEST_CASE("ebpps sample: downsampling", "[ebpps_sketch]") { - ebpps_sample sample = ebpps_sample('a', 1.0); + ebpps_sample sample(1); + sample.replace_content('a', 1.0); sample.downsample(2.0); // no-op REQUIRE(sample.get_c() == 1.0); @@ -121,8 +123,9 @@ TEST_CASE("ebpps sample: merge unit samples", "[ebpps_sketch]") { uint32_t k = 8; ebpps_sample sample = ebpps_sample(k); + ebpps_sample s(1); for (uint32_t i = 1; i <= k; ++i) { - ebpps_sample s = ebpps_sample(i, 1.0); + s.replace_content(i, 1.0); sample.merge(s); REQUIRE(sample.get_c() == static_cast(i)); REQUIRE(sample.get_num_retained_items() == i);