From 6e8d50f0d2d6295677cebb17e3a0dc34a2d6e479 Mon Sep 17 00:00:00 2001 From: jinfeng Date: Tue, 26 Sep 2023 22:50:37 -0700 Subject: [PATCH] add docstring and revise tests to include n_part --- cpp/include/cuml/linear_model/qn_mg.hpp | 8 +++++++- python/cuml/tests/dask/test_dask_logistic_regression.py | 9 ++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/cpp/include/cuml/linear_model/qn_mg.hpp b/cpp/include/cuml/linear_model/qn_mg.hpp index 50643aecaf..f70fd833e9 100644 --- a/cpp/include/cuml/linear_model/qn_mg.hpp +++ b/cpp/include/cuml/linear_model/qn_mg.hpp @@ -28,7 +28,13 @@ namespace ML { namespace GLM { namespace opg { -/*TODO add docstring*/ +/** + * @brief Calculate unique class labels across multiple GPUs in a multi-node environment. + * @param[in] handle: the internal cuml handle object + * @param[in] input_desc: PartDescriptor object for the input + * @param[in] labels: labels data + * @returns host vector that stores the distinct labels + */ std::vector getUniquelabelsMG(const raft::handle_t& handle, Matrix::PartDescriptor& input_desc, std::vector*>& labels); diff --git a/python/cuml/tests/dask/test_dask_logistic_regression.py b/python/cuml/tests/dask/test_dask_logistic_regression.py index 6fc8b0bad9..8e02fb8566 100644 --- a/python/cuml/tests/dask/test_dask_logistic_regression.py +++ b/python/cuml/tests/dask/test_dask_logistic_regression.py @@ -20,6 +20,8 @@ from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression as skLR from cuml.internals.safe_imports import cpu_only_import +from hypothesis import given +from hypothesis import strategies as st pd = cpu_only_import("pandas") np = cpu_only_import("numpy") @@ -385,13 +387,14 @@ def assert_small(X, y, n_classes): ) +@pytest.mark.parametrize("n_parts", [2, 23]) @pytest.mark.parametrize("fit_intercept", [False, True]) -@pytest.mark.parametrize("n_classes", [2, 8]) -def test_n_classes(fit_intercept, n_classes, client): +@pytest.mark.parametrize("n_classes", [8]) +def test_n_classes(n_parts, fit_intercept, n_classes, client): lr = test_lbfgs( nrows=1e5, ncols=20, - n_parts=23, + n_parts=n_parts, fit_intercept=False, datatype=np.float32, delayed=True,