Skip to content

Commit

Permalink
REF: simplify core.algorithms, reshape.cut (pandas-dev#29385)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jbrockmendel authored and jreback committed Nov 4, 2019
1 parent 6cc8234 commit 0d977e9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 54 deletions.
36 changes: 13 additions & 23 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

from pandas._libs import algos, hashtable as htable, lib
from pandas._libs import Timestamp, algos, hashtable as htable, lib
from pandas._libs.tslib import iNaT
from pandas.util._decorators import Appender, Substitution, deprecate_kwarg

Expand Down Expand Up @@ -1440,7 +1440,9 @@ def _take_nd_object(arr, indexer, out, axis: int, fill_value, mask_info):
}


def _get_take_nd_function(ndim, arr_dtype, out_dtype, axis: int = 0, mask_info=None):
def _get_take_nd_function(
ndim: int, arr_dtype, out_dtype, axis: int = 0, mask_info=None
):
if ndim <= 2:
tup = (arr_dtype.name, out_dtype.name)
if ndim == 1:
Expand Down Expand Up @@ -1474,7 +1476,7 @@ def func2(arr, indexer, out, fill_value=np.nan):
return func2


def take(arr, indices, axis=0, allow_fill: bool = False, fill_value=None):
def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None):
"""
Take elements from an array.
Expand Down Expand Up @@ -1568,13 +1570,7 @@ def take(arr, indices, axis=0, allow_fill: bool = False, fill_value=None):


def take_nd(
arr,
indexer,
axis=0,
out=None,
fill_value=np.nan,
mask_info=None,
allow_fill: bool = True,
arr, indexer, axis: int = 0, out=None, fill_value=np.nan, allow_fill: bool = True
):
"""
Specialized Cython take which sets NaN values in one pass
Expand All @@ -1597,10 +1593,6 @@ def take_nd(
maybe_promote to determine this type for any fill_value
fill_value : any, default np.nan
Fill value to replace -1 values with
mask_info : tuple of (ndarray, boolean)
If provided, value should correspond to:
(indexer != -1, (indexer != -1).any())
If not provided, it will be computed internally if necessary
allow_fill : boolean, default True
If False, indexer is assumed to contain no -1 values so no filling
will be done. This short-circuits computation of a mask. Result is
Expand All @@ -1611,6 +1603,7 @@ def take_nd(
subarray : array-like
May be the same type as the input, or cast to an ndarray.
"""
mask_info = None

if is_extension_array_dtype(arr):
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)
Expand All @@ -1632,12 +1625,9 @@ def take_nd(
dtype, fill_value = maybe_promote(arr.dtype, fill_value)
if dtype != arr.dtype and (out is None or out.dtype != dtype):
# check if promotion is actually required based on indexer
if mask_info is not None:
mask, needs_masking = mask_info
else:
mask = indexer == -1
needs_masking = mask.any()
mask_info = mask, needs_masking
mask = indexer == -1
needs_masking = mask.any()
mask_info = mask, needs_masking
if needs_masking:
if out is not None and out.dtype != dtype:
raise TypeError("Incompatible type for fill_value")
Expand Down Expand Up @@ -1818,12 +1808,12 @@ def searchsorted(arr, value, side="left", sorter=None):
elif not (
is_object_dtype(arr) or is_numeric_dtype(arr) or is_categorical_dtype(arr)
):
from pandas.core.series import Series

# E.g. if `arr` is an array with dtype='datetime64[ns]'
# and `value` is a pd.Timestamp, we may need to convert value
value_ser = Series(value)._values
value_ser = array([value]) if is_scalar(value) else array(value)
value = value_ser[0] if is_scalar(value) else value_ser
if isinstance(value, Timestamp) and value.tzinfo is None:
value = value.to_datetime64()

result = arr.searchsorted(value, side=side, sorter=sorter)
return result
Expand Down
44 changes: 13 additions & 31 deletions pandas/core/reshape/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np

from pandas._libs import Timedelta, Timestamp
from pandas._libs.interval import Interval
from pandas._libs.lib import infer_dtype

from pandas.core.dtypes.common import (
Expand All @@ -18,17 +19,10 @@
is_scalar,
is_timedelta64_dtype,
)
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.missing import isna

from pandas import (
Categorical,
Index,
Interval,
IntervalIndex,
Series,
to_datetime,
to_timedelta,
)
from pandas import Categorical, Index, IntervalIndex, to_datetime, to_timedelta
import pandas.core.algorithms as algos
import pandas.core.nanops as nanops

Expand Down Expand Up @@ -206,7 +200,8 @@ def cut(
# NOTE: this binning code is changed a bit from histogram for var(x) == 0

# for handling the cut for datetime and timedelta objects
x_is_series, series_index, name, x = _preprocess_for_cut(x)
original = x
x = _preprocess_for_cut(x)
x, dtype = _coerce_to_type(x)

if not np.iterable(bins):
Expand Down Expand Up @@ -268,9 +263,7 @@ def cut(
duplicates=duplicates,
)

return _postprocess_for_cut(
fac, bins, retbins, x_is_series, series_index, name, dtype
)
return _postprocess_for_cut(fac, bins, retbins, dtype, original)


def qcut(
Expand Down Expand Up @@ -333,8 +326,8 @@ def qcut(
>>> pd.qcut(range(5), 4, labels=False)
array([0, 0, 1, 2, 3])
"""
x_is_series, series_index, name, x = _preprocess_for_cut(x)

original = x
x = _preprocess_for_cut(x)
x, dtype = _coerce_to_type(x)

if is_integer(q):
Expand All @@ -352,9 +345,7 @@ def qcut(
duplicates=duplicates,
)

return _postprocess_for_cut(
fac, bins, retbins, x_is_series, series_index, name, dtype
)
return _postprocess_for_cut(fac, bins, retbins, dtype, original)


def _bins_to_cuts(
Expand Down Expand Up @@ -544,13 +535,6 @@ def _preprocess_for_cut(x):
input to array, strip the index information and store it
separately
"""
x_is_series = isinstance(x, Series)
series_index = None
name = None

if x_is_series:
series_index = x.index
name = x.name

# Check that the passed array is a Pandas or Numpy object
# We don't want to strip away a Pandas data-type here (e.g. datetimetz)
Expand All @@ -560,19 +544,17 @@ def _preprocess_for_cut(x):
if x.ndim != 1:
raise ValueError("Input array must be 1 dimensional")

return x_is_series, series_index, name, x
return x


def _postprocess_for_cut(
fac, bins, retbins: bool, x_is_series, series_index, name, dtype
):
def _postprocess_for_cut(fac, bins, retbins: bool, dtype, original):
"""
handles post processing for the cut method where
we combine the index information if the originally passed
datatype was a series
"""
if x_is_series:
fac = Series(fac, index=series_index, name=name)
if isinstance(original, ABCSeries):
fac = original._constructor(fac, index=original.index, name=original.name)

if not retbins:
return fac
Expand Down

0 comments on commit 0d977e9

Please sign in to comment.