Skip to content

Commit

Permalink
feat: teach PyArray to compare (#1090)
Browse files Browse the repository at this point in the history
  • Loading branch information
danking authored Oct 21, 2024
1 parent 015b067 commit 36a7a94
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/encoding.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ Arrays
.. automodule:: vortex.encoding
:members:
:imported-members:
:special-members: __len__
:special-members: __len__, __lt__, __le__, __eq__, __ne__, __ge__, __gt__
112 changes: 111 additions & 1 deletion pyvortex/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyList};
use vortex::array::ChunkedArray;
use vortex::compute::unary::fill_forward;
use vortex::compute::{slice, take};
use vortex::compute::{compare, slice, take, Operator};
use vortex::{Array, ArrayDType, IntoCanonical};

use crate::dtype::PyDType;
Expand All @@ -14,6 +14,68 @@ use crate::python_repr::PythonRepr;

#[pyclass(name = "Array", module = "vortex", sequence, subclass)]
/// An array of zero or more *rows* each with the same set of *columns*.
///
/// Examples
/// --------
///
/// Arrays support all the standard comparison operations:
///
/// >>> a = vortex.encoding.array(['dog', None, 'cat', 'mouse', 'fish'])
/// >>> b = vortex.encoding.array(['doug', 'jennifer', 'casper', 'mouse', 'faust'])
/// >>> (a < b).to_arrow_array()
/// <pyarrow.lib.BooleanArray object at ...>
/// [
/// true,
/// null,
/// false,
/// false,
/// false
/// ]
/// >>> (a <= b).to_arrow_array()
/// <pyarrow.lib.BooleanArray object at ...>
/// [
/// true,
/// null,
/// false,
/// true,
/// false
/// ]
/// >>> (a == b).to_arrow_array()
/// <pyarrow.lib.BooleanArray object at ...>
/// [
/// false,
/// null,
/// false,
/// true,
/// false
/// ]
/// >>> (a != b).to_arrow_array()
/// <pyarrow.lib.BooleanArray object at ...>
/// [
/// true,
/// null,
/// true,
/// false,
/// true
/// ]
/// >>> (a >= b).to_arrow_array()
/// <pyarrow.lib.BooleanArray object at ...>
/// [
/// false,
/// null,
/// true,
/// true,
/// true
/// ]
/// >>> (a > b).to_arrow_array()
/// <pyarrow.lib.BooleanArray object at ...>
/// [
/// false,
/// null,
/// true,
/// false,
/// true
/// ]
pub struct PyArray {
inner: Array,
}
Expand Down Expand Up @@ -139,6 +201,54 @@ impl PyArray {
PyDType::wrap(self_.py(), self_.inner.dtype().clone())
}

// Rust docs are *not* copied into Python for __lt__: https://github.com/PyO3/pyo3/issues/4326
fn __lt__(&self, other: &Bound<PyArray>) -> PyResult<PyArray> {
let other = other.borrow();
compare(&self.inner, &other.inner, Operator::Lt)
.map(|arr| PyArray { inner: arr })
.map_err(PyVortexError::map_err)
}

// Rust docs are *not* copied into Python for __le__: https://github.com/PyO3/pyo3/issues/4326
fn __le__(&self, other: &Bound<PyArray>) -> PyResult<PyArray> {
let other = other.borrow();
compare(&self.inner, &other.inner, Operator::Lte)
.map(|arr| PyArray { inner: arr })
.map_err(PyVortexError::map_err)
}

// Rust docs are *not* copied into Python for __eq__: https://github.com/PyO3/pyo3/issues/4326
fn __eq__(&self, other: &Bound<PyArray>) -> PyResult<PyArray> {
let other = other.borrow();
compare(&self.inner, &other.inner, Operator::Eq)
.map(|arr| PyArray { inner: arr })
.map_err(PyVortexError::map_err)
}

// Rust docs are *not* copied into Python for __ne__: https://github.com/PyO3/pyo3/issues/4326
fn __ne__(&self, other: &Bound<PyArray>) -> PyResult<PyArray> {
let other = other.borrow();
compare(&self.inner, &other.inner, Operator::NotEq)
.map(|arr| PyArray { inner: arr })
.map_err(PyVortexError::map_err)
}

// Rust docs are *not* copied into Python for __ge__: https://github.com/PyO3/pyo3/issues/4326
fn __ge__(&self, other: &Bound<PyArray>) -> PyResult<PyArray> {
let other = other.borrow();
compare(&self.inner, &other.inner, Operator::Gte)
.map(|arr| PyArray { inner: arr })
.map_err(PyVortexError::map_err)
}

// Rust docs are *not* copied into Python for __gt__: https://github.com/PyO3/pyo3/issues/4326
fn __gt__(&self, other: &Bound<PyArray>) -> PyResult<PyArray> {
let other = other.borrow();
compare(&self.inner, &other.inner, Operator::Gt)
.map(|arr| PyArray { inner: arr })
.map_err(PyVortexError::map_err)
}

/// Filter an Array by another Boolean array.
///
/// Parameters
Expand Down

0 comments on commit 36a7a94

Please sign in to comment.