Skip to content

Commit

Permalink
implement for boxed rbr
Browse files Browse the repository at this point in the history
  • Loading branch information
wjones127 committed Aug 30, 2023
1 parent 32e973d commit 8675826
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 3 deletions.
16 changes: 16 additions & 0 deletions arrow-pyarrow-integration-testing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
use std::sync::Arc;

use arrow::array::new_empty_array;
use arrow::record_batch::{RecordBatchIterator, RecordBatchReader};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
Expand Down Expand Up @@ -152,6 +153,20 @@ fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()
}
}

#[pyfunction]
fn boxed_reader_roundtrip(
obj: PyArrowType<Box<dyn RecordBatchReader + Send>>,
) -> PyArrowType<Box<dyn RecordBatchReader + Send>> {
let schema = obj.0.schema();
let batches = obj
.0
.collect::<Result<Vec<RecordBatch>, ArrowError>>()
.unwrap();
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
PyArrowType(reader)
}

#[pymodule]
fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double))?;
Expand All @@ -166,5 +181,6 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()>
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?;
Ok(())
}
7 changes: 7 additions & 0 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,13 @@ def test_record_batch_reader():
got_batches = list(b)
assert got_batches == batches

# Also try the boxed reader variant
a = pa.RecordBatchReader.from_batches(schema, batches)
b = rust.boxed_reader_roundtrip(a)
assert b.schema == schema
got_batches = list(b)
assert got_batches == batches

def test_record_batch_reader_error():
schema = pa.schema([('ints', pa.list_(pa.int32()))])

Expand Down
3 changes: 2 additions & 1 deletion arrow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ pub mod pyarrow;

pub mod record_batch {
pub use arrow_array::{
RecordBatch, RecordBatchOptions, RecordBatchReader, RecordBatchWriter,
RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader,
RecordBatchWriter,
};
}
pub use arrow_array::temporal_conversions;
Expand Down
21 changes: 19 additions & 2 deletions arrow/src/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::convert::{From, TryFrom};
use std::ptr::{addr_of, addr_of_mut};
use std::sync::Arc;

use arrow_array::RecordBatchReader;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
Expand Down Expand Up @@ -277,10 +278,19 @@ impl FromPyArrow for ArrowArrayStreamReader {
}
}

impl IntoPyArrow for ArrowArrayStreamReader {
impl FromPyArrow for Box<dyn RecordBatchReader + Send> {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
let stream_reader = ArrowArrayStreamReader::from_pyarrow(value)?;
Ok(Box::new(stream_reader))
}
}

// We can't implement `ToPyArrow` for `T: RecordBatchReader + Send` because
// there is already a blanket implementation for `T: ToPyArrow`.
impl IntoPyArrow for Box<dyn RecordBatchReader + Send> {
fn into_pyarrow(self, py: Python) -> PyResult<PyObject> {
let mut stream = FFI_ArrowArrayStream::empty();
unsafe { export_reader_into_raw(Box::new(self), &mut stream) };
unsafe { export_reader_into_raw(self, &mut stream) };

let stream_ptr = (&mut stream) as *mut FFI_ArrowArrayStream;
let module = py.import("pyarrow")?;
Expand All @@ -292,6 +302,13 @@ impl IntoPyArrow for ArrowArrayStreamReader {
}
}

impl IntoPyArrow for ArrowArrayStreamReader {
fn into_pyarrow(self, py: Python) -> PyResult<PyObject> {
let boxed: Box<dyn RecordBatchReader + Send> = Box::new(self);
boxed.into_pyarrow(py)
}
}

/// A newtype wrapper around a `T: PyArrowConvert` that implements
/// [`FromPyObject`] and [`IntoPy`] allowing usage with pyo3 macros
#[derive(Debug)]
Expand Down

0 comments on commit 8675826

Please sign in to comment.