Skip to content

Commit

Permalink
Sum statistic + compute fn (#2474)
Browse files Browse the repository at this point in the history
FLUP: remove TrueCount
  • Loading branch information
gatesn authored Feb 24, 2025
1 parent 4cad7f3 commit cfb198c
Show file tree
Hide file tree
Showing 27 changed files with 667 additions and 120 deletions.
2 changes: 1 addition & 1 deletion encodings/sparse/src/compute/binary_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl BinaryNumericFn<&SparseArray> for SparseEncoding {
let new_fill_value = array
.fill_scalar()
.as_primitive()
.checked_binary_numeric(rhs_scalar.as_primitive(), op)?
.checked_binary_numeric(&rhs_scalar.as_primitive(), op)
.ok_or_else(|| vortex_err!("numeric overflow"))?
.into();
Ok(Some(
Expand Down
62 changes: 36 additions & 26 deletions vortex-array/src/array/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::RwLock;
use vortex_error::VortexExpect;
use vortex_scalar::{Scalar, ScalarValue};

use crate::compute::{min_max, scalar_at, MinMaxResult};
use crate::compute::{min_max, scalar_at, sum, MinMaxResult};
use crate::stats::{Precision, Stat, Statistics, StatsSet};
use crate::{Array, ArrayImpl};

Expand Down Expand Up @@ -67,36 +67,46 @@ impl<A: Array + ArrayImpl> Statistics for A {

// NOTE(ngates): this is the beginning of the stats refactor that pushes stats compute into
// regular compute functions.
let stats_set = if matches!(stat, Stat::Min | Stat::Max) {
let mut stats_set = self.statistics().stats_set();
if let Some(MinMaxResult { min, max }) =
min_max(self).vortex_expect("Failed to compute min/max")
{
if min == max
&& stats_set.get_as::<u64>(Stat::NullCount) == Some(Precision::exact(0u64))
let stats_set = match stat {
Stat::Min | Stat::Max => {
let mut stats_set = self.statistics().stats_set();
if let Some(MinMaxResult { min, max }) =
min_max(self).vortex_expect("Failed to compute min/max")
{
stats_set.set(Stat::IsConstant, Precision::exact(true));
if min == max
&& stats_set.get_as::<u64>(Stat::NullCount) == Some(Precision::exact(0u64))
{
stats_set.set(Stat::IsConstant, Precision::exact(true));
}

stats_set
.combine_sets(
&StatsSet::from_iter([
(Stat::Min, Precision::exact(min.into_value())),
(Stat::Max, Precision::exact(max.into_value())),
]),
self.dtype(),
)
// TODO(ngates): this shouldn't be fallible
.vortex_expect("Failed to combine stats sets");
}

stats_set
.combine_sets(
&StatsSet::from_iter([
(Stat::Min, Precision::exact(min.into_value())),
(Stat::Max, Precision::exact(max.into_value())),
]),
self.dtype(),
)
// TODO(ngates): this shouldn't be fallible
.vortex_expect("Failed to combine stats sets");
}

stats_set
} else {
let vtable = self.vtable();
vtable
.compute_statistics(self, stat)
// TODO(ngates): hmmm, then why does it return a result?
.vortex_expect("compute_statistics must not fail")
// Try to compute the sum and return it.
Stat::Sum => {
return sum(self)
.inspect_err(|e| log::warn!("{}", e))
.ok()
.map(|sum| sum.into_value())
}
_ => {
let vtable = self.vtable();
vtable
.compute_statistics(self, stat)
// TODO(ngates): hmmm, then why does it return a result?
.vortex_expect("compute_statistics must not fail")
}
};

{
Expand Down
7 changes: 6 additions & 1 deletion vortex-array/src/arrays/bool/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::arrays::BoolEncoding;
use crate::compute::{
BinaryBooleanFn, CastFn, FillForwardFn, FillNullFn, FilterFn, InvertFn, MaskFn, MinMaxFn,
ScalarAtFn, SliceFn, TakeFn, ToArrowFn,
ScalarAtFn, SliceFn, SumFn, TakeFn, ToArrowFn,
};
use crate::vtable::ComputeVTable;
use crate::Array;
Expand All @@ -16,6 +16,7 @@ mod mask;
mod min_max;
mod scalar_at;
mod slice;
mod sum;
mod take;
mod to_arrow;

Expand Down Expand Up @@ -60,6 +61,10 @@ impl ComputeVTable for BoolEncoding {
Some(self)
}

fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
Some(self)
}

fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
Some(self)
}
Expand Down
35 changes: 35 additions & 0 deletions vortex-array/src/arrays/bool/compute/sum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use std::ops::BitAnd;

use vortex_error::VortexResult;
use vortex_mask::AllOr;
use vortex_scalar::Scalar;

use crate::arrays::{BoolArray, BoolEncoding};
use crate::compute::SumFn;
use crate::stats::Stat;
use crate::Array;

impl SumFn<&BoolArray> for BoolEncoding {
    /// Sums a boolean array by counting the `true` values at valid positions.
    ///
    /// The count is wrapped in a scalar typed as `Stat::Sum.dtype(array.dtype())`.
    /// When every element is invalid there is no count at all, and `None` is
    /// handed to the scalar constructor instead of a number.
    ///
    /// # Errors
    /// Propagates any error raised while computing the validity mask.
    fn sum(&self, array: &BoolArray) -> VortexResult<Scalar> {
        let validity = array.validity_mask()?;

        let true_count: Option<u64> = match validity.boolean_buffer() {
            // Nothing is valid, so there is nothing to count.
            AllOr::None => None,
            // Everything is valid: every set bit is a counted `true`.
            AllOr::All => Some(array.boolean_buffer().count_set_bits() as u64),
            // Mixed validity: only bits set in both the values and the validity count.
            AllOr::Some(validity_mask) => {
                let valid_trues = array.boolean_buffer().bitand(validity_mask);
                Some(valid_trues.count_set_bits() as u64)
            }
        };

        let sum_dtype = Stat::Sum.dtype(array.dtype());
        Ok(Scalar::new(sum_dtype, true_count.into()))
    }
}
1 change: 1 addition & 0 deletions vortex-array/src/arrays/chunked/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod mask;
mod min_max;
mod scalar_at;
mod slice;
mod sum;
mod take;

impl ComputeVTable for ChunkedEncoding {
Expand Down
60 changes: 60 additions & 0 deletions vortex-array/src/arrays/chunked/compute/sum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use num_traits::PrimInt;
use vortex_dtype::{match_each_native_ptype, NativePType, PType};
use vortex_error::{VortexExpect, VortexResult};
use vortex_scalar::{FromPrimitiveOrF16, Scalar};

use crate::arrays::{ChunkedArray, ChunkedEncoding};
use crate::compute::{sum, SumFn};
use crate::stats::Stat;
use crate::{Array, ArrayRef};

impl SumFn<&ChunkedArray> for ChunkedEncoding {
    /// Sums a chunked array by combining the sums of its individual chunks.
    ///
    /// The accumulator is chosen from the primitive type of the sum dtype:
    /// unsigned sums accumulate in `u64`, signed sums in `i64`, and floating
    /// point sums through `sum_float`.
    ///
    /// # Errors
    /// Propagates any error from summing a chunk.
    ///
    /// # Panics
    /// Panics with "sum dtype must be primitive" if the sum dtype cannot be
    /// converted to a `PType`.
    fn sum(&self, array: &ChunkedArray) -> VortexResult<Scalar> {
        let result_dtype = Stat::Sum.dtype(array.dtype());
        let result_ptype =
            PType::try_from(&result_dtype).vortex_expect("sum dtype must be primitive");
        let chunks = array.chunks();

        // `$T` is unused on purpose: only the ptype family selects the accumulator.
        let total = match_each_native_ptype!(
            result_ptype,
            unsigned: |$T| { sum_int::<u64>(chunks)?.into() }
            signed: |$T| { sum_int::<i64>(chunks)?.into() }
            floating: |$T| { sum_float(chunks)?.into() }
        );

        Ok(Scalar::new(result_dtype, total))
    }
}

/// Adds up the integer sums of every chunk with overflow checking.
///
/// Returns `Ok(None)` as soon as a chunk's sum is not available as a `T`, or
/// as soon as adding it to the running total would overflow `T`. Returns
/// `Ok(Some(total))` otherwise; with no chunks the total is zero.
///
/// NOTE(review): a null chunk sum is treated the same as overflow here. If a
/// chunk's sum can also be null because the chunk is entirely invalid, that
/// makes the whole result null too — confirm this is intended.
///
/// # Errors
/// Propagates any error from summing a chunk or from reading its sum as `T`.
fn sum_int<T: NativePType + PrimInt + FromPrimitiveOrF16>(
    chunks: &[ArrayRef],
) -> VortexResult<Option<T>> {
    let mut total = T::zero();
    for chunk in chunks {
        let chunk_total = sum(chunk)?;
        let partial = chunk_total.as_primitive().as_::<T>()?;

        // Both a missing chunk sum and an overflowing addition end the fold early.
        let next = partial.and_then(|value| total.checked_add(&value));
        match next {
            Some(updated) => total = updated,
            None => return Ok(None),
        }
    }
    Ok(Some(total))
}

/// Adds up the floating point sums of every chunk as an `f64`.
///
/// Chunks whose sum is null contribute nothing to the total. A null chunk sum
/// is a legitimate input rather than an invariant violation: the primitive
/// encoding's `SumFn` returns a null scalar for an all-invalid array, so
/// unwrapping the value unconditionally could panic on an all-null chunk.
///
/// With no chunks (or only null chunk sums) the result is `0.0`.
///
/// # Errors
/// Propagates any error from summing a chunk or from reading its sum as `f64`.
fn sum_float(chunks: &[ArrayRef]) -> VortexResult<f64> {
    let mut result = 0f64;
    for chunk in chunks {
        let chunk_sum = sum(chunk)?;
        // Skip null sums instead of panicking on them.
        if let Some(chunk_sum) = chunk_sum.as_primitive().as_::<f64>()? {
            result += chunk_sum;
        }
    }
    Ok(result)
}
2 changes: 1 addition & 1 deletion vortex-array/src/arrays/constant/compute/binary_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl BinaryNumericFn<&ConstantArray> for ConstantEncoding {
array
.scalar()
.as_primitive()
.checked_binary_numeric(rhs.as_primitive(), op)?
.checked_binary_numeric(&rhs.as_primitive(), op)
.ok_or_else(|| vortex_err!("numeric overflow"))?,
array.len(),
)
Expand Down
15 changes: 10 additions & 5 deletions vortex-array/src/arrays/primitive/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::arrays::PrimitiveEncoding;
use crate::compute::{
BetweenFn, CastFn, FillForwardFn, FillNullFn, FilterFn, MaskFn, MinMaxFn, ScalarAtFn,
SearchSortedFn, SearchSortedUsizeFn, SliceFn, TakeFn, ToArrowFn,
SearchSortedFn, SearchSortedUsizeFn, SliceFn, SumFn, TakeFn, ToArrowFn,
};
use crate::vtable::ComputeVTable;
use crate::Array;
Expand All @@ -16,6 +16,7 @@ mod min_max;
mod scalar_at;
mod search_sorted;
mod slice;
mod sum;
mod take;
mod to_arrow;

Expand All @@ -24,7 +25,7 @@ impl ComputeVTable for PrimitiveEncoding {
Some(self)
}

fn mask_fn(&self) -> Option<&dyn MaskFn<&dyn Array>> {
fn between_fn(&self) -> Option<&dyn BetweenFn<&dyn Array>> {
Some(self)
}

Expand All @@ -40,15 +41,15 @@ impl ComputeVTable for PrimitiveEncoding {
Some(self)
}

fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
fn mask_fn(&self) -> Option<&dyn MaskFn<&dyn Array>> {
Some(self)
}

fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<&dyn Array>> {
fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
Some(self)
}

fn between_fn(&self) -> Option<&dyn BetweenFn<&dyn Array>> {
fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<&dyn Array>> {
Some(self)
}

Expand All @@ -60,6 +61,10 @@ impl ComputeVTable for PrimitiveEncoding {
Some(self)
}

fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
Some(self)
}

fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
Some(self)
}
Expand Down
94 changes: 94 additions & 0 deletions vortex-array/src/arrays/primitive/compute/sum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use arrow_buffer::BooleanBuffer;
use itertools::Itertools;
use num_traits::{CheckedAdd, Float, ToPrimitive};
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_error::{VortexExpect, VortexResult};
use vortex_mask::AllOr;
use vortex_scalar::Scalar;

use crate::arrays::{PrimitiveArray, PrimitiveEncoding};
use crate::compute::SumFn;
use crate::stats::Stat;
use crate::variants::PrimitiveArrayTrait;
use crate::Array;

impl SumFn<&PrimitiveArray> for PrimitiveEncoding {
    /// Sums the valid elements of a primitive array.
    ///
    /// Unsigned values accumulate into `u64`, signed values into `i64`, and
    /// floating point values into `f64`. The integer helpers return `None`
    /// when a checked addition or conversion fails, and that `None` is passed
    /// through `.into()` as the scalar value. An array with no valid element
    /// produces a null scalar of the sum dtype.
    ///
    /// # Errors
    /// Propagates any error raised while computing the validity mask.
    fn sum(&self, array: &PrimitiveArray) -> VortexResult<Scalar> {
        let validity = array.validity_mask()?;
        let sum_dtype = Stat::Sum.dtype(array.dtype());

        let scalar_value = match validity.boolean_buffer() {
            AllOr::None => {
                // No valid element at all: the sum itself is null.
                return Ok(Scalar::null(sum_dtype));
            }
            AllOr::All => {
                // Every element is valid, so the validity buffer can be ignored.
                match_each_native_ptype!(
                    array.ptype(),
                    unsigned: |$T| { sum_integer::<_, u64>(array.as_slice::<$T>()).into() }
                    signed: |$T| { sum_integer::<_, i64>(array.as_slice::<$T>()).into() }
                    floating: |$T| { sum_float(array.as_slice::<$T>()).into() }
                )
            }
            AllOr::Some(validity_mask) => {
                // Mixed validity: only elements flagged valid take part in the sum.
                match_each_native_ptype!(
                    array.ptype(),
                    unsigned: |$T| {
                        sum_integer_with_validity::<_, u64>(array.as_slice::<$T>(), validity_mask)
                            .into()
                    }
                    signed: |$T| {
                        sum_integer_with_validity::<_, i64>(array.as_slice::<$T>(), validity_mask)
                            .into()
                    }
                    floating: |$T| {
                        sum_float_with_validity(array.as_slice::<$T>(), validity_mask).into()
                    }
                )
            }
        };

        Ok(Scalar::new(sum_dtype, scalar_value))
    }
}

/// Sums `values` into the wider accumulator type `R` using checked arithmetic.
///
/// Returns `None` if a value cannot be converted to `R` or if an addition
/// overflows `R`. An empty slice sums to zero.
fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
    values: &[T],
) -> Option<R> {
    values.iter().try_fold(R::zero(), |acc, &value| {
        // Convert first, then add; either step failing aborts the whole fold.
        R::from(value).and_then(|widened| acc.checked_add(&widened))
    })
}

/// Sums the entries of `values` whose matching `validity` bit is set, using
/// checked arithmetic in the accumulator type `R`.
///
/// Returns `None` if a valid value cannot be converted to `R` or if an
/// addition overflows `R`. Invalid entries are never converted or added.
///
/// # Panics
/// `zip_eq` panics if `values` and `validity` turn out to have different
/// lengths while being iterated.
fn sum_integer_with_validity<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
    values: &[T],
    validity: &BooleanBuffer,
) -> Option<R> {
    values
        .iter()
        .zip_eq(validity.iter())
        .filter(|(_, is_valid)| *is_valid)
        .try_fold(R::zero(), |acc, (&value, _)| {
            R::from(value).and_then(|widened| acc.checked_add(&widened))
        })
}

/// Sums `values` left to right as `f64`, starting from `0.0`.
///
/// # Panics
/// Panics with "Failed to cast value to f64" if a value cannot be converted.
fn sum_float<T: NativePType + Float>(values: &[T]) -> f64 {
    // An explicit fold keeps the `0.0` seed and the left-to-right addition order.
    values.iter().fold(0.0_f64, |acc, &value| {
        acc + value.to_f64().vortex_expect("Failed to cast value to f64")
    })
}

/// Sums, as `f64` and starting from `0.0`, the entries of `array` whose
/// matching `validity` bit is set. Invalid entries are never converted.
///
/// # Panics
/// Panics with "Failed to cast value to f64" if a valid value cannot be
/// converted. `zip_eq` also panics if `array` and `validity` turn out to have
/// different lengths while being iterated.
fn sum_float_with_validity<T: NativePType + Float>(array: &[T], validity: &BooleanBuffer) -> f64 {
    array
        .iter()
        .zip_eq(validity.iter())
        .filter(|(_, is_valid)| *is_valid)
        .fold(0.0_f64, |acc, (&value, _)| {
            acc + value.to_f64().vortex_expect("Failed to cast value to f64")
        })
}
1 change: 1 addition & 0 deletions vortex-array/src/arrays/varbin/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub fn compute_varbin_statistics<T: ArrayAccessor<[u8]> + Array>(
stat
)
}
Stat::Sum => unreachable!("Sum is not supported for VarBinArray"),
})
}

Expand Down
Loading

0 comments on commit cfb198c

Please sign in to comment.