Skip to content

Commit

Permalink
[dask] fix Dask docstrings and mimic sklearn wrapper importing way (#…
Browse files Browse the repository at this point in the history
…3855)

* fix Dask docstrings and mimic sklearn importing way

* Update .vsts-ci.yml

* revert CI checks

* use import aliases for Dask classes

* check Dask is installed in _predict() func

* fix lint issues introduced during resolving merge conflicts

* Update dask.py
  • Loading branch information
StrikerRUS authored Jan 26, 2021
1 parent 56b99d4 commit 5312b95
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 42 deletions.
3 changes: 1 addition & 2 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np
import scipy.sparse

from .compat import PANDAS_INSTALLED, DataFrame, Series, is_dtype_sparse, DataTable
from .compat import PANDAS_INSTALLED, DataFrame, Series, concat, is_dtype_sparse, DataTable
from .libpath import find_lib_path


Expand Down Expand Up @@ -2081,7 +2081,6 @@ def add_features_from(self, other):
if not PANDAS_INSTALLED:
raise LightGBMError("Cannot add features to DataFrame type of raw data "
"without pandas installed")
from pandas import concat
if isinstance(other.data, np.ndarray):
self.data = concat((self.data, DataFrame(other.data)),
axis=1, ignore_index=True)
Expand Down
25 changes: 21 additions & 4 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""pandas"""
try:
from pandas import Series, DataFrame
from pandas import Series, DataFrame, concat
from pandas.api.types import is_sparse as is_dtype_sparse
PANDAS_INSTALLED = True
except ImportError:
Expand All @@ -19,6 +19,7 @@ class DataFrame:

pass

concat = None
is_dtype_sparse = None

"""matplotlib"""
Expand Down Expand Up @@ -108,9 +109,25 @@ def _check_sample_weight(sample_weight, X, dtype=None):

"""dask"""
try:
from dask import array
from dask import dataframe
from dask.distributed import Client
from dask import delayed
from dask.array import Array as dask_Array
from dask.dataframe import _Frame as dask_Frame
from dask.distributed import Client, default_client, get_worker, wait
DASK_INSTALLED = True
except ImportError:
DASK_INSTALLED = False
delayed = None
Client = object
default_client = None
get_worker = None
wait = None

class dask_Array:
"""Dummy class for dask.array.Array."""

pass

class dask_Frame:
"""Dummy class for ddask.dataframe._Frame."""

pass
107 changes: 71 additions & 36 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,12 @@
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import scipy.sparse as ss

from dask import array as da
from dask import dataframe as dd
from dask import delayed
from dask.distributed import Client, default_client, get_worker, wait

from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
from .compat import DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED
from .compat import (PANDAS_INSTALLED, DataFrame, Series, concat,
SKLEARN_INSTALLED,
DASK_INSTALLED, dask_Frame, dask_Array, delayed, Client, default_client, get_worker, wait)
from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker


Expand All @@ -46,7 +42,7 @@ def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Itera
Returns
-------
result : int
port : int
A free port on the machine referenced by ``worker_ip``.
"""
max_tries = 1000
Expand Down Expand Up @@ -81,7 +77,7 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc
client : dask.distributed.Client
Dask client.
worker_addresses : Iterable[str]
An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port``
An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port``.
local_listen_port : int
First port to try when searching for open ports.
Expand Down Expand Up @@ -109,8 +105,8 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc
def _concat(seq):
if isinstance(seq[0], np.ndarray):
return np.concatenate(seq, axis=0)
elif isinstance(seq[0], (pd.DataFrame, pd.Series)):
return pd.concat(seq, axis=0)
elif isinstance(seq[0], (DataFrame, Series)):
return concat(seq, axis=0)
elif isinstance(seq[0], ss.spmatrix):
return ss.vstack(seq, format='csr')
else:
Expand Down Expand Up @@ -152,9 +148,9 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
try:
model = model_factory(**params)
if is_ranker:
model.fit(data, y=label, sample_weight=weight, group=group, **kwargs)
model.fit(data, label, sample_weight=weight, group=group, **kwargs)
else:
model.fit(data, y=label, sample_weight=weight, **kwargs)
model.fit(data, label, sample_weight=weight, **kwargs)

finally:
_safe_call(_LIB.LGBM_NetworkFree())
Expand All @@ -178,13 +174,16 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
Parameters
----------
client: dask.Client - client
X : dask array of shape = [n_samples, n_features]
client : dask.distributed.Client
Dask client.
data : dask array of shape = [n_samples, n_features]
Input feature matrix.
y : dask array of shape = [n_samples]
label : dask array of shape = [n_samples]
The target values (class labels in classification, real numbers in regression).
params : dict
Parameters passed to constructor of the local underlying model.
model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
Class of the local underlying model.
sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
Weights of training data.
group : array-like or None, optional (default=None)
Expand All @@ -193,6 +192,13 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
**kwargs
Other parameters passed to ``fit`` method of the local underlying model.
Returns
-------
model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
Returns fitted underlying model.
"""
params = deepcopy(params)

Expand Down Expand Up @@ -298,7 +304,7 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group


def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, **kwargs):
data = part.values if isinstance(part, pd.DataFrame) else part
data = part.values if isinstance(part, DataFrame) else part

if data.shape[0] == 0:
result = np.array([])
Expand All @@ -319,11 +325,11 @@ def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, *
**kwargs
)

if isinstance(part, pd.DataFrame):
if isinstance(part, DataFrame):
if pred_proba or pred_contrib:
result = pd.DataFrame(result, index=part.index)
result = DataFrame(result, index=part.index)
else:
result = pd.Series(result, index=part.index, name='predictions')
result = Series(result, index=part.index, name='predictions')

return result

Expand All @@ -335,20 +341,34 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr
Parameters
----------
model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
Fitted underlying model.
data : dask array of shape = [n_samples, n_features]
Input feature matrix.
raw_score : bool, optional (default=False)
Whether to predict raw scores.
pred_proba : bool, optional (default=False)
Should method return results of ``predict_proba`` (``pred_proba=True``) or ``predict`` (``pred_proba=False``).
pred_leaf : bool, optional (default=False)
Whether to predict leaf index.
pred_contrib : bool, optional (default=False)
Whether to predict feature contributions.
dtype : np.dtype
dtype : np.dtype, optional (default=np.float32)
Dtype of the output.
kwargs : dict
**kwargs
Other parameters passed to ``predict`` or ``predict_proba`` method.
Returns
-------
predicted_result : dask array of shape = [n_samples] or shape = [n_samples, n_classes]
The predicted values.
X_leaves : dask arrayof shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
X_SHAP_values : dask array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects
If ``pred_contrib=True``, the feature contributions for each sample.
"""
if isinstance(data, dd._Frame):
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
if isinstance(data, dask_Frame):
return data.map_partitions(
_predict_part,
model=model,
Expand All @@ -358,7 +378,7 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr
pred_contrib=pred_contrib,
**kwargs
).values
elif isinstance(data, da.Array):
elif isinstance(data, dask_Array):
if pred_proba:
kwargs['chunks'] = (data.chunks[0], (model.n_classes_,))
else:
Expand All @@ -378,12 +398,9 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr


class _DaskLGBMModel:
def __init__(self):
def _fit(self, model_factory, X, y, sample_weight=None, group=None, client=None, **kwargs):
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')

def _fit(self, model_factory, X, y=None, sample_weight=None, group=None, client=None, **kwargs):
"""Docstring is inherited from the LGBMModel."""
if client is None:
client = default_client()

Expand Down Expand Up @@ -422,7 +439,7 @@ def _copy_extra_params(source, dest):
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMClassifier."""

def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
def fit(self, X, y, sample_weight=None, client=None, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
return self._fit(
model_factory=LGBMClassifier,
Expand All @@ -433,7 +450,12 @@ def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
**kwargs
)

fit.__doc__ = LGBMClassifier.fit.__doc__
_base_doc = LGBMClassifier.fit.__doc__
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
fit.__doc__ = (_before_init_score
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)

def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
Expand Down Expand Up @@ -463,14 +485,15 @@ def to_local(self):
Returns
-------
model : lightgbm.LGBMClassifier
Local underlying model.
"""
return self._to_local(LGBMClassifier)


class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
"""Docstring is inherited from the lightgbm.LGBMRegressor."""
"""Distributed version of lightgbm.LGBMRegressor."""

def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
def fit(self, X, y, sample_weight=None, client=None, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
return self._fit(
model_factory=LGBMRegressor,
Expand All @@ -481,7 +504,12 @@ def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
**kwargs
)

fit.__doc__ = LGBMRegressor.fit.__doc__
_base_doc = LGBMRegressor.fit.__doc__
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
fit.__doc__ = (_before_init_score
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)

def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
Expand All @@ -499,14 +527,15 @@ def to_local(self):
Returns
-------
model : lightgbm.LGBMRegressor
Local underlying model.
"""
return self._to_local(LGBMRegressor)


class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
"""Docstring is inherited from the lightgbm.LGBMRanker."""
"""Distributed version of lightgbm.LGBMRanker."""

def fit(self, X, y=None, sample_weight=None, init_score=None, group=None, client=None, **kwargs):
def fit(self, X, y, sample_weight=None, init_score=None, group=None, client=None, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRanker.fit."""
if init_score is not None:
raise RuntimeError('init_score is not currently supported in lightgbm.dask')
Expand All @@ -521,7 +550,12 @@ def fit(self, X, y=None, sample_weight=None, init_score=None, group=None, client
**kwargs
)

fit.__doc__ = LGBMRanker.fit.__doc__
_base_doc = LGBMRanker.fit.__doc__
_before_eval_set, _eval_set, _after_eval_set = _base_doc.partition('eval_set :')
fit.__doc__ = (_before_eval_set
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _eval_set + _after_eval_set)

def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRanker.predict."""
Expand All @@ -535,5 +569,6 @@ def to_local(self):
Returns
-------
model : lightgbm.LGBMRanker
Local underlying model.
"""
return self._to_local(LGBMRanker)

0 comments on commit 5312b95

Please sign in to comment.