-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[dask] fix Dask docstrings and mimic sklearn wrapper importing way #3855
Changes from 6 commits
a3beeff
9432108
fdd69f3
968fcba
6e51b8f
63466a6
8a7e6c2
ea0f939
95b5697
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,16 +13,12 @@ | |
from urllib.parse import urlparse | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import scipy.sparse as ss | ||
|
||
from dask import array as da | ||
from dask import dataframe as dd | ||
from dask import delayed | ||
from dask.distributed import Client, default_client, get_worker, wait | ||
|
||
from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError | ||
from .compat import DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED | ||
from .compat import (PANDAS_INSTALLED, DataFrame, Series, concat, | ||
SKLEARN_INSTALLED, | ||
DASK_INSTALLED, dask_Frame, dask_Array, delayed, Client, default_client, get_worker, wait) | ||
from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker | ||
|
||
|
||
|
@@ -46,7 +42,7 @@ def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Itera | |
|
||
Returns | ||
------- | ||
result : int | ||
port : int | ||
A free port on the machine referenced by ``worker_ip``. | ||
""" | ||
max_tries = 1000 | ||
|
@@ -81,7 +77,7 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc | |
client : dask.distributed.Client | ||
Dask client. | ||
worker_addresses : Iterable[str] | ||
An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port`` | ||
An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port``. | ||
local_listen_port : int | ||
First port to try when searching for open ports. | ||
|
||
|
@@ -109,8 +105,8 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc | |
def _concat(seq): | ||
if isinstance(seq[0], np.ndarray): | ||
return np.concatenate(seq, axis=0) | ||
elif isinstance(seq[0], (pd.DataFrame, pd.Series)): | ||
return pd.concat(seq, axis=0) | ||
elif isinstance(seq[0], (DataFrame, Series)): | ||
return concat(seq, axis=0) | ||
elif isinstance(seq[0], ss.spmatrix): | ||
return ss.vstack(seq, format='csr') | ||
else: | ||
|
@@ -152,9 +148,9 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re | |
try: | ||
model = model_factory(**params) | ||
if is_ranker: | ||
model.fit(data, y=label, sample_weight=weight, group=group, **kwargs) | ||
model.fit(data, label, sample_weight=weight, group=group, **kwargs) | ||
else: | ||
model.fit(data, y=label, sample_weight=weight, **kwargs) | ||
model.fit(data, label, sample_weight=weight, **kwargs) | ||
|
||
finally: | ||
_safe_call(_LIB.LGBM_NetworkFree()) | ||
|
@@ -178,13 +174,16 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group | |
|
||
Parameters | ||
---------- | ||
client: dask.Client - client | ||
X : dask array of shape = [n_samples, n_features] | ||
client : dask.distributed.Client | ||
Dask client. | ||
data : dask array of shape = [n_samples, n_features] | ||
Input feature matrix. | ||
y : dask array of shape = [n_samples] | ||
label : dask array of shape = [n_samples] | ||
The target values (class labels in classification, real numbers in regression). | ||
params : dict | ||
Parameters passed to constructor of the local underlying model. | ||
model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class | ||
Class of the local underlying model. | ||
sample_weight : array-like of shape = [n_samples] or None, optional (default=None) | ||
Weights of training data. | ||
group : array-like or None, optional (default=None) | ||
|
@@ -193,6 +192,13 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group | |
sum(group) = n_samples. | ||
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups, | ||
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc. | ||
**kwargs | ||
Other parameters passed to ``fit`` method of the local underlying model. | ||
|
||
Returns | ||
------- | ||
model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class | ||
Returns fitted underlying model. | ||
""" | ||
params = deepcopy(params) | ||
|
||
|
@@ -298,7 +304,7 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group | |
|
||
|
||
def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, **kwargs): | ||
data = part.values if isinstance(part, pd.DataFrame) else part | ||
data = part.values if isinstance(part, DataFrame) else part | ||
|
||
if data.shape[0] == 0: | ||
result = np.array([]) | ||
|
@@ -319,11 +325,11 @@ def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, * | |
**kwargs | ||
) | ||
|
||
if isinstance(part, pd.DataFrame): | ||
if isinstance(part, DataFrame): | ||
if pred_proba or pred_contrib: | ||
result = pd.DataFrame(result, index=part.index) | ||
result = DataFrame(result, index=part.index) | ||
else: | ||
result = pd.Series(result, index=part.index, name='predictions') | ||
result = Series(result, index=part.index, name='predictions') | ||
|
||
return result | ||
|
||
|
@@ -335,20 +341,34 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr | |
Parameters | ||
---------- | ||
model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class | ||
Fitted underlying model. | ||
data : dask array of shape = [n_samples, n_features] | ||
Input feature matrix. | ||
raw_score : bool, optional (default=False) | ||
Whether to predict raw scores. | ||
pred_proba : bool, optional (default=False) | ||
Should method return results of ``predict_proba`` (``pred_proba=True``) or ``predict`` (``pred_proba=False``). | ||
pred_leaf : bool, optional (default=False) | ||
Whether to predict leaf index. | ||
pred_contrib : bool, optional (default=False) | ||
Whether to predict feature contributions. | ||
dtype : np.dtype | ||
dtype : np.dtype, optional (default=np.float32) | ||
Dtype of the output. | ||
kwargs : dict | ||
**kwargs | ||
Other parameters passed to ``predict`` or ``predict_proba`` method. | ||
|
||
Returns | ||
------- | ||
predicted_result : dask array of shape = [n_samples] or shape = [n_samples, n_classes] | ||
The predicted values. | ||
X_leaves : dask arrayof shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes] | ||
If ``pred_leaf=True``, the predicted leaf of every tree for each sample. | ||
X_SHAP_values : dask array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects | ||
If ``pred_contrib=True``, the feature contributions for each sample. | ||
""" | ||
if isinstance(data, dd._Frame): | ||
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): | ||
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') | ||
if isinstance(data, _Frame): | ||
return data.map_partitions( | ||
_predict_part, | ||
model=model, | ||
|
@@ -358,7 +378,7 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr | |
pred_contrib=pred_contrib, | ||
**kwargs | ||
).values | ||
elif isinstance(data, da.Array): | ||
elif isinstance(data, Array): | ||
if pred_proba: | ||
kwargs['chunks'] = (data.chunks[0], (model.n_classes_,)) | ||
else: | ||
|
@@ -378,12 +398,9 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr | |
|
||
|
||
class _LGBMModel: | ||
def __init__(self): | ||
def _fit(self, model_factory, X, y, sample_weight=None, group=None, client=None, **kwargs): | ||
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): | ||
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if this check is being moved out of the constructor, then can you please put it in the If someone tries to load a saved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great catch! Will do.
I wish I could leave it in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah! didn't think about that when reviewing the MRO change in #3822 |
||
|
||
def _fit(self, model_factory, X, y=None, sample_weight=None, group=None, client=None, **kwargs): | ||
"""Docstring is inherited from the LGBMModel.""" | ||
if client is None: | ||
client = default_client() | ||
|
||
|
@@ -422,7 +439,7 @@ def _copy_extra_params(source, dest): | |
class DaskLGBMClassifier(LGBMClassifier, _LGBMModel): | ||
"""Distributed version of lightgbm.LGBMClassifier.""" | ||
|
||
def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): | ||
def fit(self, X, y, sample_weight=None, client=None, **kwargs): | ||
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit.""" | ||
return self._fit( | ||
model_factory=LGBMClassifier, | ||
|
@@ -433,7 +450,12 @@ def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): | |
**kwargs | ||
) | ||
|
||
fit.__doc__ = LGBMClassifier.fit.__doc__ | ||
_base_doc = LGBMClassifier.fit.__doc__ | ||
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :') | ||
fit.__doc__ = (_before_init_score | ||
+ 'client : dask.distributed.Client or None, optional (default=None)\n' | ||
+ ' ' * 12 + 'Dask client.\n' | ||
+ ' ' * 8 + _init_score + _after_init_score) | ||
|
||
def predict(self, X, **kwargs): | ||
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" | ||
|
@@ -463,14 +485,15 @@ def to_local(self): | |
Returns | ||
------- | ||
model : lightgbm.LGBMClassifier | ||
Local underlying model. | ||
""" | ||
return self._to_local(LGBMClassifier) | ||
|
||
|
||
class DaskLGBMRegressor(LGBMRegressor, _LGBMModel): | ||
"""Docstring is inherited from the lightgbm.LGBMRegressor.""" | ||
"""Distributed version of lightgbm.LGBMRegressor.""" | ||
|
||
def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): | ||
def fit(self, X, y, sample_weight=None, client=None, **kwargs): | ||
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit.""" | ||
return self._fit( | ||
model_factory=LGBMRegressor, | ||
|
@@ -481,7 +504,12 @@ def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): | |
**kwargs | ||
) | ||
|
||
fit.__doc__ = LGBMRegressor.fit.__doc__ | ||
_base_doc = LGBMRegressor.fit.__doc__ | ||
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :') | ||
fit.__doc__ = (_before_init_score | ||
+ 'client : dask.distributed.Client or None, optional (default=None)\n' | ||
+ ' ' * 12 + 'Dask client.\n' | ||
+ ' ' * 8 + _init_score + _after_init_score) | ||
|
||
def predict(self, X, **kwargs): | ||
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" | ||
|
@@ -499,14 +527,15 @@ def to_local(self): | |
Returns | ||
------- | ||
model : lightgbm.LGBMRegressor | ||
Local underlying model. | ||
""" | ||
return self._to_local(LGBMRegressor) | ||
|
||
|
||
class DaskLGBMRanker(LGBMRanker, _LGBMModel): | ||
"""Docstring is inherited from the lightgbm.LGBMRanker.""" | ||
"""Distributed version of lightgbm.LGBMRanker.""" | ||
|
||
def fit(self, X, y=None, sample_weight=None, init_score=None, group=None, client=None, **kwargs): | ||
def fit(self, X, y, sample_weight=None, init_score=None, group=None, client=None, **kwargs): | ||
"""Docstring is inherited from the lightgbm.LGBMRanker.fit.""" | ||
if init_score is not None: | ||
raise RuntimeError('init_score is not currently supported in lightgbm.dask') | ||
|
@@ -521,7 +550,12 @@ def fit(self, X, y=None, sample_weight=None, init_score=None, group=None, client | |
**kwargs | ||
) | ||
|
||
fit.__doc__ = LGBMRanker.fit.__doc__ | ||
_base_doc = LGBMRanker.fit.__doc__ | ||
_before_eval_set, _eval_set, _after_eval_set = _base_doc.partition('eval_set :') | ||
fit.__doc__ = (_before_eval_set | ||
+ 'client : dask.distributed.Client or None, optional (default=None)\n' | ||
+ ' ' * 12 + 'Dask client.\n' | ||
+ ' ' * 8 + _eval_set + _after_eval_set) | ||
|
||
def predict(self, X, **kwargs): | ||
"""Docstring is inherited from the lightgbm.LGBMRanker.predict.""" | ||
|
@@ -535,5 +569,6 @@ def to_local(self): | |
Returns | ||
------- | ||
model : lightgbm.LGBMRanker | ||
Local underlying model. | ||
""" | ||
return self._to_local(LGBMRanker) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with importing from
compat
, but could the names be changed on import to something likepd_DataFrame
?Since both
pandas
anddask
have aDataFrame
class, I think just calling thisDataFrame
makes the code difficult to read. I know that I personally will read this in the future and think "wait does that mean pandas or Dask DataFrame".There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm totally agree! But I think we can make import aliases in
compat.py
and then import likefrom .compat import pd_DataFrame
. Otherwise in case of identical names it will be confusing to havein
compat.py
.I'm going to rename only Dask imports in this PR to not overcomplicate review. pandas will be done in a follow-up PR. Do you agree?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes sounds good, thank you