From 0c39d5d497c39bf1280dfdf15307af8eba2a6bd0 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 22 Jul 2024 22:00:15 -0400 Subject: [PATCH] Read from Arrow C Data interface --- vegafusion-common/src/data/ffi.rs | 62 +++++++++++++++++++++++++++++ vegafusion-common/src/data/mod.rs | 3 ++ vegafusion-common/src/data/table.rs | 10 ++++- vegafusion-python-embed/src/lib.rs | 4 ++ 4 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 vegafusion-common/src/data/ffi.rs diff --git a/vegafusion-common/src/data/ffi.rs b/vegafusion-common/src/data/ffi.rs new file mode 100644 index 000000000..0ef3a5e48 --- /dev/null +++ b/vegafusion-common/src/data/ffi.rs @@ -0,0 +1,62 @@ +use arrow::array::{RecordBatch, RecordBatchReader}; +use arrow::datatypes::SchemaRef; +use arrow::ffi_stream::ArrowArrayStreamReader; +use arrow::ffi_stream::FFI_ArrowArrayStream; +use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::types::PyCapsule; + +/// Validate PyCapsule has provided name +fn validate_pycapsule_name(capsule: &PyCapsule, expected_name: &str) -> PyResult<()> { + let capsule_name = capsule.name()?; + if let Some(capsule_name) = capsule_name { + let capsule_name = capsule_name.to_str()?; + if capsule_name != expected_name { + return Err(PyValueError::new_err(format!( + "Expected name '{}' in PyCapsule, instead got '{}'", + expected_name, capsule_name + ))); + } + } else { + return Err(PyValueError::new_err( + "Expected schema PyCapsule to have name set.", + )); + } + + Ok(()) +} + +/// Import `__arrow_c_stream__` across Python boundary. +fn call_arrow_c_stream(ob: &'_ PyAny) -> PyResult<&'_ PyCapsule> { + if !ob.hasattr("__arrow_c_stream__")? { + return Err(PyValueError::new_err( + "Expected an object with dunder __arrow_c_stream__", + )); + } + + let capsule = ob.getattr("__arrow_c_stream__")?.call0()?.downcast()?; + Ok(capsule) +} + +fn import_stream_pycapsule(capsule: &PyCapsule) -> PyResult { + validate_pycapsule_name(capsule, "arrow_array_stream")?; + + let stream = unsafe { FFI_ArrowArrayStream::from_raw(capsule.pointer() as _) }; + Ok(stream) +} + +pub(crate) fn import_arrow_c_stream(ob: &'_ PyAny) -> PyResult<(Vec, SchemaRef)> { + let capsule = call_arrow_c_stream(ob)?; + let stream = import_stream_pycapsule(capsule)?; + let stream_reader = ArrowArrayStreamReader::try_new(stream) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + let schema = stream_reader.schema(); + + let mut batches = vec![]; + for batch in stream_reader { + let batch = batch.map_err(|err| PyTypeError::new_err(err.to_string()))?; + batches.push(batch); + } + + Ok((batches, schema)) +} diff --git a/vegafusion-common/src/data/mod.rs b/vegafusion-common/src/data/mod.rs index cbbc7c93e..71c6d3b21 100644 --- a/vegafusion-common/src/data/mod.rs +++ b/vegafusion-common/src/data/mod.rs @@ -1,5 +1,8 @@ pub mod table; +#[cfg(feature = "pyarrow")] +mod ffi; + #[cfg(feature = "json")] pub mod json_writer; pub mod scalar; diff --git a/vegafusion-common/src/data/table.rs b/vegafusion-common/src/data/table.rs index 50568364b..6d48f876f 100644 --- a/vegafusion-common/src/data/table.rs +++ b/vegafusion-common/src/data/table.rs @@ -8,6 +8,8 @@ use arrow::{ record_batch::RecordBatch, }; +#[cfg(feature = "pyarrow")] +use crate::data::ffi::import_arrow_c_stream; use crate::{ data::{ORDER_COL, ORDER_COL_DTYPE}, error::{Result, ResultWithContext, VegaFusionError}, @@ -36,7 +38,7 @@ use { pyo3::{ prelude::PyModule, types::{PyList, PyTuple}, - PyAny, PyErr, PyObject, Python, + PyAny, PyErr, PyObject, PyResult, Python, }, }; @@ -267,6 +269,12 @@ impl VegaFusionTable { } } + #[cfg(feature = "pyarrow")] + pub fn from_arrow_c_stream(ob: &PyAny) -> PyResult { + let (batches, schema) = import_arrow_c_stream(ob)?; + Ok(VegaFusionTable::try_new(schema, batches)?) + } + #[cfg(feature = "pyarrow")] pub fn from_pyarrow(py: Python, pyarrow_table: &PyAny) -> std::result::Result { // Extract table.schema as a Rust Schema diff --git a/vegafusion-python-embed/src/lib.rs b/vegafusion-python-embed/src/lib.rs index f64ff5700..1b21e5e97 100644 --- a/vegafusion-python-embed/src/lib.rs +++ b/vegafusion-python-embed/src/lib.rs @@ -186,6 +186,10 @@ impl PyVegaFusionRuntime { .scan_py_datasource(inline_dataset.to_object(py)), )?; VegaFusionDataset::DataFrame(df) + } else if inline_dataset.hasattr("__arrow_c_stream__")? { + // Import via Arrow PyCapsule Interface + let table = VegaFusionTable::from_arrow_c_stream(inline_dataset)?; + VegaFusionDataset::from_table_ipc_bytes(&table.to_ipc_bytes()?)? } else { // Assume PyArrow Table // We convert to ipc bytes for two reasons: