diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index adcec769f247..1814ac4fb121 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow::array::new_empty_array; +use arrow::record_batch::{RecordBatchIterator, RecordBatchReader}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::wrap_pyfunction; @@ -152,6 +153,20 @@ fn reader_return_errors(obj: PyArrowType) -> PyResult<() } } +#[pyfunction] +fn boxed_reader_roundtrip( + obj: PyArrowType>, +) -> PyArrowType> { + let schema = obj.0.schema(); + let batches = obj + .0 + .collect::, ArrowError>>() + .unwrap(); + let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); + let reader: Box = Box::new(reader); + PyArrowType(reader) +} + #[pymodule] fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(double))?; @@ -166,5 +181,6 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?; m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?; m.add_wrapped(wrap_pyfunction!(reader_return_errors))?; + m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?; Ok(()) } diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index 92782b9ed473..2d53882b5f12 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -409,6 +409,13 @@ def test_record_batch_reader(): got_batches = list(b) assert got_batches == batches + # Also try the boxed reader variant + a = pa.RecordBatchReader.from_batches(schema, batches) + b = rust.boxed_reader_roundtrip(a) + assert b.schema == schema + got_batches = list(b) + assert got_batches == batches + def test_record_batch_reader_error(): schema = pa.schema([('ints', pa.list_(pa.int32()))]) diff --git a/arrow/src/lib.rs b/arrow/src/lib.rs index fb904c1908e6..f4d0585fa6b5 100644 --- a/arrow/src/lib.rs +++ b/arrow/src/lib.rs @@ -375,7 +375,8 @@ pub mod pyarrow; pub mod record_batch { pub use arrow_array::{ - RecordBatch, RecordBatchOptions, RecordBatchReader, RecordBatchWriter, + RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, + RecordBatchWriter, }; } pub use arrow_array::temporal_conversions; diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs index 54a247d53e6d..bbb96e2cc4cc 100644 --- a/arrow/src/pyarrow.rs +++ b/arrow/src/pyarrow.rs @@ -24,6 +24,7 @@ use std::convert::{From, TryFrom}; use std::ptr::{addr_of, addr_of_mut}; use std::sync::Arc; +use arrow_array::RecordBatchReader; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::ffi::Py_uintptr_t; use pyo3::import_exception; @@ -277,10 +278,19 @@ impl FromPyArrow for ArrowArrayStreamReader { } } -impl IntoPyArrow for ArrowArrayStreamReader { +impl FromPyArrow for Box { + fn from_pyarrow(value: &PyAny) -> PyResult { + let stream_reader = ArrowArrayStreamReader::from_pyarrow(value)?; + Ok(Box::new(stream_reader)) + } +} + +// We can't implement `ToPyArrow` for `T: RecordBatchReader + Send` because +// there is already a blanket implementation for `T: ToPyArrow`. +impl IntoPyArrow for Box { fn into_pyarrow(self, py: Python) -> PyResult { let mut stream = FFI_ArrowArrayStream::empty(); - unsafe { export_reader_into_raw(Box::new(self), &mut stream) }; + unsafe { export_reader_into_raw(self, &mut stream) }; let stream_ptr = (&mut stream) as *mut FFI_ArrowArrayStream; let module = py.import("pyarrow")?; @@ -292,6 +302,13 @@ impl IntoPyArrow for ArrowArrayStreamReader { } } +impl IntoPyArrow for ArrowArrayStreamReader { + fn into_pyarrow(self, py: Python) -> PyResult { + let boxed: Box = Box::new(self); + boxed.into_pyarrow(py) + } +} + /// A newtype wrapper around a `T: PyArrowConvert` that implements /// [`FromPyObject`] and [`IntoPy`] allowing usage with pyo3 macros #[derive(Debug)]