Read from Arrow C Data interface
kylebarron committed Jul 23, 2024 · 1 parent 007bd44 · commit 0c39d5d
Showing 4 changed files with 78 additions and 1 deletion.
62 changes: 62 additions & 0 deletions vegafusion-common/src/data/ffi.rs
@@ -0,0 +1,62 @@
use arrow::array::{RecordBatch, RecordBatchReader};
use arrow::datatypes::SchemaRef;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::ffi_stream::FFI_ArrowArrayStream;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyCapsule;

/// Validate that the PyCapsule has the provided name
fn validate_pycapsule_name(capsule: &PyCapsule, expected_name: &str) -> PyResult<()> {
    let capsule_name = capsule.name()?;
    if let Some(capsule_name) = capsule_name {
        let capsule_name = capsule_name.to_str()?;
        if capsule_name != expected_name {
            return Err(PyValueError::new_err(format!(
                "Expected name '{}' in PyCapsule, instead got '{}'",
                expected_name, capsule_name
            )));
        }
    } else {
        return Err(PyValueError::new_err(
            "Expected PyCapsule to have a name set.",
        ));
    }

    Ok(())
}

/// Import `__arrow_c_stream__` across the Python boundary.
fn call_arrow_c_stream(ob: &'_ PyAny) -> PyResult<&'_ PyCapsule> {
    if !ob.hasattr("__arrow_c_stream__")? {
        return Err(PyValueError::new_err(
            "Expected an object with dunder __arrow_c_stream__",
        ));
    }

    let capsule = ob.getattr("__arrow_c_stream__")?.call0()?.downcast()?;
    Ok(capsule)
}

fn import_stream_pycapsule(capsule: &PyCapsule) -> PyResult<FFI_ArrowArrayStream> {
    validate_pycapsule_name(capsule, "arrow_array_stream")?;

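    // Safety: the name check above ensures the capsule holds an Arrow C
    // stream; `from_raw` moves the stream struct out of the capsule.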
    let stream = unsafe { FFI_ArrowArrayStream::from_raw(capsule.pointer() as _) };
    Ok(stream)
}

pub(crate) fn import_arrow_c_stream(ob: &'_ PyAny) -> PyResult<(Vec<RecordBatch>, SchemaRef)> {
    let capsule = call_arrow_c_stream(ob)?;
    let stream = import_stream_pycapsule(capsule)?;
    let stream_reader = ArrowArrayStreamReader::try_new(stream)
        .map_err(|err| PyValueError::new_err(err.to_string()))?;
    let schema = stream_reader.schema();

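    // Eagerly drain the stream: VegaFusionTable needs fully materialized batches.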
    let mut batches = vec![];
    for batch in stream_reader {
        let batch = batch.map_err(|err| PyTypeError::new_err(err.to_string()))?;
        batches.push(batch);
    }

    Ok((batches, schema))
}
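
For context (not part of the commit): a minimal pure-Rust sketch of the same export/import round trip, assuming only the arrow crate. A RecordBatchReader is moved behind the FFI_ArrowArrayStream C ABI and read back with ArrowArrayStreamReader, mirroring what import_arrow_c_stream does once the PyCapsule pointer is unwrapped.

use std::sync::Arc;

use arrow::array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build a one-column batch to stream.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // Export: hand a reader across the C stream ABI, as a Python producer's
    // __arrow_c_stream__ does before wrapping the struct in a PyCapsule.
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let stream = FFI_ArrowArrayStream::new(Box::new(reader));

    // Import: the consumer side, identical to import_arrow_c_stream above.
    let stream_reader = ArrowArrayStreamReader::try_new(stream)?;
    println!("schema: {:?}", stream_reader.schema());
    for batch in stream_reader {
        println!("{} rows", batch?.num_rows());
    }
    Ok(())
}
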
3 changes: 3 additions & 0 deletions vegafusion-common/src/data/mod.rs
@@ -1,5 +1,8 @@
pub mod table;

#[cfg(feature = "pyarrow")]
mod ffi;

#[cfg(feature = "json")]
pub mod json_writer;
pub mod scalar;
10 changes: 9 additions & 1 deletion vegafusion-common/src/data/table.rs
@@ -8,6 +8,8 @@ use arrow::{
    record_batch::RecordBatch,
};

#[cfg(feature = "pyarrow")]
use crate::data::ffi::import_arrow_c_stream;
use crate::{
    data::{ORDER_COL, ORDER_COL_DTYPE},
    error::{Result, ResultWithContext, VegaFusionError},
@@ -36,7 +38,7 @@ use {
    pyo3::{
        prelude::PyModule,
        types::{PyList, PyTuple},
-       PyAny, PyErr, PyObject, Python,
+       PyAny, PyErr, PyObject, PyResult, Python,
    },
};

@@ -267,6 +269,12 @@ impl VegaFusionTable {
        }
    }

    #[cfg(feature = "pyarrow")]
    pub fn from_arrow_c_stream(ob: &PyAny) -> PyResult<Self> {
        let (batches, schema) = import_arrow_c_stream(ob)?;
        Ok(VegaFusionTable::try_new(schema, batches)?)
    }

    #[cfg(feature = "pyarrow")]
    pub fn from_pyarrow(py: Python, pyarrow_table: &PyAny) -> std::result::Result<Self, PyErr> {
        // Extract table.schema as a Rust Schema
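
Also not part of the diff: a hedged sketch of calling the new constructor from Rust through pyo3, assuming a Python environment where polars (or any other producer of __arrow_c_stream__) is importable; the module name and the one-column frame are illustrative only.

use pyo3::prelude::*;
use vegafusion_common::data::table::VegaFusionTable;

fn table_from_python() -> PyResult<VegaFusionTable> {
    Python::with_gil(|py| {
        // Any Python object exposing __arrow_c_stream__ works here; a polars
        // DataFrame is used purely for illustration.
        let polars = PyModule::import(py, "polars")?;
        let df = polars.getattr("DataFrame")?.call1((vec![1i64, 2, 3],))?;
        VegaFusionTable::from_arrow_c_stream(df)
    })
}
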
4 changes: 4 additions & 0 deletions vegafusion-python-embed/src/lib.rs
@@ -186,6 +186,10 @@ impl PyVegaFusionRuntime {
                    .scan_py_datasource(inline_dataset.to_object(py)),
            )?;
            VegaFusionDataset::DataFrame(df)
        } else if inline_dataset.hasattr("__arrow_c_stream__")? {
            // Import via Arrow PyCapsule Interface
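            // https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html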
            let table = VegaFusionTable::from_arrow_c_stream(inline_dataset)?;
            VegaFusionDataset::from_table_ipc_bytes(&table.to_ipc_bytes()?)?
        } else {
            // Assume PyArrow Table
            // We convert to ipc bytes for two reasons:
