diff --git a/Cargo.toml b/Cargo.toml index 0a9669a..e9ac9f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,11 +22,13 @@ fixedbitset = { version = "0.4", optional = true } num-traits = { version = "^0.2.4", optional = true } rand = { version = "0.8", optional = true } succinct = { version = "^0.5", optional = true } +serde = { version = "1.0", optional = true } [dev-dependencies] criterion = "0.5" rand_chacha = "0.3" rand_distr = "0.4" +serde_json = "1.0" [features] default = [ @@ -35,6 +37,7 @@ default = [ "num-traits", "rand", "succinct", + "serde", ] [[bench]] diff --git a/src/filters/quotientfilter.rs b/src/filters/quotientfilter.rs index d75bc09..b349d46 100644 --- a/src/filters/quotientfilter.rs +++ b/src/filters/quotientfilter.rs @@ -1,7 +1,7 @@ //! QuotientFilter implementation. use std::collections::hash_map::DefaultHasher; use std::collections::VecDeque; -use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; +use std::hash::{BuildHasher, BuildHasherDefault, Hash}; use std::marker::PhantomData; use fixedbitset::FixedBitSet; @@ -346,9 +346,7 @@ where fn calc_quotient_remainder(&self, obj: &T) -> (usize, usize) { let bits_remainder = self.bits_remainder(); - let mut hasher = self.buildhasher.build_hasher(); - obj.hash(&mut hasher); - let fingerprint = hasher.finish(); + let fingerprint = self.buildhasher.hash_one(obj); let bits_trash = 64 - bits_remainder - self.bits_quotient; let trash = if bits_trash > 0 { (fingerprint >> (64 - bits_trash)) << (64 - bits_trash) diff --git a/src/hash_utils.rs b/src/hash_utils.rs index 7509d10..0d287b3 100644 --- a/src/hash_utils.rs +++ b/src/hash_utils.rs @@ -281,7 +281,7 @@ impl Hash for AnyHash { mod tests { use super::{BuildHasherSeeded, HashIterBuilder}; use std::collections::hash_map::DefaultHasher; - use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; + use std::hash::{BuildHasher, BuildHasherDefault}; #[test] fn hash_iter_builder_getter() { @@ -347,18 +347,11 @@ mod tests { let bh2 = BuildHasherSeeded::new(0); let bh3 = BuildHasherSeeded::new(1); - let mut hasher1 = bh1.build_hasher(); - let mut hasher2 = bh2.build_hasher(); - let mut hasher3 = bh3.build_hasher(); - let obj = "foo bar"; - obj.hash(&mut hasher1); - obj.hash(&mut hasher2); - obj.hash(&mut hasher3); - let result1 = hasher1.finish(); - let result2 = hasher2.finish(); - let result3 = hasher3.finish(); + let result1 = bh1.hash_one(obj); + let result2 = bh2.hash_one(obj); + let result3 = bh3.hash_one(obj); assert_eq!(result1, result2); assert_ne!(result1, result3); diff --git a/src/hyperloglog_data.rs b/src/hyperloglog/data.rs similarity index 100% rename from src/hyperloglog_data.rs rename to src/hyperloglog/data.rs diff --git a/src/hyperloglog.rs b/src/hyperloglog/mod.rs similarity index 90% rename from src/hyperloglog.rs rename to src/hyperloglog/mod.rs index 64ffed0..58d46e2 100644 --- a/src/hyperloglog.rs +++ b/src/hyperloglog/mod.rs @@ -2,14 +2,20 @@ use std::cmp; use std::collections::hash_map::DefaultHasher; use std::fmt; -use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; +use std::hash::{BuildHasher, BuildHasherDefault, Hash}; use std::marker::PhantomData; -use crate::hyperloglog_data::{ +use crate::hyperloglog::data::{ BIAS_DATA_OFFSET, BIAS_DATA_VEC, POW2MINX, RAW_ESTIMATE_DATA_OFFSET, RAW_ESTIMATE_DATA_VEC, THRESHOLD_DATA_OFFSET, THRESHOLD_DATA_VEC, }; +mod data; + +/// Serde support for `pdatastructs::hyperloglog::HyperLogLog` +#[cfg(feature = "serde")] +pub mod serde; + /// A HyperLogLog is a data structure to count unique elements on a data stream. /// /// # Examples @@ -79,7 +85,7 @@ use crate::hyperloglog_data::{ /// - ["Appendix to HyperLogLog in Practice: Algorithmic Engineering of a State of the Art /// Cardinality Estimation Algorithm", Stefan Heule, Marc Nunkesser, Alexander Hall, 2016](https://goo.gl/iU8Ig) /// - [Wikipedia: HyperLogLog](https://en.wikipedia.org/wiki/HyperLogLog) -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq)] pub struct HyperLogLog> where T: Hash + ?Sized, @@ -114,6 +120,13 @@ where { /// Same as `new` but with a specific `BuildHasher`. pub fn with_hash(b: usize, buildhasher: B) -> Self { + let m = 1_usize << b; + let registers = vec![0; m]; + Self::with_registers_and_hash(b, registers, buildhasher) + } + + /// Same as `new` but with pre-initialized registers and a specific `BuildHasher`. + pub fn with_registers_and_hash(b: usize, registers: Vec, buildhasher: B) -> Self { assert!( (4..=18).contains(&b), "b ({}) must be larger or equal than 4 and smaller or equal than 18", @@ -121,7 +134,13 @@ where ); let m = 1_usize << b; - let registers = vec![0; m]; + let len = registers.len(); + assert!( + m == len, + "registers must have length of {}, but had {}", + m, + len + ); Self { registers, b, @@ -140,6 +159,13 @@ where self.registers.len() } + /// Get register data + /// This is useful if you need to persist or serialize the structure using something else than + /// Serde + pub fn registers(&self) -> &[u8] { + &self.registers + } + /// Get `BuildHasher`. pub fn buildhasher(&self) -> &B { &self.buildhasher @@ -152,15 +178,18 @@ where /// Adds an element to the HyperLogLog. pub fn add(&mut self, obj: &T) { - let mut hasher = self.buildhasher.build_hasher(); - obj.hash(&mut hasher); - let h: u64 = hasher.finish(); + self.add_hashed(self.buildhasher.hash_one(obj)); + } - // split h into: + /// Adds an already-hashed element to the HyperLogLog + /// + /// Note: Make sure to use the same hasher as the rest of the HyperLogLog when hashing values on your own + pub fn add_hashed(&mut self, hashed_value: u64) { + // split hashed_value into: // - w = 64 - b upper bits // - j = b lower bits - let w = h >> self.b; - let j = h - (w << self.b); // no 1 as in the paper since register indices are 0-based + let w = hashed_value >> self.b; + let j = hashed_value - (w << self.b); // no 1 as in the paper since register indices are 0-based // p = leftmost bit (1-based count) let p = w.leading_zeros() + 1 - (self.b as u32); @@ -328,9 +357,10 @@ where } } -impl fmt::Debug for HyperLogLog +impl fmt::Debug for HyperLogLog where T: Hash + ?Sized, + B: BuildHasher + Clone + Eq, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "HyperLogLog {{ b: {} }}", self.b) @@ -363,7 +393,7 @@ where mod tests { use super::HyperLogLog; use crate::hash_utils::BuildHasherSeeded; - use crate::hyperloglog_data::{RAW_ESTIMATE_DATA_OFFSET, RAW_ESTIMATE_DATA_VEC}; + use crate::hyperloglog::data::{RAW_ESTIMATE_DATA_OFFSET, RAW_ESTIMATE_DATA_VEC}; use crate::test_util::{assert_send, NotSend}; #[test] @@ -720,4 +750,27 @@ mod tests { let hll: HyperLogLog = HyperLogLog::new(4); assert_send(&hll); } + + #[test] + fn reconstruct() { + let h = BuildHasherSeeded::new(0); + let b = 4; + let mut hll = HyperLogLog::with_hash(b, h); + hll.add("abc"); + + let hll2 = HyperLogLog::with_registers_and_hash(b, hll.registers().to_vec(), h); + assert_eq!(hll, hll2); + } + + #[test] + #[should_panic(expected = "registers must have length of 16, but had 0")] + fn reconstruct_panics() { + let h = BuildHasherSeeded::new(0); + let b = 4; + let mut hll = HyperLogLog::with_hash(b, h); + hll.add("abc"); + + let hll2 = HyperLogLog::with_registers_and_hash(b, vec![], h); + assert_eq!(hll, hll2); + } } diff --git a/src/hyperloglog/serde.rs b/src/hyperloglog/serde.rs new file mode 100644 index 0000000..0960aab --- /dev/null +++ b/src/hyperloglog/serde.rs @@ -0,0 +1,189 @@ +use super::HyperLogLog; +use serde::de::{self, Visitor}; +use serde::ser::SerializeStruct; +use serde::{Deserialize, Serialize}; +use std::fmt; +use std::hash::{BuildHasher, Hash}; +use std::marker::PhantomData; + +impl Serialize for HyperLogLog +where + T: Hash + ?Sized, + B: BuildHasher + Clone + Eq + Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut state = serializer.serialize_struct("HyperLogLog", 3)?; + state.serialize_field("registers", &self.registers)?; + state.serialize_field("b", &self.b)?; + state.serialize_field("buildhasher", &self.buildhasher)?; + state.end() + } +} + +impl<'de, T, B> Deserialize<'de> for HyperLogLog +where + T: Hash + ?Sized, + B: BuildHasher + Clone + Eq + Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + enum Field { + Registers, + B, + Buildhasher, + } + impl<'de> Deserialize<'de> for Field { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct FieldVisitor; + impl<'de> Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("`registers` or `b` or `buildhasher") + } + + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + match v { + "registers" => Ok(Field::Registers), + "b" => Ok(Field::B), + "buildhasher" => Ok(Field::Buildhasher), + _ => Err(de::Error::unknown_field(v, FIELDS)), + } + } + } + + deserializer.deserialize_identifier(FieldVisitor) + } + } + struct HyperLogLogVisitor + where + T: Hash + ?Sized, + B: BuildHasher + Clone + Eq, + { + _t: PhantomData, + _b: PhantomData, + } + impl HyperLogLogVisitor + where + T: Hash + ?Sized, + B: BuildHasher + Clone + Eq, + { + fn new() -> Self { + Self { + _t: PhantomData, + _b: PhantomData, + } + } + } + impl<'de, T, B> Visitor<'de> for HyperLogLogVisitor + where + T: Hash + ?Sized, + B: BuildHasher + Clone + Eq + Deserialize<'de>, + { + type Value = HyperLogLog; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("struct HyperLogLog") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut registers = None; + let mut b = None; + let mut buildhasher = None; + while let Some(key) = map.next_key()? { + match key { + Field::Registers => { + if registers.is_some() { + return Err(de::Error::duplicate_field("registers")); + } + registers = Some(map.next_value()?); + } + Field::B => { + if b.is_some() { + return Err(de::Error::duplicate_field("b")); + } + b = Some(map.next_value()?); + } + Field::Buildhasher => { + if buildhasher.is_some() { + return Err(de::Error::duplicate_field("buildhasher")); + } + buildhasher = Some(map.next_value()?); + } + } + } + let registers = registers.ok_or_else(|| de::Error::missing_field("registers"))?; + let b = b.ok_or_else(|| de::Error::missing_field("b"))?; + let buildhasher = + buildhasher.ok_or_else(|| de::Error::missing_field("buildhasher"))?; + Ok(HyperLogLog { + registers, + b, + buildhasher, + phantom: PhantomData, + }) + } + } + const FIELDS: &[&str] = &["registers", "b", "buildhasher"]; + deserializer.deserialize_struct("HyperLogLog", FIELDS, HyperLogLogVisitor::::new()) + } +} + +#[cfg(test)] +mod tests { + use std::hash::{BuildHasher, Hasher}; + + use serde::{Deserialize, Serialize}; + + use crate::hyperloglog::HyperLogLog; + + #[test] + fn serde() { + #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] + struct MyHasher { + state: u64, + } + impl Hasher for MyHasher { + fn finish(&self) -> u64 { + self.state + } + + fn write(&mut self, bytes: &[u8]) { + let _ = bytes; + } + } + impl BuildHasher for MyHasher { + type Hasher = Self; + fn build_hasher(&self) -> Self::Hasher { + Self { state: 4 } + } + } + let hasher = MyHasher { state: 4 }; + // Construct a HLL + let mut hll = HyperLogLog::with_hash(4, hasher.clone()); + hll.add("abc"); + // Serialize to JSON + let json = serde_json::to_string(&hll).expect("can serialize to json"); + // Deserialize back to HLL + let mut de_hll = serde_json::from_str(&json).expect("can deserialize from json"); + // Check they're the same + assert_eq!(hll, de_hll); + // Add the same string again to check if the hasher is reconstructed correctly + de_hll.add("abc"); + assert_eq!(hll, de_hll); + } +} diff --git a/src/lib.rs b/src/lib.rs index 1d70343..9f6bcb1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,9 +44,6 @@ mod helpers; #[cfg(feature = "bytecount")] pub mod hyperloglog; -#[cfg(feature = "bytecount")] -mod hyperloglog_data; - #[cfg(feature = "rand")] pub mod reservoirsampling; diff --git a/src/reservoirsampling.rs b/src/reservoirsampling.rs index 948fce2..b81d386 100644 --- a/src/reservoirsampling.rs +++ b/src/reservoirsampling.rs @@ -179,9 +179,10 @@ mod tests { #[test] fn getter() { let rs = ReservoirSampling::::new(10, ChaChaRng::from_seed([0; 32])); + let empty: Vec = vec![]; assert_eq!(rs.k(), 10); assert_eq!(rs.i(), 0); - assert_eq!(rs.reservoir(), &vec![]); + assert_eq!(rs.reservoir(), &empty); } #[test]