diff --git a/.ci/test.sh b/.ci/test.sh index 84b7d82cad38..2b8b48daa6c0 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -48,7 +48,7 @@ if [[ $TRAVIS == "true" ]] && [[ $TASK == "lint" ]]; then conda install -q -y -n $CONDA_ENV \ -c conda-forge \ libxml2 \ - r-lintr>=2.0 + "r-lintr>=2.0" pip install --user cpplint echo "Linting Python code" pycodestyle --ignore=E501,W503 --exclude=./compute,./.nuget . || exit -1 @@ -74,7 +74,7 @@ if [[ $TASK == "r-package" ]]; then exit 0 fi -conda install -q -y -n $CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz "scikit-learn<=0.21.3" scipy +conda install -q -y -n $CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy if [[ $OS_NAME == "macos" ]] && [[ $COMPILER == "clang" ]]; then # fix "OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized." (OpenMP library conflict due to conda's MKL) diff --git a/.ci/test_r_package.sh b/.ci/test_r_package.sh index 30254f8ecf70..4dddd36225f4 100755 --- a/.ci/test_r_package.sh +++ b/.ci/test_r_package.sh @@ -35,11 +35,15 @@ fi # Installing R precompiled for Mac OS 10.11 or higher if [[ $OS_NAME == "macos" ]]; then + # temp fix for basictex + if [[ $AZURE == "true" ]]; then + brew update + fi brew install qpdf brew cask install basictex export PATH="/Library/TeX/texbin:$PATH" - sudo tlmgr update --self - sudo tlmgr install inconsolata helvetic + sudo tlmgr --verify-repo=none update --self + sudo tlmgr --verify-repo=none install inconsolata helvetic wget -q https://cran.r-project.org/bin/macosx/R-${R_MAC_VERSION}.pkg -O R.pkg sudo installer \ diff --git a/.ci/test_windows.ps1 b/.ci/test_windows.ps1 index ec292b0b79fc..fd0e9f95a5e4 100644 --- a/.ci/test_windows.ps1 +++ b/.ci/test_windows.ps1 @@ -22,7 +22,7 @@ conda init powershell conda activate conda config --set always_yes yes --set changeps1 no conda update -q -y conda -conda create -q -y -n $env:CONDA_ENV python=$env:PYTHON_VERSION joblib matplotlib numpy pandas psutil pytest python-graphviz "scikit-learn<=0.21.3" scipy ; Check-Output $? +conda create -q -y -n $env:CONDA_ENV python=$env:PYTHON_VERSION joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy ; Check-Output $? conda activate $env:CONDA_ENV if ($env:TASK -eq "regular") { diff --git a/.travis.yml b/.travis.yml index dde786e9dfbb..acba8bfc3ad0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,7 +12,7 @@ osx_image: xcode11.3 env: global: # default values - - PYTHON_VERSION=3.7 + - PYTHON_VERSION=3.8 matrix: - TASK=regular PYTHON_VERSION=3.6 - TASK=sdist PYTHON_VERSION=2.7 @@ -21,7 +21,7 @@ env: - TASK=lint - TASK=check-docs - TASK=mpi METHOD=source - - TASK=mpi METHOD=pip + - TASK=mpi METHOD=pip PYTHON_VERSION=3.7 - TASK=gpu METHOD=source PYTHON_VERSION=3.5 - TASK=gpu METHOD=pip PYTHON_VERSION=3.6 - TASK=r-package diff --git a/.vsts-ci.yml b/.vsts-ci.yml index 81fbc1c38366..5e99f9d36f45 100644 --- a/.vsts-ci.yml +++ b/.vsts-ci.yml @@ -6,7 +6,7 @@ trigger: include: - v* variables: - PYTHON_VERSION: 3.7 + PYTHON_VERSION: 3.8 CONDA_ENV: test-env resources: containers: @@ -33,7 +33,7 @@ jobs: PYTHON_VERSION: 3.5 bdist: TASK: bdist - PYTHON_VERSION: 3.6 + PYTHON_VERSION: 3.7 inference: TASK: if-else mpi_source: @@ -82,7 +82,7 @@ jobs: TASK: r-package regular: TASK: regular - PYTHON_VERSION: 3.6 + PYTHON_VERSION: 3.7 sdist: TASK: sdist PYTHON_VERSION: 3.5 @@ -124,13 +124,12 @@ jobs: R_WINDOWS_VERSION: 3.6.3 regular: TASK: regular - PYTHON_VERSION: 3.7 + PYTHON_VERSION: 3.6 sdist: TASK: sdist PYTHON_VERSION: 2.7 bdist: TASK: bdist - PYTHON_VERSION: 3.5 steps: - powershell: | Write-Host "##vso[task.prependpath]$env:CONDA\Scripts" diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index d8e3e9ce485a..f7aa4d10f49d 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -795,32 +795,27 @@ predict.lgb.Booster <- function(object, #' @export lgb.load <- function(filename = NULL, model_str = NULL) { - if (is.null(filename) && is.null(model_str)) { - stop("lgb.load: either filename or model_str must be given") - } - - # Load from filename - if (!is.null(filename) && !is.character(filename)) { - stop("lgb.load: filename should be character") - } + filename_provided <- !is.null(filename) + model_str_provided <- !is.null(model_str) - # Return new booster - if (!is.null(filename) && !file.exists(filename)) { - stop("lgb.load: file does not exist for supplied filename") - } - if (!is.null(filename)) { + if (filename_provided) { + if (!is.character(filename)) { + stop("lgb.load: filename should be character") + } + if (!file.exists(filename)) { + stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename)) + } return(invisible(Booster$new(modelfile = filename))) } - # Load from model_str - if (!is.null(model_str) && !is.character(model_str)) { - stop("lgb.load: model_str should be character") - } - # Return new booster - if (!is.null(model_str)) { + if (model_str_provided) { + if (!is.character(model_str)) { + stop("lgb.load: model_str should be character") + } return(invisible(Booster$new(model_str = model_str))) } + stop("lgb.load: either filename or model_str must be given") } #' @name lgb.save diff --git a/R-package/README.md b/R-package/README.md index 6845a63a1c04..6e4a6eb33050 100644 --- a/R-package/README.md +++ b/R-package/README.md @@ -150,7 +150,6 @@ export CC=/usr/local/bin/gcc-8 Rscript build_r.R # Get coverage -rm -rf lightgbm_r/build Rscript -e " \ coverage <- covr::package_coverage('./lightgbm_r', quiet=FALSE); print(coverage); diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index d35f257ddac9..e3f1030d7755 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -571,3 +571,27 @@ test_that("lgb.train() works with early stopping for regression", { , early_stopping_rounds + 1L ) }) + +test_that("lgb.train() supports non-ASCII feature names", { + testthat::skip("UTF-8 feature names are not fully supported in the R package") + dtrain <- lgb.Dataset( + data = matrix(rnorm(400L), ncol = 4L) + , label = rnorm(100L) + ) + feature_names <- c("F_零", "F_一", "F_二", "F_三") + bst <- lgb.train( + data = dtrain + , nrounds = 5L + , obj = "regression" + , params = list( + metric = "rmse" + ) + , colnames = feature_names + ) + expect_true(lgb.is.Booster(bst)) + dumped_model <- jsonlite::fromJSON(bst$dump_model()) + expect_identical( + dumped_model[["feature_names"]] + , feature_names + ) +}) diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R index c038a6aa3194..8ccb357626ce 100644 --- a/R-package/tests/testthat/test_lgb.Booster.R +++ b/R-package/tests/testthat/test_lgb.Booster.R @@ -88,3 +88,142 @@ test_that("lgb.get.eval.result() should throw an informative error for incorrect ) }, regexp = "Only the following eval_names exist for dataset.*\\: \\[l2\\]", fixed = FALSE) }) + +context("lgb.load()") + +test_that("lgb.load() gives the expected error messages given different incorrect inputs", { + set.seed(708L) + data(agaricus.train, package = "lightgbm") + data(agaricus.test, package = "lightgbm") + train <- agaricus.train + test <- agaricus.test + bst <- lightgbm( + data = as.matrix(train$data) + , label = train$label + , num_leaves = 4L + , learning_rate = 1.0 + , nrounds = 2L + , objective = "binary" + ) + + # you have to give model_str or filename + expect_error({ + lgb.load() + }, regexp = "either filename or model_str must be given") + expect_error({ + lgb.load(filename = NULL, model_str = NULL) + }, regexp = "either filename or model_str must be given") + + # if given, filename should be a string that points to an existing file + out_file <- "lightgbm.model" + expect_error({ + lgb.load(filename = list(out_file)) + }, regexp = "filename should be character") + file_to_check <- paste0("a.model") + while (file.exists(file_to_check)) { + file_to_check <- paste0("a", file_to_check) + } + expect_error({ + lgb.load(filename = file_to_check) + }, regexp = "passed to filename does not exist") + + # if given, model_str should be a string + expect_error({ + lgb.load(model_str = c(4.0, 5.0, 6.0)) + }, regexp = "model_str should be character") + +}) + +test_that("Loading a Booster from a file works", { + set.seed(708L) + data(agaricus.train, package = "lightgbm") + data(agaricus.test, package = "lightgbm") + train <- agaricus.train + test <- agaricus.test + bst <- lightgbm( + data = as.matrix(train$data) + , label = train$label + , num_leaves = 4L + , learning_rate = 1.0 + , nrounds = 2L + , objective = "binary" + ) + expect_true(lgb.is.Booster(bst)) + + pred <- predict(bst, test$data) + lgb.save(bst, "lightgbm.model") + + # finalize the booster and destroy it so you know we aren't cheating + bst$finalize() + expect_null(bst$.__enclos_env__$private$handle) + rm(bst) + + bst2 <- lgb.load( + filename = "lightgbm.model" + ) + pred2 <- predict(bst2, test$data) + expect_identical(pred, pred2) +}) + +test_that("Loading a Booster from a string works", { + set.seed(708L) + data(agaricus.train, package = "lightgbm") + data(agaricus.test, package = "lightgbm") + train <- agaricus.train + test <- agaricus.test + bst <- lightgbm( + data = as.matrix(train$data) + , label = train$label + , num_leaves = 4L + , learning_rate = 1.0 + , nrounds = 2L + , objective = "binary" + ) + expect_true(lgb.is.Booster(bst)) + + pred <- predict(bst, test$data) + model_string <- bst$save_model_to_string() + + # finalize the booster and destroy it so you know we aren't cheating + bst$finalize() + expect_null(bst$.__enclos_env__$private$handle) + rm(bst) + + bst2 <- lgb.load( + model_str = model_string + ) + pred2 <- predict(bst2, test$data) + expect_identical(pred, pred2) +}) + +test_that("If a string and a file are both passed to lgb.load() the file is used model_str is totally ignored", { + set.seed(708L) + data(agaricus.train, package = "lightgbm") + data(agaricus.test, package = "lightgbm") + train <- agaricus.train + test <- agaricus.test + bst <- lightgbm( + data = as.matrix(train$data) + , label = train$label + , num_leaves = 4L + , learning_rate = 1.0 + , nrounds = 2L + , objective = "binary" + ) + expect_true(lgb.is.Booster(bst)) + + pred <- predict(bst, test$data) + lgb.save(bst, "lightgbm.model") + + # finalize the booster and destroy it so you know we aren't cheating + bst$finalize() + expect_null(bst$.__enclos_env__$private$handle) + rm(bst) + + bst2 <- lgb.load( + filename = "lightgbm.model" + , model_str = 4.0 + ) + pred2 <- predict(bst2, test$data) + expect_identical(pred, pred2) +}) diff --git a/docker/dockerfile-python b/docker/dockerfile-python index 4029a097fac4..b157b41117ba 100644 --- a/docker/dockerfile-python +++ b/docker/dockerfile-python @@ -18,7 +18,7 @@ RUN apt-get update && \ export PATH="$CONDA_DIR/bin:$PATH" && \ conda config --set always_yes yes --set changeps1 no && \ # lightgbm - conda install -q -y numpy scipy "scikit-learn<=0.21.3" pandas && \ + conda install -q -y numpy scipy scikit-learn pandas && \ git clone --recursive --branch stable --depth 1 https://github.com/Microsoft/LightGBM && \ cd LightGBM/python-package && python setup.py install && \ # clean diff --git a/docker/gpu/dockerfile.gpu b/docker/gpu/dockerfile.gpu index 2060b39974bf..c4801d6e462f 100644 --- a/docker/gpu/dockerfile.gpu +++ b/docker/gpu/dockerfile.gpu @@ -75,8 +75,8 @@ RUN echo "export PATH=$CONDA_DIR/bin:"'$PATH' > /etc/profile.d/conda.sh && \ rm ~/miniconda.sh RUN conda config --set always_yes yes --set changeps1 no && \ - conda create -y -q -n py2 python=2.7 mkl numpy scipy "scikit-learn<=0.21.3" jupyter notebook ipython pandas matplotlib && \ - conda create -y -q -n py3 python=3.6 mkl numpy scipy "scikit-learn<=0.21.3" jupyter notebook ipython pandas matplotlib + conda create -y -q -n py2 python=2.7 mkl numpy scipy scikit-learn jupyter notebook ipython pandas matplotlib && \ + conda create -y -q -n py3 python=3.6 mkl numpy scipy scikit-learn jupyter notebook ipython pandas matplotlib ################################################################################################################# # LightGBM diff --git a/docs/GPU-Targets.rst b/docs/GPU-Targets.rst index 47e59865507e..d5cbfb873f27 100644 --- a/docs/GPU-Targets.rst +++ b/docs/GPU-Targets.rst @@ -157,7 +157,7 @@ Known issues: .. _Intel SDK for OpenCL: https://software.intel.com/en-us/articles/opencl-drivers -.. _ROCm: https://rocm.github.io/ +.. _ROCm: https://rocm-documentation.readthedocs.io/en/latest/ .. _our GitHub repo: https://github.com/microsoft/LightGBM/releases/download/v2.0.12/AMD-APP-SDKInstaller-v3.0.130.136-GA-linux64.tar.bz2 diff --git a/docs/GPU-Tutorial.rst b/docs/GPU-Tutorial.rst index dca95c1e2031..d8da7ec83385 100644 --- a/docs/GPU-Tutorial.rst +++ b/docs/GPU-Tutorial.rst @@ -1,4 +1,4 @@ -LightGBM GPU Tutorial +LightGBM GPU Tutorial ===================== The purpose of this document is to give you a quick step-by-step tutorial on GPU training. @@ -78,7 +78,7 @@ If you want to use the Python interface of LightGBM, you can install it now (alo :: sudo apt-get -y install python-pip - sudo -H pip install setuptools numpy scipy "scikit-learn<=0.21.3" -U + sudo -H pip install setuptools numpy scipy scikit-learn -U cd python-package/ sudo python setup.py install --precompile cd .. diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 0d9ade659fef..30ea6cd51674 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -470,6 +470,14 @@ Learning Control Parameters - ``intermediate``, a `more advanced method `__, which may slow the library very slightly. However, this method is much less constraining than the basic method and should significantly improve the results +- ``monotone_penalty`` :raw-html:`🔗︎`, default = ``0.0``, type = double, aliases: ``monotone_splits_penalty``, ``ms_penalty``, ``mc_penalty``, constraints: ``monotone_penalty >= 0.0`` + + - used only if ``monotone_constraints`` is set + + - `monotone penalty `__: a penalization parameter X forbids any monotone splits on the first X (rounded down) level(s) of the tree. The penalty applied to monotone splits on a given depth is a continuous, increasing function the penalization parameter + + - if ``0.0`` (the default), no penalization is applied + - ``feature_contri`` :raw-html:`🔗︎`, default = ``None``, type = multi-double, aliases: ``feature_contrib``, ``fc``, ``fp``, ``feature_penalty`` - used to control feature's split gain, will use ``gain[i] = max(0, feature_contri[i]) * gain[i]`` to replace the split gain of i-th feature diff --git a/docs/Python-API.rst b/docs/Python-API.rst index e87a3523223b..de6b1ec6f2b9 100644 --- a/docs/Python-API.rst +++ b/docs/Python-API.rst @@ -24,10 +24,6 @@ Training API Scikit-learn API ---------------- -.. warning:: - - The last supported version of scikit-learn is ``0.21.3``. Our estimators are incompatible with newer versions. - .. autosummary:: :toctree: pythonapi/ diff --git a/docs/Python-Intro.rst b/docs/Python-Intro.rst index 8293733809c5..1c69af8d0c53 100644 --- a/docs/Python-Intro.rst +++ b/docs/Python-Intro.rst @@ -15,11 +15,11 @@ Install ------- Install Python-package dependencies, -``setuptools``, ``wheel``, ``numpy`` and ``scipy`` are required, ``scikit-learn<=0.21.3`` is required for sklearn interface and recommended: +``setuptools``, ``wheel``, ``numpy`` and ``scipy`` are required, ``scikit-learn`` is required for sklearn interface and recommended: :: - pip install setuptools wheel numpy scipy "scikit-learn<=0.21.3" -U + pip install setuptools wheel numpy scipy scikit-learn -U Refer to `Python-package`_ folder for the installation guide. diff --git a/docs/requirements.txt b/docs/requirements.txt index 17896e0c7283..2fb1ed05cb53 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,2 @@ -r requirements_base.txt -breathe +breathe < 4.15 diff --git a/docs/requirements_base.txt b/docs/requirements_base.txt index 9c3dfc2a5b90..23bd0e5b8c86 100644 --- a/docs/requirements_base.txt +++ b/docs/requirements_base.txt @@ -1,3 +1,3 @@ -sphinx +sphinx < 3.0 sphinx_rtd_theme >= 0.3 mock; python_version < '3' diff --git a/examples/python-guide/README.md b/examples/python-guide/README.md index 8ff716344d33..aba3c9f51d7a 100644 --- a/examples/python-guide/README.md +++ b/examples/python-guide/README.md @@ -8,7 +8,7 @@ You should install LightGBM [Python-package](https://github.com/microsoft/LightG You also need scikit-learn, pandas, matplotlib (only for plot example), and scipy (only for logistic regression example) to run the examples, but they are not required for the package itself. You can install them with pip: ``` -pip install "scikit-learn<=0.21.3" pandas matplotlib scipy -U +pip install scikit-learn pandas matplotlib scipy -U ``` Now you can run examples in this folder, for example: diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 057f56e99491..01a2061f8efb 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -447,6 +447,13 @@ struct Config { // descl2 = ``intermediate``, a `more advanced method `__, which may slow the library very slightly. However, this method is much less constraining than the basic method and should significantly improve the results std::string monotone_constraints_method = "basic"; + // alias = monotone_splits_penalty, ms_penalty, mc_penalty + // check = >=0.0 + // desc = used only if ``monotone_constraints`` is set + // desc = `monotone penalty `__: a penalization parameter X forbids any monotone splits on the first X (rounded down) level(s) of the tree. The penalty applied to monotone splits on a given depth is a continuous, increasing function the penalization parameter + // desc = if ``0.0`` (the default), no penalization is applied + double monotone_penalty = 0.0; + // type = multi-double // alias = feature_contrib, fc, fp, feature_penalty // default = None diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index 866da3697460..802b44b9fc2e 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include @@ -633,10 +632,6 @@ class Dataset { // replace ' ' in feature_names with '_' bool spaceInFeatureName = false; for (auto& feature_name : feature_names_) { - // check ascii - if (!Common::CheckASCII(feature_name)) { - Log::Fatal("Do not support non-ASCII characters in feature name."); - } // check json if (!Common::CheckAllowedJSON(feature_name)) { Log::Fatal("Do not support special JSON characters in feature name."); diff --git a/include/LightGBM/utils/common.h b/include/LightGBM/utils/common.h index 6d332f49450a..8acf62d218e9 100644 --- a/include/LightGBM/utils/common.h +++ b/include/LightGBM/utils/common.h @@ -313,6 +313,7 @@ inline static unsigned CountDecimalDigit32(uint32_t n) { 1000000000 }; #ifdef _MSC_VER + // NOLINTNEXTLINE unsigned long i = 0; _BitScanReverse(&i, n | 1); uint32_t t = (i + 1) * 1233 >> 12; @@ -921,15 +922,6 @@ static T SafeLog(T x) { } } -inline bool CheckASCII(const std::string& s) { - for (auto c : s) { - if (static_cast(c) > 127) { - return false; - } - } - return true; -} - inline bool CheckAllowedJSON(const std::string& s) { unsigned char char_code; for (auto c : s) { diff --git a/include/LightGBM/utils/json11.h b/include/LightGBM/utils/json11.h index cd8bd485bac5..eac70a3c91d0 100644 --- a/include/LightGBM/utils/json11.h +++ b/include/LightGBM/utils/json11.h @@ -21,200 +21,206 @@ /* json11 * - * json11 is a tiny JSON library for C++11, providing JSON parsing and serialization. + * json11 is a tiny JSON library for C++11, providing JSON parsing and + * serialization. * - * The core object provided by the library is json11::Json. A Json object represents any JSON - * value: null, bool, number (int or double), string (std::string), array (std::vector), or - * object (std::map). + * The core object provided by the library is json11::Json. A Json object + * represents any JSON value: null, bool, number (int or double), string + * (std::string), array (std::vector), or object (std::map). * - * Json objects act like values: they can be assigned, copied, moved, compared for equality or - * order, etc. There are also helper methods Json::dump, to serialize a Json to a string, and - * Json::parse (static) to parse a std::string as a Json object. + * Json objects act like values: they can be assigned, copied, moved, compared + * for equality or order, etc. There are also helper methods Json::dump, to + * serialize a Json to a string, and Json::parse (static) to parse a std::string + * as a Json object. * - * Internally, the various types of Json object are represented by the JsonValue class - * hierarchy. + * Internally, the various types of Json object are represented by the JsonValue + * class hierarchy. * - * A note on numbers - JSON specifies the syntax of number formatting but not its semantics, - * so some JSON implementations distinguish between integers and floating-point numbers, while - * some don't. In json11, we choose the latter. Because some JSON implementations (namely - * Javascript itself) treat all numbers as the same type, distinguishing the two leads - * to JSON that will be *silently* changed by a round-trip through those implementations. - * Dangerous! To avoid that risk, json11 stores all numbers as double internally, but also + * A note on numbers - JSON specifies the syntax of number formatting but not + * its semantics, so some JSON implementations distinguish between integers and + * floating-point numbers, while some don't. In json11, we choose the latter. + * Because some JSON implementations (namely Javascript itself) treat all + * numbers as the same type, distinguishing the two leads to JSON that will be + * *silently* changed by a round-trip through those implementations. Dangerous! + * To avoid that risk, json11 stores all numbers as double internally, but also * provides integer helpers. * - * Fortunately, double-precision IEEE754 ('double') can precisely store any integer in the - * range +/-2^53, which includes every 'int' on most systems. (Timestamps often use int64 - * or long long to avoid the Y2038K problem; a double storing microseconds since some epoch - * will be exact for +/- 275 years.) + * Fortunately, double-precision IEEE754 ('double') can precisely store any + * integer in the range +/-2^53, which includes every 'int' on most systems. + * (Timestamps often use int64 or long long to avoid the Y2038K problem; a + * double storing microseconds since some epoch will be exact for +/- 275 + * years.) */ #pragma once -#include #include #include #include +#include #include #include namespace json11 { -enum JsonParse { - STANDARD, COMMENTS -}; +enum JsonParse { STANDARD, COMMENTS }; class JsonValue; class Json final { public: - // Types - enum Type { - NUL, NUMBER, BOOL, STRING, ARRAY, OBJECT - }; - - // Array and object typedefs - typedef std::vector array; - typedef std::map object; - - // Constructors for the various types of JSON value. - Json() noexcept; // NUL - Json(std::nullptr_t) noexcept; // NUL - Json(double value); // NUMBER - Json(int value); // NUMBER - Json(bool value); // BOOL - Json(const std::string &value); // STRING - Json(std::string &&value); // STRING - Json(const char * value); // STRING - Json(const array &values); // ARRAY - Json(array &&values); // ARRAY - Json(const object &values); // OBJECT - Json(object &&values); // OBJECT - - // Implicit constructor: anything with a to_json() function. - template - Json(const T & t) : Json(t.to_json()) {} - - // Implicit constructor: map-like objects (std::map, std::unordered_map, etc) - template ().begin()->first)>::value - && std::is_constructible().begin()->second)>::value, - int>::type = 0> - Json(const M & m) : Json(object(m.begin(), m.end())) {} - - // Implicit constructor: vector-like objects (std::list, std::vector, std::set, etc) - template ().begin())>::value, - int>::type = 0> - Json(const V & v) : Json(array(v.begin(), v.end())) {} - - // This prevents Json(some_pointer) from accidentally producing a bool. Use - // Json(bool(some_pointer)) if that behavior is desired. - Json(void *) = delete; - - // Accessors - Type type() const; - - bool is_null() const { return type() == NUL; } - bool is_number() const { return type() == NUMBER; } - bool is_bool() const { return type() == BOOL; } - bool is_string() const { return type() == STRING; } - bool is_array() const { return type() == ARRAY; } - bool is_object() const { return type() == OBJECT; } - - // Return the enclosed value if this is a number, 0 otherwise. Note that json11 does not - // distinguish between integer and non-integer numbers - number_value() and int_value() - // can both be applied to a NUMBER-typed object. - double number_value() const; - int int_value() const; - - // Return the enclosed value if this is a boolean, false otherwise. - bool bool_value() const; - // Return the enclosed string if this is a string, "" otherwise. - const std::string &string_value() const; - // Return the enclosed std::vector if this is an array, or an empty vector otherwise. - const array &array_items() const; - // Return the enclosed std::map if this is an object, or an empty map otherwise. - const object &object_items() const; - - // Return a reference to arr[i] if this is an array, Json() otherwise. - const Json & operator[](size_t i) const; - // Return a reference to obj[key] if this is an object, Json() otherwise. - const Json & operator[](const std::string &key) const; - - // Serialize. - void dump(std::string &out) const; - std::string dump() const { - std::string out; - dump(out); - return out; + // Types + enum Type { NUL, NUMBER, BOOL, STRING, ARRAY, OBJECT }; + + // Array and object typedefs + typedef std::vector array; + typedef std::map object; + + // Constructors for the various types of JSON value. + Json() noexcept; // NUL + explicit Json(std::nullptr_t) noexcept; // NUL + explicit Json(double value); // NUMBER + explicit Json(int value); // NUMBER + explicit Json(bool value); // BOOL + explicit Json(const std::string &value); // STRING + explicit Json(std::string &&value); // STRING + explicit Json(const char *value); // STRING + explicit Json(const array &values); // ARRAY + explicit Json(array &&values); // ARRAY + explicit Json(const object &values); // OBJECT + explicit Json(object &&values); // OBJECT + + // Implicit constructor: anything with a to_json() function. + template + explicit Json(const T &t) : Json(t.to_json()) {} + + // Implicit constructor: map-like objects (std::map, std::unordered_map, etc) + template < + class M, + typename std::enable_if< + std::is_constructible< + std::string, decltype(std::declval().begin()->first)>::value && + std::is_constructible< + Json, decltype(std::declval().begin()->second)>::value, + int>::type = 0> + explicit Json(const M &m) : Json(object(m.begin(), m.end())) {} + + // Implicit constructor: vector-like objects (std::list, std::vector, + // std::set, etc) + template ().begin())>::value, + int>::type = 0> + explicit Json(const V &v) : Json(array(v.begin(), v.end())) {} + + // This prevents Json(some_pointer) from accidentally producing a bool. Use + // Json(bool(some_pointer)) if that behavior is desired. + explicit Json(void *) = delete; + + // Accessors + Type type() const; + + bool is_null() const { return type() == NUL; } + bool is_number() const { return type() == NUMBER; } + bool is_bool() const { return type() == BOOL; } + bool is_string() const { return type() == STRING; } + bool is_array() const { return type() == ARRAY; } + bool is_object() const { return type() == OBJECT; } + + // Return the enclosed value if this is a number, 0 otherwise. Note that + // json11 does not distinguish between integer and non-integer numbers - + // number_value() and int_value() can both be applied to a NUMBER-typed + // object. + double number_value() const; + int int_value() const; + + // Return the enclosed value if this is a boolean, false otherwise. + bool bool_value() const; + // Return the enclosed string if this is a string, "" otherwise. + const std::string &string_value() const; + // Return the enclosed std::vector if this is an array, or an empty vector + // otherwise. + const array &array_items() const; + // Return the enclosed std::map if this is an object, or an empty map + // otherwise. + const object &object_items() const; + + // Return a reference to arr[i] if this is an array, Json() otherwise. + const Json &operator[](size_t i) const; + // Return a reference to obj[key] if this is an object, Json() otherwise. + const Json &operator[](const std::string &key) const; + + // Serialize. + void dump(std::string *out) const; + std::string dump() const { + std::string out; + dump(&out); + return out; + } + + // Parse. If parse fails, return Json() and assign an error message to err. + static Json parse(const std::string &in, std::string *err, + JsonParse strategy = JsonParse::STANDARD); + static Json parse(const char *in, std::string *err, + JsonParse strategy = JsonParse::STANDARD) { + if (in) { + return parse(std::string(in), err, strategy); + } else { + *err = "null input"; + return Json(nullptr); } - - // Parse. If parse fails, return Json() and assign an error message to err. - static Json parse(const std::string & in, - std::string & err, - JsonParse strategy = JsonParse::STANDARD); - static Json parse(const char * in, - std::string & err, - JsonParse strategy = JsonParse::STANDARD) { - if (in) { - return parse(std::string(in), err, strategy); - } else { - err = "null input"; - return nullptr; - } - } - // Parse multiple objects, concatenated or separated by whitespace - static std::vector parse_multi( - const std::string & in, - std::string::size_type & parser_stop_pos, - std::string & err, - JsonParse strategy = JsonParse::STANDARD); - - static inline std::vector parse_multi( - const std::string & in, - std::string & err, - JsonParse strategy = JsonParse::STANDARD) { - std::string::size_type parser_stop_pos; - return parse_multi(in, parser_stop_pos, err, strategy); - } - - bool operator== (const Json &rhs) const; - bool operator< (const Json &rhs) const; - bool operator!= (const Json &rhs) const { return !(*this == rhs); } - bool operator<= (const Json &rhs) const { return !(rhs < *this); } - bool operator> (const Json &rhs) const { return (rhs < *this); } - bool operator>= (const Json &rhs) const { return !(*this < rhs); } - - /* has_shape(types, err) - * - * Return true if this is a JSON object and, for each item in types, has a field of - * the given type. If not, return false and set err to a descriptive message. - */ - typedef std::initializer_list> shape; - bool has_shape(const shape & types, std::string & err) const; + } + // Parse multiple objects, concatenated or separated by whitespace + static std::vector parse_multi( + const std::string &in, std::string::size_type *parser_stop_pos, + std::string *err, JsonParse strategy = JsonParse::STANDARD); + + static inline std::vector parse_multi( + const std::string &in, std::string *err, + JsonParse strategy = JsonParse::STANDARD) { + std::string::size_type parser_stop_pos; + return parse_multi(in, &parser_stop_pos, err, strategy); + } + + bool operator==(const Json &rhs) const; + bool operator<(const Json &rhs) const; + bool operator!=(const Json &rhs) const { return !(*this == rhs); } + bool operator<=(const Json &rhs) const { return !(rhs < *this); } + bool operator>(const Json &rhs) const { return (rhs < *this); } + bool operator>=(const Json &rhs) const { return !(*this < rhs); } + + /* has_shape(types, err) + * + * Return true if this is a JSON object and, for each item in types, has a + * field of the given type. If not, return false and set err to a descriptive + * message. + */ + typedef std::initializer_list> shape; + bool has_shape(const shape &types, std::string *err) const; private: - std::shared_ptr m_ptr; + std::shared_ptr m_ptr; }; -// Internal class hierarchy - JsonValue objects are not exposed to users of this API. +// Internal class hierarchy - JsonValue objects are not exposed to users of this +// API. class JsonValue { protected: - friend class Json; - friend class JsonInt; - friend class JsonDouble; - virtual Json::Type type() const = 0; - virtual bool equals(const JsonValue * other) const = 0; - virtual bool less(const JsonValue * other) const = 0; - virtual void dump(std::string &out) const = 0; - virtual double number_value() const; - virtual int int_value() const; - virtual bool bool_value() const; - virtual const std::string &string_value() const; - virtual const Json::array &array_items() const; - virtual const Json &operator[](size_t i) const; - virtual const Json::object &object_items() const; - virtual const Json &operator[](const std::string &key) const; - virtual ~JsonValue() {} + friend class Json; + friend class JsonInt; + friend class JsonDouble; + virtual Json::Type type() const = 0; + virtual bool equals(const JsonValue *other) const = 0; + virtual bool less(const JsonValue *other) const = 0; + virtual void dump(std::string *out) const = 0; + virtual double number_value() const; + virtual int int_value() const; + virtual bool bool_value() const; + virtual const std::string &string_value() const; + virtual const Json::array &array_items() const; + virtual const Json &operator[](size_t i) const; + virtual const Json::object &object_items() const; + virtual const Json &operator[](const std::string &key) const; + virtual ~JsonValue() {} }; } // namespace json11 diff --git a/include/LightGBM/utils/locale_context.h b/include/LightGBM/utils/locale_context.h deleted file mode 100644 index f8f0073fb5cc..000000000000 --- a/include/LightGBM/utils/locale_context.h +++ /dev/null @@ -1,40 +0,0 @@ -/*! - * Copyright (c) 2020 Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See LICENSE file in the project root for license information. - */ -#ifndef LIGHTGBM_LOCALE_CONTEXT_H_ -#define LIGHTGBM_LOCALE_CONTEXT_H_ - -#include -#include - -/*! - * Class to override the program global locale during this object lifetime. - * After the object is destroyed, the locale is returned to its original state. - * - * @warn This is not thread-safe. - */ -class LocaleContext { - public: - /*! - * Override the current program global locale during this object lifetime. - * - * @param target_locale override the locale to this locale setting. - * @warn This is not thread-safe. - * @note This doesn't override cout, cerr, etc. - */ - explicit LocaleContext(const char* target_locale = "C") { - std::locale::global(std::locale(target_locale)); - } - - /*! - * Restores the old global locale. - */ - ~LocaleContext() { - std::locale::global(_saved_global_locale); - } - private: - std::locale _saved_global_locale; //!< Stores global locale at initialization. -}; - -#endif // LIGHTGBM_LOCALE_CONTEXT_H_ diff --git a/include/LightGBM/utils/openmp_wrapper.h b/include/LightGBM/utils/openmp_wrapper.h index aab5cea2905c..075c991371c0 100644 --- a/include/LightGBM/utils/openmp_wrapper.h +++ b/include/LightGBM/utils/openmp_wrapper.h @@ -6,9 +6,10 @@ #define LIGHTGBM_OPENMP_WRAPPER_H_ #ifdef _OPENMP -#include #include +#include + #include #include #include diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index be260fa8cfb5..5d9255d5be74 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -2536,7 +2536,7 @@ def model_to_string(self, num_iteration=None, start_iteration=0): ctypes.c_int64(actual_len), ctypes.byref(tmp_out_len), ptr_string_buffer)) - ret = string_buffer.value.decode() + ret = string_buffer.value.decode('utf-8') ret += _dump_pandas_categorical(self.pandas_categorical) return ret @@ -2582,7 +2582,7 @@ def dump_model(self, num_iteration=None, start_iteration=0): ctypes.c_int64(actual_len), ctypes.byref(tmp_out_len), ptr_string_buffer)) - ret = json.loads(string_buffer.value.decode()) + ret = json.loads(string_buffer.value.decode('utf-8')) ret['pandas_categorical'] = json.loads(json.dumps(self.pandas_categorical, default=json_default_with_numpy)) return ret @@ -2754,7 +2754,7 @@ def feature_name(self): "Allocated feature name buffer size ({}) was inferior to the needed size ({})." .format(reserved_string_buffer_size, required_string_buffer_size.value) ) - return [string_buffers[i].value.decode() for i in range_(num_feature)] + return [string_buffers[i].value.decode('utf-8') for i in range_(num_feature)] def feature_importance(self, importance_type='split', iteration=None): """Get feature importances. @@ -2954,7 +2954,7 @@ def __get_eval_info(self): .format(reserved_string_buffer_size, required_string_buffer_size.value) ) self.__name_inner_eval = \ - [string_buffers[i].value.decode() for i in range_(self.__num_inner_eval)] + [string_buffers[i].value.decode('utf-8') for i in range_(self.__num_inner_eval)] self.__higher_better_inner_eval = \ [name.startswith(('auc', 'ndcg@', 'map@')) for name in self.__name_inner_eval] diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 358478d7305f..5d951a56800a 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -116,16 +116,24 @@ class DataTable(object): from sklearn.preprocessing import LabelEncoder from sklearn.utils.class_weight import compute_sample_weight from sklearn.utils.multiclass import check_classification_targets - from sklearn.utils.validation import (assert_all_finite, check_X_y, - check_array, check_consistent_length) + from sklearn.utils.validation import assert_all_finite, check_X_y, check_array try: from sklearn.model_selection import StratifiedKFold, GroupKFold from sklearn.exceptions import NotFittedError except ImportError: from sklearn.cross_validation import StratifiedKFold, GroupKFold from sklearn.utils.validation import NotFittedError + try: + from sklearn.utils.validation import _check_sample_weight + except ImportError: + from sklearn.utils.validation import check_consistent_length + + # dummy function to support older version of scikit-learn + def _check_sample_weight(sample_weight, X, dtype=None): + check_consistent_length(sample_weight, X) + return sample_weight + SKLEARN_INSTALLED = True - from sklearn import __version__ as SKLEARN_VERSION _LGBMModelBase = BaseEstimator _LGBMRegressorBase = RegressorMixin _LGBMClassifierBase = ClassifierMixin @@ -135,13 +143,12 @@ class DataTable(object): _LGBMGroupKFold = GroupKFold _LGBMCheckXY = check_X_y _LGBMCheckArray = check_array - _LGBMCheckConsistentLength = check_consistent_length + _LGBMCheckSampleWeight = _check_sample_weight _LGBMAssertAllFinite = assert_all_finite _LGBMCheckClassificationTargets = check_classification_targets _LGBMComputeSampleWeight = compute_sample_weight except ImportError: SKLEARN_INSTALLED = False - SKLEARN_VERSION = '0.0.0' _LGBMModelBase = object _LGBMClassifierBase = object _LGBMRegressorBase = object @@ -151,7 +158,7 @@ class DataTable(object): _LGBMGroupKFold = None _LGBMCheckXY = None _LGBMCheckArray = None - _LGBMCheckConsistentLength = None + _LGBMCheckSampleWeight = None _LGBMAssertAllFinite = None _LGBMCheckClassificationTargets = None _LGBMComputeSampleWeight = None diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index e6d5b33a651f..2731bb120a9a 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -7,9 +7,9 @@ import numpy as np from .basic import Dataset, LightGBMError, _ConfigAliases -from .compat import (SKLEARN_INSTALLED, SKLEARN_VERSION, _LGBMClassifierBase, +from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase, LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase, - _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength, + _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckSampleWeight, _LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight, argc_, range_, zip_, string_type, DataFrame, DataTable) from .engine import train @@ -298,9 +298,6 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, """ if not SKLEARN_INSTALLED: raise LightGBMError('Scikit-learn is required for this module') - elif SKLEARN_VERSION > '0.21.3': - raise RuntimeError("The last supported version of scikit-learn is 0.21.3.\n" - "Found version: {0}.".format(SKLEARN_VERSION)) self.boosting_type = boosting_type self.objective = objective @@ -547,7 +544,8 @@ def fit(self, X, y, if not isinstance(X, (DataFrame, DataTable)): _X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2) - _LGBMCheckConsistentLength(_X, _y, sample_weight) + if sample_weight is not None: + sample_weight = _LGBMCheckSampleWeight(sample_weight, _X) else: _X, _y = X, y diff --git a/python-package/setup.py b/python-package/setup.py index f04eaa405028..962d9a3537cb 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -276,7 +276,7 @@ def run(self): install_requires=[ 'numpy', 'scipy', - 'scikit-learn<=0.21.3' + 'scikit-learn!=0.22.0' ], maintainer='Guolin Ke', maintainer_email='guolin.ke@microsoft.com', @@ -304,4 +304,5 @@ def run(self): 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', 'Topic :: Scientific/Engineering :: Artificial Intelligence']) diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 9543e977e082..6a2e3e27c791 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -64,7 +64,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective std::stringstream buffer; buffer << forced_splits_file.rdbuf(); std::string err; - forced_splits_json_ = Json::parse(buffer.str(), err); + forced_splits_json_ = Json::parse(buffer.str(), &err); } objective_function_ = objective_function; @@ -725,7 +725,7 @@ void GBDT::ResetConfig(const Config* config) { std::stringstream buffer; buffer << forced_splits_file.rdbuf(); std::string err; - forced_splits_json_ = Json::parse(buffer.str(), err); + forced_splits_json_ = Json::parse(buffer.str(), &err); tree_learner_->SetForcedSplit(&forced_splits_json_); } else { forced_splits_json_ = Json(); diff --git a/src/c_api.cpp b/src/c_api.cpp index 678de700f759..5066947a1482 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -13,7 +13,6 @@ #include #include #include -#include #include #include #include @@ -1237,7 +1236,6 @@ int LGBM_BoosterCreateFromModelfile( int* out_num_iterations, BoosterHandle* out) { API_BEGIN(); - LocaleContext withLocaleContext("C"); auto ret = std::unique_ptr(new Booster(filename)); *out_num_iterations = ret->GetBoosting()->GetCurrentIteration(); *out = ret.release(); @@ -1249,7 +1247,6 @@ int LGBM_BoosterLoadModelFromString( int* out_num_iterations, BoosterHandle* out) { API_BEGIN(); - LocaleContext withLocaleContext("C"); auto ret = std::unique_ptr(new Booster(nullptr)); ret->LoadModelFromString(model_str); *out_num_iterations = ret->GetBoosting()->GetCurrentIteration(); @@ -1674,7 +1671,6 @@ int LGBM_BoosterSaveModel(BoosterHandle handle, int num_iteration, const char* filename) { API_BEGIN(); - LocaleContext withLocaleContext("C"); Booster* ref_booster = reinterpret_cast(handle); ref_booster->SaveModelToFile(start_iteration, num_iteration, filename); API_END(); @@ -1687,7 +1683,6 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle, int64_t* out_len, char* out_str) { API_BEGIN(); - LocaleContext withLocaleContext("C"); Booster* ref_booster = reinterpret_cast(handle); std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration); *out_len = static_cast(model.size()) + 1; @@ -1704,7 +1699,6 @@ int LGBM_BoosterDumpModel(BoosterHandle handle, int64_t* out_len, char* out_str) { API_BEGIN(); - LocaleContext withLocaleContext("C"); Booster* ref_booster = reinterpret_cast(handle); std::string model = ref_booster->DumpModel(start_iteration, num_iteration); *out_len = static_cast(model.size()) + 1; diff --git a/src/io/config.cpp b/src/io/config.cpp index 0cf1d3c8bf21..1d4e0a52d736 100644 --- a/src/io/config.cpp +++ b/src/io/config.cpp @@ -20,9 +20,6 @@ void Config::KV2Map(std::unordered_map* params, const if (tmp_strs.size() == 2) { value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1])); } - if (!Common::CheckASCII(key) || !Common::CheckASCII(value)) { - Log::Fatal("Do not support non-ASCII characters in config."); - } if (key.size() > 0) { auto value_search = params->find(key); if (value_search == params->end()) { // not set @@ -328,6 +325,9 @@ void Config::CheckParamConflict() { Log::Warning("Cannot use \"intermediate\" monotone constraints with feature fraction different from 1, auto set monotone constraints to \"basic\" method."); monotone_constraints_method = "basic"; } + if (max_depth > 0 && monotone_penalty >= max_depth) { + Log::Warning("Monotone penalty greater than tree depth. Monotone features won't be used."); + } } std::string Config::ToString() const { diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index 59cc62a5d375..b2204affb4df 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -87,6 +87,9 @@ const std::unordered_map& Config::alias_table() { {"monotone_constraint", "monotone_constraints"}, {"monotone_constraining_method", "monotone_constraints_method"}, {"mc_method", "monotone_constraints_method"}, + {"monotone_splits_penalty", "monotone_penalty"}, + {"ms_penalty", "monotone_penalty"}, + {"mc_penalty", "monotone_penalty"}, {"feature_contrib", "feature_contri"}, {"fc", "feature_contri"}, {"fp", "feature_contri"}, @@ -218,6 +221,7 @@ const std::unordered_set& Config::parameter_set() { "top_k", "monotone_constraints", "monotone_constraints_method", + "monotone_penalty", "feature_contri", "forcedsplits_filename", "refit_decay_rate", @@ -419,6 +423,9 @@ void Config::GetMembersFromString(const std::unordered_map(tmp_str, ','); } @@ -639,6 +646,7 @@ std::string Config::SaveMembersToString() const { str_buf << "[top_k: " << top_k << "]\n"; str_buf << "[monotone_constraints: " << Common::Join(Common::ArrayCast(monotone_constraints), ",") << "]\n"; str_buf << "[monotone_constraints_method: " << monotone_constraints_method << "]\n"; + str_buf << "[monotone_penalty: " << monotone_penalty << "]\n"; str_buf << "[feature_contri: " << Common::Join(feature_contri, ",") << "]\n"; str_buf << "[forcedsplits_filename: " << forcedsplits_filename << "]\n"; str_buf << "[refit_decay_rate: " << refit_decay_rate << "]\n"; diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index e48636c103df..fa979d61cadd 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -434,7 +434,7 @@ void Dataset::FinishLoad() { void PushDataToMultiValBin( data_size_t num_data, const std::vector most_freq_bins, const std::vector offsets, - std::vector>>& iters, + std::vector>>* iters, MultiValBin* ret) { Common::FunctionTimer fun_time("Dataset::PushDataToMultiValBin", global_timer); @@ -444,12 +444,12 @@ void PushDataToMultiValBin( std::vector cur_data; cur_data.reserve(most_freq_bins.size()); for (size_t j = 0; j < most_freq_bins.size(); ++j) { - iters[tid][j]->Reset(start); + (*iters)[tid][j]->Reset(start); } for (data_size_t i = start; i < end; ++i) { cur_data.clear(); for (size_t j = 0; j < most_freq_bins.size(); ++j) { - auto cur_bin = iters[tid][j]->Get(i); + auto cur_bin = (*iters)[tid][j]->Get(i); if (cur_bin == most_freq_bins[j]) { continue; } @@ -467,11 +467,11 @@ void PushDataToMultiValBin( 0, num_data, 1024, [&](int tid, data_size_t start, data_size_t end) { std::vector cur_data(most_freq_bins.size(), 0); for (size_t j = 0; j < most_freq_bins.size(); ++j) { - iters[tid][j]->Reset(start); + (*iters)[tid][j]->Reset(start); } for (data_size_t i = start; i < end; ++i) { for (size_t j = 0; j < most_freq_bins.size(); ++j) { - auto cur_bin = iters[tid][j]->Get(i); + auto cur_bin = (*iters)[tid][j]->Get(i); if (cur_bin == most_freq_bins[j]) { cur_bin = 0; } else { @@ -528,7 +528,7 @@ MultiValBin* Dataset::GetMultiBinFromSparseFeatures() const { std::unique_ptr ret; ret.reset(MultiValBin::CreateMultiValBin(num_data_, offsets.back(), num_feature, sum_sparse_rate)); - PushDataToMultiValBin(num_data_, most_freq_bins, offsets, iters, ret.get()); + PushDataToMultiValBin(num_data_, most_freq_bins, offsets, &iters, ret.get()); ret->FinishLoad(); return ret.release(); } @@ -581,7 +581,7 @@ MultiValBin* Dataset::GetMultiBinFromAllFeatures() const { ret.reset(MultiValBin::CreateMultiValBin( num_data_, num_total_bin, static_cast(most_freq_bins.size()), 1.0 - sum_dense_ratio)); - PushDataToMultiValBin(num_data_, most_freq_bins, offsets, iters, ret.get()); + PushDataToMultiValBin(num_data_, most_freq_bins, offsets, &iters, ret.get()); ret->FinishLoad(); return ret.release(); } diff --git a/src/io/dataset_loader.cpp b/src/io/dataset_loader.cpp index 0f321e1089bd..c0b2edf1a8c3 100644 --- a/src/io/dataset_loader.cpp +++ b/src/io/dataset_loader.cpp @@ -1208,7 +1208,7 @@ std::vector> DatasetLoader::GetForcedBins(std::string forced std::stringstream buffer; buffer << forced_bins_stream.rdbuf(); std::string err; - Json forced_bins_json = Json::parse(buffer.str(), err); + Json forced_bins_json = Json::parse(buffer.str(), &err); CHECK(forced_bins_json.is_array()); std::vector forced_bins_arr = forced_bins_json.array_items(); for (size_t i = 0; i < forced_bins_arr.size(); ++i) { diff --git a/src/io/json11.cpp b/src/io/json11.cpp index 0be9bbdf3a6d..db21c6aab544 100644 --- a/src/io/json11.cpp +++ b/src/io/json11.cpp @@ -31,12 +31,12 @@ namespace json11 { static const int max_depth = 200; -using std::string; -using std::vector; -using std::map; -using std::make_shared; using std::initializer_list; +using std::make_shared; +using std::map; using std::move; +using std::string; +using std::vector; using LightGBM::Log; @@ -45,104 +45,98 @@ using LightGBM::Log; * it may not be orderable. */ struct NullStruct { - bool operator==(NullStruct) const { return true; } - bool operator<(NullStruct) const { return false; } + bool operator==(NullStruct) const { return true; } + bool operator<(NullStruct) const { return false; } }; /* * * * * * * * * * * * * * * * * * * * * Serialization */ -static void dump(NullStruct, string &out) { - out += "null"; -} - -static void dump(double value, string &out) { - if (std::isfinite(value)) { - char buf[32]; - snprintf(buf, sizeof buf, "%.17g", value); - out += buf; - } else { - out += "null"; - } -} +static void dump(NullStruct, string *out) { *out += "null"; } -static void dump(int value, string &out) { +static void dump(double value, string *out) { + if (std::isfinite(value)) { char buf[32]; - snprintf(buf, sizeof buf, "%d", value); - out += buf; + snprintf(buf, sizeof buf, "%.17g", value); + *out += buf; + } else { + *out += "null"; + } } -static void dump(bool value, string &out) { - out += value ? "true" : "false"; +static void dump(int value, string *out) { + char buf[32]; + snprintf(buf, sizeof buf, "%d", value); + *out += buf; } -static void dump(const string &value, string &out) { - out += '"'; - for (size_t i = 0; i < value.length(); i++) { - const char ch = value[i]; - if (ch == '\\') { - out += "\\\\"; - } else if (ch == '"') { - out += "\\\""; - } else if (ch == '\b') { - out += "\\b"; - } else if (ch == '\f') { - out += "\\f"; - } else if (ch == '\n') { - out += "\\n"; - } else if (ch == '\r') { - out += "\\r"; - } else if (ch == '\t') { - out += "\\t"; - } else if (static_cast(ch) <= 0x1f) { - char buf[8]; - snprintf(buf, sizeof buf, "\\u%04x", ch); - out += buf; - } else if (static_cast(ch) == 0xe2 && static_cast(value[i+1]) == 0x80 - && static_cast(value[i+2]) == 0xa8) { - out += "\\u2028"; - i += 2; - } else if (static_cast(ch) == 0xe2 && static_cast(value[i+1]) == 0x80 - && static_cast(value[i+2]) == 0xa9) { - out += "\\u2029"; - i += 2; - } else { - out += ch; - } +static void dump(bool value, string *out) { *out += value ? "true" : "false"; } + +static void dump(const string &value, string *out) { + *out += '"'; + for (size_t i = 0; i < value.length(); i++) { + const char ch = value[i]; + if (ch == '\\') { + *out += "\\\\"; + } else if (ch == '"') { + *out += "\\\""; + } else if (ch == '\b') { + *out += "\\b"; + } else if (ch == '\f') { + *out += "\\f"; + } else if (ch == '\n') { + *out += "\\n"; + } else if (ch == '\r') { + *out += "\\r"; + } else if (ch == '\t') { + *out += "\\t"; + } else if (static_cast(ch) <= 0x1f) { + char buf[8]; + snprintf(buf, sizeof buf, "\\u%04x", ch); + *out += buf; + } else if (static_cast(ch) == 0xe2 && + static_cast(value[i + 1]) == 0x80 && + static_cast(value[i + 2]) == 0xa8) { + *out += "\\u2028"; + i += 2; + } else if (static_cast(ch) == 0xe2 && + static_cast(value[i + 1]) == 0x80 && + static_cast(value[i + 2]) == 0xa9) { + *out += "\\u2029"; + i += 2; + } else { + *out += ch; } - out += '"'; + } + *out += '"'; } -static void dump(const Json::array &values, string &out) { - bool first = true; - out += "["; - for (const auto &value : values) { - if (!first) - out += ", "; - value.dump(out); - first = false; - } - out += "]"; +static void dump(const Json::array &values, string *out) { + bool first = true; + *out += "["; + for (const auto &value : values) { + if (!first) *out += ", "; + value.dump(out); + first = false; + } + *out += "]"; } -static void dump(const Json::object &values, string &out) { - bool first = true; - out += "{"; - for (const auto &kv : values) { - if (!first) - out += ", "; - dump(kv.first, out); - out += ": "; - kv.second.dump(out); - first = false; - } - out += "}"; +static void dump(const Json::object &values, string *out) { + bool first = true; + *out += "{"; + for (const auto &kv : values) { + if (!first) *out += ", "; + dump(kv.first, out); + *out += ": "; + kv.second.dump(out); + first = false; + } + *out += "}"; } -void Json::dump(string &out) const { - m_ptr->dump(out); -} +void Json::dump(string *out) const { m_ptr->dump(out); } /* * * * * * * * * * * * * * * * * * * * * Value wrappers @@ -151,174 +145,195 @@ void Json::dump(string &out) const { template class Value : public JsonValue { protected: - // Constructors - explicit Value(const T &value) : m_value(value) {} - explicit Value(T &&value) : m_value(move(value)) {} - - // Get type tag - Json::Type type() const override { - return tag; - } - - // Comparisons - bool equals(const JsonValue * other) const override { - return m_value == static_cast *>(other)->m_value; - } - bool less(const JsonValue * other) const override { - return m_value < static_cast *>(other)->m_value; - } - - const T m_value; - void dump(string &out) const override { json11::dump(m_value, out); } + // Constructors + explicit Value(const T &value) : m_value(value) {} + explicit Value(T &&value) : m_value(move(value)) {} + + // Get type tag + Json::Type type() const override { return tag; } + + // Comparisons + bool equals(const JsonValue *other) const override { + return m_value == static_cast *>(other)->m_value; + } + bool less(const JsonValue *other) const override { + return m_value < (static_cast *>(other)->m_value); + } + + const T m_value; + void dump(string *out) const override { json11::dump(m_value, out); } }; class JsonDouble final : public Value { - double number_value() const override { return m_value; } - int int_value() const override { return static_cast(m_value); } - bool equals(const JsonValue * other) const override { return m_value == other->number_value(); } - bool less(const JsonValue * other) const override { return m_value < other->number_value(); } + double number_value() const override { return m_value; } + int int_value() const override { return static_cast(m_value); } + bool equals(const JsonValue *other) const override { + return m_value == other->number_value(); + } + bool less(const JsonValue *other) const override { + return m_value < other->number_value(); + } + public: - explicit JsonDouble(double value) : Value(value) {} + explicit JsonDouble(double value) : Value(value) {} }; class JsonInt final : public Value { - double number_value() const override { return m_value; } - int int_value() const override { return m_value; } - bool equals(const JsonValue * other) const override { return m_value == other->number_value(); } - bool less(const JsonValue * other) const override { return m_value < other->number_value(); } + double number_value() const override { return m_value; } + int int_value() const override { return m_value; } + bool equals(const JsonValue *other) const override { + return m_value == other->number_value(); + } + bool less(const JsonValue *other) const override { + return m_value < other->number_value(); + } + public: - explicit JsonInt(int value) : Value(value) {} + explicit JsonInt(int value) : Value(value) {} }; class JsonBoolean final : public Value { - bool bool_value() const override { return m_value; } + bool bool_value() const override { return m_value; } + public: - explicit JsonBoolean(bool value) : Value(value) {} + explicit JsonBoolean(bool value) : Value(value) {} }; class JsonString final : public Value { - const string &string_value() const override { return m_value; } + const string &string_value() const override { return m_value; } + public: - explicit JsonString(const string &value) : Value(value) {} - explicit JsonString(string &&value) : Value(move(value)) {} + explicit JsonString(const string &value) : Value(value) {} + explicit JsonString(string &&value) : Value(move(value)) {} }; class JsonArray final : public Value { - const Json::array &array_items() const override { return m_value; } - const Json & operator[](size_t i) const override; + const Json::array &array_items() const override { return m_value; } + const Json &operator[](size_t i) const override; + public: - explicit JsonArray(const Json::array &value) : Value(value) {} - explicit JsonArray(Json::array &&value) : Value(move(value)) {} + explicit JsonArray(const Json::array &value) : Value(value) {} + explicit JsonArray(Json::array &&value) : Value(move(value)) {} }; class JsonObject final : public Value { - const Json::object &object_items() const override { return m_value; } - const Json & operator[](const string &key) const override; + const Json::object &object_items() const override { return m_value; } + const Json &operator[](const string &key) const override; + public: - explicit JsonObject(const Json::object &value) : Value(value) {} - explicit JsonObject(Json::object &&value) : Value(move(value)) {} + explicit JsonObject(const Json::object &value) : Value(value) {} + explicit JsonObject(Json::object &&value) : Value(move(value)) {} }; class JsonNull final : public Value { public: - JsonNull() : Value({}) {} + JsonNull() : Value({}) {} }; /* * * * * * * * * * * * * * * * * * * * * Static globals - static-init-safe */ struct Statics { - const std::shared_ptr null = make_shared(); - const std::shared_ptr t = make_shared(true); - const std::shared_ptr f = make_shared(false); - const string empty_string; - const vector empty_vector; - const map empty_map; - Statics() {} + const std::shared_ptr null = make_shared(); + const std::shared_ptr t = make_shared(true); + const std::shared_ptr f = make_shared(false); + const string empty_string; + const vector empty_vector; + const map empty_map; + Statics() {} }; -static const Statics & statics() { - static const Statics s {}; - return s; +static const Statics &statics() { + static const Statics s{}; + return s; } -static const Json & static_null() { - // This has to be separate, not in Statics, because Json() accesses statics().null. - static const Json json_null; - return json_null; +static const Json &static_null() { + // This has to be separate, not in Statics, because Json() accesses + // statics().null. + static const Json json_null; + return json_null; } /* * * * * * * * * * * * * * * * * * * * * Constructors */ -Json::Json() noexcept : m_ptr(statics().null) {} -Json::Json(std::nullptr_t) noexcept : m_ptr(statics().null) {} -Json::Json(double value) : m_ptr(make_shared(value)) {} -Json::Json(int value) : m_ptr(make_shared(value)) {} -Json::Json(bool value) : m_ptr(value ? statics().t : statics().f) {} -Json::Json(const string &value) : m_ptr(make_shared(value)) {} -Json::Json(string &&value) : m_ptr(make_shared(move(value))) {} -Json::Json(const char * value) : m_ptr(make_shared(value)) {} -Json::Json(const Json::array &values) : m_ptr(make_shared(values)) {} -Json::Json(Json::array &&values) : m_ptr(make_shared(move(values))) {} -Json::Json(const Json::object &values) : m_ptr(make_shared(values)) {} -Json::Json(Json::object &&values) : m_ptr(make_shared(move(values))) {} +Json::Json() noexcept : m_ptr(statics().null) {} +Json::Json(std::nullptr_t) noexcept : m_ptr(statics().null) {} +Json::Json(double value) : m_ptr(make_shared(value)) {} +Json::Json(int value) : m_ptr(make_shared(value)) {} +Json::Json(bool value) : m_ptr(value ? statics().t : statics().f) {} +Json::Json(const string &value) : m_ptr(make_shared(value)) {} +Json::Json(string &&value) : m_ptr(make_shared(move(value))) {} +Json::Json(const char *value) : m_ptr(make_shared(value)) {} +Json::Json(const Json::array &values) : m_ptr(make_shared(values)) {} +Json::Json(Json::array &&values) + : m_ptr(make_shared(move(values))) {} +Json::Json(const Json::object &values) + : m_ptr(make_shared(values)) {} +Json::Json(Json::object &&values) + : m_ptr(make_shared(move(values))) {} /* * * * * * * * * * * * * * * * * * * * * Accessors */ -Json::Type Json::type() const { return m_ptr->type(); } -double Json::number_value() const { return m_ptr->number_value(); } -int Json::int_value() const { return m_ptr->int_value(); } -bool Json::bool_value() const { return m_ptr->bool_value(); } -const string & Json::string_value() const { return m_ptr->string_value(); } -const vector & Json::array_items() const { return m_ptr->array_items(); } -const map & Json::object_items() const { return m_ptr->object_items(); } -const Json & Json::operator[] (size_t i) const { return (*m_ptr)[i]; } -const Json & Json::operator[] (const string &key) const { return (*m_ptr)[key]; } - -double JsonValue::number_value() const { return 0; } -int JsonValue::int_value() const { return 0; } -bool JsonValue::bool_value() const { return false; } -const string & JsonValue::string_value() const { return statics().empty_string; } -const vector & JsonValue::array_items() const { return statics().empty_vector; } -const map & JsonValue::object_items() const { return statics().empty_map; } -const Json & JsonValue::operator[] (size_t) const { return static_null(); } -const Json & JsonValue::operator[] (const string &) const { return static_null(); } - -const Json & JsonObject::operator[] (const string &key) const { - auto iter = m_value.find(key); - return (iter == m_value.end()) ? static_null() : iter->second; +Json::Type Json::type() const { return m_ptr->type(); } +double Json::number_value() const { return m_ptr->number_value(); } +int Json::int_value() const { return m_ptr->int_value(); } +bool Json::bool_value() const { return m_ptr->bool_value(); } +const string &Json::string_value() const { return m_ptr->string_value(); } +const vector &Json::array_items() const { return m_ptr->array_items(); } +const map &Json::object_items() const { + return m_ptr->object_items(); +} +const Json &Json::operator[](size_t i) const { return (*m_ptr)[i]; } +const Json &Json::operator[](const string &key) const { return (*m_ptr)[key]; } + +double JsonValue::number_value() const { return 0; } +int JsonValue::int_value() const { return 0; } +bool JsonValue::bool_value() const { return false; } +const string &JsonValue::string_value() const { return statics().empty_string; } +const vector &JsonValue::array_items() const { + return statics().empty_vector; +} +const map &JsonValue::object_items() const { + return statics().empty_map; +} +const Json &JsonValue::operator[](size_t) const { return static_null(); } +const Json &JsonValue::operator[](const string &) const { + return static_null(); +} + +const Json &JsonObject::operator[](const string &key) const { + auto iter = m_value.find(key); + return (iter == m_value.end()) ? static_null() : iter->second; } -const Json & JsonArray::operator[] (size_t i) const { - if (i >= m_value.size()) - return static_null(); - else - return m_value[i]; +const Json &JsonArray::operator[](size_t i) const { + if (i >= m_value.size()) + return static_null(); + else + return m_value[i]; } /* * * * * * * * * * * * * * * * * * * * * Comparison */ -bool Json::operator== (const Json &other) const { - if (m_ptr == other.m_ptr) - return true; - if (m_ptr->type() != other.m_ptr->type()) - return false; +bool Json::operator==(const Json &other) const { + if (m_ptr == other.m_ptr) return true; + if (m_ptr->type() != other.m_ptr->type()) return false; - return m_ptr->equals(other.m_ptr.get()); + return m_ptr->equals(other.m_ptr.get()); } -bool Json::operator< (const Json &other) const { - if (m_ptr == other.m_ptr) - return false; - if (m_ptr->type() != other.m_ptr->type()) - return m_ptr->type() < other.m_ptr->type(); +bool Json::operator<(const Json &other) const { + if (m_ptr == other.m_ptr) return false; + if (m_ptr->type() != other.m_ptr->type()) + return m_ptr->type() < other.m_ptr->type(); - return m_ptr->less(other.m_ptr.get()); + return m_ptr->less(other.m_ptr.get()); } /* * * * * * * * * * * * * * * * * * * * @@ -330,17 +345,18 @@ bool Json::operator< (const Json &other) const { * Format char c suitable for printing in an error message. */ static inline string esc(char c) { - char buf[12]; - if (static_cast(c) >= 0x20 && static_cast(c) <= 0x7f) { - snprintf(buf, sizeof buf, "'%c' (%d)", c, c); - } else { - snprintf(buf, sizeof buf, "(%d)", c); - } - return string(buf); + char buf[12]; + if (static_cast(c) >= 0x20 && static_cast(c) <= 0x7f) { + snprintf(buf, sizeof buf, "'%c' (%d)", c, c); + } else { + snprintf(buf, sizeof buf, "(%d)", c); + } + return string(buf); } -static inline bool in_range(long x, long lower, long upper) { - return (x >= lower && x <= upper); +template +static inline bool in_range(T x, T lower, T upper) { + return (x >= lower && x <= upper); } namespace { @@ -349,440 +365,417 @@ namespace { * Object that tracks all state of an in-progress parse. */ struct JsonParser final { - /* State - */ - const string &str; - size_t i; - string &err; - bool failed; - const JsonParse strategy; - - /* fail(msg, err_ret = Json()) - * - * Mark this parse as failed. - */ - Json fail(string &&msg) { - return fail(move(msg), Json()); - } - - template - T fail(string &&msg, const T err_ret) { - if (!failed) - err = std::move(msg); - failed = true; - return err_ret; - } - - /* consume_whitespace() - * - * Advance until the current character is non-whitespace. - */ - void consume_whitespace() { - while (str[i] == ' ' || str[i] == '\r' || str[i] == '\n' || str[i] == '\t') - i++; - } - - /* consume_comment() - * - * Advance comments (c-style inline and multiline). - */ - bool consume_comment() { - bool comment_found = false; - if (str[i] == '/') { + /* State + */ + const char *str; + const size_t str_len; + size_t i; + string *err; + bool failed; + const JsonParse strategy; + + /* fail(msg, err_ret = Json()) + * + * Mark this parse as failed. + */ + Json fail(string &&msg) { return fail(move(msg), Json()); } + + template + T fail(string &&msg, const T err_ret) { + if (!failed) *err = std::move(msg); + failed = true; + return err_ret; + } + + /* consume_whitespace() + * + * Advance until the current character is non-whitespace. + */ + void consume_whitespace() { + while (str[i] == ' ' || str[i] == '\r' || str[i] == '\n' || str[i] == '\t') + i++; + } + + /* consume_comment() + * + * Advance comments (c-style inline and multiline). + */ + bool consume_comment() { + bool comment_found = false; + if (str[i] == '/') { + i++; + if (i == str_len) + return fail("Unexpected end of input after start of comment", false); + if (str[i] == '/') { // inline comment i++; - if (i == str.size()) - return fail("Unexpected end of input after start of comment", false); - if (str[i] == '/') { // inline comment + // advance until next line, or end of input + while (i < str_len && str[i] != '\n') { i++; - // advance until next line, or end of input - while (i < str.size() && str[i] != '\n') { - i++; - } - comment_found = true; - } else if (str[i] == '*') { // multiline comment + } + comment_found = true; + } else if (str[i] == '*') { // multiline comment + i++; + if (i > str_len - 2) + return fail("Unexpected end of input inside multi-line comment", + false); + // advance until closing tokens + while (!(str[i] == '*' && str[i + 1] == '/')) { i++; - if (i > str.size()-2) - return fail("Unexpected end of input inside multi-line comment", false); - // advance until closing tokens - while (!(str[i] == '*' && str[i+1] == '/')) { - i++; - if (i > str.size()-2) - return fail("Unexpected end of input inside multi-line comment", false); - } - i += 2; - comment_found = true; - } else { - return fail("Malformed comment", false); + if (i > str_len - 2) + return fail("Unexpected end of input inside multi-line comment", + false); } + i += 2; + comment_found = true; + } else { + return fail("Malformed comment", false); } - return comment_found; } + return comment_found; + } + + /* consume_garbage() + * + * Advance until the current character is non-whitespace and non-comment. + */ + void consume_garbage() { + consume_whitespace(); + if (strategy == JsonParse::COMMENTS) { + bool comment_found = false; + do { + comment_found = consume_comment(); + if (failed) return; + consume_whitespace(); + } while (comment_found); + } + } + + /* get_next_token() + * + * Return the next non-whitespace character. If the end of the input is + * reached, flag an error and return 0. + */ + char get_next_token() { + consume_garbage(); + if (failed) return char{0}; + if (i == str_len) return fail("Unexpected end of input", char{0}); + + return str[i++]; + } + + /* encode_utf8(pt, out) + * + * Encode pt as UTF-8 and add it to out. + */ + void encode_utf8(int64_t pt, string* out) { + if (pt < 0) return; + + if (pt < 0x80) { + *out += static_cast(pt); + } else if (pt < 0x800) { + *out += static_cast((pt >> 6) | 0xC0); + *out += static_cast((pt & 0x3F) | 0x80); + } else if (pt < 0x10000) { + *out += static_cast((pt >> 12) | 0xE0); + *out += static_cast(((pt >> 6) & 0x3F) | 0x80); + *out += static_cast((pt & 0x3F) | 0x80); + } else { + *out += static_cast((pt >> 18) | 0xF0); + *out += static_cast(((pt >> 12) & 0x3F) | 0x80); + *out += static_cast(((pt >> 6) & 0x3F) | 0x80); + *out += static_cast((pt & 0x3F) | 0x80); + } + } + + /* parse_string() + * + * Parse a string, starting at the current position. + */ + string parse_string() { + string out; + int64_t last_escaped_codepoint = -1; + while (true) { + if (i == str_len) return fail("Unexpected end of input in string", ""); + + char ch = str[i++]; + + if (ch == '"') { + encode_utf8(last_escaped_codepoint, &out); + return out; + } - /* consume_garbage() - * - * Advance until the current character is non-whitespace and non-comment. - */ - void consume_garbage() { - consume_whitespace(); - if (strategy == JsonParse::COMMENTS) { - bool comment_found = false; - do { - comment_found = consume_comment(); - if (failed) return; - consume_whitespace(); - } while (comment_found); + if (in_range(ch, 0, 0x1f)) + return fail("Unescaped " + esc(ch) + " in string", ""); + + // The usual case: non-escaped characters + if (ch != '\\') { + encode_utf8(last_escaped_codepoint, &out); + last_escaped_codepoint = -1; + out += ch; + continue; } - } - /* get_next_token() - * - * Return the next non-whitespace character. If the end of the input is reached, - * flag an error and return 0. - */ - char get_next_token() { - consume_garbage(); - if (failed) return char{0}; - if (i == str.size()) - return fail("Unexpected end of input", char{0}); - - return str[i++]; - } + // Handle escapes + if (i == str_len) return fail("Unexpected end of input in string", ""); - /* encode_utf8(pt, out) - * - * Encode pt as UTF-8 and add it to out. - */ - void encode_utf8(long pt, string & out) { - if (pt < 0) - return; - - if (pt < 0x80) { - out += static_cast(pt); - } else if (pt < 0x800) { - out += static_cast((pt >> 6) | 0xC0); - out += static_cast((pt & 0x3F) | 0x80); - } else if (pt < 0x10000) { - out += static_cast((pt >> 12) | 0xE0); - out += static_cast(((pt >> 6) & 0x3F) | 0x80); - out += static_cast((pt & 0x3F) | 0x80); - } else { - out += static_cast((pt >> 18) | 0xF0); - out += static_cast(((pt >> 12) & 0x3F) | 0x80); - out += static_cast(((pt >> 6) & 0x3F) | 0x80); - out += static_cast((pt & 0x3F) | 0x80); - } - } + ch = str[i++]; - /* parse_string() - * - * Parse a string, starting at the current position. - */ - string parse_string() { - string out; - long last_escaped_codepoint = -1; - while (true) { - if (i == str.size()) - return fail("Unexpected end of input in string", ""); - - char ch = str[i++]; - - if (ch == '"') { - encode_utf8(last_escaped_codepoint, out); - return out; - } - - if (in_range(ch, 0, 0x1f)) - return fail("Unescaped " + esc(ch) + " in string", ""); - - // The usual case: non-escaped characters - if (ch != '\\') { - encode_utf8(last_escaped_codepoint, out); - last_escaped_codepoint = -1; - out += ch; - continue; - } - - // Handle escapes - if (i == str.size()) - return fail("Unexpected end of input in string", ""); - - ch = str[i++]; - - if (ch == 'u') { - // Extract 4-byte escape sequence - string esc = str.substr(i, 4); - // Explicitly check length of the substring. The following loop - // relies on std::string returning the terminating NUL when - // accessing str[length]. Checking here reduces brittleness. - if (esc.length() < 4) { - return fail("Bad \\u escape: " + esc, ""); - } - for (size_t j = 0; j < 4; j++) { - if (!in_range(esc[j], 'a', 'f') && !in_range(esc[j], 'A', 'F') - && !in_range(esc[j], '0', '9')) - return fail("Bad \\u escape: " + esc, ""); - } - - long codepoint = strtol(esc.data(), nullptr, 16); - - // JSON specifies that characters outside the BMP shall be encoded as a pair - // of 4-hex-digit \u escapes encoding their surrogate pair components. Check - // whether we're in the middle of such a beast: the previous codepoint was an - // escaped lead (high) surrogate, and this is a trail (low) surrogate. - if (in_range(last_escaped_codepoint, 0xD800, 0xDBFF) - && in_range(codepoint, 0xDC00, 0xDFFF)) { - // Reassemble the two surrogate pairs into one astral-plane character, per - // the UTF-16 algorithm. - encode_utf8((((last_escaped_codepoint - 0xD800) << 10) - | (codepoint - 0xDC00)) + 0x10000, out); - last_escaped_codepoint = -1; - } else { - encode_utf8(last_escaped_codepoint, out); - last_escaped_codepoint = codepoint; - } - - i += 4; - continue; - } - - encode_utf8(last_escaped_codepoint, out); - last_escaped_codepoint = -1; - - if (ch == 'b') { - out += '\b'; - } else if (ch == 'f') { - out += '\f'; - } else if (ch == 'n') { - out += '\n'; - } else if (ch == 'r') { - out += '\r'; - } else if (ch == 't') { - out += '\t'; - } else if (ch == '"' || ch == '\\' || ch == '/') { - out += ch; - } else { - return fail("Invalid escape character " + esc(ch), ""); - } + if (ch == 'u') { + // Extract 4-byte escape sequence + string esc = string(str + i, 4); + // Explicitly check length of the substring. The following loop + // relies on std::string returning the terminating NUL when + // accessing str[length]. Checking here reduces brittleness. + if (esc.length() < 4) { + return fail("Bad \\u escape: " + esc, ""); + } + for (size_t j = 0; j < 4; j++) { + if (!in_range(esc[j], 'a', 'f') && !in_range(esc[j], 'A', 'F') && + !in_range(esc[j], '0', '9')) + return fail("Bad \\u escape: " + esc, ""); } - } - /* parse_number() - * - * Parse a double. - */ - Json parse_number() { - size_t start_pos = i; - - if (str[i] == '-') - i++; - - // Integer part - if (str[i] == '0') { - i++; - if (in_range(str[i], '0', '9')) - return fail("Leading 0s not permitted in numbers"); - } else if (in_range(str[i], '1', '9')) { - i++; - while (in_range(str[i], '0', '9')) - i++; + int64_t codepoint = + static_cast(strtol(esc.data(), nullptr, 16)); + + // JSON specifies that characters outside the BMP shall be encoded as a + // pair of 4-hex-digit \u escapes encoding their surrogate pair + // components. Check whether we're in the middle of such a beast: the + // previous codepoint was an escaped lead (high) surrogate, and this is + // a trail (low) surrogate. + if (in_range(last_escaped_codepoint, 0xD800, 0xDBFF) && + in_range(codepoint, 0xDC00, 0xDFFF)) { + // Reassemble the two surrogate pairs into one astral-plane character, + // per the UTF-16 algorithm. + encode_utf8((((last_escaped_codepoint - 0xD800) << 10) | + (codepoint - 0xDC00)) + + 0x10000, + &out); + last_escaped_codepoint = -1; } else { - return fail("Invalid " + esc(str[i]) + " in number"); + encode_utf8(last_escaped_codepoint, &out); + last_escaped_codepoint = codepoint; } - if (str[i] != '.' && str[i] != 'e' && str[i] != 'E' - && (i - start_pos) <= static_cast(std::numeric_limits::digits10)) { - return std::atoi(str.c_str() + start_pos); - } + i += 4; + continue; + } - // Decimal part - if (str[i] == '.') { - i++; - if (!in_range(str[i], '0', '9')) - return fail("At least one digit required in fractional part"); + encode_utf8(last_escaped_codepoint, &out); + last_escaped_codepoint = -1; + + if (ch == 'b') { + out += '\b'; + } else if (ch == 'f') { + out += '\f'; + } else if (ch == 'n') { + out += '\n'; + } else if (ch == 'r') { + out += '\r'; + } else if (ch == 't') { + out += '\t'; + } else if (ch == '"' || ch == '\\' || ch == '/') { + out += ch; + } else { + return fail("Invalid escape character " + esc(ch), ""); + } + } + } + + /* parse_number() + * + * Parse a double. + */ + Json parse_number() { + size_t start_pos = i; + + if (str[i] == '-') i++; + + // Integer part + if (str[i] == '0') { + i++; + if (in_range(str[i], '0', '9')) + return fail("Leading 0s not permitted in numbers"); + } else if (in_range(str[i], '1', '9')) { + i++; + while (in_range(str[i], '0', '9')) i++; + } else { + return fail("Invalid " + esc(str[i]) + " in number"); + } - while (in_range(str[i], '0', '9')) - i++; - } + if (str[i] != '.' && str[i] != 'e' && str[i] != 'E' && + (i - start_pos) <= + static_cast(std::numeric_limits::digits10)) { + return Json(std::atoi(str + start_pos)); + } + + // Decimal part + if (str[i] == '.') { + i++; + if (!in_range(str[i], '0', '9')) + return fail("At least one digit required in fractional part"); - // Exponent part - if (str[i] == 'e' || str[i] == 'E') { - i++; + while (in_range(str[i], '0', '9')) i++; + } - if (str[i] == '+' || str[i] == '-') - i++; + // Exponent part + if (str[i] == 'e' || str[i] == 'E') { + i++; - if (!in_range(str[i], '0', '9')) - return fail("At least one digit required in exponent"); + if (str[i] == '+' || str[i] == '-') i++; - while (in_range(str[i], '0', '9')) - i++; - } + if (!in_range(str[i], '0', '9')) + return fail("At least one digit required in exponent"); - return std::strtod(str.c_str() + start_pos, nullptr); + while (in_range(str[i], '0', '9')) i++; } - /* expect(str, res) - * - * Expect that 'str' starts at the character that was just read. If it does, advance - * the input and return res. If not, flag an error. - */ - Json expect(const string &expected, Json res) { - CHECK_NE(i, 0) - i--; - if (str.compare(i, expected.length(), expected) == 0) { - i += expected.length(); - return res; - } else { - return fail("Parse error: expected " + expected + ", got " + str.substr(i, expected.length())); - } + return Json(std::strtod(str + start_pos, nullptr)); + } + + /* expect(str, res) + * + * Expect that 'str' starts at the character that was just read. If it does, + * advance the input and return res. If not, flag an error. + */ + Json expect(const string &expected, Json res) { + CHECK_NE(i, 0) + i--; + auto substr = string(str + i, expected.length()); + if (substr == expected) { + i += expected.length(); + return res; + } else { + return fail("Parse error: expected " + expected + ", got " + substr); + } + } + + /* parse_json() + * + * Parse a JSON object. + */ + Json parse_json(int depth) { + if (depth > max_depth) { + return fail("Exceeded maximum nesting depth"); } - /* parse_json() - * - * Parse a JSON object. - */ - Json parse_json(int depth) { - if (depth > max_depth) { - return fail("Exceeded maximum nesting depth"); - } + char ch = get_next_token(); + if (failed) return Json(); - char ch = get_next_token(); - if (failed) - return Json(); + if (ch == '-' || (ch >= '0' && ch <= '9')) { + i--; + return parse_number(); + } - if (ch == '-' || (ch >= '0' && ch <= '9')) { - i--; - return parse_number(); - } + if (ch == 't') return expect("true", Json(true)); - if (ch == 't') - return expect("true", true); + if (ch == 'f') return expect("false", Json(false)); - if (ch == 'f') - return expect("false", false); + if (ch == 'n') return expect("null", Json()); - if (ch == 'n') - return expect("null", Json()); + if (ch == '"') return Json(parse_string()); - if (ch == '"') - return parse_string(); + if (ch == '{') { + map data; + ch = get_next_token(); + if (ch == '}') return Json(data); - if (ch == '{') { - map data; - ch = get_next_token(); - if (ch == '}') - return data; + while (1) { + if (ch != '"') return fail("Expected '\"' in object, got " + esc(ch)); - while (1) { - if (ch != '"') - return fail("Expected '\"' in object, got " + esc(ch)); + string key = parse_string(); + if (failed) return Json(); - string key = parse_string(); - if (failed) - return Json(); + ch = get_next_token(); + if (ch != ':') return fail("Expected ':' in object, got " + esc(ch)); - ch = get_next_token(); - if (ch != ':') - return fail("Expected ':' in object, got " + esc(ch)); + data[std::move(key)] = parse_json(depth + 1); + if (failed) return Json(); - data[std::move(key)] = parse_json(depth + 1); - if (failed) - return Json(); + ch = get_next_token(); + if (ch == '}') break; + if (ch != ',') return fail("Expected ',' in object, got " + esc(ch)); - ch = get_next_token(); - if (ch == '}') - break; - if (ch != ',') - return fail("Expected ',' in object, got " + esc(ch)); + ch = get_next_token(); + } + return Json(data); + } - ch = get_next_token(); - } - return data; - } + if (ch == '[') { + vector data; + ch = get_next_token(); + if (ch == ']') return Json(data); - if (ch == '[') { - vector data; - ch = get_next_token(); - if (ch == ']') - return data; - - while (1) { - i--; - data.push_back(parse_json(depth + 1)); - if (failed) - return Json(); - - ch = get_next_token(); - if (ch == ']') - break; - if (ch != ',') - return fail("Expected ',' in list, got " + esc(ch)); - - ch = get_next_token(); - (void)ch; - } - return data; - } + while (1) { + i--; + data.push_back(parse_json(depth + 1)); + if (failed) return Json(); - return fail("Expected value, got " + esc(ch)); + ch = get_next_token(); + if (ch == ']') break; + if (ch != ',') return fail("Expected ',' in list, got " + esc(ch)); + + ch = get_next_token(); + (void)ch; + } + return Json(data); } + + return fail("Expected value, got " + esc(ch)); + } }; } // namespace -Json Json::parse(const string &in, string &err, JsonParse strategy) { - JsonParser parser { in, 0, err, false, strategy }; - Json result = parser.parse_json(0); +Json Json::parse(const string &in, string *err, JsonParse strategy) { + JsonParser parser{in.c_str(), in.size(), 0, err, false, strategy}; + Json result = parser.parse_json(0); - // Check for any trailing garbage - parser.consume_garbage(); - if (parser.failed) - return Json(); - if (parser.i != in.size()) - return parser.fail("Unexpected trailing " + esc(in[parser.i])); + // Check for any trailing garbage + parser.consume_garbage(); + if (parser.failed) return Json(); + if (parser.i != in.size()) + return parser.fail("Unexpected trailing " + esc(in[parser.i])); - return result; + return result; } // Documented in json11.hpp vector Json::parse_multi(const string &in, - std::string::size_type &parser_stop_pos, - string &err, - JsonParse strategy) { - JsonParser parser { in, 0, err, false, strategy }; - parser_stop_pos = 0; - vector json_vec; - while (parser.i != in.size() && !parser.failed) { - json_vec.push_back(parser.parse_json(0)); - if (parser.failed) - break; - - // Check for another object - parser.consume_garbage(); - if (parser.failed) - break; - parser_stop_pos = parser.i; - } - return json_vec; + std::string::size_type *parser_stop_pos, + string *err, JsonParse strategy) { + JsonParser parser{in.c_str(), in.size(), 0, err, false, strategy}; + *parser_stop_pos = 0; + vector json_vec; + while (parser.i != in.size() && !parser.failed) { + json_vec.push_back(parser.parse_json(0)); + if (parser.failed) break; + + // Check for another object + parser.consume_garbage(); + if (parser.failed) break; + *parser_stop_pos = parser.i; + } + return json_vec; } /* * * * * * * * * * * * * * * * * * * * * Shape-checking */ -bool Json::has_shape(const shape & types, string & err) const { - if (!is_object()) { - err = "Expected JSON object, got " + dump(); - return false; - } +bool Json::has_shape(const shape &types, string *err) const { + if (!is_object()) { + *err = "Expected JSON object, got " + dump(); + return false; + } - for (auto & item : types) { - if ((*this)[item.first].type() != item.second) { - err = "Bad type for " + item.first + " in " + dump(); - return false; - } + for (auto &item : types) { + if ((*this)[item.first].type() != item.second) { + *err = "Bad type for " + item.first + " in " + dump(); + return false; } + } - return true; + return true; } } // namespace json11 diff --git a/src/io/tree.cpp b/src/io/tree.cpp index 4c8fb4eb0e20..be928b7e3124 100644 --- a/src/io/tree.cpp +++ b/src/io/tree.cpp @@ -57,7 +57,7 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin, decision_type_[new_node_idx] = 0; SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask); SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask); - SetMissingType(&decision_type_[new_node_idx], missing_type); + SetMissingType(&decision_type_[new_node_idx], static_cast(missing_type)); threshold_in_bin_[new_node_idx] = threshold_bin; threshold_[new_node_idx] = threshold_double; ++num_leaves_; @@ -71,7 +71,7 @@ int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32 int new_node_idx = num_leaves_ - 1; decision_type_[new_node_idx] = 0; SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask); - SetMissingType(&decision_type_[new_node_idx], missing_type); + SetMissingType(&decision_type_[new_node_idx], static_cast(missing_type)); threshold_in_bin_[new_node_idx] = num_cat_; threshold_[new_node_idx] = num_cat_; ++num_cat_; @@ -335,7 +335,7 @@ std::string Tree::NumericalDecisionIfElse(int node) const { std::stringstream str_buf; uint8_t missing_type = GetMissingType(decision_type_[node]); bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask); - if (missing_type == MissingType::None + if (missing_type == MissingType::None || (missing_type == MissingType::Zero && default_left && kZeroThreshold < threshold_[node])) { str_buf << "if (fval <= " << threshold_[node] << ") {"; } else if (missing_type == MissingType::Zero) { diff --git a/src/lightgbm_R.cpp b/src/lightgbm_R.cpp index 7d3a8dcafe80..653d3cc5d0ed 100644 --- a/src/lightgbm_R.cpp +++ b/src/lightgbm_R.cpp @@ -9,6 +9,8 @@ #include #include +#include + #include #include #include @@ -16,8 +18,6 @@ #include #include -#include - #define COL_MAJOR (0) #define R_API_BEGIN() \ diff --git a/src/network/socket_wrapper.hpp b/src/network/socket_wrapper.hpp index 244bc0db548f..70f9586b99c5 100644 --- a/src/network/socket_wrapper.hpp +++ b/src/network/socket_wrapper.hpp @@ -218,6 +218,7 @@ class TcpSocket { continue; } if (ifa->ifa_addr->sa_family == AF_INET) { + // NOLINTNEXTLINE tmpAddrPtr = &((struct sockaddr_in *)ifa->ifa_addr)->sin_addr; char addressBuffer[INET_ADDRSTRLEN]; inet_ntop(AF_INET, tmpAddrPtr, addressBuffer, INET_ADDRSTRLEN); diff --git a/src/treelearner/monotone_constraints.hpp b/src/treelearner/monotone_constraints.hpp index 4d804d7fbfa0..dcad0d6d3288 100644 --- a/src/treelearner/monotone_constraints.hpp +++ b/src/treelearner/monotone_constraints.hpp @@ -62,6 +62,24 @@ class LeafConstraintsBase { const std::vector& best_split_per_leaf) = 0; inline static LeafConstraintsBase* Create(const Config* config, int num_leaves); + + double ComputeMonotoneSplitGainPenalty(int leaf_index, double penalization) { + int depth = tree_->leaf_depth(leaf_index); + if (penalization >= depth + 1.) { + return kEpsilon; + } + if (penalization <= 1.) { + return 1. - penalization / pow(2., depth) + kEpsilon; + } + return 1. - pow(2, penalization - 1. - depth) + kEpsilon; + } + + void ShareTreePointer(const Tree* tree) { + tree_ = tree; + } + + private: + const Tree* tree_; }; class BasicLeafConstraints : public LeafConstraintsBase { diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp index b7569d22c8e2..6c4390553efd 100644 --- a/src/treelearner/serial_tree_learner.cpp +++ b/src/treelearner/serial_tree_learner.cpp @@ -165,6 +165,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians auto tree = std::unique_ptr(new Tree(config_->num_leaves)); auto tree_prt = tree.get(); + constraints_->ShareTreePointer(tree_prt); + // root leaf int left_leaf = 0; int cur_depth = 1; @@ -692,6 +694,11 @@ void SerialTreeLearner::ComputeBestSplitForFeature( cegb_->DetlaGain(feature_index, real_fidx, leaf_splits->leaf_index(), num_data, new_split); } + if (new_split.monotone_type != 0) { + double penalty = constraints_->ComputeMonotoneSplitGainPenalty( + leaf_splits->leaf_index(), config_->monotone_penalty); + new_split.gain *= penalty; + } if (new_split > *best_split) { *best_split = new_split; } diff --git a/tests/c_api_test/test_.py b/tests/c_api_test/test_.py index d5f837f515bf..b138de4a0ef4 100644 --- a/tests/c_api_test/test_.py +++ b/tests/c_api_test/test_.py @@ -58,7 +58,7 @@ def c_array(ctype, values): def c_str(string): - return ctypes.c_char_p(string.encode('ascii')) + return ctypes.c_char_p(string.encode('utf-8')) def load_from_file(filename, reference): diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 51be083a9f01..c5348e9858c4 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -747,6 +747,21 @@ def test_feature_name(self): gbm = lgb.train(params, lgb_train, num_boost_round=5, feature_name=feature_names_with_space) self.assertListEqual(feature_names, gbm.feature_name()) + def test_feature_name_with_non_ascii(self): + X_train = np.random.normal(size=(100, 4)) + y_train = np.random.random(100) + # This has non-ascii strings. + feature_names = [u'F_零', u'F_一', u'F_二', u'F_三'] + params = {'verbose': -1} + lgb_train = lgb.Dataset(X_train, y_train) + + gbm = lgb.train(params, lgb_train, num_boost_round=5, feature_name=feature_names) + self.assertListEqual(feature_names, gbm.feature_name()) + gbm.save_model('lgb.model') + + gbm2 = lgb.Booster(model_file='lgb.model') + self.assertListEqual(feature_names, gbm2.feature_name()) + def test_save_load_copy_pickle(self): def train_and_predict(init_model=None, return_model=False): X, y = load_boston(True) @@ -1036,7 +1051,7 @@ def generate_trainset_for_monotone_constraints_tests(self, x3_to_category=True): categorical_features = [] if x3_to_category: categorical_features = [2] - trainset = lgb.Dataset(x, label=y, categorical_feature=categorical_features) + trainset = lgb.Dataset(x, label=y, categorical_feature=categorical_features, free_raw_data=False) return trainset def test_monotone_constraints(self): @@ -1071,8 +1086,8 @@ def is_correctly_constrained(learner, x3_to_category=True): return True for test_with_categorical_variable in [True, False]: + trainset = self.generate_trainset_for_monotone_constraints_tests(test_with_categorical_variable) for monotone_constraints_method in ["basic", "intermediate"]: - trainset = self.generate_trainset_for_monotone_constraints_tests(test_with_categorical_variable) params = { 'min_data': 20, 'num_leaves': 20, @@ -1083,6 +1098,76 @@ def is_correctly_constrained(learner, x3_to_category=True): constrained_model = lgb.train(params, trainset) self.assertTrue(is_correctly_constrained(constrained_model, test_with_categorical_variable)) + def test_monotone_penalty(self): + def are_first_splits_non_monotone(tree, n, monotone_constraints): + if n <= 0: + return True + if "leaf_value" in tree: + return True + if monotone_constraints[tree["split_feature"]] != 0: + return False + return (are_first_splits_non_monotone(tree["left_child"], n - 1, monotone_constraints) + and are_first_splits_non_monotone(tree["right_child"], n - 1, monotone_constraints)) + + def are_there_monotone_splits(tree, monotone_constraints): + if "leaf_value" in tree: + return False + if monotone_constraints[tree["split_feature"]] != 0: + return True + return (are_there_monotone_splits(tree["left_child"], monotone_constraints) + or are_there_monotone_splits(tree["right_child"], monotone_constraints)) + + max_depth = 5 + monotone_constraints = [1, -1, 0] + penalization_parameter = 2.0 + trainset = self.generate_trainset_for_monotone_constraints_tests(x3_to_category=False) + for monotone_constraints_method in ["basic", "intermediate"]: + params = { + 'max_depth': max_depth, + 'monotone_constraints': monotone_constraints, + 'monotone_penalty': penalization_parameter, + "monotone_constraints_method": monotone_constraints_method, + } + constrained_model = lgb.train(params, trainset, 10) + dumped_model = constrained_model.dump_model()["tree_info"] + for tree in dumped_model: + self.assertTrue(are_first_splits_non_monotone(tree["tree_structure"], int(penalization_parameter), + monotone_constraints)) + self.assertTrue(are_there_monotone_splits(tree["tree_structure"], monotone_constraints)) + + # test if a penalty as high as the depth indeed prohibits all monotone splits + def test_monotone_penalty_max(self): + max_depth = 5 + monotone_constraints = [1, -1, 0] + penalization_parameter = max_depth + trainset_constrained_model = self.generate_trainset_for_monotone_constraints_tests(x3_to_category=False) + x = trainset_constrained_model.data + y = trainset_constrained_model.label + x3_negatively_correlated_with_y = x[:, 2] + trainset_unconstrained_model = lgb.Dataset(x3_negatively_correlated_with_y.reshape(-1, 1), label=y) + params_constrained_model = { + 'monotone_constraints': monotone_constraints, + 'monotone_penalty': penalization_parameter, + "max_depth": max_depth, + "gpu_use_dp": True, + } + params_unconstrained_model = { + "max_depth": max_depth, + "gpu_use_dp": True, + } + + unconstrained_model = lgb.train(params_unconstrained_model, trainset_unconstrained_model, 10) + unconstrained_model_predictions = unconstrained_model.\ + predict(x3_negatively_correlated_with_y.reshape(-1, 1)) + + for monotone_constraints_method in ["basic", "intermediate"]: + params_constrained_model["monotone_constraints_method"] = monotone_constraints_method + # The penalization is so high that the first 2 features should not be used here + constrained_model = lgb.train(params_constrained_model, trainset_constrained_model, 10) + + # Check that a very high penalization is the same as not using the features at all + np.testing.assert_array_equal(constrained_model.predict(x), unconstrained_model_predictions) + def test_max_bin_by_feature(self): col1 = np.arange(0, 100)[:, np.newaxis] col2 = np.zeros((100, 1)) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 747508f0d5ce..1fb44d0dc12c 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -293,6 +293,11 @@ def test_sklearn_integration(self): check_name = check.func.__name__ if hasattr(check, 'func') else check.__name__ if check_name == 'check_estimators_nan_inf': continue # skip test because LightGBM deals with nan + elif check_name == "check_no_attributes_set_in_init": + # skip test because scikit-learn incorrectly asserts that + # private attributes cannot be set in __init__ + # (see https://github.com/microsoft/LightGBM/issues/2628) + continue try: check(name, estimator) except SkipTest as message: