[dask] Include support for init_score (#3950)

* include support for init_score

* use dataframe from init_score and test difference with and without init_score in local model

* revert refactoring

* initial docs. test between distributed models with and without init_score

* remove ranker from tests

* test value for root node and change docs

* comma

* re-include parametrize

* fix incorrect merge

* use single init_score and the booster_ attribute

* use np.float64 instead of float
jmoralez authored Mar 4, 2021
1 parent 19f3577 commit 37e9878
Showing 3 changed files with 77 additions and 14 deletions.
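
For context, the new parameter rides along with the other fit arguments on the Dask estimators. A minimal usage sketch, assuming a local dask.distributed cluster; the synthetic arrays, chunk sizes, and the 0.5 baseline are illustrative only, not part of this commit:

import dask.array as da
import numpy as np
import lightgbm as lgb
from distributed import Client, LocalCluster

# assumption: a local two-worker cluster stands in for a real deployment
client = Client(LocalCluster(n_workers=2))

X = da.random.random((1_000, 10), chunks=(100, 10))
y = da.random.random((1_000,), chunks=(100,))
# one init_score entry per sample, chunked like the label
init_scores = da.full_like(y, fill_value=0.5, dtype=np.float64)

model = lgb.DaskLGBMRegressor(client=client, n_estimators=10)
model.fit(X, y, init_score=init_scores)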
43 changes: 30 additions & 13 deletions python-package/lightgbm/dask.py
@@ -105,12 +105,17 @@ def _train_part(
else:
group = None

if 'init_score' in list_of_parts[0]:
init_score = _concat([x['init_score'] for x in list_of_parts])
else:
init_score = None

try:
model = model_factory(**params)
if is_ranker:
- model.fit(data, label, sample_weight=weight, group=group, **kwargs)
model.fit(data, label, sample_weight=weight, init_score=init_score, group=group, **kwargs)
else:
- model.fit(data, label, sample_weight=weight, **kwargs)
model.fit(data, label, sample_weight=weight, init_score=init_score, **kwargs)

finally:
_safe_call(_LIB.LGBM_NetworkFree())
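
Each worker reassembles its chunk of the training set from a list of part dicts, so the new init_score pieces are concatenated the same way data, label, and weight already were. A rough, simplified sketch of what a helper like _concat does (the real one lives in lightgbm/dask.py; this standalone reimplementation is illustrative only):

import numpy as np
import pandas as pd
from scipy import sparse as ss

def concat_sketch(seq):
    # numpy parts (e.g. init_score arrays) are joined end to end
    if isinstance(seq[0], np.ndarray):
        return np.concatenate(seq)
    # pandas parts are stacked with pd.concat
    if isinstance(seq[0], (pd.DataFrame, pd.Series)):
        return pd.concat(seq)
    # sparse feature chunks are stacked row-wise
    if isinstance(seq[0], ss.spmatrix):
        return ss.vstack(seq, format='csr')
    raise TypeError(f'Cannot concatenate parts of type {type(seq[0]).__name__}')

print(concat_sketch([np.ones(3), np.zeros(2)]).shape)  # (5,)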
@@ -168,6 +173,7 @@ def _train(
params: Dict[str, Any],
model_factory: Type[LGBMModel],
sample_weight: Optional[_DaskCollection] = None,
init_score: Optional[_DaskCollection] = None,
group: Optional[_DaskCollection] = None,
**kwargs: Any
) -> LGBMModel:
@@ -187,6 +193,8 @@
Class of the local underlying model.
sample_weight : Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)
Weights of training data.
init_score : Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)
Init score of training data.
group : Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)
Group/query data.
Only used in the learning-to-rank task.
@@ -289,6 +297,11 @@ def _train(
for i in range(n_parts):
parts[i]['group'] = group_parts[i]

if init_score is not None:
init_score_parts = _split_to_parts(data=init_score, is_matrix=False)
for i in range(n_parts):
parts[i]['init_score'] = init_score_parts[i]

# Start computation in the background
parts = list(map(delayed, parts))
parts = client.compute(parts)
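
Conceptually, each element of parts is a plain dict holding all the aligned pieces for one chunk; optional keys such as 'group' and the new 'init_score' are attached only when the caller supplied them. A standalone illustration with numpy stand-ins (the chunk contents here are synthetic):

import numpy as np

n_parts = 3
data_parts = [np.random.random((50, 4)) for _ in range(n_parts)]
label_parts = [np.random.random(50) for _ in range(n_parts)]
init_score_parts = [np.full(50, 0.5) for _ in range(n_parts)]

parts = [{'data': d, 'label': l} for d, l in zip(data_parts, label_parts)]
# optional pieces are bolted on afterwards, mirroring the loop above
for i in range(n_parts):
    parts[i]['init_score'] = init_score_parts[i]

print(sorted(parts[0]))  # ['data', 'init_score', 'label']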
@@ -540,6 +553,7 @@ def _lgb_dask_fit(
X: _DaskMatrixLike,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
init_score: Optional[_DaskCollection] = None,
group: Optional[_DaskCollection] = None,
**kwargs: Any
) -> "_DaskLGBMModel":
@@ -556,6 +570,7 @@
params=params,
model_factory=model_factory,
sample_weight=sample_weight,
init_score=init_score,
group=group,
**kwargs
)
@@ -657,6 +672,7 @@ def fit(
X: _DaskMatrixLike,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
init_score: Optional[_DaskCollection] = None,
**kwargs: Any
) -> "DaskLGBMClassifier":
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
@@ -665,18 +681,20 @@ def fit(
X=X,
y=y,
sample_weight=sample_weight,
init_score=init_score,
**kwargs
)

_base_doc = _lgbmmodel_doc_fit.format(
X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
y_shape="Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]",
sample_weight_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)",
init_score_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)",
group_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)"
)

- # DaskLGBMClassifier does not support init_score, evaluation data, or early stopping
- _base_doc = (_base_doc[:_base_doc.find('init_score :')]
# DaskLGBMClassifier does not support evaluation data, or early stopping
_base_doc = (_base_doc[:_base_doc.find('group :')]
+ _base_doc[_base_doc.find('verbose :'):])

# DaskLGBMClassifier support for callbacks and init_model is not tested
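
The slicing above trims entries out of the inherited docstring: everything from the 'group :' entry up to (but not including) 'verbose :' is dropped, while the newly supported init_score entry is kept. A self-contained illustration on a toy docstring (the text below is made up for demonstration):

doc = (
    "init_score : array-like\n"
    "    Init score of training data.\n"
    "group : array-like\n"
    "    Group/query data.\n"
    "verbose : bool\n"
    "    Controls the verbosity.\n"
)
trimmed = doc[:doc.find('group :')] + doc[doc.find('verbose :'):]
print(trimmed)  # init_score survives; the group entry is gone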
@@ -808,6 +826,7 @@ def fit(
X: _DaskMatrixLike,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
init_score: Optional[_DaskCollection] = None,
**kwargs: Any
) -> "DaskLGBMRegressor":
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
@@ -816,18 +835,20 @@ def fit(
X=X,
y=y,
sample_weight=sample_weight,
init_score=init_score,
**kwargs
)

_base_doc = _lgbmmodel_doc_fit.format(
X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
y_shape="Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]",
sample_weight_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)",
init_score_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)",
group_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)"
)

- # DaskLGBMRegressor does not support init_score, evaluation data, or early stopping
- _base_doc = (_base_doc[:_base_doc.find('init_score :')]
# DaskLGBMRegressor does not support evaluation data, or early stopping
_base_doc = (_base_doc[:_base_doc.find('group :')]
+ _base_doc[_base_doc.find('verbose :'):])

# DaskLGBMRegressor support for callbacks and init_model is not tested
@@ -945,14 +966,12 @@ def fit(
**kwargs: Any
) -> "DaskLGBMRanker":
"""Docstring is inherited from the lightgbm.LGBMRanker.fit."""
- if init_score is not None:
- raise RuntimeError('init_score is not currently supported in lightgbm.dask')

return self._lgb_dask_fit(
model_factory=LGBMRanker,
X=X,
y=y,
sample_weight=sample_weight,
init_score=init_score,
group=group,
**kwargs
)
@@ -961,13 +980,11 @@
X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
y_shape="Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]",
sample_weight_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)",
init_score_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)",
group_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)"
)

- # DaskLGBMRanker does not support init_score, evaluation data, or early stopping
- _base_doc = (_base_doc[:_base_doc.find('init_score :')]
- + _base_doc[_base_doc.find('group :'):])

# DaskLGBMRanker does not support evaluation data, or early stopping
_base_doc = (_base_doc[:_base_doc.find('eval_set :')]
+ _base_doc[_base_doc.find('verbose :'):])

3 changes: 2 additions & 1 deletion python-package/lightgbm/sklearn.py
@@ -189,7 +189,7 @@ def __call__(self, preds, dataset):
The target values (class labels in classification, real numbers in regression).
sample_weight : {sample_weight_shape}
Weights of training data.
- init_score : array-like of shape = [n_samples] or None, optional (default=None)
init_score : {init_score_shape}
Init score of training data.
group : {group_shape}
Group/query data.
@@ -706,6 +706,7 @@ def _get_meta_data(collection, name, i):
X_shape="array-like or sparse matrix of shape = [n_samples, n_features]",
y_shape="array-like of shape = [n_samples]",
sample_weight_shape="array-like of shape = [n_samples] or None, optional (default=None)",
init_score_shape="array-like of shape = [n_samples] or None, optional (default=None)",
group_shape="array-like or None, optional (default=None)"
) + "\n\n" + _lgbmmodel_doc_custom_eval_note

45 changes: 45 additions & 0 deletions tests/python_package_test/test_dask.py
@@ -3,6 +3,7 @@

import inspect
import pickle
import random
import socket
from itertools import groupby
from os import getenv
@@ -1228,6 +1229,50 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(
client.close(timeout=CLIENT_CLOSE_TIMEOUT)


@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output)
def test_init_score(
task,
output,
client):
if task == 'ranking' and output == 'scipy_csr_matrix':
pytest.skip('LGBMRanker is not currently tested on sparse matrices')

if task == 'ranking':
_, _, _, _, dX, dy, dw, dg = _create_ranking_data(
output=output,
group=None
)
model_factory = lgb.DaskLGBMRanker
else:
_, _, _, dX, dy, dw = _create_data(
objective=task,
output=output,
)
dg = None
if task == 'classification':
model_factory = lgb.DaskLGBMClassifier
elif task == 'regression':
model_factory = lgb.DaskLGBMRegressor

params = {
'n_estimators': 1,
'num_leaves': 2,
'time_out': 5
}
init_score = random.random()
if output.startswith('dataframe'):
init_scores = dy.map_partitions(lambda x: pd.Series([init_score] * x.size))
else:
init_scores = da.full_like(dy, fill_value=init_score, dtype=np.float64)
model = model_factory(client=client, **params)
model.fit(dX, dy, sample_weight=dw, init_score=init_scores, group=dg)
# value of the root node is 0 when init_score is set
assert model.booster_.trees_to_dataframe()['value'][0] == 0

client.close(timeout=CLIENT_CLOSE_TIMEOUT)
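
The assertion above captures the observable effect of this change: when init_score is supplied, boosting starts from that baseline rather than boosting from the label average, so the first tree's root value comes out as 0. The same property can be checked locally without Dask; a minimal sketch with synthetic data (the seed and shapes are arbitrary):

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
X = rng.random((100, 5))
y = rng.random(100)
init_scores = np.full(100, 0.5, dtype=np.float64)

local_model = lgb.LGBMRegressor(n_estimators=1, num_leaves=2)
local_model.fit(X, y, init_score=init_scores)
# root node value should be 0 when init_score is set, matching the test
print(local_model.booster_.trees_to_dataframe()['value'][0])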


def sklearn_checks_to_run():
check_names = [
"check_estimator_get_tags_default_keys",
