Skip to content

Commit

Permalink
fix: export record batch through stream (#4806)
Browse files Browse the repository at this point in the history
* fix: export record batch through stream

* Update arrow/src/pyarrow.rs

---------

Co-authored-by: Raphael Taylor-Davies <[email protected]>
  • Loading branch information
wjones127 and tustvold authored Sep 17, 2023
1 parent d960379 commit 80b0888
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 22 deletions.
17 changes: 17 additions & 0 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,23 @@ def test_sparse_union_python():
del a
del b

def test_tensor_array():
    """
    Python -> Rust -> Python with a fixed-shape tensor extension array,
    both as a bare array and as a record batch column.
    """
    # 18 float32 slots: the values 1.0 through 6.0 followed by twelve nulls,
    # grouped below into fixed-size lists of 6.
    flat_values = pa.array(
        [float(n) for n in range(1, 7)] + [None] * 12,
        pa.float32(),
    )
    list_storage = pa.FixedSizeListArray.from_arrays(flat_values, 6)
    ext_type = pa.fixed_shape_tensor(pa.float32(), [2, 3])
    tensor_array = pa.ExtensionArray.from_storage(ext_type, list_storage)

    # arrow-rs has no notion of extension types, so a bare array that is
    # round-tripped comes back as its storage type.
    result = rust.round_trip_array(tensor_array)
    assert result == tensor_array.storage

    # Round-tripped inside a record batch, the result must compare equal to
    # the original batch.
    tensor_batch = pa.record_batch([tensor_array], ["tensor"])
    result = rust.round_trip_record_batch(tensor_batch)
    assert result == tensor_batch

    del result

def test_record_batch_reader():
"""
Python -> Rust -> Python
Expand Down
31 changes: 9 additions & 22 deletions arrow/src/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ use std::convert::{From, TryFrom};
use std::ptr::{addr_of, addr_of_mut};
use std::sync::Arc;

use arrow_array::RecordBatchReader;
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyTuple};
use pyo3::types::{PyList, PyTuple};

use crate::array::{make_array, Array, ArrayData};
use crate::array::{make_array, ArrayData};
use crate::datatypes::{DataType, Field, Schema};
use crate::error::ArrowError;
use crate::ffi;
Expand Down Expand Up @@ -270,25 +270,12 @@ impl FromPyArrow for RecordBatch {

impl ToPyArrow for RecordBatch {
/// Converts this `RecordBatch` into a Python object by going through pyarrow.
///
/// Returns the resulting `PyObject`, or propagates the `PyErr` from whichever
/// pyarrow call fails.
fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
// NOTE(review): this listing is a flattened diff whose +/- markers were lost.
// The statements from here through `Ok(PyObject::from(record))` appear to be
// the body REMOVED by this commit: it converted every column separately and
// rebuilt the batch on the Python side with `RecordBatch.from_arrays`.
let mut py_arrays = vec![];

let schema = self.schema();
let columns = self.columns().iter();

// Convert each column's data to its pyarrow counterpart, stopping at the
// first conversion error.
for array in columns {
py_arrays.push(array.to_data().to_pyarrow(py)?);
}

let py_schema = schema.to_pyarrow(py)?;

// Equivalent to `pyarrow.RecordBatch.from_arrays(py_arrays, schema=py_schema)`.
let module = py.import("pyarrow")?;
let class = module.getattr("RecordBatch")?;
let args = (py_arrays,);
let kwargs = PyDict::new(py);
kwargs.set_item("schema", py_schema)?;
let record = class.call_method("from_arrays", args, Some(kwargs))?;

Ok(PyObject::from(record))
// NOTE(review): the statements below appear to be the body ADDED by this
// commit. Instead of converting column by column, the batch is wrapped in a
// single-element reader, the reader is converted to Python, and the one batch
// is read back from it.
// Workaround apache/arrow#37669 by returning RecordBatchIterator
let reader =
RecordBatchIterator::new(vec![Ok(self.clone())], self.schema().clone());
// Erase the concrete iterator type; `into_pyarrow` is invoked on the boxed
// `RecordBatchReader + Send` trait object.
let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
// Presumably yields a pyarrow `RecordBatchReader`; the conversion is defined
// outside this view -- TODO confirm.
let py_reader = reader.into_pyarrow(py)?;
// Pull the single batch back out of the Python-side reader and return it.
py_reader.call_method0(py, "read_next_batch")
}
}

Expand Down

0 comments on commit 80b0888

Please sign in to comment.