From 2e4ea5ecc060a7742edea03b35b61ad20303e623 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 21 Feb 2024 02:18:42 +0800 Subject: [PATCH 1/8] Support f64 for ubjson. (#10055) --- include/xgboost/json.h | 19 +++++++++++------ include/xgboost/json_io.h | 4 +++- src/common/json.cc | 38 +++++++++++++++++++++------------ tests/cpp/common/test_json.cc | 35 ++++++++++++++++++++++++++++++ tests/cpp/test_serialization.cc | 11 +++++++++- 5 files changed, 84 insertions(+), 23 deletions(-) diff --git a/include/xgboost/json.h b/include/xgboost/json.h index a5872ec3a9d6..77ca6a510c96 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2023 by XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #ifndef XGBOOST_JSON_H_ #define XGBOOST_JSON_H_ @@ -42,7 +42,8 @@ class Value { kBoolean, kNull, // typed array for ubjson - kNumberArray, + kF32Array, + kF64Array, kU8Array, kI32Array, kI64Array @@ -173,7 +174,11 @@ class JsonTypedArray : public Value { /** * @brief Typed UBJSON array for 32-bit floating point. */ -using F32Array = JsonTypedArray; +using F32Array = JsonTypedArray; +/** + * @brief Typed UBJSON array for 64-bit floating point. + */ +using F64Array = JsonTypedArray; /** * @brief Typed UBJSON array for uint8_t. */ @@ -457,9 +462,9 @@ class Json { Json& operator[](int ind) const { return (*ptr_)[ind]; } /*! \brief Return the reference to stored Json value. */ - Value const& GetValue() const & { return *ptr_; } - Value const& GetValue() && { return *ptr_; } - Value& GetValue() & { return *ptr_; } + [[nodiscard]] Value const& GetValue() const& { return *ptr_; } + Value const& GetValue() && { return *ptr_; } + Value& GetValue() & { return *ptr_; } bool operator==(Json const& rhs) const { return *ptr_ == *(rhs.ptr_); @@ -472,7 +477,7 @@ class Json { return os; } - IntrusivePtr const& Ptr() const { return ptr_; } + [[nodiscard]] IntrusivePtr const& Ptr() const { return ptr_; } private: IntrusivePtr ptr_{new JsonNull}; diff --git a/include/xgboost/json_io.h b/include/xgboost/json_io.h index 3a73d170a4c7..ce3d25c37e19 100644 --- a/include/xgboost/json_io.h +++ b/include/xgboost/json_io.h @@ -142,6 +142,7 @@ class JsonWriter { virtual void Visit(JsonArray const* arr); virtual void Visit(F32Array const* arr); + virtual void Visit(F64Array const*) { LOG(FATAL) << "Only UBJSON format can handle f64 array."; } virtual void Visit(U8Array const* arr); virtual void Visit(I32Array const* arr); virtual void Visit(I64Array const* arr); @@ -244,7 +245,8 @@ class UBJReader : public JsonReader { */ class UBJWriter : public JsonWriter { void Visit(JsonArray const* arr) override; - void Visit(F32Array const* arr) override; + void Visit(F32Array const* arr) override; + void Visit(F64Array const* arr) override; void Visit(U8Array const* arr) override; void Visit(I32Array const* arr) override; void Visit(I64Array const* arr) override; diff --git a/src/common/json.cc b/src/common/json.cc index 21be2a5bc674..2887eeccf855 100644 --- a/src/common/json.cc +++ b/src/common/json.cc @@ -1,11 +1,12 @@ /** - * Copyright 2019-2023, XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include "xgboost/json.h" #include // for array #include // for isdigit #include // for isinf, isnan +#include // for uint8_t, uint16_t, uint32_t #include // for EOF #include // for size_t, strtof #include // for memcpy @@ -72,15 +73,16 @@ void JsonWriter::Visit(JsonNumber const* num) { } void JsonWriter::Visit(JsonInteger const* num) { - char i2s_buffer_[NumericLimits::kToCharsSize]; + std::array::kToCharsSize> i2s_buffer_; auto i = num->GetInteger(); - auto ret = to_chars(i2s_buffer_, i2s_buffer_ + NumericLimits::kToCharsSize, i); + auto ret = + to_chars(i2s_buffer_.data(), i2s_buffer_.data() + NumericLimits::kToCharsSize, i); auto end = ret.ptr; CHECK(ret.ec == std::errc()); - auto digits = std::distance(i2s_buffer_, end); + auto digits = std::distance(i2s_buffer_.data(), end); auto ori_size = stream_->size(); stream_->resize(ori_size + digits); - std::memcpy(stream_->data() + ori_size, i2s_buffer_, digits); + std::memcpy(stream_->data() + ori_size, i2s_buffer_.data(), digits); } void JsonWriter::Visit(JsonNull const* ) { @@ -143,8 +145,10 @@ std::string Value::TypeStr() const { return "Null"; case ValueKind::kInteger: return "Integer"; - case ValueKind::kNumberArray: + case ValueKind::kF32Array: return "F32Array"; + case ValueKind::kF64Array: + return "F64Array"; case ValueKind::kU8Array: return "U8Array"; case ValueKind::kI32Array: @@ -262,10 +266,11 @@ bool JsonTypedArray::operator==(Value const& rhs) const { return std::equal(arr.cbegin(), arr.cend(), vec_.cbegin()); } -template class JsonTypedArray; -template class JsonTypedArray; -template class JsonTypedArray; -template class JsonTypedArray; +template class JsonTypedArray; +template class JsonTypedArray; +template class JsonTypedArray; +template class JsonTypedArray; +template class JsonTypedArray; // Json Number bool JsonNumber::operator==(Value const& rhs) const { @@ -708,6 +713,8 @@ Json UBJReader::ParseArray() { switch (type) { case 'd': return ParseTypedArray(n); + case 'D': + return ParseTypedArray(n); case 'U': return ParseTypedArray(n); case 'l': @@ -797,6 +804,10 @@ Json UBJReader::Parse() { auto v = this->ReadPrimitive(); return Json{v}; } + case 'D': { + auto v = this->ReadPrimitive(); + return Json{v}; + } case 'S': { auto str = this->DecodeStr(); return Json{str}; @@ -825,10 +836,6 @@ Json UBJReader::Parse() { Integer::Int i = this->ReadPrimitive(); return Json{i}; } - case 'D': { - LOG(FATAL) << "f64 is not supported."; - break; - } case 'H': { LOG(FATAL) << "High precision number is not supported."; break; @@ -882,6 +889,8 @@ void WriteTypedArray(JsonTypedArray const* arr, std::vector* stre stream->push_back('$'); if (std::is_same::value) { stream->push_back('d'); + } else if (std::is_same_v) { + stream->push_back('D'); } else if (std::is_same::value) { stream->push_back('i'); } else if (std::is_same::value) { @@ -910,6 +919,7 @@ void WriteTypedArray(JsonTypedArray const* arr, std::vector* stre } void UBJWriter::Visit(F32Array const* arr) { WriteTypedArray(arr, stream_); } +void UBJWriter::Visit(F64Array const* arr) { WriteTypedArray(arr, stream_); } void UBJWriter::Visit(U8Array const* arr) { WriteTypedArray(arr, stream_); } void UBJWriter::Visit(I32Array const* arr) { WriteTypedArray(arr, stream_); } void UBJWriter::Visit(I64Array const* arr) { WriteTypedArray(arr, stream_); } diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index 155cf04adf9a..72163efd78cc 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -639,6 +639,40 @@ TEST(Json, TypedArray) { ASSERT_EQ(arr[i + 8], i); } } + + { + Json f64{Object{}}; + auto array = F64Array(); + auto& vec = array.GetArray(); + // Construct test data + vec.resize(18); + std::iota(vec.begin(), vec.end(), 0.0); + // special values + vec.push_back(std::numeric_limits::epsilon()); + vec.push_back(std::numeric_limits::max()); + vec.push_back(std::numeric_limits::min()); + vec.push_back(std::numeric_limits::denorm_min()); + vec.push_back(std::numeric_limits::quiet_NaN()); + + static_assert( + std::is_same_v::value_type>); + + f64["f64"] = std::move(array); + ASSERT_TRUE(IsA(f64["f64"])); + std::vector out; + Json::Dump(f64, &out, std::ios::binary); + + auto loaded = Json::Load(StringView{out.data(), out.size()}, std::ios::binary); + ASSERT_TRUE(IsA(loaded["f64"])); + auto const& result = get(loaded["f64"]); + + auto& vec1 = get(f64["f64"]); + ASSERT_EQ(result.size(), vec1.size()); + for (std::size_t i = 0; i < vec1.size() - 1; ++i) { + ASSERT_EQ(result[i], vec1[i]); + } + ASSERT_TRUE(std::isnan(result.back())); + } } TEST(UBJson, Basic) { @@ -694,6 +728,7 @@ TEST(UBJson, Basic) { } } + TEST(Json, TypeCheck) { Json config{Object{}}; config["foo"] = String{"bar"}; diff --git a/tests/cpp/test_serialization.cc b/tests/cpp/test_serialization.cc index 0b65220ab9e6..283a56fc5ddf 100644 --- a/tests/cpp/test_serialization.cc +++ b/tests/cpp/test_serialization.cc @@ -60,7 +60,7 @@ void CompareJSON(Json l, Json r) { } break; } - case Value::ValueKind::kNumberArray: { + case Value::ValueKind::kF32Array: { auto const& l_arr = get(l); auto const& r_arr = get(r); ASSERT_EQ(l_arr.size(), r_arr.size()); @@ -69,6 +69,15 @@ void CompareJSON(Json l, Json r) { } break; } + case Value::ValueKind::kF64Array: { + auto const& l_arr = get(l); + auto const& r_arr = get(r); + ASSERT_EQ(l_arr.size(), r_arr.size()); + for (size_t i = 0; i < l_arr.size(); ++i) { + ASSERT_NEAR(l_arr[i], r_arr[i], kRtEps); + } + break; + } case Value::ValueKind::kU8Array: { CompareIntArray(l, r); break; From b9171d8f0baca7a8608221ad89da98496f52dd85 Mon Sep 17 00:00:00 2001 From: UncleLLD Date: Thu, 22 Feb 2024 17:34:12 +0800 Subject: [PATCH 2/8] [doc] Fix python docs (#10058) --- doc/python/callbacks.rst | 2 +- doc/python/sklearn_estimator.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/python/callbacks.rst b/doc/python/callbacks.rst index 7cb257a819ed..6d8b43a11557 100644 --- a/doc/python/callbacks.rst +++ b/doc/python/callbacks.rst @@ -36,7 +36,7 @@ inside iteration loop. You can also pass this callback function directly into X # Specify which dataset and which metric should be used for early stopping. early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, metric_name='CustomErr', - data_name='Train') + data_name='Valid') booster = xgb.train( {'objective': 'binary:logistic', diff --git a/doc/python/sklearn_estimator.rst b/doc/python/sklearn_estimator.rst index 207b9fa30920..1aaa340b1abb 100644 --- a/doc/python/sklearn_estimator.rst +++ b/doc/python/sklearn_estimator.rst @@ -62,7 +62,7 @@ stack of trees: .. code-block:: python early_stop = xgb.callback.EarlyStopping( - rounds=2, metric_name='logloss', data_name='Validation_0', save_best=True + rounds=2, metric_name='logloss', data_name='validation_0', save_best=True ) clf = xgb.XGBClassifier(tree_method="hist", callbacks=[early_stop]) clf.fit(X_train, y_train, eval_set=[(X_test, y_test)]) From eb281ff9b48bf88f01ec5ee0b1f41648ca27014f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 23 Feb 2024 06:21:32 +0800 Subject: [PATCH 3/8] [CI] Fix JVM tests on GH Action (#10064) --------- Co-authored-by: Hyunsu Cho --- .github/workflows/jvm_tests.yml | 32 ++++++++------ jvm-packages/create_jni.py | 65 +++++++++++++++++----------- python-package/packager/nativelib.py | 9 ++-- 3 files changed, 64 insertions(+), 42 deletions(-) diff --git a/.github/workflows/jvm_tests.yml b/.github/workflows/jvm_tests.yml index 330c037d7024..13c343d32479 100644 --- a/.github/workflows/jvm_tests.yml +++ b/.github/workflows/jvm_tests.yml @@ -15,31 +15,35 @@ jobs: os: [windows-latest, ubuntu-latest, macos-11] steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - - uses: actions/setup-python@7f80679172b057fc5e90d70d197929d454754a5a # v4.3.0 + - uses: mamba-org/setup-micromamba@422500192359a097648154e8db4e39bdb6c6eed7 # v1.8.1 with: - python-version: '3.8' - architecture: 'x64' - - - uses: actions/setup-java@d202f5dbf7256730fb690ec59f6381650114feb2 # v3.6.0 - with: - java-version: 1.8 - - - name: Install Python packages - run: | - python -m pip install wheel setuptools - python -m pip install awscli + micromamba-version: '1.5.6-0' + environment-name: jvm_tests + extra-specs: >- + python=3.10 + awscli + cache-downloads: true + cache-env: true - name: Cache Maven packages - uses: actions/cache@6998d139ddd3e68c71e9e398d8e40b71a2f39812 # v3.2.5 + uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: path: ~/.m2 key: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }} restore-keys: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }} + - name: Build xgboost4j.dll + run: | + mkdir build + cd build + cmake .. -G"Visual Studio 17 2022" -A x64 -DJVM_BINDINGS=ON + cmake --build . --config Release + if: matrix.os == 'windows-latest' + - name: Test XGBoost4J (Core) run: | cd jvm-packages diff --git a/jvm-packages/create_jni.py b/jvm-packages/create_jni.py index c39d354cf8cb..865d07fe8b0f 100755 --- a/jvm-packages/create_jni.py +++ b/jvm-packages/create_jni.py @@ -81,40 +81,55 @@ def native_build(args): with cd(".."): build_dir = "build-gpu" if cli_args.use_cuda == "ON" else "build" maybe_makedirs(build_dir) - with cd(build_dir): - if sys.platform == "win32": - # Force x64 build on Windows. - maybe_generator = " -A x64" - else: - maybe_generator = "" - if sys.platform == "linux": - maybe_parallel_build = " -- -j $(nproc)" - else: - maybe_parallel_build = "" - if cli_args.log_capi_invocation == "ON": - CONFIG["LOG_CAPI_INVOCATION"] = "ON" + if sys.platform == "linux": + maybe_parallel_build = " -- -j $(nproc)" + else: + maybe_parallel_build = "" + + if cli_args.log_capi_invocation == "ON": + CONFIG["LOG_CAPI_INVOCATION"] = "ON" - if cli_args.use_cuda == "ON": - CONFIG["USE_CUDA"] = "ON" - CONFIG["USE_NCCL"] = "ON" - CONFIG["USE_DLOPEN_NCCL"] = "OFF" + if cli_args.use_cuda == "ON": + CONFIG["USE_CUDA"] = "ON" + CONFIG["USE_NCCL"] = "ON" + CONFIG["USE_DLOPEN_NCCL"] = "OFF" - args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()] + args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()] - # if enviorment set rabit_mock - if os.getenv("RABIT_MOCK", None) is not None: - args.append("-DRABIT_MOCK:BOOL=ON") + # if enviorment set rabit_mock + if os.getenv("RABIT_MOCK", None) is not None: + args.append("-DRABIT_MOCK:BOOL=ON") - # if enviorment set GPU_ARCH_FLAG - gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None) - if gpu_arch_flag is not None: - args.append("%s" % gpu_arch_flag) + # if enviorment set GPU_ARCH_FLAG + gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None) + if gpu_arch_flag is not None: + args.append("%s" % gpu_arch_flag) + with cd(build_dir): lib_dir = os.path.join(os.pardir, "lib") if os.path.exists(lib_dir): shutil.rmtree(lib_dir) - run("cmake .. " + " ".join(args) + maybe_generator) + + # Same trick as Python build, just test all possible generators. + if sys.platform == "win32": + supported_generators = ( + "", # empty, decided by cmake + '-G"Visual Studio 17 2022" -A x64', + '-G"Visual Studio 16 2019" -A x64', + '-G"Visual Studio 15 2017" -A x64', + ) + for generator in supported_generators: + try: + run("cmake .. " + " ".join(args + [generator])) + break + except subprocess.CalledProcessError as e: + print(f"Failed to build with generator: {generator}", e) + with cd(os.path.pardir): + shutil.rmtree(build_dir) + maybe_makedirs(build_dir) + else: + run("cmake .. " + " ".join(args)) run("cmake --build . --config Release" + maybe_parallel_build) with cd("demo/CLI/regression"): diff --git a/python-package/packager/nativelib.py b/python-package/packager/nativelib.py index 0227cff37f25..42b510eef540 100644 --- a/python-package/packager/nativelib.py +++ b/python-package/packager/nativelib.py @@ -32,7 +32,10 @@ def build_libxgboost( build_dir: pathlib.Path, build_config: BuildConfiguration, ) -> pathlib.Path: - """Build libxgboost in a temporary directory and obtain the path to built libxgboost""" + """Build libxgboost in a temporary directory and obtain the path to built + libxgboost. + + """ logger = logging.getLogger("xgboost.packager.build_libxgboost") if not cpp_src_dir.is_dir(): @@ -51,8 +54,8 @@ def _build(*, generator: str) -> None: cmake_cmd.extend(build_config.get_cmake_args()) # Flag for cross-compiling for Apple Silicon - # We use environment variable because it's the only way to pass down custom flags - # through the cibuildwheel package, which calls `pip wheel` command. + # We use environment variable because it's the only way to pass down custom + # flags through the cibuildwheel package, which calls `pip wheel` command. if "CIBW_TARGET_OSX_ARM64" in os.environ: cmake_cmd.append("-DCMAKE_OSX_ARCHITECTURES=arm64") From 5b1d7a760bbadd5c04b16e5c859d9918085813c0 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 22 Feb 2024 14:40:11 -0800 Subject: [PATCH 4/8] [CI] Hotfix for GH Action --- .github/workflows/jvm_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/jvm_tests.yml b/.github/workflows/jvm_tests.yml index 13c343d32479..994fc96ffa8f 100644 --- a/.github/workflows/jvm_tests.yml +++ b/.github/workflows/jvm_tests.yml @@ -23,7 +23,7 @@ jobs: with: micromamba-version: '1.5.6-0' environment-name: jvm_tests - extra-specs: >- + create-args: >- python=3.10 awscli cache-downloads: true From aaa950951b75179e938aa4fd00a05a2a479c9438 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 22 Feb 2024 14:53:55 -0800 Subject: [PATCH 5/8] [CI] Hotfix for GH Action --- .github/workflows/jvm_tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/jvm_tests.yml b/.github/workflows/jvm_tests.yml index 994fc96ffa8f..d508ac6634c2 100644 --- a/.github/workflows/jvm_tests.yml +++ b/.github/workflows/jvm_tests.yml @@ -28,6 +28,7 @@ jobs: awscli cache-downloads: true cache-env: true + init-shell: bash powershell - name: Cache Maven packages uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 From 3ab8ccaa0cb8462516064b715be95f84dde5238c Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 22 Feb 2024 14:55:41 -0800 Subject: [PATCH 6/8] [CI] Hotfix for GH Action --- .github/workflows/jvm_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/jvm_tests.yml b/.github/workflows/jvm_tests.yml index d508ac6634c2..6c04fee8a6aa 100644 --- a/.github/workflows/jvm_tests.yml +++ b/.github/workflows/jvm_tests.yml @@ -27,7 +27,7 @@ jobs: python=3.10 awscli cache-downloads: true - cache-env: true + cache-environment: true init-shell: bash powershell - name: Cache Maven packages From 9f7b94cf7028136b555e44c661bb678c48c9c1cf Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Thu, 22 Feb 2024 17:16:48 -0800 Subject: [PATCH 7/8] [CI] Patch GitHub Action pipeline (#10067) --- .github/workflows/jvm_tests.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/jvm_tests.yml b/.github/workflows/jvm_tests.yml index 6c04fee8a6aa..bbded088387f 100644 --- a/.github/workflows/jvm_tests.yml +++ b/.github/workflows/jvm_tests.yml @@ -52,7 +52,8 @@ jobs: - name: Extract branch name shell: bash - run: echo "##[set-output name=branch;]$(echo ${GITHUB_REF#refs/heads/})" + run: | + echo "branch=${GITHUB_REF#refs/heads/}" >> "$GITHUB_OUTPUT" id: extract_branch if: | (github.ref == 'refs/heads/master' || contains(github.ref, 'refs/heads/release_')) && @@ -63,7 +64,7 @@ jobs: cd lib/ Rename-Item -Path xgboost4j.dll -NewName xgboost4j_${{ github.sha }}.dll dir - python -m awscli s3 cp xgboost4j_${{ github.sha }}.dll s3://xgboost-nightly-builds/${{ steps.extract_branch.outputs.branch }}/libxgboost4j/ --acl public-read + python -m awscli s3 cp xgboost4j_${{ github.sha }}.dll s3://xgboost-nightly-builds/${{ steps.extract_branch.outputs.branch }}/libxgboost4j/ --acl public-read --region us-west-2 if: | (github.ref == 'refs/heads/master' || contains(github.ref, 'refs/heads/release_')) && matrix.os == 'windows-latest' @@ -72,11 +73,12 @@ jobs: AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }} - name: Publish artifact libxgboost4j.dylib to S3 + shell: bash -l {0} run: | cd lib/ mv -v libxgboost4j.dylib libxgboost4j_${{ github.sha }}.dylib ls - python -m awscli s3 cp libxgboost4j_${{ github.sha }}.dylib s3://xgboost-nightly-builds/${{ steps.extract_branch.outputs.branch }}/libxgboost4j/ --acl public-read + python -m awscli s3 cp libxgboost4j_${{ github.sha }}.dylib s3://xgboost-nightly-builds/${{ steps.extract_branch.outputs.branch }}/libxgboost4j/ --acl public-read --region us-west-2 if: | (github.ref == 'refs/heads/master' || contains(github.ref, 'refs/heads/release_')) && matrix.os == 'macos-11' From 729fd9719646cd1e0636d94a7a7e5f9158e27429 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Thu, 22 Feb 2024 22:01:24 -0600 Subject: [PATCH 8/8] [doc] Fix spark_estimator doc (#10066) --- doc/tutorials/spark_estimator.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/tutorials/spark_estimator.rst b/doc/tutorials/spark_estimator.rst index 8bd1dcd97a76..4e608440a536 100644 --- a/doc/tutorials/spark_estimator.rst +++ b/doc/tutorials/spark_estimator.rst @@ -28,7 +28,7 @@ We can create a ``SparkXGBRegressor`` estimator like: .. code-block:: python from xgboost.spark import SparkXGBRegressor - spark_reg_estimator = SparkXGBRegressor( + xgb_regressor = SparkXGBRegressor( features_col="features", label_col="label", num_workers=2, @@ -61,7 +61,7 @@ type or spark array type. .. code-block:: python - transformed_test_spark_dataframe = xgb_regressor.predict(test_spark_dataframe) + transformed_test_spark_dataframe = xgb_regressor_model.transform(test_spark_dataframe) The above snippet code returns a ``transformed_test_spark_dataframe`` that contains the input