WIP: [python-package] ensure predict() always returns an array #6348

Draft · wants to merge 4 commits into base: master
33 changes: 17 additions & 16 deletions python-package/lightgbm/basic.py
@@ -1083,7 +1083,7 @@ def predict(
pred_contrib: bool = False,
data_has_header: bool = False,
validate_features: bool = False,
) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
) -> Union[np.ndarray, scipy.sparse.spmatrix]:
"""Predict logic.

Parameters
@@ -1112,9 +1112,9 @@ def predict(

Returns
-------
result : numpy array, scipy.sparse or list of scipy.sparse
result : numpy array or scipy.sparse
Prediction result.
Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
If ``data`` is a sparse matrix, result will be a sparse matrix.
"""
if isinstance(data, Dataset):
raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
@@ -1354,7 +1354,7 @@ def __create_sparse_native(
indptr_type: int,
data_type: int,
is_csr: bool,
) -> Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]]:
) -> Union[scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]:
# create numpy array from output arrays
data_indices_len = out_shape[0]
indptr_len = out_shape[1]
@@ -1402,9 +1402,10 @@ def __create_sparse_native(
ctypes.c_int(data_type),
)
)
if len(cs_output_matrices) == 1:
return cs_output_matrices[0]
return cs_output_matrices
if is_csr:
return scipy.sparse.hstack(cs_output_matrices, format="csr")
else:
return scipy.sparse.hstack(cs_output_matrices, format="csc")

def __inner_predict_csr(
self,
@@ -1462,7 +1463,7 @@ def __inner_predict_csr_sparse(
start_iteration: int,
num_iteration: int,
predict_type: int,
) -> Tuple[Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]], int]:
) -> Tuple[Union[scipy.sparse.csc_matrix, scipy.sparse.csr_matrix], int]:
ptr_indptr, type_ptr_indptr, __ = _c_int_array(csr.indptr)
ptr_data, type_ptr_data, _ = _c_float_array(csr.data)
csr_indices = csr.indices.astype(np.int32, copy=False)
@@ -1501,7 +1502,7 @@ def __inner_predict_csr_sparse(
ctypes.byref(out_ptr_data),
)
)
matrices = self.__create_sparse_native(
out_mat = self.__create_sparse_native(
cs=csr,
out_shape=out_shape,
out_ptr_indptr=out_ptr_indptr,
@@ -1512,7 +1513,7 @@ def __inner_predict_csr_sparse(
is_csr=True,
)
nrow = len(csr.indptr) - 1
return matrices, nrow
return out_mat, nrow

def __pred_for_csr(
self,
@@ -1563,7 +1564,7 @@ def __inner_predict_sparse_csc(
start_iteration: int,
num_iteration: int,
predict_type: int,
):
) -> Tuple[scipy.sparse.csc_matrix, int]:
ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
ptr_data, type_ptr_data, _ = _c_float_array(csc.data)
csc_indices = csc.indices.astype(np.int32, copy=False)
@@ -1602,7 +1603,7 @@ def __inner_predict_sparse_csc(
ctypes.byref(out_ptr_data),
)
)
matrices = self.__create_sparse_native(
out_mat = self.__create_sparse_native(
cs=csc,
out_shape=out_shape,
out_ptr_indptr=out_ptr_indptr,
@@ -1613,7 +1614,7 @@ def __inner_predict_sparse_csc(
is_csr=False,
)
nrow = csc.shape[0]
return matrices, nrow
return out_mat, nrow

def __pred_for_csc(
self,
@@ -4677,7 +4678,7 @@ def predict(
data_has_header: bool = False,
validate_features: bool = False,
**kwargs: Any,
) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
) -> Union[np.ndarray, scipy.sparse.spmatrix]:
"""Make a prediction.

Parameters
@@ -4719,9 +4720,9 @@ def predict(

Returns
-------
result : numpy array, scipy.sparse or list of scipy.sparse
result : numpy array or scipy.sparse
Prediction result.
Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
If ``data`` is a sparse matrix, result will be a sparse matrix.
"""
predictor = _InnerPredictor.from_booster(
booster=self,
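The net effect of the basic.py changes: for multiclass models, `__create_sparse_native` now horizontally stacks the per-class sparse matrices coming back from the C API (with an explicit `format` so `hstack` cannot fall back to COO), and `predict(..., pred_contrib=True)` therefore returns one sparse matrix rather than a list. Below is a minimal, self-contained sketch of the intended contract on this branch; the data, sizes, and parameters are illustrative, not taken from the PR:

```python
import numpy as np
import scipy.sparse
import lightgbm as lgb

# hypothetical toy problem: 500 rows, 10 features, 3 classes
X = scipy.sparse.random(500, 10, density=0.5, format="csr", random_state=42)
y = np.random.default_rng(42).integers(0, 3, size=500)

bst = lgb.train(
    params={"objective": "multiclass", "num_class": 3, "verbosity": -1},
    train_set=lgb.Dataset(X, label=y),
    num_boost_round=5,
)

contribs = bst.predict(X, pred_contrib=True)
# before this PR: a list of 3 CSR matrices, each [n_samples, n_features + 1]
# after this PR: one CSR matrix with the per-class blocks stacked side by side
assert isinstance(contribs, scipy.sparse.csr_matrix)
assert contribs.shape == (500, (10 + 1) * 3)
```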
4 changes: 0 additions & 4 deletions python-package/lightgbm/compat.py
@@ -157,8 +157,6 @@ class _LGBMRegressorBase: # type: ignore
try:
from dask import delayed
from dask.array import Array as dask_Array
from dask.array import from_delayed as dask_array_from_delayed
from dask.bag import from_delayed as dask_bag_from_delayed
from dask.dataframe import DataFrame as dask_DataFrame
from dask.dataframe import Series as dask_Series
from dask.distributed import Client, Future, default_client, wait
@@ -167,8 +165,6 @@ class _LGBMRegressorBase: # type: ignore
except ImportError:
DASK_INSTALLED = False

dask_array_from_delayed = None # type: ignore[assignment]
dask_bag_from_delayed = None # type: ignore[assignment]
delayed = None
default_client = None # type: ignore[assignment]
wait = None # type: ignore[assignment]
74 changes: 3 additions & 71 deletions python-package/lightgbm/dask.py
@@ -28,8 +28,6 @@
LGBMNotFittedError,
concat,
dask_Array,
dask_array_from_delayed,
dask_bag_from_delayed,
dask_DataFrame,
dask_Series,
default_client,
@@ -906,7 +904,7 @@ def _predict(
The predicted values.
X_leaves : Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]
X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]
If ``pred_contrib=True``, the feature contributions for each sample.
"""
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
@@ -922,72 +920,6 @@ def _predict(
**kwargs,
).values
elif isinstance(data, dask_Array):
# for multi-class classification with sparse matrices, pred_contrib predictions
# are returned as a list of sparse matrices (one per class)
num_classes = model._n_classes

if num_classes > 2 and pred_contrib and isinstance(data._meta, ss.spmatrix):
predict_function = partial(
_predict_part,
model=model,
raw_score=False,
pred_proba=pred_proba,
pred_leaf=False,
pred_contrib=True,
**kwargs,
)

delayed_chunks = data.to_delayed()
bag = dask_bag_from_delayed(delayed_chunks[:, 0])

@delayed
def _extract(items: List[Any], i: int) -> Any:
return items[i]

preds = bag.map_partitions(predict_function)

# pred_contrib output will have one column per feature,
# plus one more for the base value
num_cols = model.n_features_ + 1

nrows_per_chunk = data.chunks[0]
out: List[List[dask_Array]] = [[] for _ in range(num_classes)]

# need to tell Dask the expected type and shape of individual preds
pred_meta = data._meta

for j, partition in enumerate(preds.to_delayed()):
for i in range(num_classes):
part = dask_array_from_delayed(
value=_extract(partition, i),
shape=(nrows_per_chunk[j], num_cols),
meta=pred_meta,
)
out[i].append(part)

# by default, dask.array.concatenate() concatenates sparse arrays into a COO matrix
# the code below is used instead to ensure that the sparse type is preserved during concatenation
if isinstance(pred_meta, ss.csr_matrix):
concat_fn = partial(ss.vstack, format="csr")
elif isinstance(pred_meta, ss.csc_matrix):
concat_fn = partial(ss.vstack, format="csc")
else:
concat_fn = ss.vstack

# At this point, `out` is a list of lists of delayeds (each of which points to a matrix).
# Concatenate them to return a list of Dask Arrays.
out_arrays: List[dask_Array] = []
for i in range(num_classes):
out_arrays.append(
dask_array_from_delayed(
value=delayed(concat_fn)(out[i]),
shape=(data.shape[0], num_cols),
meta=pred_meta,
)
)

return out_arrays

data_row = client.compute(data[[0]]).result()
predict_fn = partial(
_predict_part,
@@ -1263,7 +1195,7 @@ def predict(
output_name="predicted_result",
predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]",
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]",
)

def predict_proba(
@@ -1298,7 +1230,7 @@ def predict_proba(
output_name="predicted_probability",
predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]",
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]",
)

def to_local(self) -> LGBMClassifier:
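With the list-of-Dask-Arrays special case removed above, sparse multiclass `pred_contrib` predictions come back from the Dask estimators as a single Dask Array as well, and per-class blocks can be recovered by column slicing. A hedged sketch; `clf`, `dX`, `n_features`, and `n_classes` are assumed to exist and are not defined in this PR:

```python
# assumes a fitted lgb.DaskLGBMClassifier `clf` and a Dask Array `dX`
# backed by CSR chunks
contribs = clf.predict(dX, pred_contrib=True)  # one Dask Array, not a list

n_cols_per_class = n_features + 1  # per-feature contributions plus the base value
per_class_blocks = [
    contribs[:, i * n_cols_per_class : (i + 1) * n_cols_per_class]
    for i in range(n_classes)
]  # each block: [n_samples, n_features + 1]
```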
15 changes: 8 additions & 7 deletions python-package/lightgbm/sklearn.py
@@ -110,6 +110,7 @@
_LGBM_ScikitCustomEvalFunction,
List[Union[str, _LGBM_ScikitCustomEvalFunction]],
]
_LGBM_ScikitPredictReturnType = Union[np.ndarray, scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]
_LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType]


@@ -945,7 +946,7 @@ def _get_meta_data(collection, name, i):

fit.__doc__ = (
_lgbmmodel_doc_fit.format(
X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame , scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
y_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples]",
sample_weight_shape="numpy array, pandas Series, list of int or float of shape = [n_samples] or None, optional (default=None)",
init_score_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task) or shape = [n_samples, n_classes] (for multi-class task) or None, optional (default=None)",
@@ -968,7 +969,7 @@ def predict(
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any,
):
) -> _LGBM_ScikitPredictReturnType:
"""Docstring is set after definition, using a template."""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
@@ -1015,11 +1016,11 @@ def predict(

predict.__doc__ = _lgbmmodel_doc_predict.format(
description="Return the predicted value for each sample.",
X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame , scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
output_name="predicted_result",
predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
X_SHAP_values_shape="array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects",
X_SHAP_values_shape="array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]",
)

@property
@@ -1270,7 +1271,7 @@ def predict(
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any,
):
) -> _LGBM_ScikitPredictReturnType:
"""Docstring is inherited from the LGBMModel."""
result = self.predict_proba(
X=X,
@@ -1300,7 +1301,7 @@ def predict_proba(
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any,
):
) -> _LGBM_ScikitPredictReturnType:
"""Docstring is set after definition, using a template."""
result = super().predict(
X=X,
@@ -1330,7 +1331,7 @@ def predict_proba(
output_name="predicted_probability",
predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
X_SHAP_values_shape="array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects",
X_SHAP_values_shape="array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]",
)

@property
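The new `_LGBM_ScikitPredictReturnType` alias makes the unified contract visible to type checkers: `predict` and `predict_proba` return an ndarray or a single sparse matrix, never a list. A small sketch of what that buys a caller (toy data, illustrative only):

```python
import numpy as np
import scipy.sparse
from lightgbm import LGBMClassifier

X = scipy.sparse.random(300, 8, density=0.5, format="csr", random_state=0)
y = np.random.default_rng(0).integers(0, 3, size=300)  # 3 classes

clf = LGBMClassifier(n_estimators=5).fit(X, y)
contribs = clf.predict(X, pred_contrib=True)

# no `isinstance(contribs, list)` branch is needed anymore: the result
# always has a .shape, here [n_samples, (n_features + 1) * n_classes]
assert contribs.shape == (300, (8 + 1) * 3)
```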
29 changes: 0 additions & 29 deletions tests/python_package_test/test_dask.py
@@ -343,35 +343,6 @@ def test_classifier_pred_contrib(output, task, cluster):
else:
expected_num_cols = (num_features + 1) * num_classes

# in the special case of multi-class classification using scipy sparse matrices,
# the output of `.predict(..., pred_contrib=True)` is a list of sparse matrices (one per class)
#
# since that case is so different than all other cases, check the relevant things here
# and then return early
if output.startswith("scipy") and task == "multiclass-classification":
if output == "scipy_csr_matrix":
expected_type = csr_matrix
elif output == "scipy_csc_matrix":
expected_type = csc_matrix
else:
raise ValueError(f"Unrecognized output type: {output}")
assert isinstance(preds_with_contrib, list)
assert all(isinstance(arr, da.Array) for arr in preds_with_contrib)
assert all(isinstance(arr._meta, expected_type) for arr in preds_with_contrib)
assert len(preds_with_contrib) == num_classes
assert len(preds_with_contrib) == len(local_preds_with_contrib)
for i in range(num_classes):
computed_preds = preds_with_contrib[i].compute()
assert isinstance(computed_preds, expected_type)
assert computed_preds.shape[1] == num_classes
assert computed_preds.shape == local_preds_with_contrib[i].shape
assert len(np.unique(computed_preds[:, -1])) == 1
# raw scores will probably be different, but at least check that all predicted classes are the same
pred_classes = np.argmax(computed_preds.toarray(), axis=1)
local_pred_classes = np.argmax(local_preds_with_contrib[i].toarray(), axis=1)
np.testing.assert_array_equal(pred_classes, local_pred_classes)
return

preds_with_contrib = preds_with_contrib.compute()
if output.startswith("scipy"):
preds_with_contrib = preds_with_contrib.toarray()
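With the early-return branch deleted, sparse multiclass predictions flow through the same assertions as every other case in this test. If the per-class checks are still wanted, they can be expressed against the stacked matrix by slicing; a sketch that reuses the test's local names (`preds_with_contrib` after it has been computed and densified, `num_features`, `num_classes`):

```python
num_cols = num_features + 1  # per-feature contributions plus one base-value column
for i in range(num_classes):
    block = preds_with_contrib[:, i * num_cols : (i + 1) * num_cols]
    # each class's base value is constant across rows
    assert len(np.unique(block[:, -1])) == 1
```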