chore: cleanup dict encoding logic (#1231)
fixes #965

---------

Co-authored-by: Will Manning <[email protected]>
robert3005 and lwwmanning authored Nov 6, 2024
1 parent 1880c2e commit 0195c02
Showing 6 changed files with 128 additions and 98 deletions.
2 changes: 1 addition & 1 deletion docs/quickstart.rst
@@ -46,7 +46,7 @@ Use :func:`~vortex.encoding.compress` to compress the Vortex array and check the

>>> cvtx = vortex.compress(vtx)
>>> cvtx.nbytes
13970
13963
>>> cvtx.nbytes / vtx.nbytes
0.099...

34 changes: 25 additions & 9 deletions encodings/dict/benches/dict_compress.rs
@@ -4,9 +4,9 @@ use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion
use rand::distributions::{Alphanumeric, Uniform};
use rand::prelude::SliceRandom;
use rand::{thread_rng, Rng};
use vortex_array::array::{PrimitiveArray, VarBinArray};
use vortex_array::array::{PrimitiveArray, VarBinArray, VarBinViewArray};
use vortex_array::{ArrayTrait, IntoArray as _, IntoCanonical as _};
use vortex_dict::{dict_encode_primitive, dict_encode_varbin, DictArray};
use vortex_dict::{dict_encode_primitive, dict_encode_varbin, dict_encode_varbinview, DictArray};

fn gen_primitive_dict(len: usize, uniqueness: f64) -> PrimitiveArray {
let mut rng = thread_rng();
@@ -17,7 +17,7 @@ fn gen_primitive_dict(len: usize, uniqueness: f64) -> PrimitiveArray {
PrimitiveArray::from(data)
}

fn gen_varbin_dict(len: usize, uniqueness: f64) -> VarBinArray {
fn gen_varbin_words(len: usize, uniqueness: f64) -> Vec<String> {
let mut rng = thread_rng();
let uniq_cnt = (len as f64 * uniqueness) as usize;
let dict: Vec<String> = (0..uniq_cnt)
@@ -29,10 +29,9 @@ fn gen_varbin_dict(len: usize, uniqueness: f64) -> VarBinArray {
.collect()
})
.collect();
let words: Vec<&str> = (0..len)
.map(|_| dict.choose(&mut rng).unwrap().as_str())
.collect();
VarBinArray::from(words)
(0..len)
.map(|_| dict.choose(&mut rng).unwrap().clone())
.collect()
}

fn dict_encode(c: &mut Criterion) {
@@ -44,11 +43,17 @@ fn dict_encode(c: &mut Criterion) {
b.iter(|| black_box(dict_encode_primitive(&primitive_arr)));
});

let varbin_arr = gen_varbin_dict(1_000_000, 0.00005);
let varbin_arr = VarBinArray::from(gen_varbin_words(1_000_000, 0.00005));
group.throughput(Throughput::Bytes(varbin_arr.nbytes() as u64));
group.bench_function("dict_encode_varbin", |b| {
b.iter(|| black_box(dict_encode_varbin(&varbin_arr)));
});

let varbinview_arr = VarBinViewArray::from_iter_str(gen_varbin_words(1_000_000, 0.00005));
group.throughput(Throughput::Bytes(varbinview_arr.nbytes() as u64));
group.bench_function("dict_encode_varbinview", |b| {
b.iter(|| black_box(dict_encode_varbinview(&varbinview_arr)));
});
}

fn dict_decode(c: &mut Criterion) {
@@ -65,7 +70,7 @@ fn dict_decode(c: &mut Criterion) {
);
});

let varbin_arr = gen_varbin_dict(1_000_000, 0.00005);
let varbin_arr = VarBinArray::from(gen_varbin_words(1_000_000, 0.00005));
let (codes, values) = dict_encode_varbin(&varbin_arr);
group.throughput(Throughput::Bytes(varbin_arr.nbytes() as u64));
group.bench_function("dict_decode_varbin", |b| {
@@ -75,6 +80,17 @@ fn dict_decode(c: &mut Criterion) {
BatchSize::SmallInput,
);
});

let varbinview_arr = VarBinViewArray::from_iter_str(gen_varbin_words(1_000_000, 0.00005));
let (codes, values) = dict_encode_varbinview(&varbinview_arr);
group.throughput(Throughput::Bytes(varbinview_arr.nbytes() as u64));
group.bench_function("dict_decode_varbinview", |b| {
b.iter_batched(
|| DictArray::try_new(codes.clone().into_array(), values.clone().into_array()).unwrap(),
|dict_arr| black_box(dict_arr.into_canonical().unwrap()),
BatchSize::SmallInput,
);
});
}

criterion_group!(benches, dict_encode, dict_decode);
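
For reference, the round trip these new varbinview benchmarks exercise looks roughly like the sketch below. It is not part of the diff; it assumes the APIs used above (VarBinViewArray::from_iter_str, dict_encode_varbinview, DictArray::try_new, into_canonical) behave as shown in the benchmark code.

use vortex_array::array::VarBinViewArray;
use vortex_array::{IntoArray as _, IntoCanonical as _};
use vortex_dict::{dict_encode_varbinview, DictArray};

fn main() {
    // Heavily repetitive input: the case dictionary encoding targets.
    let strings = VarBinViewArray::from_iter_str(["foo", "bar", "foo", "foo", "bar"]);

    // Encode into per-element codes plus a deduplicated values array.
    let (codes, values) = dict_encode_varbinview(&strings);

    // Rebuild the DictArray and canonicalize it back to a flat array,
    // which is what the dict_decode_varbinview benchmark times.
    let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();
    let _decoded = dict.into_canonical().unwrap();
}
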
157 changes: 79 additions & 78 deletions encodings/dict/src/compress.rs
@@ -1,14 +1,18 @@
use std::hash::{BuildHasher, Hash, Hasher};

use hashbrown::hash_map::{Entry, RawEntryMut};
use hashbrown::{DefaultHashBuilder, HashMap};
use hashbrown::hash_map::Entry;
use hashbrown::HashTable;
use num_traits::AsPrimitive;
use vortex_array::accessor::ArrayAccessor;
use vortex_array::array::{PrimitiveArray, VarBinArray, VarBinViewArray};
use vortex_array::aliases::hash_map::{DefaultHashBuilder, HashMap};
use vortex_array::array::{
ConstantArray, PrimitiveArray, SparseArray, VarBinArray, VarBinViewArray,
};
use vortex_array::validity::Validity;
use vortex_array::{ArrayDType, IntoArray, IntoCanonical};
use vortex_dtype::{match_each_native_ptype, DType, NativePType, ToBytes};
use vortex_error::{VortexExpect as _, VortexUnwrap};
use vortex_scalar::ScalarValue;

/// Statically assigned code for a null value.
pub const NULL_CODE: u64 = 0;
@@ -41,43 +45,36 @@ pub fn dict_encode_primitive(array: &PrimitiveArray) -> (PrimitiveArray, Primiti
pub fn dict_encode_typed_primitive<T: NativePType>(
array: &PrimitiveArray,
) -> (PrimitiveArray, PrimitiveArray) {
let mut lookup_dict: HashMap<Value<T>, u64> = HashMap::new();
let mut lookup: HashMap<Value<T>, u64> = HashMap::new();
let mut codes: Vec<u64> = Vec::new();
let mut values: Vec<T> = Vec::new();

if array.dtype().is_nullable() {
values.push(T::zero());
}

ArrayAccessor::<T>::with_iterator(array, |iter| {
for ov in iter {
match ov {
None => codes.push(NULL_CODE),
Some(&v) => {
let code = match lookup_dict.entry(Value(v)) {
Entry::Occupied(o) => *o.get(),
Entry::Vacant(vac) => {
let next_code = values.len() as u64;
vac.insert(next_code.as_());
values.push(v);
next_code
}
};
codes.push(code);
array
.with_iterator(|iter| {
for ov in iter {
match ov {
None => codes.push(NULL_CODE),
Some(&v) => {
codes.push(match lookup.entry(Value(v)) {
Entry::Occupied(o) => *o.get(),
Entry::Vacant(vac) => {
let next_code = values.len() as u64;
vac.insert(next_code.as_());
values.push(v);
next_code
}
});
}
}
}
}
})
.vortex_expect("Failed to dictionary encode primitive array");
})
.vortex_expect("Failed to dictionary encode primitive array");

let values_validity = if array.dtype().is_nullable() {
let mut validity = vec![true; values.len()];
validity[0] = false;

validity.into()
} else {
Validity::NonNullable
};
let values_validity = dict_values_validity(array.dtype().is_nullable(), values.len());

(
PrimitiveArray::from(codes),
@@ -88,14 +85,14 @@ pub fn dict_encode_typed_primitive<T: NativePType>(
/// Dictionary encode varbin array. Specializes for primitive byte arrays to avoid double copying
pub fn dict_encode_varbin(array: &VarBinArray) -> (PrimitiveArray, VarBinArray) {
array
.with_iterator(|iter| dict_encode_typed_varbin(array.dtype().clone(), iter))
.vortex_expect("Failed to dictionary encode varbin array")
.with_iterator(|iter| dict_encode_varbin_bytes(array.dtype().clone(), iter))
.vortex_unwrap()
}

/// Dictionary encode a VarbinViewArray.
pub fn dict_encode_varbinview(array: &VarBinViewArray) -> (PrimitiveArray, VarBinViewArray) {
let (codes, values) = array
.with_iterator(|iter| dict_encode_typed_varbin(array.dtype().clone(), iter))
.with_iterator(|iter| dict_encode_varbin_bytes(array.dtype().clone(), iter))
.vortex_unwrap();
(
codes,
@@ -107,74 +104,51 @@ pub fn dict_encode_varbinview(array: &VarBinViewArray) -> (PrimitiveArray, VarBi
)
}

fn lookup_bytes<'a, T: NativePType + AsPrimitive<usize>>(
offsets: &'a [T],
bytes: &'a [u8],
idx: usize,
) -> &'a [u8] {
let begin: usize = offsets[idx].as_();
let end: usize = offsets[idx + 1].as_();
&bytes[begin..end]
}

fn dict_encode_typed_varbin<I, U>(dtype: DType, values: I) -> (PrimitiveArray, VarBinArray)
where
I: Iterator<Item = Option<U>>,
U: AsRef<[u8]>,
{
fn dict_encode_varbin_bytes<'a, I: Iterator<Item = Option<&'a [u8]>>>(
dtype: DType,
values: I,
) -> (PrimitiveArray, VarBinArray) {
let (lower, _) = values.size_hint();
let hasher = DefaultHashBuilder::default();
let mut lookup_dict: HashMap<u64, (), ()> = HashMap::with_hasher(());
let mut lookup_dict: HashTable<u64> = HashTable::new();
let mut codes: Vec<u64> = Vec::with_capacity(lower);
let mut bytes: Vec<u8> = Vec::new();
let mut offsets: Vec<u32> = Vec::new();
offsets.push(0);
let mut offsets: Vec<u32> = vec![0];

if dtype.is_nullable() {
offsets.push(0);
}

for o_val in values {
match o_val {
None => codes.push(0),
None => codes.push(NULL_CODE),
Some(val) => {
let byte_ref = val.as_ref();
let value_hash = hasher.hash_one(byte_ref);
let raw_entry = lookup_dict.raw_entry_mut().from_hash(value_hash, |idx| {
byte_ref == lookup_bytes(offsets.as_slice(), bytes.as_slice(), idx.as_())
});

let code = match raw_entry {
RawEntryMut::Occupied(o) => *o.into_key(),
RawEntryMut::Vacant(vac) => {
let next_code = offsets.len() as u64 - 1;
bytes.extend_from_slice(byte_ref);
offsets.push(bytes.len() as u32);
vac.insert_with_hasher(value_hash, next_code, (), |idx| {
let code = *lookup_dict
.entry(
hasher.hash_one(val),
|idx| val == lookup_bytes(offsets.as_slice(), bytes.as_slice(), idx.as_()),
|idx| {
hasher.hash_one(lookup_bytes(
offsets.as_slice(),
bytes.as_slice(),
idx.as_(),
))
});
},
)
.or_insert_with(|| {
let next_code = offsets.len() as u64 - 1;
bytes.extend_from_slice(val);
offsets.push(bytes.len() as u32);
next_code
}
};
})
.get();

codes.push(code)
}
}
}

let values_validity = if dtype.is_nullable() {
let mut validity = Vec::with_capacity(offsets.len() - 1);
validity.push(false);
validity.extend(vec![true; offsets.len() - 2]);

validity.into()
} else {
Validity::NonNullable
};

let values_validity = dict_values_validity(dtype.is_nullable(), offsets.len() - 1);
(
PrimitiveArray::from(codes),
VarBinArray::try_new(
@@ -187,6 +161,33 @@ where
)
}

fn dict_values_validity(nullable: bool, len: usize) -> Validity {
if nullable {
Validity::Array(
SparseArray::try_new(
ConstantArray::new(0u64, 1).into_array(),
ConstantArray::new(false, 1).into_array(),
len,
ScalarValue::Bool(true),
)
.vortex_unwrap()
.into_array(),
)
} else {
Validity::NonNullable
}
}

fn lookup_bytes<'a, T: AsPrimitive<usize>>(
offsets: &'a [T],
bytes: &'a [u8],
idx: usize,
) -> &'a [u8] {
let begin: usize = offsets[idx].as_();
let end: usize = offsets[idx + 1].as_();
&bytes[begin..end]
}

#[cfg(test)]
mod test {
use std::str;
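
The rewritten encoder drops hashbrown's raw-entry API in favor of HashTable: the table stores only u64 codes, and the hash and equality closures resolve a candidate code back to its bytes via lookup_bytes over the shared offsets and bytes buffers. The values-side validity is likewise no longer a materialized bool vector; dict_values_validity marks just position 0 (the reserved null slot) as invalid through a SparseArray over a ConstantArray. Below is a self-contained sketch of the HashTable deduplication pattern, assuming hashbrown's HashTable::entry API as used in the diff; the dedup helper is illustrative, not part of the crate, and it skips null handling and the packed offsets/bytes representation.

use std::hash::BuildHasher;

use hashbrown::{DefaultHashBuilder, HashTable};

// Deduplicate byte strings, returning one code per input plus the unique
// values in first-seen order.
fn dedup<'a>(inputs: impl Iterator<Item = &'a [u8]>) -> (Vec<u64>, Vec<Vec<u8>>) {
    let hasher = DefaultHashBuilder::default();
    let mut table: HashTable<u64> = HashTable::new();
    let mut values: Vec<Vec<u8>> = Vec::new();
    let mut codes: Vec<u64> = Vec::new();

    for val in inputs {
        let code = *table
            .entry(
                hasher.hash_one(val),
                // Equality: does this candidate code point at identical bytes?
                |code| values[*code as usize].as_slice() == val,
                // Rehash existing entries when the table grows.
                |code| hasher.hash_one(values[*code as usize].as_slice()),
            )
            .or_insert_with(|| {
                // First time we see these bytes: assign the next code.
                let next_code = values.len() as u64;
                values.push(val.to_vec());
                next_code
            })
            .get();
        codes.push(code);
    }
    (codes, values)
}

fn main() {
    let (codes, values) = dedup(["a", "b", "a"].iter().map(|s| s.as_bytes()));
    assert_eq!(codes, vec![0, 1, 0]);
    assert_eq!(values.len(), 2);
}
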
2 changes: 2 additions & 0 deletions vortex-array/src/aliases/hash_map.rs
@@ -1,3 +1,5 @@
pub type DefaultHashBuilder = hashbrown::DefaultHashBuilder;
pub type HashMap<K, V> = hashbrown::HashMap<K, V>;
pub type Entry<'a, K, V, S> = hashbrown::hash_map::Entry<'a, K, V, S>;
pub type IntoIter<K, V> = hashbrown::hash_map::IntoIter<K, V>;
pub type HashTable<T> = hashbrown::HashTable<T>;
17 changes: 7 additions & 10 deletions vortex-array/src/array/varbinview/accessor.rs
@@ -16,20 +16,19 @@ impl ArrayAccessor<[u8]> for VarBinViewArray {
let bytes: Vec<PrimitiveArray> = (0..self.metadata().buffer_lens.len())
.map(|i| self.buffer(i).into_canonical()?.into_primitive())
.try_collect()?;
let bytes_slices: Vec<&[u8]> = bytes.iter().map(|b| b.maybe_null_slice::<u8>()).collect();
let views: Vec<BinaryView> = self.binary_views()?.collect();
let validity = self.logical_validity().to_null_buffer()?;

match validity {
None => {
let mut iter = views.iter().map(|view| {
if view.is_inlined() {
Some(unsafe { &view.inlined.data[..view.len() as usize] })
Some(view.as_inlined().value())
} else {
let offset = unsafe { view._ref.offset as usize };
let buffer_idx = unsafe { view._ref.buffer_index as usize };
Some(
&bytes[buffer_idx].maybe_null_slice::<u8>()
[offset..offset + view.len() as usize],
&bytes_slices[view.as_view().buffer_index() as usize]
[view.as_view().to_range()],
)
}
});
@@ -39,13 +38,11 @@ impl ArrayAccessor<[u8]> for VarBinViewArray {
let mut iter = views.iter().zip(validity.iter()).map(|(view, valid)| {
if valid {
if view.is_inlined() {
Some(unsafe { &view.inlined.data[..view.len() as usize] })
Some(view.as_inlined().value())
} else {
let offset = unsafe { view._ref.offset as usize };
let buffer_idx = unsafe { view._ref.buffer_index as usize };
Some(
&bytes[buffer_idx].maybe_null_slice::<u8>()
[offset..offset + view.len() as usize],
&bytes_slices[view.as_view().buffer_index() as usize]
[view.as_view().to_range()],
)
}
} else {
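
The accessor now resolves non-inlined views through the safe as_view() helpers (buffer_index(), to_range()) instead of reading the view's union fields in unsafe blocks, and it hoists the buffer byte slices out of the per-view closure. Stripped of iteration and validity handling, the per-view resolution amounts to the sketch below; it assumes the BinaryView helpers used above, and both the BinaryView import path and the resolve helper are illustrative rather than taken from the crate.

use vortex_array::array::BinaryView;

// Resolve one view to its backing bytes: inlined views carry the bytes inside
// the 16-byte view itself, longer values point into one of the shared buffers.
fn resolve<'a>(view: &'a BinaryView, buffers: &'a [&'a [u8]]) -> &'a [u8] {
    if view.is_inlined() {
        view.as_inlined().value()
    } else {
        &buffers[view.as_view().buffer_index() as usize][view.as_view().to_range()]
    }
}
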