Skip to content

Commit

Permalink
Fix pylibcudf to_arrow with multiple nested data types (#17504)
Browse files Browse the repository at this point in the history
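Fixes the following case, where converting a table that contains a multiply nested column (a list of structs) to Arrow without explicit metadata raised an error: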

```python
In [25]: import pyarrow as pa, pylibcudf as plc

In [26]: pa_array = pa.array([[{"a": 1}]])

In [27]: pa_array.type
Out[27]: ListType(list<item: struct<a: int64>>)

In [28]: plc_table = plc.Table([plc.interop.from_arrow(pa_array)])

In [29]: plc.interop.to_arrow(plc_table)
RuntimeError: CUDF failure at: cpp/src/interop/to_arrow_schema.cpp:146: Number of field names and number of children doesn't match
```

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #17504
  • Loading branch information
mroeschke authored Dec 19, 2024
1 parent 550ea35 commit d8f469f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
13 changes: 11 additions & 2 deletions python/pylibcudf/pylibcudf/interop.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,19 @@ cdef void _release_array(object array_capsule) noexcept:
free(array)


def _maybe_create_nested_column_metadata(Column col):
    """Build a ``ColumnMetadata`` whose nesting mirrors that of ``col``.

    Parameters
    ----------
    col : Column
        The column whose child hierarchy should be mirrored.

    Returns
    -------
    ColumnMetadata
        A metadata object constructed with only ``children_meta`` set,
        holding one recursively built entry per child of ``col``. A column
        without children yields an empty ``children_meta`` list.
    """
    nested_meta = []
    # Recurse so every level of nesting gets its own metadata entry.
    for child in col.children():
        nested_meta.append(_maybe_create_nested_column_metadata(child))
    return ColumnMetadata(children_meta=nested_meta)


def _table_to_schema(Table tbl, metadata):
if metadata is None:
metadata = [ColumnMetadata() for _ in range(len(tbl.columns()))]
metadata = [ColumnMetadata(m) if isinstance(m, str) else m for m in metadata]
metadata = [_maybe_create_nested_column_metadata(col) for col in tbl.columns()]
else:
metadata = [ColumnMetadata(m) if isinstance(m, str) else m for m in metadata]

cdef vector[column_metadata] c_metadata
c_metadata.reserve(len(metadata))
Expand Down
22 changes: 22 additions & 0 deletions python/pylibcudf/pylibcudf/tests/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,28 @@ def test_struct_dtype_roundtrip():
assert arrow_type == struct_type


def test_table_with_nested_dtype_to_arrow():
    """Check ``to_arrow`` on a table holding a list-of-struct column.

    No metadata is passed to ``to_arrow``; the resulting schema is expected
    to have an empty name and ``nullable=False`` at every nesting level.
    """
    nested_array = pa.array([[{"": 1}]])
    table = plc.Table([plc.interop.from_arrow(nested_array)])
    result = plc.interop.to_arrow(table)

    # Assemble the expected schema from the innermost field outwards.
    leaf_field = pa.field("", pa.int64(), nullable=False)
    struct_field = pa.field("", pa.struct([leaf_field]), nullable=False)
    list_field = pa.field("", pa.list_(struct_field), nullable=False)
    expected_schema = pa.schema([list_field])

    assert result.schema == expected_schema


def test_decimal128_roundtrip():
decimal_type = pa.decimal128(10, 2)
plc_type = plc.interop.from_arrow(decimal_type)
Expand Down

0 comments on commit d8f469f

Please sign in to comment.