Skip to content

Commit

Permalink
Fix interval_range right child non-zero offset
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke committed Aug 23, 2024
1 parent a7ca3af commit 82d7497
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
15 changes: 12 additions & 3 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
IntervalColumn,
NumericalColumn,
StringColumn,
StructColumn,
TimeDeltaColumn,
column,
)
Expand Down Expand Up @@ -154,7 +153,7 @@ def _index_from_data(data: MutableMapping, name: Any = no_default):
index_class_type = Index
elif isinstance(values, CategoricalColumn):
index_class_type = CategoricalIndex
elif isinstance(values, (IntervalColumn, StructColumn)):
elif isinstance(values, IntervalColumn):
index_class_type = IntervalIndex
else:
raise NotImplementedError(
Expand Down Expand Up @@ -3250,7 +3249,7 @@ def interval_range(
freq=None,
name=None,
closed="right",
) -> "IntervalIndex":
) -> IntervalIndex:
"""
Returns a fixed frequency IntervalIndex.
Expand Down Expand Up @@ -3347,6 +3346,16 @@ def interval_range(
)
left_col = bin_edges.slice(0, len(bin_edges) - 1)
right_col = bin_edges.slice(1, len(bin_edges))
# For indexing, children should both have 0 offset
right_col = type(right_col)(
data=right_col.data,
dtype=right_col.dtype,
size=right_col.size,
mask=right_col.mask,
offset=0,
null_count=right_col.null_count,
children=right_col.children,
)

if len(right_col) == 0 or len(left_col) == 0:
dtype = IntervalDtype("int64", closed)
Expand Down
6 changes: 6 additions & 0 deletions python/cudf/cudf/tests/indexes/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,9 @@ def test_interval_range_name():
expected = pd.interval_range(start=0, periods=5, freq=2, name="foo")
result = cudf.interval_range(start=0, periods=5, freq=2, name="foo")
assert_eq(result, expected)


def test_from_interval_range_indexing():
result = cudf.interval_range(start=0, end=1, name="a").repeat(2)
expected = pd.interval_range(start=0, end=1, name="a").repeat(2)
assert_eq(result, expected)
6 changes: 6 additions & 0 deletions python/cudf/cudf/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3303,3 +3303,9 @@ def test_bool_rangeindex_raises():
lfunc_args_and_kwargs=[[pd.RangeIndex(0)]],
rfunc_args_and_kwargs=[[cudf.RangeIndex(0)]],
)


def test_index_struct_data_disallowed():
data = pa.array([{"a": [1], "b": [2]}])
with pytest.raises(NotImplementedError):
Index(data)

0 comments on commit 82d7497

Please sign in to comment.