From c4e1480741da504c17aed470d3b261f62d1ce4ce Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Tue, 14 Mar 2023 17:54:53 +0800
Subject: [PATCH] Check inf in data for all types of DMatrix.

---
 python-package/xgboost/testing/data.py | 15 +++++++++++
 src/common/error_msg.h                 |  4 +++
 src/data/data.cc                       | 11 +++++---
 src/data/device_adapter.cuh            | 19 +++++++++++++-
 src/data/ellpack_page.cu               | 26 ++++++++++---------
 src/data/gradient_index.h              | 20 +++++++++-----
 src/data/simple_dmatrix.cuh            | 22 +++++++++-------
 .../test_device_quantile_dmatrix.py    |  7 +++++
 tests/python/test_quantile_dmatrix.py  |  6 ++++-
 9 files changed, 97 insertions(+), 33 deletions(-)

diff --git a/python-package/xgboost/testing/data.py b/python-package/xgboost/testing/data.py
index 4f79d7358e93..295c3294d715 100644
--- a/python-package/xgboost/testing/data.py
+++ b/python-package/xgboost/testing/data.py
@@ -2,7 +2,10 @@
 from typing import Any, Generator, Tuple, Union
 
 import numpy as np
+import pytest
+from numpy.random import Generator as RNG
 
+import xgboost
 from xgboost.data import pandas_pyarrow_mapper
 
 
@@ -179,3 +182,15 @@ def pd_arrow_dtypes() -> Generator:
             dtype=pd.ArrowDtype(pa.bool_()),
         )
         yield orig, df
+
+
+def check_inf(rng: RNG) -> None:
+    X = rng.random(size=32).reshape(8, 4)
+    y = rng.random(size=8)
+    X[5, 2] = np.inf
+
+    with pytest.raises(ValueError, match="Input data contains `inf`"):
+        xgboost.QuantileDMatrix(X, y)
+
+    with pytest.raises(ValueError, match="Input data contains `inf`"):
+        xgboost.DMatrix(X, y)
diff --git a/src/common/error_msg.h b/src/common/error_msg.h
index 48a2c92a4ca5..484595316e26 100644
--- a/src/common/error_msg.h
+++ b/src/common/error_msg.h
@@ -20,5 +20,9 @@ constexpr StringView GroupSize() {
 constexpr StringView LabelScoreSize() {
   return "The size of label doesn't match the size of prediction.";
 }
+
+constexpr StringView InfInData() {
+  return "Input data contains `inf` while `missing` is not set to `inf`";
+}
 }  // namespace xgboost::error
 #endif  // XGBOOST_COMMON_ERROR_MSG_H_
diff --git a/src/data/data.cc b/src/data/data.cc
index d24048a2ab23..aa96a1bc801a 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -10,13 +10,16 @@
 #include
 
 #include "../collective/communicator-inl.h"
-#include "../common/algorithm.h"  // StableSort
-#include "../common/api_entry.h"  // XGBAPIThreadLocalEntry
+#include "../collective/communicator.h"
+#include "../common/common.h"
+#include "../common/algorithm.h"  // for StableSort
+#include "../common/api_entry.h"  // for XGBAPIThreadLocalEntry
+#include "../common/error_msg.h"  // for InfInData
 #include "../common/group_data.h"
 #include "../common/io.h"
 #include "../common/linalg_op.h"
 #include "../common/math.h"
-#include "../common/numeric.h"  // Iota
+#include "../common/numeric.h"  // for Iota
 #include "../common/threading_utils.h"
 #include "../common/version.h"
 #include "../data/adapter.h"
@@ -1144,7 +1147,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
     });
   }
   exec.Rethrow();
-  CHECK(valid) << "Input data contains `inf` or `nan`";
+  CHECK(valid) << error::InfInData();
   for (const auto & max : max_columns_vector) {
     max_columns = std::max(max_columns, max[0]);
   }
diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh
index 56c494dd1b12..494fb7d1c438 100644
--- a/src/data/device_adapter.cuh
+++ b/src/data/device_adapter.cuh
@@ -4,7 +4,10 @@
  */
 #ifndef XGBOOST_DATA_DEVICE_ADAPTER_H_
 #define XGBOOST_DATA_DEVICE_ADAPTER_H_
-#include <cstddef>  // for size_t
+#include <thrust/iterator/counting_iterator.h>  // for make_counting_iterator
+#include <thrust/logical.h>                     // for none_of
+
+#include <cstddef>  // for size_t
 #include
 #include
 #include
@@ -213,6 +216,20 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span offset,
                                  static_cast(0), thrust::maximum());
   return row_stride;
 }
+
+/**
+ * \brief Check there's no inf in data.
+ */
+template <typename AdapterBatchT>
+bool HasInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) {
+  auto counting = thrust::make_counting_iterator(0llu);
+  auto value_iter = dh::MakeTransformIterator(
+      counting, [=] XGBOOST_DEVICE(std::size_t idx) { return batch.GetElement(idx).value; });
+  auto valid =
+      thrust::none_of(value_iter, value_iter + batch.Size(),
+                      [is_valid] XGBOOST_DEVICE(float v) { return is_valid(v) && std::isinf(v); });
+  return valid;
+}
 };  // namespace data
 }  // namespace xgboost
 #endif  // XGBOOST_DATA_DEVICE_ADAPTER_H_
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index 99e17d886df9..d631407a1eb0 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2019-2022 XGBoost contributors
+/**
+ * Copyright 2019-2023 by XGBoost contributors
  */
 #include
 #include
@@ -9,7 +9,7 @@
 #include "../common/random.h"
 #include "../common/transform_iterator.h"  // MakeIndexTransformIter
 #include "./ellpack_page.cuh"
-#include "device_adapter.cuh"
+#include "device_adapter.cuh"  // for HasInfInData
 #include "gradient_index.h"
 #include "xgboost/data.h"
 
@@ -189,9 +189,8 @@ struct TupleScanOp {
 // Here the data is already correctly ordered and simply needs to be compacted
 // to remove missing data
 template <typename AdapterBatchT>
-void CopyDataToEllpack(const AdapterBatchT &batch,
-                       common::Span feature_types,
-                       EllpackPageImpl *dst, int device_idx, float missing) {
+void CopyDataToEllpack(const AdapterBatchT& batch, common::Span feature_types,
+                       EllpackPageImpl* dst, int device_idx, float missing) {
   // Some witchcraft happens here
   // The goal is to copy valid elements out of the input to an ELLPACK matrix
   // with a given row stride, using no extra working memory Standard stream
@@ -201,6 +200,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
   // correct output position
   auto counting = thrust::make_counting_iterator(0llu);
   data::IsValidFunctor is_valid(missing);
+  bool valid = data::HasInfInData(batch, is_valid);
+  CHECK(valid) << error::InfInData();
+
   auto key_iter = dh::MakeTransformIterator(
       counting,
       [=] __device__(size_t idx) {
@@ -239,9 +241,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
       cub::DispatchScan,
                         cub::NullType, int64_t>;
 #if THRUST_MAJOR_VERSION >= 2
-  DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
-                         TupleScanOp(), cub::NullType(), batch.Size(),
-                         nullptr);
+  dh::safe_cuda(DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
+                                       TupleScanOp(), cub::NullType(), batch.Size(),
+                                       nullptr));
 #else
   DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
                          TupleScanOp(), cub::NullType(), batch.Size(),
@@ -249,9 +251,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
 #endif
   dh::TemporaryArray temp_storage(temp_storage_bytes);
 #if THRUST_MAJOR_VERSION >= 2
-  DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
-                         key_value_index_iter, out, TupleScanOp(),
-                         cub::NullType(), batch.Size(), nullptr);
+  dh::safe_cuda(DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
+                                       key_value_index_iter, out, TupleScanOp(),
+                                       cub::NullType(), batch.Size(), nullptr));
 #else
   DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
                          key_value_index_iter, out, TupleScanOp(),
diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h
index 9eba9637fbea..3cb0709bd95d 100644
--- a/src/data/gradient_index.h
+++ b/src/data/gradient_index.h
@@ -1,21 +1,23 @@
-/*!
- * Copyright 2017-2022 by XGBoost Contributors
+/**
+ * Copyright 2017-2023 by XGBoost Contributors
  * \brief Data type for fast histogram aggregation.
  */
 #ifndef XGBOOST_DATA_GRADIENT_INDEX_H_
 #define XGBOOST_DATA_GRADIENT_INDEX_H_
 
-#include <algorithm>  // std::min
-#include <cstdint>    // std::uint32_t
-#include <cstddef>    // std::size_t
+#include <algorithm>  // for min
+#include <atomic>     // for atomic
+#include <cstdint>    // for uint32_t
+#include <cstddef>    // for size_t
 #include
 #include
 
 #include "../common/categorical.h"
+#include "../common/error_msg.h"  // for InfInData
 #include "../common/hist_util.h"
 #include "../common/numeric.h"
 #include "../common/threading_utils.h"
-#include "../common/transform_iterator.h"  // common::MakeIndexTransformIter
+#include "../common/transform_iterator.h"  // for MakeIndexTransformIter
 #include "adapter.h"
 #include "proxy_dmatrix.h"
 #include "xgboost/base.h"
@@ -62,6 +64,7 @@ class GHistIndexMatrix {
     BinIdxType* index_data = index_data_span.data();
     auto const& ptrs = cut.Ptrs();
     auto const& values = cut.Values();
+    std::atomic<bool> valid{true};
     common::ParallelFor(batch_size, batch_threads, [&](size_t i) {
       auto line = batch.GetLine(i);
       size_t ibegin = row_ptr[rbegin + i];  // index of first entry for current block
@@ -70,6 +73,9 @@ class GHistIndexMatrix {
       for (size_t j = 0; j < line.Size(); ++j) {
         data::COOTuple elem = line.GetElement(j);
         if (is_valid(elem)) {
+          if (XGBOOST_EXPECT((std::isinf(elem.value)), false)) {
+            valid = false;
+          }
           bst_bin_t bin_idx{-1};
           if (common::IsCat(ft, elem.column_idx)) {
             bin_idx = cut.SearchCatBin(elem.value, elem.column_idx, ptrs, values);
@@ -82,6 +88,8 @@
         }
       }
     });
+
+    CHECK(valid) << error::InfInData();
   }
 
   // Gather hit_count from all threads
diff --git a/src/data/simple_dmatrix.cuh b/src/data/simple_dmatrix.cuh
index c71a52b6746e..63310a92984f 100644
--- a/src/data/simple_dmatrix.cuh
+++ b/src/data/simple_dmatrix.cuh
@@ -1,18 +1,19 @@
-/*!
- * Copyright 2019-2021 by XGBoost Contributors
+/**
+ * Copyright 2019-2023 by XGBoost Contributors
  * \file simple_dmatrix.cuh
  */
 #ifndef XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
 #define XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
 
 #include
-#include
 #include
-#include "device_adapter.cuh"
+#include
+
 #include "../common/device_helpers.cuh"
+#include "../common/error_msg.h"  // for InfInData
+#include "device_adapter.cuh"     // for HasInfInData
 
-namespace xgboost {
-namespace data {
+namespace xgboost::data {
 
 template <typename AdapterBatchT>
 struct COOToEntryOp {
@@ -61,7 +62,11 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span offset,
 }
 
 template <typename AdapterBatchT>
-size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing, SparsePage* page) {
+size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing,
+                        SparsePage* page) {
+  bool valid = HasInfInData(batch, IsValidFunctor{missing});
+  CHECK(valid) << error::InfInData();
+
   page->offset.SetDevice(device);
   page->data.SetDevice(device);
   page->offset.Resize(batch.NumRows() + 1);
@@ -73,6 +78,5 @@ size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missin
   return num_nonzero_;
 }
 
-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
 #endif  // XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py
index 0250cea3f03d..3cd65e30fe8f 100644
--- a/tests/python-gpu/test_device_quantile_dmatrix.py
+++ b/tests/python-gpu/test_device_quantile_dmatrix.py
@@ -6,6 +6,7 @@
 
 import xgboost as xgb
 from xgboost import testing as tm
+from xgboost.testing.data import check_inf
 
 sys.path.append("tests/python")
 import test_quantile_dmatrix as tqd
@@ -153,3 +154,9 @@ def test_ltr(self) -> None:
         from_qdm = xgb.QuantileDMatrix(X, weight=w, ref=Xy_qdm)
         assert tm.predictor_equal(from_qdm, from_dm)
+
+    @pytest.mark.skipif(**tm.no_cupy())
+    def test_check_inf(self) -> None:
+        import cupy as cp
+        rng = cp.random.default_rng(1994)
+        check_inf(rng)
diff --git a/tests/python/test_quantile_dmatrix.py b/tests/python/test_quantile_dmatrix.py
index 316d0e5f6515..537910725b7f 100644
--- a/tests/python/test_quantile_dmatrix.py
+++ b/tests/python/test_quantile_dmatrix.py
@@ -15,7 +15,7 @@
     make_sparse_regression,
     predictor_equal,
 )
-from xgboost.testing.data import np_dtypes
+from xgboost.testing.data import check_inf, np_dtypes
 
 
 class TestQuantileDMatrix:
@@ -244,6 +244,10 @@ def test_ltr(self) -> None:
         from_dm = xgb.QuantileDMatrix(X, weight=w, ref=Xy)
         assert predictor_equal(from_qdm, from_dm)
 
+    def test_check_inf(self) -> None:
+        rng = np.random.default_rng(1994)
+        check_inf(rng)
+
     # we don't test empty Quantile DMatrix in single node construction.
     @given(
         strategies.integers(1, 1000),