diff --git a/Cargo.lock b/Cargo.lock index b5a3fac37a..a353aaeb54 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3079,6 +3079,7 @@ dependencies = [ "arrow", "flexbuffers", "futures", + "itertools 0.13.0", "log", "paste", "pyo3", diff --git a/pyvortex/Cargo.toml b/pyvortex/Cargo.toml index c8ca8b86d5..3368ca9290 100644 --- a/pyvortex/Cargo.toml +++ b/pyvortex/Cargo.toml @@ -42,6 +42,7 @@ vortex-sampling-compressor = { workspace = true } vortex-serde = { workspace = true, features = ["tokio"] } vortex-scalar = { workspace = true } vortex-zigzag = { workspace = true } +itertools = { workspace = true } # We may need this workaround? # https://pyo3.rs/v0.20.2/faq.html#i-cant-run-cargo-test-or-i-cant-build-in-a-cargo-workspace-im-having-linker-issues-like-symbol-not-found-or-undefined-reference-to-_pyexc_systemerror diff --git a/pyvortex/src/array.rs b/pyvortex/src/array.rs index 5e99ac86a8..3fa0596d50 100644 --- a/pyvortex/src/array.rs +++ b/pyvortex/src/array.rs @@ -9,6 +9,7 @@ use vortex::{Array, ArrayDType, IntoCanonical}; use crate::dtype::PyDType; use crate::error::PyVortexError; +use crate::python_repr::PythonRepr; #[pyclass(name = "Array", module = "vortex", sequence, subclass)] /// An array of zero or more *rows* each with the same set of *columns*. @@ -167,9 +168,19 @@ impl PyArray { /// "b", /// "a" /// ] - fn take<'py>(&self, indices: PyRef<'py, Self>) -> PyResult> { - take(&self.inner, indices.unwrap()) + fn take<'py>(&self, indices: &Bound<'py, PyArray>) -> PyResult> { + let py = indices.py(); + let indices = &indices.borrow().inner; + + if !indices.dtype().is_int() { + return Err(PyValueError::new_err(format!( + "indices: expected int or uint array, but found: {}", + indices.dtype().python_repr() + ))); + } + + take(&self.inner, &indices) .map_err(PyVortexError::map_err) - .and_then(|arr| Bound::new(indices.py(), PyArray { inner: arr })) + .and_then(|arr| Bound::new(py, PyArray { inner: arr })) } } diff --git a/pyvortex/src/dtype.rs b/pyvortex/src/dtype.rs index 0737f2c0f2..5677287692 100644 --- a/pyvortex/src/dtype.rs +++ b/pyvortex/src/dtype.rs @@ -6,6 +6,8 @@ use pyo3::{pyclass, pyfunction, pymethods, Bound, Py, PyAny, PyResult, Python}; use vortex::arrow::FromArrowType; use vortex_dtype::{DType, PType}; +use crate::python_repr::PythonRepr; + #[pyclass(name = "DType", module = "vortex", subclass)] /// A data type describes the set of operations available on a given column. These operations are /// implemented by the column *encoding*. Each data type is implemented by one or more encodings. diff --git a/pyvortex/src/lib.rs b/pyvortex/src/lib.rs index d9895c3a22..d8629e6850 100644 --- a/pyvortex/src/lib.rs +++ b/pyvortex/src/lib.rs @@ -11,6 +11,7 @@ mod encode; mod error; mod expr; mod io; +mod python_repr; /// Vortex is an Apache Arrow-compatible toolkit for working with compressed array data. #[pymodule] diff --git a/pyvortex/src/python_repr.rs b/pyvortex/src/python_repr.rs new file mode 100644 index 0000000000..83e34899a2 --- /dev/null +++ b/pyvortex/src/python_repr.rs @@ -0,0 +1,107 @@ +use std::convert::AsRef; +use std::fmt::{Display, Formatter}; + +use itertools::Itertools; +use vortex_dtype::{DType, ExtID, ExtMetadata, Nullability, PType}; + +pub trait PythonRepr { + fn python_repr(&self) -> impl Display; +} + +struct DTypePythonRepr<'a>(&'a DType); + +impl PythonRepr for DType { + fn python_repr(&self) -> impl Display { + return DTypePythonRepr(self); + } +} + +impl Display for DTypePythonRepr<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let DTypePythonRepr(x) = self; + match x { + DType::Null => write!(f, "null()"), + DType::Bool(n) => write!(f, "bool({})", n.python_repr()), + DType::Primitive(p, n) => match p { + PType::U8 | PType::U16 | PType::U32 | PType::U64 => { + write!(f, "uint({}, {})", p.bit_width(), n.python_repr()) + } + PType::I8 | PType::I16 | PType::I32 | PType::I64 => { + write!(f, "int({}, {})", p.bit_width(), n.python_repr()) + } + PType::F16 | PType::F32 | PType::F64 => { + write!(f, "float({}, {})", p.bit_width(), n.python_repr()) + } + }, + DType::Utf8(n) => write!(f, "utf8({})", n.python_repr()), + DType::Binary(n) => write!(f, "binary({})", n.python_repr()), + DType::Struct(st, n) => write!( + f, + "struct({{{}}}, {})", + st.names() + .iter() + .zip(st.dtypes().iter()) + .map(|(n, dt)| format!("\"{}\": {}", n, dt.python_repr())) + .join(", "), + n.python_repr() + ), + DType::List(c, n) => write!(f, "list({}, {})", c.python_repr(), n.python_repr()), + DType::Extension(ext, n) => { + write!(f, "ext(\"{}\", ", ext.id().python_repr())?; + match ext.metadata() { + None => write!(f, "None")?, + Some(metadata) => write!(f, "{}", metadata.python_repr())?, + }; + write!(f, ", {})", n.python_repr()) + } + } + } +} + +struct NullabilityPythonRepr<'a>(&'a Nullability); + +impl PythonRepr for Nullability { + fn python_repr(&self) -> impl Display { + return NullabilityPythonRepr(self); + } +} + +impl Display for NullabilityPythonRepr<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let NullabilityPythonRepr(x) = self; + match x { + Nullability::NonNullable => write!(f, "False"), + Nullability::Nullable => write!(f, "True"), + } + } +} + +struct ExtMetadataPythonRepr<'a>(&'a ExtMetadata); + +impl PythonRepr for ExtMetadata { + fn python_repr(&self) -> impl Display { + return ExtMetadataPythonRepr(self); + } +} + +impl Display for ExtMetadataPythonRepr<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let ExtMetadataPythonRepr(metadata) = self; + write!(f, "\"{}\"", metadata.as_ref().escape_ascii()) + } +} + +struct ExtIDPythonRepr<'a>(&'a ExtID); + +impl PythonRepr for ExtID { + fn python_repr(&self) -> impl Display { + ExtIDPythonRepr(self) + } +} + +impl Display for ExtIDPythonRepr<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let ExtIDPythonRepr(ext_id) = self; + write!(f, "\"{}\"", ext_id.as_ref().escape_default()) + } +} diff --git a/vortex-array/src/array/struct_/mod.rs b/vortex-array/src/array/struct_/mod.rs index 5d32c4899a..e6d0486301 100644 --- a/vortex-array/src/array/struct_/mod.rs +++ b/vortex-array/src/array/struct_/mod.rs @@ -41,6 +41,8 @@ impl StructArray { length: usize, validity: Validity, ) -> VortexResult { + let nullability = validity.nullability(); + if names.len() != fields.len() { vortex_bail!("Got {} names and {} fields", names.len(), fields.len()); } @@ -55,15 +57,12 @@ impl StructArray { let mut children = Vec::with_capacity(fields.len() + 1); children.extend(fields); - if let Some(v) = validity.clone().into_array() { + if let Some(v) = validity.into_array() { children.push(v); } Self::try_from_parts( - DType::Struct( - StructDType::new(names, field_dtypes), - validity.nullability(), - ), + DType::Struct(StructDType::new(names, field_dtypes), nullability), length, StructMetadata { length, diff --git a/vortex-array/src/compute/take.rs b/vortex-array/src/compute/take.rs index dcab99d1a2..6f4136a021 100644 --- a/vortex-array/src/compute/take.rs +++ b/vortex-array/src/compute/take.rs @@ -1,7 +1,7 @@ use log::info; -use vortex_error::{vortex_bail, vortex_err, VortexResult}; +use vortex_error::{vortex_err, VortexResult}; -use crate::{Array, ArrayDType, IntoCanonical}; +use crate::{Array, IntoCanonical}; pub trait TakeFn { fn take(&self, indices: &Array) -> VortexResult; @@ -10,10 +10,6 @@ pub trait TakeFn { pub fn take(array: &Array, indices: &Array) -> VortexResult { array.with_dyn(|a| { if let Some(take) = a.take() { - if !indices.dtype().is_int() { - vortex_bail!(InvalidArgument: "indices: expected int or uint array, but found: {}", indices.dtype().python_repr()); - } - return take.take(indices); } diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index 97450d367a..c513604962 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -120,10 +120,6 @@ impl DType { _ => None, } } - - pub fn python_repr(&self) -> DTypePythonRepr { - DTypePythonRepr { dtype: self } - } } impl Display for DType { @@ -158,51 +154,6 @@ impl Display for DType { } } -pub struct DTypePythonRepr<'a> { - dtype: &'a DType, -} - -impl Display for DTypePythonRepr<'_> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.dtype { - Null => write!(f, "null()"), - Bool(n) => write!(f, "bool({})", n.python_repr()), - Primitive(p, n) => match p { - PType::U8 | PType::U16 | PType::U32 | PType::U64 => { - write!(f, "uint({}, {})", p.bit_width(), n.python_repr()) - } - PType::I8 | PType::I16 | PType::I32 | PType::I64 => { - write!(f, "int({}, {})", p.bit_width(), n.python_repr()) - } - PType::F16 | PType::F32 | PType::F64 => { - write!(f, "float({}, {})", p.bit_width(), n.python_repr()) - } - }, - Utf8(n) => write!(f, "utf8({})", n.python_repr()), - Binary(n) => write!(f, "binary({})", n.python_repr()), - Struct(st, n) => write!( - f, - "struct({{{}}}, {})", - st.names() - .iter() - .zip(st.dtypes().iter()) - .map(|(n, dt)| format!("\"{}\": {}", n, dt.python_repr())) - .join(", "), - n.python_repr() - ), - List(c, n) => write!(f, "list({}, {})", c.python_repr(), n.python_repr()), - Extension(ext, n) => { - write!(f, "ext(\"{}\", ", ext.id().python_repr())?; - match ext.metadata() { - None => write!(f, "None")?, - Some(metadata) => write!(f, "{}", metadata.python_repr())?, - }; - write!(f, ", {})", n.python_repr()) - } - } - } -} - #[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct StructDType { diff --git a/vortex-dtype/src/extension.rs b/vortex-dtype/src/extension.rs index 6651d97fe5..18daec5505 100644 --- a/vortex-dtype/src/extension.rs +++ b/vortex-dtype/src/extension.rs @@ -9,10 +9,6 @@ impl ExtID { pub fn new(value: Arc) -> Self { Self(value) } - - pub fn python_repr(&self) -> ExtIDPythonRepr { - ExtIDPythonRepr { ext_id: self } - } } impl Display for ExtID { @@ -33,18 +29,6 @@ impl From<&str> for ExtID { } } -pub struct ExtIDPythonRepr<'a> { - ext_id: &'a ExtID, -} - -impl Display for ExtIDPythonRepr<'_> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.ext_id { - ExtID(id) => write!(f, "\"{}\"", id.escape_default()), - } - } -} - #[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct ExtMetadata(Arc<[u8]>); @@ -53,10 +37,6 @@ impl ExtMetadata { pub fn new(value: Arc<[u8]>) -> Self { Self(value) } - - pub fn python_repr(&self) -> ExtMetadataPythonRepr { - ExtMetadataPythonRepr { ext_metadata: self } - } } impl AsRef<[u8]> for ExtMetadata { @@ -71,18 +51,6 @@ impl From<&[u8]> for ExtMetadata { } } -pub struct ExtMetadataPythonRepr<'a> { - ext_metadata: &'a ExtMetadata, -} - -impl Display for ExtMetadataPythonRepr<'_> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.ext_metadata { - ExtMetadata(metadata) => write!(f, "\"{}\"", metadata.escape_ascii()), - } - } -} - #[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct ExtDType { diff --git a/vortex-dtype/src/nullability.rs b/vortex-dtype/src/nullability.rs index e9decc31cd..d60ba49610 100644 --- a/vortex-dtype/src/nullability.rs +++ b/vortex-dtype/src/nullability.rs @@ -7,11 +7,7 @@ pub enum Nullability { Nullable, } -impl Nullability { - pub fn python_repr(&self) -> NullabilityPythonRepr { - NullabilityPythonRepr { nullability: self } - } -} +impl Nullability {} impl From for Nullability { fn from(value: bool) -> Self { @@ -40,16 +36,3 @@ impl Display for Nullability { } } } - -pub struct NullabilityPythonRepr<'a> { - nullability: &'a Nullability, -} - -impl Display for NullabilityPythonRepr<'_> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.nullability { - Nullability::NonNullable => write!(f, "False"), - Nullability::Nullable => write!(f, "True"), - } - } -}