Skip to content

Commit

Permalink
implement retract_batch for xor accumulator (#11500)
Browse files Browse the repository at this point in the history
* implement retract_batch for xor accumulator

* add comment
  • Loading branch information
drewhayward authored Jul 17, 2024
1 parent b0925c8 commit fb34ef2
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions datafusion/functions-aggregate/src/bit_and_or_xor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,15 @@ where
Ok(())
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
// XOR is it's own inverse
self.update_batch(values)
}

fn supports_retract_batch(&self) -> bool {
true
}

fn evaluate(&mut self) -> Result<ScalarValue> {
ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
}
Expand Down Expand Up @@ -456,3 +465,41 @@ where
Ok(())
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use arrow::array::{ArrayRef, UInt64Array};
use arrow::datatypes::UInt64Type;
use datafusion_common::ScalarValue;

use crate::bit_and_or_xor::BitXorAccumulator;
use datafusion_expr::Accumulator;

#[test]
fn test_bit_xor_accumulator() {
let mut accumulator = BitXorAccumulator::<UInt64Type> { value: None };
let batches: Vec<_> = vec![vec![1, 2], vec![1]]
.into_iter()
.map(|b| Arc::new(b.into_iter().collect::<UInt64Array>()) as ArrayRef)
.collect();

let added = &[Arc::clone(&batches[0])];
let retracted = &[Arc::clone(&batches[1])];

// XOR of 1..3 is 3
accumulator.update_batch(added).unwrap();
assert_eq!(
accumulator.evaluate().unwrap(),
ScalarValue::UInt64(Some(3))
);

// Removing [1] ^ 3 = 2
accumulator.retract_batch(retracted).unwrap();
assert_eq!(
accumulator.evaluate().unwrap(),
ScalarValue::UInt64(Some(2))
);
}
}

0 comments on commit fb34ef2

Please sign in to comment.