From d3782da962051b9c4340ff615430e64ce16ec3fa Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Thu, 29 Feb 2024 21:29:44 -0600
Subject: [PATCH 1/4] [python-package] ensure predict() always returns an array

---
 python-package/lightgbm/basic.py  | 33 +++++++-------
 python-package/lightgbm/compat.py |  4 --
 python-package/lightgbm/dask.py   | 74 ++-----------------------------
 3 files changed, 20 insertions(+), 91 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index f78d8c35216c..3d6783a303a0 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -1083,7 +1083,7 @@ def predict(
         pred_contrib: bool = False,
         data_has_header: bool = False,
         validate_features: bool = False,
-    ) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
+    ) -> Union[np.ndarray, scipy.sparse.spmatrix]:
         """Predict logic.
 
         Parameters
@@ -1112,9 +1112,9 @@ def predict(
 
         Returns
         -------
-        result : numpy array, scipy.sparse or list of scipy.sparse
+        result : numpy array or scipy.sparse
             Prediction result.
-            Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
+            If ``data`` is a sparse matrix, result will be a sparse matrix.
         """
         if isinstance(data, Dataset):
             raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
@@ -1354,7 +1354,7 @@ def __create_sparse_native(
         indptr_type: int,
         data_type: int,
         is_csr: bool,
-    ) -> Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]]:
+    ) -> Union[scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]:
         # create numpy array from output arrays
         data_indices_len = out_shape[0]
         indptr_len = out_shape[1]
@@ -1402,9 +1402,10 @@ def __create_sparse_native(
                 ctypes.c_int(data_type),
             )
         )
-        if len(cs_output_matrices) == 1:
-            return cs_output_matrices[0]
-        return cs_output_matrices
+        if is_csr:
+            return scipy.sparse.hstack(cs_output_matrices, format="csr")
+        else:
+            return scipy.sparse.hstack(cs_output_matrices, format="csc")
 
     def __inner_predict_csr(
         self,
@@ -1462,7 +1463,7 @@ def __inner_predict_csr_sparse(
         start_iteration: int,
         num_iteration: int,
         predict_type: int,
-    ) -> Tuple[Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]], int]:
+    ) -> Tuple[Union[scipy.sparse.csc_matrix, scipy.sparse.csr_matrix], int]:
         ptr_indptr, type_ptr_indptr, __ = _c_int_array(csr.indptr)
         ptr_data, type_ptr_data, _ = _c_float_array(csr.data)
         csr_indices = csr.indices.astype(np.int32, copy=False)
@@ -1501,7 +1502,7 @@ def __inner_predict_csr_sparse(
                 ctypes.byref(out_ptr_data),
             )
         )
-        matrices = self.__create_sparse_native(
+        out_mat = self.__create_sparse_native(
             cs=csr,
             out_shape=out_shape,
             out_ptr_indptr=out_ptr_indptr,
@@ -1512,7 +1513,7 @@ def __inner_predict_csr_sparse(
             is_csr=True,
         )
         nrow = len(csr.indptr) - 1
-        return matrices, nrow
+        return out_mat, nrow
 
     def __pred_for_csr(
         self,
@@ -1563,7 +1564,7 @@ def __inner_predict_sparse_csc(
         start_iteration: int,
         num_iteration: int,
         predict_type: int,
-    ):
+    ) -> Tuple[scipy.sparse.csc_matrix, int]:
         ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
         ptr_data, type_ptr_data, _ = _c_float_array(csc.data)
         csc_indices = csc.indices.astype(np.int32, copy=False)
@@ -1602,7 +1603,7 @@ def __inner_predict_sparse_csc(
                 ctypes.byref(out_ptr_data),
            )
         )
-        matrices = self.__create_sparse_native(
+        out_mat = self.__create_sparse_native(
             cs=csc,
             out_shape=out_shape,
             out_ptr_indptr=out_ptr_indptr,
@@ -1613,7 +1614,7 @@ def __inner_predict_sparse_csc(
             is_csr=False,
         )
         nrow = csc.shape[0]
-        return matrices, nrow
+        return out_mat, nrow
 
     def __pred_for_csc(
         self,
@@ -4677,7 +4678,7 @@ def predict(
         data_has_header: bool = False,
         validate_features: bool = False,
         **kwargs: Any,
-    ) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
+    ) -> Union[np.ndarray, scipy.sparse.spmatrix]:
         """Make a prediction.
 
         Parameters
@@ -4719,9 +4720,9 @@ def predict(
 
         Returns
         -------
-        result : numpy array, scipy.sparse or list of scipy.sparse
+        result : numpy array or scipy.sparse
             Prediction result.
-            Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
+            If ``data`` is a sparse matrix, result will be a sparse matrix.
         """
         predictor = _InnerPredictor.from_booster(
             booster=self,
diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py
index 086c6a199ff3..d8534708ff96 100644
--- a/python-package/lightgbm/compat.py
+++ b/python-package/lightgbm/compat.py
@@ -157,8 +157,6 @@ class _LGBMRegressorBase:  # type: ignore
 try:
     from dask import delayed
     from dask.array import Array as dask_Array
-    from dask.array import from_delayed as dask_array_from_delayed
-    from dask.bag import from_delayed as dask_bag_from_delayed
     from dask.dataframe import DataFrame as dask_DataFrame
     from dask.dataframe import Series as dask_Series
     from dask.distributed import Client, Future, default_client, wait
@@ -167,8 +165,6 @@ class _LGBMRegressorBase:  # type: ignore
 except ImportError:
     DASK_INSTALLED = False
 
-    dask_array_from_delayed = None  # type: ignore[assignment]
-    dask_bag_from_delayed = None  # type: ignore[assignment]
     delayed = None
     default_client = None  # type: ignore[assignment]
     wait = None  # type: ignore[assignment]
diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index 928fe51bddce..fad54efaba3e 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -28,8 +28,6 @@
     LGBMNotFittedError,
     concat,
     dask_Array,
-    dask_array_from_delayed,
-    dask_bag_from_delayed,
     dask_DataFrame,
     dask_Series,
     default_client,
@@ -906,7 +904,7 @@ def _predict(
         The predicted values.
     X_leaves : Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
         If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
-    X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]
+    X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]
         If ``pred_contrib=True``, the feature contributions for each sample.
""" if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): @@ -922,72 +920,6 @@ def _predict( **kwargs, ).values elif isinstance(data, dask_Array): - # for multi-class classification with sparse matrices, pred_contrib predictions - # are returned as a list of sparse matrices (one per class) - num_classes = model._n_classes - - if num_classes > 2 and pred_contrib and isinstance(data._meta, ss.spmatrix): - predict_function = partial( - _predict_part, - model=model, - raw_score=False, - pred_proba=pred_proba, - pred_leaf=False, - pred_contrib=True, - **kwargs, - ) - - delayed_chunks = data.to_delayed() - bag = dask_bag_from_delayed(delayed_chunks[:, 0]) - - @delayed - def _extract(items: List[Any], i: int) -> Any: - return items[i] - - preds = bag.map_partitions(predict_function) - - # pred_contrib output will have one column per feature, - # plus one more for the base value - num_cols = model.n_features_ + 1 - - nrows_per_chunk = data.chunks[0] - out: List[List[dask_Array]] = [[] for _ in range(num_classes)] - - # need to tell Dask the expected type and shape of individual preds - pred_meta = data._meta - - for j, partition in enumerate(preds.to_delayed()): - for i in range(num_classes): - part = dask_array_from_delayed( - value=_extract(partition, i), - shape=(nrows_per_chunk[j], num_cols), - meta=pred_meta, - ) - out[i].append(part) - - # by default, dask.array.concatenate() concatenates sparse arrays into a COO matrix - # the code below is used instead to ensure that the sparse type is preserved during concatentation - if isinstance(pred_meta, ss.csr_matrix): - concat_fn = partial(ss.vstack, format="csr") - elif isinstance(pred_meta, ss.csc_matrix): - concat_fn = partial(ss.vstack, format="csc") - else: - concat_fn = ss.vstack - - # At this point, `out` is a list of lists of delayeds (each of which points to a matrix). - # Concatenate them to return a list of Dask Arrays. 
-            out_arrays: List[dask_Array] = []
-            for i in range(num_classes):
-                out_arrays.append(
-                    dask_array_from_delayed(
-                        value=delayed(concat_fn)(out[i]),
-                        shape=(data.shape[0], num_cols),
-                        meta=pred_meta,
-                    )
-                )
-
-            return out_arrays
-
         data_row = client.compute(data[[0]]).result()
         predict_fn = partial(
             _predict_part,
@@ -1263,7 +1195,7 @@ def predict(
         output_name="predicted_result",
         predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
         X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
-        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]",
+        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]",
     )
 
     def predict_proba(
@@ -1298,7 +1230,7 @@ def predict_proba(
         output_name="predicted_probability",
         predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
         X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
-        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]",
+        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]",
     )
 
     def to_local(self) -> LGBMClassifier:

From 7605296b62339d3d372a74ee5426f21e0d8434f2 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Thu, 29 Feb 2024 22:24:20 -0600
Subject: [PATCH 2/4] update tests

---
 tests/python_package_test/test_engine.py | 34 ++++++++++++++--------------------
 1 file changed, 14 insertions(+), 20 deletions(-)

diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 3fad36b34407..cfee0b6f6b8b 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -1834,37 +1834,31 @@ def test_contribs_sparse_multiclass():
     lgb_train = lgb.Dataset(X_train, y_train)
     gbm = lgb.train(params, lgb_train, num_boost_round=20)
     contribs_csr = gbm.predict(X_test, pred_contrib=True)
-    assert isinstance(contribs_csr, list)
-    for perclass_contribs_csr in contribs_csr:
-        assert isspmatrix_csr(perclass_contribs_csr)
+    assert isspmatrix_csr(contribs_csr)
     # convert data to dense and get back same contribs
     contribs_dense = gbm.predict(X_test.toarray(), pred_contrib=True)
     # validate the values are the same
-    contribs_csr_array = np.swapaxes(np.array([sparse_array.toarray() for sparse_array in contribs_csr]), 0, 1)
-    contribs_csr_arr_re = contribs_csr_array.reshape(
-        (contribs_csr_array.shape[0], contribs_csr_array.shape[1] * contribs_csr_array.shape[2])
-    )
     if platform.machine() == "aarch64":
-        np.testing.assert_allclose(contribs_csr_arr_re, contribs_dense, rtol=1, atol=1e-12)
+        np.testing.assert_allclose(contribs_csr.toarray(), contribs_dense, rtol=1, atol=1e-12)
     else:
-        np.testing.assert_allclose(contribs_csr_arr_re, contribs_dense)
-    contribs_dense_re = contribs_dense.reshape(contribs_csr_array.shape)
-    assert np.linalg.norm(gbm.predict(X_test, raw_score=True) - np.sum(contribs_dense_re, axis=2)) < 1e-4
+        np.testing.assert_allclose(contribs_csr.toarray(), contribs_dense)
+    # values should sum to predictions
+    preds_by_class = np.hstack(
+        [
+            np.sum(contribs_dense[:, i * (n_features + 1) : (i + 1) * (n_features + 1)], axis=1).reshape(-1, 1)
+            for i in range(n_labels)
+        ]
+    )
+    assert np.linalg.norm(gbm.predict(X_test, raw_score=True) - preds_by_class) < 1e-4
     # validate using CSC matrix
     X_test_csc = X_test.tocsc()
     contribs_csc = gbm.predict(X_test_csc, pred_contrib=True)
-    assert isinstance(contribs_csc, list)
-    for perclass_contribs_csc in contribs_csc:
-        assert isspmatrix_csc(perclass_contribs_csc)
+    assert isspmatrix_csc(contribs_csc)
     # validate the values are the same
-    contribs_csc_array = np.swapaxes(np.array([sparse_array.toarray() for sparse_array in contribs_csc]), 0, 1)
-    contribs_csc_array = contribs_csc_array.reshape(
-        (contribs_csc_array.shape[0], contribs_csc_array.shape[1] * contribs_csc_array.shape[2])
-    )
     if platform.machine() == "aarch64":
-        np.testing.assert_allclose(contribs_csc_array, contribs_dense, rtol=1, atol=1e-12)
+        np.testing.assert_allclose(contribs_csc.toarray(), contribs_dense, rtol=1, atol=1e-12)
     else:
-        np.testing.assert_allclose(contribs_csc_array, contribs_dense)
+        np.testing.assert_allclose(contribs_csc.toarray(), contribs_dense)
 
 
 @pytest.mark.skipif(psutil.virtual_memory().available / 1024 / 1024 / 1024 < 3, reason="not enough RAM")

From 5a45947cee0f7f8e325f0303f9f8186634b50da6 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Thu, 29 Feb 2024 22:30:07 -0600
Subject: [PATCH 3/4] remove special case from dask test

---
 tests/python_package_test/test_dask.py | 29 -----------------------------
 1 file changed, 29 deletions(-)

diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index 9fe4da18faaf..9f64f2d0d2ee 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -343,35 +343,6 @@ def test_classifier_pred_contrib(output, task, cluster):
     else:
         expected_num_cols = (num_features + 1) * num_classes
 
-    # in the special case of multi-class classification using scipy sparse matrices,
-    # the output of `.predict(..., pred_contrib=True)` is a list of sparse matrices (one per class)
-    #
-    # since that case is so different than all other cases, check the relevant things here
-    # and then return early
-    if output.startswith("scipy") and task == "multiclass-classification":
-        if output == "scipy_csr_matrix":
-            expected_type = csr_matrix
-        elif output == "scipy_csc_matrix":
-            expected_type = csc_matrix
-        else:
-            raise ValueError(f"Unrecognized output type: {output}")
-        assert isinstance(preds_with_contrib, list)
-        assert all(isinstance(arr, da.Array) for arr in preds_with_contrib)
-        assert all(isinstance(arr._meta, expected_type) for arr in preds_with_contrib)
-        assert len(preds_with_contrib) == num_classes
-        assert len(preds_with_contrib) == len(local_preds_with_contrib)
-        for i in range(num_classes):
-            computed_preds = preds_with_contrib[i].compute()
-            assert isinstance(computed_preds, expected_type)
-            assert computed_preds.shape[1] == num_classes
-            assert computed_preds.shape == local_preds_with_contrib[i].shape
-            assert len(np.unique(computed_preds[:, -1])) == 1
-            # raw scores will probably be different, but at least check that all predicted classes are the same
-            pred_classes = np.argmax(computed_preds.toarray(), axis=1)
-            local_pred_classes = np.argmax(local_preds_with_contrib[i].toarray(), axis=1)
-            np.testing.assert_array_equal(pred_classes, local_pred_classes)
-        return
-
     preds_with_contrib = preds_with_contrib.compute()
     if output.startswith("scipy"):
         preds_with_contrib = preds_with_contrib.toarray()

From 92d3171b14c138af1bc8a935ab9dd03859d41da5 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Fri, 1 Mar 2024 00:29:02 -0600
Subject: [PATCH 4/4] add scikit-learn type hints

---
 python-package/lightgbm/sklearn.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 5e0d51f4546d..617cc7d00a84 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -110,6 +110,7 @@
     _LGBM_ScikitCustomEvalFunction,
     List[Union[str, _LGBM_ScikitCustomEvalFunction]],
 ]
+_LGBM_ScikitPredictReturnType = Union[np.ndarray, scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]
 _LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType]
 
 
@@ -945,7 +946,7 @@ def _get_meta_data(collection, name, i):
 
     fit.__doc__ = (
         _lgbmmodel_doc_fit.format(
-            X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame , scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
+            X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
             y_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples]",
             sample_weight_shape="numpy array, pandas Series, list of int or float of shape = [n_samples] or None, optional (default=None)",
             init_score_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task) or shape = [n_samples, n_classes] (for multi-class task) or None, optional (default=None)",
@@ -968,7 +969,7 @@ def predict(
         pred_contrib: bool = False,
         validate_features: bool = False,
         **kwargs: Any,
-    ):
+    ) -> _LGBM_ScikitPredictReturnType:
         """Docstring is set after definition, using a template."""
         if not self.__sklearn_is_fitted__():
             raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
@@ -1015,11 +1016,11 @@ def predict(
 
     predict.__doc__ = _lgbmmodel_doc_predict.format(
         description="Return the predicted value for each sample.",
-        X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame , scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
+        X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
         output_name="predicted_result",
         predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]",
         X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
-        X_SHAP_values_shape="array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects",
+        X_SHAP_values_shape="array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]",
     )
 
     @property
@@ -1270,7 +1271,7 @@ def predict(
         pred_contrib: bool = False,
         validate_features: bool = False,
         **kwargs: Any,
-    ):
+    ) -> _LGBM_ScikitPredictReturnType:
         """Docstring is inherited from the LGBMModel."""
         result = self.predict_proba(
             X=X,
@@ -1300,7 +1301,7 @@ def predict_proba(
         pred_contrib: bool = False,
         validate_features: bool = False,
         **kwargs: Any,
-    ):
+    ) -> _LGBM_ScikitPredictReturnType:
         """Docstring is set after definition, using a template."""
        result = super().predict(
             X=X,
@@ -1330,7 +1331,7 @@ def predict_proba(
output_name="predicted_probability", predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]", X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]", - X_SHAP_values_shape="array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects", + X_SHAP_values_shape="array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]", ) @property