From 80b0888fa172090c91588142a5f964cc158cb5c8 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sun, 17 Sep 2023 06:35:14 -0400 Subject: [PATCH] fix: export record batch through stream (#4806) * fix: export record batch through stream * Update arrow/src/pyarrow.rs --------- Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --- .../tests/test_sql.py | 17 ++++++++++ arrow/src/pyarrow.rs | 31 ++++++------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index 3be5b9ec52fe..1748fd3ffb6b 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -393,6 +393,23 @@ def test_sparse_union_python(): del a del b +def test_tensor_array(): + tensor_type = pa.fixed_shape_tensor(pa.float32(), [2, 3]) + inner = pa.array([float(x) for x in range(1, 7)] + [None] * 12, pa.float32()) + storage = pa.FixedSizeListArray.from_arrays(inner, 6) + f32_array = pa.ExtensionArray.from_storage(tensor_type, storage) + + # Round-tripping as an array gives back storage type, because arrow-rs has + # no notion of extension types. + b = rust.round_trip_array(f32_array) + assert b == f32_array.storage + + batch = pa.record_batch([f32_array], ["tensor"]) + b = rust.round_trip_record_batch(batch) + assert b == batch + + del b + def test_record_batch_reader(): """ Python -> Rust -> Python diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs index 6063ae763228..ab0ea8ef8d74 100644 --- a/arrow/src/pyarrow.rs +++ b/arrow/src/pyarrow.rs @@ -59,14 +59,14 @@ use std::convert::{From, TryFrom}; use std::ptr::{addr_of, addr_of_mut}; use std::sync::Arc; -use arrow_array::RecordBatchReader; +use arrow_array::{RecordBatchIterator, RecordBatchReader}; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::ffi::Py_uintptr_t; use pyo3::import_exception; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList, PyTuple}; +use pyo3::types::{PyList, PyTuple}; -use crate::array::{make_array, Array, ArrayData}; +use crate::array::{make_array, ArrayData}; use crate::datatypes::{DataType, Field, Schema}; use crate::error::ArrowError; use crate::ffi; @@ -270,25 +270,12 @@ impl FromPyArrow for RecordBatch { impl ToPyArrow for RecordBatch { fn to_pyarrow(&self, py: Python) -> PyResult { - let mut py_arrays = vec![]; - - let schema = self.schema(); - let columns = self.columns().iter(); - - for array in columns { - py_arrays.push(array.to_data().to_pyarrow(py)?); - } - - let py_schema = schema.to_pyarrow(py)?; - - let module = py.import("pyarrow")?; - let class = module.getattr("RecordBatch")?; - let args = (py_arrays,); - let kwargs = PyDict::new(py); - kwargs.set_item("schema", py_schema)?; - let record = class.call_method("from_arrays", args, Some(kwargs))?; - - Ok(PyObject::from(record)) + // Workaround apache/arrow#37669 by returning RecordBatchIterator + let reader = + RecordBatchIterator::new(vec![Ok(self.clone())], self.schema().clone()); + let reader: Box = Box::new(reader); + let py_reader = reader.into_pyarrow(py)?; + py_reader.call_method0(py, "read_next_batch") } }