Commit 3bde37f

Merge github.com:microsoft/lightgbm into Rleaks

david-cortes committed Sep 10, 2021
2 parents 4d73f3a + a08c37f commit 3bde37f
Showing 9 changed files with 186 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/r_solaris.yml
@@ -48,7 +48,7 @@ jobs:
url=${line#*@}
body="${body}**${platform}**: ${url}\r\n"
done < "$GITHUB_WORKSPACE/rhub_logs.txt" || true
body="${body}Reports also have been sent to LightGBM public e-mail: http://www.yopmail.com/lightgbm_rhub_checks\r\n"
body="${body}Reports also have been sent to LightGBM public e-mail: https://yopmail.com?lightgbm_rhub_checks\r\n"
body="${body}Status: ${{ job.status }}."
$GITHUB_WORKSPACE/.ci/append_comment.sh \
"${{ github.event.client_payload.comment_number }}" \
6 changes: 3 additions & 3 deletions CMakeLists.txt
@@ -405,9 +405,9 @@ if(USE_SWIG)
COMMAND "${Java_JAR_EXECUTABLE}" -cf lightgbmlib.jar com)
else()
add_custom_command(TARGET _lightgbm_swig POST_BUILD
COMMAND "${Java_JAVAC_EXECUTABLE}" -d . java/*.java
COMMAND cp "${PROJECT_SOURCE_DIR}/*.so" com/microsoft/ml/lightgbm/linux/x86_64
COMMAND "${Java_JAR_EXECUTABLE}" -cf lightgbmlib.jar com)
COMMAND "${Java_JAVAC_EXECUTABLE}" -d . java/*.java
COMMAND cp "${PROJECT_SOURCE_DIR}/*.so" com/microsoft/ml/lightgbm/linux/x86_64
COMMAND "${Java_JAR_EXECUTABLE}" -cf lightgbmlib.jar com)
endif()
endif(USE_SWIG)

27 changes: 10 additions & 17 deletions R-package/R/lgb.Dataset.R
@@ -663,34 +663,27 @@ Dataset <- R6::R6Class(
# Set reference
set_reference = function(reference) {

- # Set known references
- self$set_categorical_feature(categorical_feature = reference$.__enclos_env__$private$categorical_feature)
- self$set_colnames(colnames = reference$get_colnames())
- private$set_predictor(predictor = reference$.__enclos_env__$private$predictor)

- # Check for identical references
+ # setting reference to this same Dataset object doesn't require any changes
if (identical(private$reference, reference)) {
return(invisible(self))
}

- # Check for empty data
+ # changing the reference removes the Dataset object on the C++ side, so it should only
+ # be done if you still have the raw_data available, so that the new Dataset can be reconstructed
if (is.null(private$raw_data)) {

stop("set_reference: cannot set reference after freeing raw data,
please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset")

}

- # Check for non-existing reference
- if (!is.null(reference)) {
-
-   # Reference is unknown
-   if (!lgb.is.Dataset(reference)) {
-     stop("set_reference: Can only use lgb.Dataset as a reference")
-   }

+ if (!lgb.is.Dataset(reference)) {
+   stop("set_reference: Can only use lgb.Dataset as a reference")
+ }

+ # Set known references
+ self$set_categorical_feature(categorical_feature = reference$.__enclos_env__$private$categorical_feature)
+ self$set_colnames(colnames = reference$get_colnames())
+ private$set_predictor(predictor = reference$.__enclos_env__$private$predictor)

# Store reference
private$reference <- reference
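
The new comments above state the constraint: changing a constructed Dataset's reference requires the raw data to still be around. The Python package enforces the same rule, so here is a minimal Python sketch of the intended call pattern (illustrative random data; this commit itself only touches the R side):

    import numpy as np
    import lightgbm as lgb

    X_train, y_train = np.random.rand(100, 5), np.random.rand(100)
    X_valid, y_valid = np.random.rand(50, 5), np.random.rand(50)

    # Keep raw data so the Dataset can be rebuilt when its reference changes.
    dtrain = lgb.Dataset(X_train, label=y_train, free_raw_data=False).construct()
    dvalid = lgb.Dataset(X_valid, label=y_valid, free_raw_data=False)
    dvalid.set_reference(dtrain)  # reuse dtrain's bin mappers and feature names

    # With the default free_raw_data=True, calling set_reference() on an
    # already-constructed Dataset raises, since it cannot be reconstructed.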

116 changes: 116 additions & 0 deletions R-package/tests/testthat/test_dataset.R
@@ -1,5 +1,9 @@
context("testing lgb.Dataset functionality")

+ data(agaricus.train, package = "lightgbm")
+ train_data <- agaricus.train$data[seq_len(1000L), ]
+ train_label <- agaricus.train$label[seq_len(1000L)]

data(agaricus.test, package = "lightgbm")
test_data <- agaricus.test$data[1L:100L, ]
test_label <- agaricus.test$label[1L:100L]
@@ -74,6 +78,118 @@ test_that("Dataset$slice() supports passing Dataset attributes through '...'", {
expect_identical(dsub1$getinfo("init_score"), init_score)
})

test_that("Dataset$set_reference() on a constructed Dataset fails if raw data has been freed", {
dtrain <- lgb.Dataset(train_data, label = train_label)
dtrain$construct()
dtest <- lgb.Dataset(test_data, label = test_label)
dtest$construct()
expect_error({
dtest$set_reference(dtrain)
}, regexp = "cannot set reference after freeing raw data")
})

test_that("Dataset$set_reference() fails if reference is not a Dataset", {
dtrain <- lgb.Dataset(
train_data
, label = train_label
, free_raw_data = FALSE
)
expect_error({
dtrain$set_reference(reference = data.frame(x = rnorm(10L)))
}, regexp = "Can only use lgb.Dataset as a reference")

# passing NULL when the Dataset already has a reference raises an error
dtest <- lgb.Dataset(
test_data
, label = test_label
, free_raw_data = FALSE
)
dtrain$set_reference(dtest)
expect_error({
dtrain$set_reference(reference = NULL)
}, regexp = "Can only use lgb.Dataset as a reference")
})

test_that("Dataset$set_reference() setting reference to the same Dataset has no side effects", {
dtrain <- lgb.Dataset(
train_data
, label = train_label
, free_raw_data = FALSE
, categorical_feature = c(2L, 3L)
)
dtrain$construct()

cat_features_before <- dtrain$.__enclos_env__$private$categorical_feature
colnames_before <- dtrain$get_colnames()
predictor_before <- dtrain$.__enclos_env__$private$predictor

dtrain$set_reference(dtrain)
expect_identical(
cat_features_before
, dtrain$.__enclos_env__$private$categorical_feature
)
expect_identical(
colnames_before
, dtrain$get_colnames()
)
expect_identical(
predictor_before
, dtrain$.__enclos_env__$private$predictor
)
})

test_that("Dataset$set_reference() updates categorical_feature, colnames, and predictor", {
dtrain <- lgb.Dataset(
train_data
, label = train_label
, free_raw_data = FALSE
, categorical_feature = c(2L, 3L)
)
dtrain$construct()
bst <- Booster$new(
train_set = dtrain
, params = list(verbose = -1L)
)
dtrain$.__enclos_env__$private$predictor <- bst$to_predictor()

test_original_feature_names <- paste0("feature_col_", seq_len(ncol(test_data)))
dtest <- lgb.Dataset(
test_data
, label = test_label
, free_raw_data = FALSE
, colnames = test_original_feature_names
)
dtest$construct()

# at this point, dtest should not have categorical_feature
expect_null(dtest$.__enclos_env__$private$predictor)
expect_null(dtest$.__enclos_env__$private$categorical_feature)
expect_identical(
dtest$get_colnames()
, test_original_feature_names
)

dtest$set_reference(dtrain)

# after setting reference to dtrain, those attributes should have dtrain's values
expect_is(dtest$.__enclos_env__$private$predictor, "lgb.Predictor")
expect_identical(
dtest$.__enclos_env__$private$predictor$.__enclos_env__$private$handle
, dtrain$.__enclos_env__$private$predictor$.__enclos_env__$private$handle
)
expect_identical(
dtest$.__enclos_env__$private$categorical_feature
, dtrain$.__enclos_env__$private$categorical_feature
)
expect_identical(
dtest$get_colnames()
, dtrain$get_colnames()
)
expect_false(
identical(dtest$get_colnames(), test_original_feature_names)
)
})

test_that("lgb.Dataset: colnames", {
dtest <- lgb.Dataset(test_data, label = test_label)
expect_equal(colnames(dtest), colnames(test_data))
2 changes: 2 additions & 0 deletions README.md
@@ -97,6 +97,8 @@ Kubeflow Fairing (LightGBM on Kubernetes): https://github.com/kubeflow/fairing

Kubeflow Operator (LightGBM on Kubernetes): https://github.com/kubeflow/xgboost-operator

+ lightgbm_ray (LightGBM on Ray): https://github.com/ray-project/lightgbm_ray

ML.NET (.NET/C#-package): https://github.com/dotnet/machinelearning

LightGBM.NET (.NET/C#-package): https://github.com/rca22/LightGBM.Net
19 changes: 18 additions & 1 deletion docs/Parallel-Learning-Guide.rst
@@ -435,7 +435,7 @@ MPI Version
3. Run following command on one machine (not need to run on all machines), need to change ``your_config_file`` to real config file.

For Windows:

.. code::

    mpiexec.exe /machinefile mlist.txt lightgbm.exe config=your_config_file
@@ -451,6 +451,17 @@ Example

- `A simple distributed learning example`_

+ Ray
+ ^^^
+
+ `Ray`_ is a Python-based framework for distributed computing. The `lightgbm_ray`_ project, maintained within the official Ray GitHub organization, can be used to perform distributed LightGBM training using ``ray``.
+
+ See `the lightgbm_ray documentation`_ for usage examples.
+
+ .. note::
+
+     ``lightgbm_ray`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should be directed to https://github.com/ray-project/lightgbm_ray/issues.
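
For a concrete flavor of what the new section points to, a minimal sketch modeled on the lightgbm_ray README (RayDMatrix, RayParams, and train are lightgbm_ray's API, not LightGBM's own; consult the linked documentation for the current interface):

    from lightgbm_ray import RayDMatrix, RayParams, train
    from sklearn.datasets import load_breast_cancer

    X, y = load_breast_cancer(return_X_y=True)
    train_set = RayDMatrix(X, y)

    # Spread boosting across 2 Ray actors, each using 2 CPUs.
    model = train(
        {"objective": "binary", "metric": ["binary_logloss"]},
        train_set,
        ray_params=RayParams(num_actors=2, cpus_per_actor=2),
    )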

.. _Dask: https://docs.dask.org/en/latest/

.. _SynapseML: https://aka.ms/spark
@@ -482,3 +493,9 @@ Example
.. _here: https://www.youtube.com/watch?v=iqzXhp5TxUY

.. _A simple distributed learning example: https://github.com/microsoft/lightgbm/tree/master/examples/parallel_learning

+ .. _lightgbm_ray: https://github.com/ray-project/lightgbm_ray
+
+ .. _Ray: https://ray.io/
+
+ .. _the lightgbm_ray documentation: https://docs.ray.io/en/latest/lightgbm-ray.html
11 changes: 7 additions & 4 deletions python-package/lightgbm/sklearn.py
@@ -26,7 +26,7 @@ def __init__(self, func):
Parameters
----------
func : callable
- Expects a callable with signature ``func(y_true, y_pred)`` or ``func(y_true, y_pred, group)
+ Expects a callable with signature ``func(y_true, y_pred)`` or ``func(y_true, y_pred, group)``
and returns (grad, hess):
y_true : array-like of shape = [n_samples]
@@ -611,9 +611,12 @@ def fit(self, X, y,
params.pop(alias, None)
params['num_class'] = self._n_classes
if hasattr(self, '_eval_at'):
+ eval_at = self._eval_at
for alias in _ConfigAliases.get('eval_at'):
-     params.pop(alias, None)
- params['eval_at'] = self._eval_at
+     if alias in params:
+         _log_warning(f"Found '{alias}' in params. Will use it instead of 'eval_at' argument")
+         eval_at = params.pop(alias)
+ params['eval_at'] = eval_at
params['objective'] = self._objective
if self._fobj:
params['objective'] = 'None' # objective = nullptr for unknown objective
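
For illustration, the user-visible effect of the alias handling above (the warning text matches the one asserted by the new test in test_sklearn.py at the end of this commit):

    import lightgbm as lgb

    # 'ndcg_at' is one of the aliases of 'eval_at'. An alias passed through
    # params now overrides the constructor's 'eval_at' argument and emits a
    # warning, instead of being silently dropped as before:
    ranker = lgb.LGBMRanker(n_estimators=5, ndcg_at=[1, 3])
    # During fit(): "Found 'ndcg_at' in params. Will use it instead of 'eval_at' argument"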
@@ -752,7 +755,7 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs):
"""Docstring is set after definition, using a template."""
if self._n_features is None:
- raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.")
+ raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
if not isinstance(X, (pd_DataFrame, dt_DataTable)):
X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
n_features = X.shape[1]
20 changes: 16 additions & 4 deletions tests/python_package_test/test_dask.py
@@ -7,14 +7,18 @@
import socket
from itertools import groupby
from os import getenv
+ from platform import machine
from sys import platform
+ from urllib.parse import urlparse

import pytest

import lightgbm as lgb

if not platform.startswith('linux'):
pytest.skip('lightgbm.dask is currently supported in Linux environments', allow_module_level=True)
+ if machine() != 'x86_64':
+     pytest.skip('lightgbm.dask tests are currently skipped on some architectures like arm64', allow_module_level=True)
if not lgb.compat.DASK_INSTALLED:
pytest.skip('Dask is not installed', allow_module_level=True)

@@ -84,6 +88,11 @@ def listen_port():
listen_port.port = 13000


+ def _get_workers_hostname(cluster: LocalCluster) -> str:
+     one_worker_address = next(iter(cluster.scheduler_info['workers']))
+     return urlparse(one_worker_address).hostname
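
A quick sketch of what this new helper extracts; the address below is a hypothetical example of the keys found in cluster.scheduler_info['workers']:

    from urllib.parse import urlparse

    # Dask worker addresses look like 'tcp://<host>:<port>'.
    print(urlparse("tcp://127.0.0.1:36525").hostname)  # -> 127.0.0.1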


def _create_ranking_data(n_samples=100, output='array', chunk_size=50, **kwargs):
X, y, g = make_ranking(n_samples=n_samples, random_state=42, **kwargs)
rnd = np.random.RandomState(42)
@@ -482,8 +491,9 @@ def test_training_does_not_fail_on_port_conflicts(cluster):
_, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')

lightgbm_default_port = 12400
+ workers_hostname = _get_workers_hostname(cluster)
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('127.0.0.1', lightgbm_default_port))
+ s.bind((workers_hostname, lightgbm_default_port))
dask_classifier = lgb.DaskLGBMClassifier(
client=client,
time_out=5,
@@ -1392,13 +1402,14 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c
assert 'machines' not in params

# model 2 - machines given
+ workers_hostname = _get_workers_hostname(cluster)
n_workers = len(client.scheduler_info()['workers'])
open_ports = lgb.dask._find_n_open_ports(n_workers)
dask_model2 = dask_model_factory(
n_estimators=5,
num_leaves=5,
machines=",".join([
f"127.0.0.1:{port}"
f"{workers_hostname}:{port}"
for port in open_ports
]),
)
@@ -1439,12 +1450,13 @@ def test_machines_should_be_used_if_provided(task, cluster):

n_workers = len(client.scheduler_info()['workers'])
assert n_workers > 1
+ workers_hostname = _get_workers_hostname(cluster)
open_ports = lgb.dask._find_n_open_ports(n_workers)
dask_model = dask_model_factory(
n_estimators=5,
num_leaves=5,
machines=",".join([
f"127.0.0.1:{port}"
f"{workers_hostname}:{port}"
for port in open_ports
]),
)
@@ -1454,7 +1466,7 @@ def test_machines_should_be_used_if_provided(task, cluster):
error_msg = f"Binding port {open_ports[0]} failed"
with pytest.raises(lgb.basic.LightGBMError, match=error_msg):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('127.0.0.1', open_ports[0]))
+ s.bind((workers_hostname, open_ports[0]))
dask_model.fit(dX, dy, group=dg)

# The above error leaves a worker waiting
13 changes: 13 additions & 0 deletions tests/python_package_test/test_sklearn.py
@@ -143,6 +143,19 @@ def test_xendcg():
assert gbm.best_score_['valid_0']['ndcg@3'] > 0.6253


+ def test_eval_at_aliases():
+     rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'
+     X_train, y_train = load_svmlight_file(str(rank_example_dir / 'rank.train'))
+     X_test, y_test = load_svmlight_file(str(rank_example_dir / 'rank.test'))
+     q_train = np.loadtxt(str(rank_example_dir / 'rank.train.query'))
+     q_test = np.loadtxt(str(rank_example_dir / 'rank.test.query'))
+     for alias in ('eval_at', 'ndcg_eval_at', 'ndcg_at', 'map_eval_at', 'map_at'):
+         gbm = lgb.LGBMRanker(n_estimators=5, **{alias: [1, 2, 3, 9]})
+         with pytest.warns(UserWarning, match=f"Found '{alias}' in params. Will use it instead of 'eval_at' argument"):
+             gbm.fit(X_train, y_train, group=q_train, eval_set=[(X_test, y_test)], eval_group=[q_test])
+         assert list(gbm.evals_result_['valid_0'].keys()) == ['ndcg@1', 'ndcg@2', 'ndcg@3', 'ndcg@9']


def test_regression_with_custom_objective():
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
