diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index 4b912dc966..9fcd6401ad 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -91,7 +91,7 @@ void fit(handle_t const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - detail::kmeans_fit(handle, params, X, sample_weight, centroids, inertia, n_iter); + kmeans::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); } /** @@ -156,7 +156,7 @@ void predict(handle_t const& handle, bool normalize_weight, raft::host_scalar_view inertia) { - detail::kmeans_predict( + kmeans::predict( handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } @@ -219,7 +219,7 @@ void fit_predict(handle_t const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - detail::kmeans_fit_predict( + kmeans::fit_predict( handle, params, X, sample_weight, centroids, labels, inertia, n_iter); } @@ -245,7 +245,7 @@ void transform(const raft::handle_t& handle, raft::device_matrix_view centroids, raft::device_matrix_view X_new) { - detail::kmeans_transform(handle, params, X, centroids, X_new); + kmeans::transform(handle, params, X, centroids, X_new); } template @@ -257,8 +257,7 @@ void transform(const raft::handle_t& handle, IndexT n_features, DataT* X_new) { - detail::kmeans_transform( - handle, params, X, centroids, n_samples, n_features, X_new); + kmeans::transform(handle, params, X, centroids, n_samples, n_features, X_new); } /** @@ -571,7 +570,7 @@ void fit_main(const raft::handle_t& handle, handle, params, X, sample_weights, centroids, inertia, n_iter, workspace); } -}; // end namespace raft::cluster::kmeans +}; // namespace raft::cluster::kmeans namespace raft::cluster { diff --git a/cpp/include/raft/solver/coordinate_descent.cuh b/cpp/include/raft/solver/coordinate_descent.cuh new file mode 100644 index 0000000000..aa2086d3b3 --- /dev/null +++ b/cpp/include/raft/solver/coordinate_descent.cuh @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::solver::coordinate_descent { + +/** + * @brief Minimizes an objective function using the Coordinate Descent solver. 
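+ *
+ * A minimal usage sketch (sizes, data initialization, and the exact cd_params
+ * fields are illustrative assumptions, not guarantees of this header):
+ * @code{.cpp}
+ *   raft::handle_t handle;
+ *   int n_rows = 100, n_cols = 8;
+ *   auto A = raft::make_device_matrix<float, int, raft::col_major>(handle, n_rows, n_cols);
+ *   auto b = raft::make_device_vector<float, int>(handle, n_rows);
+ *   auto x = raft::make_device_vector<float, int>(handle, n_cols);
+ *   // ... fill A and b with training data ...
+ *   cd_params<float> params;  // e.g. epochs, alpha, l1_ratio, tol, shuffle
+ *   std::optional<raft::device_vector_view<float, int>> no_weights{std::nullopt};
+ *   std::optional<raft::host_scalar_view<float>> no_intercept{std::nullopt};
+ *   minimize(handle, A.view(), b.view(), no_weights, x.view(), no_intercept, params);
+ * @endcode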
+ *
+ * Note: Currently only least squares loss is supported with optional lasso and
+ * elastic-net penalties:
+ *   f(x) = 1/2 * || b - A * x ||^2
+ *        + 1/2 * alpha * (1 - l1_ratio) * ||x||^2
+ *        + alpha * l1_ratio * ||x||_1
+ *
+ * @param[in] handle: Reference of raft::handle_t
+ * @param[in] A: Input matrix in column-major format (size of n_rows, n_cols)
+ * @param[in] b: Input vector of labels (size of n_rows)
+ * @param[in] sample_weights: Optional input vector for sample weights (size n_rows)
+ * @param[out] x: Output vector of learned coefficients (size of n_cols)
+ * @param[out] intercept: Optional scalar to hold intercept if desired
+ * @param[in] params: Hyper-parameters for the solver
+ */
+template <typename math_t, typename idx_t>
+void minimize(const raft::handle_t& handle,
+              raft::device_matrix_view<math_t, idx_t, raft::col_major> A,
+              raft::device_vector_view<math_t, idx_t> b,
+              std::optional<raft::device_vector_view<math_t, idx_t>> sample_weights,
+              raft::device_vector_view<math_t, idx_t> x,
+              std::optional<raft::host_scalar_view<math_t>> intercept,
+              cd_params<math_t>& params)
+{
+  RAFT_EXPECTS(A.extent(0) == b.extent(0),
+               "Number of labels must match the number of rows in input matrix");
+
+  if (sample_weights.has_value()) {
+    RAFT_EXPECTS(A.extent(0) == sample_weights.value().extent(0),
+                 "Number of sample weights must match number of rows in input matrix");
+  }
+
+  RAFT_EXPECTS(x.extent(0) == A.extent(1),
+               "Objective is linear. The number of coefficients must match the number of "
+               "features in the input matrix");
+  RAFT_EXPECTS(params.loss == loss_funct::SQRD_LOSS,
+               "Only squared loss is supported in the current implementation.");
+
+  math_t* intercept_ptr = intercept.has_value() ? intercept.value().data_handle() : nullptr;
+  math_t* sample_weight_ptr =
+    sample_weights.has_value() ? sample_weights.value().data_handle() : nullptr;
+
+  detail::cdFit(handle,
+                A.data_handle(),
+                A.extent(0),
+                A.extent(1),
+                b.data_handle(),
+                x.data_handle(),
+                intercept_ptr,
+                intercept.has_value(),
+                params.normalize,
+                params.epochs,
+                params.loss,
+                params.alpha,
+                params.l1_ratio,
+                params.shuffle,
+                params.tol,
+                sample_weight_ptr);
+}
+}  // namespace raft::solver::coordinate_descent
\ No newline at end of file
diff --git a/cpp/include/raft/solver/detail/cd.cuh b/cpp/include/raft/solver/detail/cd.cuh
new file mode 100644
index 0000000000..ad6092c929
--- /dev/null
+++ b/cpp/include/raft/solver/detail/cd.cuh
@@ -0,0 +1,355 @@
+/*
+ * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace raft::solver::detail {
+
+namespace {
+
+/** Epoch- and iteration-related state. */
+template <typename math_t>
+struct ConvState {
+  math_t coef;
+  math_t coefMax;
+  math_t diffMax;
+};
+
+/**
+ * Update a single CD coefficient and the corresponding convergence criteria.
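+ *
+ * The update applies lasso soft-thresholding: the incoming *coefLoc holds
+ * dot(X[:, ci], residual), and the new value is
+ *   sign(dot) * max(|dot| - l1_alpha, 0) / squared,
+ * where `squared` is the L2-regularized squared column norm.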
+ * + * @param[inout] coefLoc pointer to the coefficient (arr ptr + column index offset) + * @param[in] squaredLoc pointer to the precomputed data - L2 norm of input for across rows + * @param[inout] convStateLoc pointer to the structure holding the convergence state + * @param[in] l1_alpha L1 regularization coef + */ +template +__global__ void __launch_bounds__(1, 1) cdUpdateCoefKernel(math_t* coefLoc, + const math_t* squaredLoc, + ConvState* convStateLoc, + const math_t l1_alpha) +{ + auto coef = *coefLoc; + auto r = coef > l1_alpha ? coef - l1_alpha : (coef < -l1_alpha ? coef + l1_alpha : 0); + auto squared = *squaredLoc; + r = squared > math_t(1e-5) ? r / squared : math_t(0); + auto diff = raft::myAbs(convStateLoc->coef - r); + if (convStateLoc->diffMax < diff) convStateLoc->diffMax = diff; + auto absv = raft::myAbs(r); + if (convStateLoc->coefMax < absv) convStateLoc->coefMax = absv; + convStateLoc->coef = -r; + *coefLoc = r; +} + +} // namespace + +/** + * Fits a linear, lasso, and elastic-net regression model using Coordinate Descent solver. + * + * i.e. finds coefficients that minimize the following loss function: + * + * f(coef) = 1/2 * || b - A * x ||^2 + * + 1/2 * alpha * (1 - l1_ratio) * ||coef||^2 + * + alpha * l1_ratio * ||coef||_1 + * + * + * @param handle + * Reference of raft::handle_t + * @param input + * pointer to an array in column-major format (size of n_rows, n_cols) + * @param n_rows + * n_samples or rows in input + * @param n_cols + * n_features or columns in X + * @param labels + * pointer to an array for labels (size of n_rows) + * @param coef + * pointer to an array for coefficients (size of n_cols). This will be filled with + * coefficients once the function is executed. + * @param intercept + * pointer to a scalar for intercept. This will be filled + * once the function is executed + * @param fit_intercept + * boolean parameter to control if the intercept will be fitted or not + * @param normalize + * boolean parameter to control if the data will be normalized or not; + * NB: the input is scaled by the column-wise biased sample standard deviation estimator. + * @param epochs + * Maximum number of iterations that solver will run + * @param loss + * enum to use different loss functions. Only linear regression loss functions is supported + * right now + * @param alpha + * L1 parameter + * @param l1_ratio + * ratio of alpha will be used for L1. 
(1 - l1_ratio) * alpha will be used for L2 + * @param shuffle + * boolean parameter to control whether coordinates will be picked randomly or not + * @param tol + * tolerance to stop the solver + * @param sample_weight + * device pointer to sample weight vector of length n_rows (nullptr or uniform weights) + * This vector is modified during the computation + */ +template +void cdFit(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + math_t* labels, + math_t* coef, + math_t* intercept, + bool fit_intercept, + bool normalize, + int epochs, + ML::loss_funct loss, + math_t alpha, + math_t l1_ratio, + bool shuffle, + math_t tol, + math_t* sample_weight = nullptr) +{ + raft::common::nvtx::range fun_scope("ML::Solver::cdFit-%d-%d", n_rows, n_cols); + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + ASSERT(loss == ML::loss_funct::SQRD_LOSS, + "Parameter loss: Only SQRT_LOSS function is supported for now"); + + cudaStream_t stream = handle.get_stream(); + rmm::device_uvector residual(n_rows, stream); + rmm::device_uvector squared(n_cols, stream); + rmm::device_uvector mu_input(0, stream); + rmm::device_uvector mu_labels(0, stream); + rmm::device_uvector norm2_input(0, stream); + math_t h_sum_sw = 0; + + if (sample_weight != nullptr) { + rmm::device_scalar sum_sw(stream); + raft::stats::sum(sum_sw.data(), sample_weight, 1, n_rows, true, stream); + raft::update_host(&h_sum_sw, sum_sw.data(), 1, stream); + + raft::linalg::multiplyScalar( + sample_weight, sample_weight, (math_t)n_rows / h_sum_sw, n_rows, stream); + } + + if (fit_intercept) { + mu_input.resize(n_cols, stream); + mu_labels.resize(1, stream); + if (normalize) { norm2_input.resize(n_cols, stream); } + + preProcessData(handle, + input, + n_rows, + n_cols, + labels, + intercept, + mu_input.data(), + mu_labels.data(), + norm2_input.data(), + fit_intercept, + normalize, + sample_weight); + } + if (sample_weight != nullptr) { + raft::linalg::sqrt(sample_weight, sample_weight, n_rows, stream); + raft::matrix::matrixVectorBinaryMult( + input, sample_weight, n_rows, n_cols, false, false, stream); + raft::linalg::map_k( + labels, + n_rows, + [] __device__(math_t a, math_t b) { return a * b; }, + stream, + labels, + sample_weight); + } + + std::vector ri(n_cols); + std::mt19937 g(rand()); + initShuffle(ri, g); + + math_t l2_alpha = (1 - l1_ratio) * alpha * n_rows; + math_t l1_alpha = l1_ratio * alpha * n_rows; + + // Precompute the residual + if (normalize) { + // if we normalized the data, we know sample variance for each column is 1, + // thus no need to compute the norm again. 
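+    // With unit (biased) sample variance, dot(X[:, j], X[:, j]) == n_rows for
+    // every column, so the regularized norm is just n_rows + l2_alpha.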
+ math_t scalar = math_t(n_rows) + l2_alpha; + raft::matrix::setValue(squared.data(), squared.data(), scalar, n_cols, stream); + } else { + raft::linalg::colNorm( + squared.data(), input, n_cols, n_rows, raft::linalg::L2Norm, false, stream); + raft::linalg::addScalar(squared.data(), squared.data(), l2_alpha, n_cols, stream); + } + + raft::copy(residual.data(), labels, n_rows, stream); + + ConvState h_convState; + rmm::device_uvector> convStateBuf(1, stream); + auto convStateLoc = convStateBuf.data(); + + rmm::device_scalar cublas_alpha(1.0, stream); + rmm::device_scalar cublas_beta(0.0, stream); + + for (int i = 0; i < epochs; i++) { + raft::common::nvtx::range epoch_scope("ML::Solver::cdFit::epoch-%d", i); + if (i > 0 && shuffle) { Solver::shuffle(ri, g); } + + RAFT_CUDA_TRY(cudaMemsetAsync(convStateLoc, 0, sizeof(ConvState), stream)); + + for (int j = 0; j < n_cols; j++) { + raft::common::nvtx::range iter_scope("ML::Solver::cdFit::col-%d", j); + int ci = ri[j]; + math_t* coef_loc = coef + ci; + math_t* squared_loc = squared.data() + ci; + math_t* input_col_loc = input + (ci * n_rows); + + // remember current coef + raft::copy(&(convStateLoc->coef), coef_loc, 1, stream); + // calculate the residual without the contribution from column ci + // residual[:] += coef[ci] * X[:, ci] + raft::linalg::axpy( + handle, n_rows, coef_loc, input_col_loc, 1, residual.data(), 1, stream); + + // coef[ci] = dot(X[:, ci], residual[:]) + raft::linalg::gemv(handle, + false, + 1, + n_rows, + cublas_alpha.data(), + input_col_loc, + 1, + residual.data(), + 1, + cublas_beta.data(), + coef_loc, + 1, + stream); + + // Calculate the new coefficient that minimizes f along coordinate line ci + // coef[ci] = SoftTreshold(dot(X[:, ci], residual[:]), l1_alpha) / dot(X[:, ci], X[:, ci])) + // Also, update the convergence criteria. + cdUpdateCoefKernel<<>>( + coef_loc, squared_loc, convStateLoc, l1_alpha); + RAFT_CUDA_TRY(cudaGetLastError()); + + // Restore the residual using the updated coeffecient + raft::linalg::axpy( + handle, n_rows, &(convStateLoc->coef), input_col_loc, 1, residual.data(), 1, stream); + } + raft::update_host(&h_convState, convStateLoc, 1, stream); + handle.sync_stream(stream); + + if (h_convState.coefMax < tol || (h_convState.diffMax / h_convState.coefMax) < tol) break; + } + + if (sample_weight != nullptr) { + raft::matrix::matrixVectorBinaryDivSkipZero( + input, sample_weight, n_rows, n_cols, false, false, stream); + raft::linalg::map_k( + labels, + n_rows, + [] __device__(math_t a, math_t b) { return a / b; }, + stream, + labels, + sample_weight); + raft::linalg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream); + raft::linalg::multiplyScalar(sample_weight, sample_weight, h_sum_sw / n_rows, n_rows, stream); + } + + if (fit_intercept) { + postProcessData(handle, + input, + n_rows, + n_cols, + labels, + coef, + intercept, + mu_input.data(), + mu_labels.data(), + norm2_input.data(), + fit_intercept, + normalize); + + } else { + *intercept = math_t(0); + } +} + +/** + * Fits a linear, lasso, and elastic-net regression model using Coordinate Descent solver + * @param handle + * cuml handle + * @param input + * pointer to an array in column-major format (size of n_rows, n_cols) + * @param n_rows + * n_samples or rows in input + * @param n_cols + * n_features or columns in X + * @param coef + * pointer to an array for coefficients (size of n_cols). Calculated in cdFit function. 
+ * @param intercept + * intercept value calculated in cdFit function + * @param preds + * pointer to an array for predictions (size of n_rows). This will be fitted once functions + * is executed. + * @param loss + * enum to use different loss functions. Only linear regression loss functions is supported + * right now. + */ +template +void cdPredict(const raft::handle_t& handle, + const math_t* input, + int n_rows, + int n_cols, + const math_t* coef, + math_t intercept, + math_t* preds, + ML::loss_funct loss) +{ + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + ASSERT(loss == ML::loss_funct::SQRD_LOSS, + "Parameter loss: Only SQRT_LOSS function is supported for now"); + + Functions::linearRegH(handle, input, n_rows, n_cols, coef, preds, intercept, handle.get_stream()); +} + +}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/detail/lars.cuh b/cpp/include/raft/solver/detail/lars.cuh new file mode 100644 index 0000000000..6ee77ec6f8 --- /dev/null +++ b/cpp/include/raft/solver/detail/lars.cuh @@ -0,0 +1,1141 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::solver::detail { + +/** + * @brief Select the largest element from the inactive working set. + * + * The inactive set consist of cor[n_active..n-1]. This function returns the + * index of the most correlated element. The value of the largest element is + * returned in cj. + * + * The correlation value is checked for numeric error and convergence, and the + * return status indicates whether training should continue. + * + * @param n_active number of active elements (n_active <= n ) + * @param n number of elements in vector cor + * @param correlation device array of correlations, size [n] + * @param cj host pointer to return the value of the largest element + * @param wokspace buffer, size >= n_cols + * @param max_idx host pointer the index of the max correlation is returned here + * @param indices host pointer of feature column indices, size [n_cols] + * @param n_iter iteration counter + * @param stream CUDA stream + * + * @return fit status + */ +template +LarsFitStatus selectMostCorrelated(idx_t n_active, + idx_t n, + math_t* correlation, + math_t* cj, + rmm::device_uvector& workspace, + idx_t* max_idx, + idx_t n_rows, + idx_t* indices, + idx_t n_iter, + cudaStream_t stream) +{ + const idx_t align_bytes = 16 * sizeof(math_t); + // We might need to start a few elements earlier to ensure that the unary + // op has aligned access for vectorized load. 
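+  // `start` rounds n_active down to a multiple of 16 elements so that
+  // `correlation + start` is aligned for vectorized loads; the extra leading
+  // values are skipped again below via the (n_active - start) offset on `ptr`.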
+ int start = raft::alignDown(n_active, align_bytes) / sizeof(math_t); + raft::linalg::unaryOp( + workspace.data(), correlation + start, n, [] __device__(math_t a) { return abs(a); }, stream); + thrust::device_ptr ptr(workspace.data() + n_active - start); + auto max_ptr = thrust::max_element(thrust::cuda::par.on(stream), ptr, ptr + n - n_active); + raft::update_host(cj, max_ptr.get(), 1, stream); + raft::interruptible::synchronize(stream); + + *max_idx = n_active + (max_ptr - ptr); // the index of the maximum element + + RAFT_LOG_DEBUG( + "Iteration %d, selected feature %d with correlation %f", n_iter, indices[*max_idx], *cj); + + if (!std::isfinite(*cj)) { + RAFT_LOG_ERROR("Correlation is not finite, aborting."); + return LarsFitStatus::kError; + } + + // Tolerance for early stopping. Note we intentionally use here fp32 epsilon, + // otherwise the tolerance is too small (which could result in numeric error + // in Cholesky rank one update if eps < 0, or exploding regression parameters + // if eps > 0). + const math_t tolerance = std::numeric_limits::epsilon(); + if (abs(*cj) / n_rows < tolerance) { + RAFT_LOG_WARN("Reached tolarence limit with %e", abs(*cj)); + return LarsFitStatus::kStop; + } + return LarsFitStatus::kOk; +} + +/** + * @brief Swap two feature vectors. + * + * The function swaps feature column j and k or the corresponding rows and + * and columns of the Gram matrix. The elements of the cor and indices arrays + * are also swapped. + * + * @param handle cuBLAS handle + * @param j column index + * @param k column index + * @param X device array of feature vectors in column major format, size + * [n_cols * ld_X] + * @param n_rows number of training vectors + * @param n_cols number of features + * @param ld_X leading dimension of X + * @param cor device array of correlations, size [n_cols] + * @param indices host array of indices, size [n_cols] + * @param G device pointer of Gram matrix (or nullptr), size [n_cols * ld_G] + * @param ld_G leading dimension of G + * @param stream CUDA stream + */ +template +void swapFeatures(cublasHandle_t handle, + idx_t j, + idx_t k, + math_t* X, + idx_t n_rows, + idx_t n_cols, + idx_t ld_X, + math_t* cor, + idx_t* indices, + math_t* G, + idx_t ld_G, + cudaStream_t stream) +{ + std::swap(indices[j], indices[k]); + if (G) { + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY( + raft::linalg::detail::cublasSwap(handle, n_cols, G + ld_G * j, 1, G + ld_G * k, 1, stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY( + raft::linalg::detail::cublasSwap(handle, n_cols, G + j, ld_G, G + k, ld_G, stream)); + } else { + // Only swap X if G is nullptr. Only in that case will we use the feature + // columns, otherwise all the necessary information is already there in G. + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY( + raft::linalg::detail::cublasSwap(handle, n_rows, X + ld_X * j, 1, X + ld_X * k, 1, stream)); + } + // swap (c[j], c[k]) + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasSwap(handle, 1, cor + j, 1, cor + k, 1, stream)); +} + +/** + * @brief Move feature at idx=j into the active set. + * + * We have an active set with n_active elements, and an inactive set with + * n_valid_cols - n_active elements. The matrix X [n_samples, n_features] is + * partitioned in a way that the first n_active columns store the active set. 
+ * Similarily the vectors correlation and indices are partitioned in a way + * that the first n_active elements belong to the active set: + * - active set: X[:,:n_active], correlation[:n_active], indices[:n_active] + * - inactive set: X[:,n_active:], correlation[n_active:], indices[n_active:]. + * + * This function moves the feature column X[:,idx] into the active set by + * replacing the first inactive element with idx. The indices and correlation + * vectors are modified accordinly. The sign array is updated with the sign + * of correlation[n_active]. + * + * @param handle cuBLAS handle + * @param n_active number of active elements, will be increased by one after + * we move the new element j into the active set + * @param j index of the new element (n_active <= j < n_cols) + * @param X device array of feature vectors in column major format, size + * [n_cols * ld_X] + * @param n_rows number of training vectors + * @param n_cols number of valid features colums (ignoring those features which + * are detected to be collinear with the active set) + * @param ld_X leading dimension of X + * @param cor device array of correlations, size [n_cols] + * @param indices host array of indices, size [n_cols] + * @param G device pointer of Gram matrix (or nullptr), size [n_cols * ld_G] + * @param ld_G leading dimension of G + * @param sign device pointer to sign array, size[n] + * @param stream CUDA stream + */ +template +void moveToActive(cublasHandle_t handle, + idx_t* n_active, + idx_t j, + math_t* X, + idx_t n_rows, + idx_t n_cols, + idx_t ld_X, + math_t* cor, + idx_t* indices, + math_t* G, + idx_t ld_G, + math_t* sign, + cudaStream_t stream) +{ + idx_t idx_free = *n_active; + swapFeatures(handle, idx_free, j, X, n_rows, n_cols, ld_X, cor, indices, G, ld_G, stream); + + // sign[n_active] = sign(c[n_active]) + raft::linalg::unaryOp( + sign + idx_free, + cor + idx_free, + 1, + [] __device__(math_t c) -> math_t { + // return the sign of c + return (math_t(0) < c) - (c < math_t(0)); + }, + stream); + + (*n_active)++; +} + +/** + * @brief Update the Cholesky decomposition of the Gram matrix of the active set + * + * G0 = X.T * X, Gram matrix without signs. We use the part that corresponds to + * the active set, [n_A x n_A] + * + * At each step on the LARS path we add one column to the active set, therefore + * the Gram matrix grows incrementally. We update the Cholesky decomposition + * G0 = U.T * U. + * + * The Cholesky decomposition can use the same storage as G0, if the input + * pointers are same. + * + * @param handle RAFT handle + * @param n_active number of active elements + * @param X device array of feature vectors in column major format, size + * [n_rows * n_cols] + * @param n_rows number of training vectors + * @param n_cols number of features + * @param ld_X leading dimension of X (stride of columns) + * @param U device pointer to the Cholesky decomposition of G0, + * size [n_cols * ld_U] + * @param ld_U leading dimension of U + * @param G0 device pointer to Gram matrix G0 = X.T*X (can be nullptr), + * size [n_cols * ld_G]. 
+ * @param ld_G leading dimension of G + * @param workspace workspace for the Cholesky update + * @param eps parameter for cheleskyRankOneUpdate + * @param stream CUDA stream + */ +template +void updateCholesky(const raft::handle_t& handle, + idx_t n_active, + const math_t* X, + idx_t n_rows, + idx_t n_cols, + idx_t ld_X, + math_t* U, + idx_t ld_U, + const math_t* G0, + idx_t ld_G, + rmm::device_uvector& workspace, + math_t eps, + cudaStream_t stream) +{ + const cublasFillMode_t fillmode = CUBLAS_FILL_MODE_UPPER; + if (G0 == nullptr) { + // Calculate the new column of G0. It is stored in U. + math_t* G_row = U + (n_active - 1) * ld_U; + const math_t* X_row = X + (n_active - 1) * ld_X; + math_t one = 1; + math_t zero = 0; + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_T, + n_rows, + n_cols, + &one, + X, + n_rows, + X_row, + 1, + &zero, + G_row, + 1, + stream)); + } else if (G0 != U) { + // Copy the new column of G0 into U, because the factorization works in + // place. + raft::copy(U + (n_active - 1) * ld_U, G0 + (n_active - 1) * ld_G, n_active, stream); + } // Otherwise the new data is already in place in U. + + // Update the Cholesky decomposition + int n_work = workspace.size(); + if (n_work == 0) { + // Query workspace size and allocate it + raft::linalg::choleskyRank1Update( + handle, U, n_active, ld_U, nullptr, &n_work, fillmode, stream); + workspace.resize(n_work, stream); + } + raft::linalg::choleskyRank1Update( + handle, U, n_active, ld_U, workspace.data(), &n_work, fillmode, stream, eps); +} + +/** + * @brief Solve for ws = S * GA^(-1) * 1_A using a Cholesky decomposition. + * + * See calcEquiangularVec for more details on the formulas. In this function we + * calculate ws = S * (S * G0 * S)^{-1} 1_A = G0^{-1} (S 1_A) = G0^{-1} sign_A. + * + * @param handle RAFT handle + * @param n_active number of active elements + * @param n_cols number of features + * @param sign array with sign of the active set, size [n_cols] + * @param U device pointer to the Cholesky decomposition of G0, + * size [n_cols * n_cols] + * @param ld_U leading dimension of U (column stride) + * @param ws device pointer, size [n_active] + * @param stream CUDA stream + */ +template +void calcW0(const raft::handle_t& handle, + idx_t n_active, + idx_t n_cols, + const math_t* sign, + const math_t* U, + idx_t ld_U, + math_t* ws, + cudaStream_t stream) +{ + const cublasFillMode_t fillmode = CUBLAS_FILL_MODE_UPPER; + + // First we calculate x by solving equation U.T x = sign_A. + raft::copy(ws, sign, n_active, stream); + math_t alpha = 1; + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublastrsm(handle.get_cublas_handle(), + CUBLAS_SIDE_LEFT, + fillmode, + CUBLAS_OP_T, + CUBLAS_DIAG_NON_UNIT, + n_active, + 1, + &alpha, + U, + ld_U, + ws, + ld_U, + stream)); + + // ws stores x, the solution of U.T x = sign_A. Now we solve U * ws = x + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublastrsm(handle.get_cublas_handle(), + CUBLAS_SIDE_LEFT, + fillmode, + CUBLAS_OP_N, + CUBLAS_DIAG_NON_UNIT, + n_active, + 1, + &alpha, + U, + ld_U, + ws, + ld_U, + stream)); + // Now ws = G0^(-1) sign_A = S GA^{-1} 1_A. +} + +/** + * @brief Calculate A = (1_A * GA^{-1} * 1_A)^{-1/2}. + * + * See calcEquiangularVec for more details on the formulas. 
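+ *
+ * Concretely, A = 1 / sqrt(sum(ws * sign)); since ws = S * GA^{-1} * 1_A and
+ * S * S = I, this sum equals 1_A * GA^{-1} * 1_A.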
+ * + * @param handle RAFT handle + * @param A device pointer to store the result + * @param n_active number of active elements + * @param sign array with sign of the active set, size [n_cols] + * @param ws device pointer, size [n_active] + * @param stream CUDA stream + */ +template +void calcA(const raft::handle_t& handle, + math_t* A, + idx_t n_active, + const math_t* sign, + const math_t* ws, + cudaStream_t stream) +{ + // Calculate sum (w) = sum(ws * sign) + auto multiply = [] __device__(math_t w, math_t s) { return w * s; }; + raft::linalg::mapThenSumReduce(A, n_active, multiply, stream, ws, sign); + // Calc Aa = 1 / sqrt(sum(w)) + raft::linalg::unaryOp( + A, A, 1, [] __device__(math_t a) { return 1 / sqrt(a); }, stream); +} + +/** + * @brief Calculate the equiangular vector u, w and A according to [1]. + * + * We introduce the following variables (Python like indexing): + * - n_A number of elements in the active set + * - S = diag(sign_A): diagonal matrix with the signs, size [n_A x n_A] + * - X_A = X[:,:n_A] * S, column vectors of the active set size [n_A x n_A] + * - G0 = X.T * X, Gram matrix without signs. We just use the part that + * corresponds to the active set, [n_A x n_A] + * - GA = X_A.T * X_A is the Gram matrix of the active set, size [n_A x n_A] + * GA = S * G0[:n_A, :n_A] * S + * - 1_A = np.ones(n_A) + * - A = (1_A * GA^{-1} * 1_A)^{-1/2}, scalar, see eq (2.5) in [1] + * - w = A GA^{-1} * 1_A, vector of size [n_A] see eq (2.6) in [1] + * - ws = S * w, vector of size [n_A] + * + * The equiangular vector can be expressed the following way (equation 2.6): + * u = X_A * w = X[:,:n_A] S * w = X[:,:n_A] * ws. + * + * The equiangular vector later appears only in an expression like X.T u, which + * can be reformulated as X.T u = X.T X[:,:n_A] S * w = G[:n_A,:n_A] * ws. + * If the gram matrix is given, then we do not need to calculate u, it will be + * sufficient to calculate ws and A. + * + * We use Cholesky decomposition G0 = U.T * U to solve to calculate A and w + * which depend on GA^{-1}. + * + * References: + * [1] B. Efron, T. Hastie, I. Johnstone, R Tibshirani, Least Angle Regression + * The Annals of Statistics (2004) Vol 32, No 2, 407-499 + * http://statweb.stanford.edu/~tibs/ftp/lars.pdf + * + * @param handle RAFT handle + * @param n_active number of active elements + * @param X device array of feature vectors in column major format, size + * [ld_X * n_cols] + * @param n_rows number of training vectors + * @param n_cols number of features + * @param ld_X leading dimension of array X (column stride, ld_X >= n_rows) + * @param sign array with sign of the active set, size [n_cols] + * @param U device pointer to the Cholesky decomposition of G0, + * size [ld_U * n_cols] + * @param ld_U leading dimension of array U (ld_U >= n_cols) + * @param G0 device pointer to Gram matrix G0 = X.T*X (can be nullptr), + * size [ld_G * n_cols]. Note the difference between G0 and + * GA = X_A.T * X_A + * @param ld_G leading dimension of array G0 (ld_G >= n_cols) + * @param workspace workspace for the Cholesky update + * @param ws device pointer, size [n_active] + * @param A device pointer to a scalar + * @param u_eq device pointer to the equiangular vector, only used if + * Gram==nullptr, size [n_rows]. 
+ * @param eps numerical regularizaton parameter for the Cholesky decomposition + * @param stream CUDA stream + * + * @return fit status + */ +template +LarsFitStatus calcEquiangularVec(const raft::handle_t& handle, + idx_t n_active, + math_t* X, + idx_t n_rows, + idx_t n_cols, + idx_t ld_X, + math_t* sign, + math_t* U, + idx_t ld_U, + math_t* G0, + idx_t ld_G, + rmm::device_uvector& workspace, + math_t* ws, + math_t* A, + math_t* u_eq, + math_t eps, + cudaStream_t stream) +{ + // Since we added a new vector to the active set, we update the Cholesky + // decomposition (U) + updateCholesky( + handle, n_active, X, n_rows, n_cols, ld_X, U, ld_U, G0, ld_G, workspace, eps, stream); + + // Calculate ws = S GA^{-1} 1_A using U + calcW0(handle, n_active, n_cols, sign, U, ld_U, ws, stream); + + calcA(handle, A, n_active, sign, ws, stream); + + // ws *= Aa + raft::linalg::unaryOp( + ws, ws, n_active, [A] __device__(math_t w) { return (*A) * w; }, stream); + + // Check for numeric error + math_t ws_host; + raft::update_host(&ws_host, ws, 1, stream); + math_t diag_host; // U[n_active-1, n_active-1] + raft::update_host(&diag_host, U + ld_U * (n_active - 1) + n_active - 1, 1, stream); + handle.sync_stream(stream); + if (diag_host < 1e-7) { + RAFT_LOG_WARN( + "Vanising diagonal in Cholesky factorization (%e). This indicates " + "collinear features. Dropping current regressor.", + diag_host); + return LarsFitStatus::kCollinear; + } + if (!std::isfinite(ws_host)) { + RAFT_LOG_WARN("ws=%f is not finite at iteration %d", ws_host, n_active); + return LarsFitStatus::kError; + } + + if (G0 == nullptr) { + // Calculate u_eq only in the case if the Gram matrix is not stored. + math_t one = 1; + math_t zero = 0; + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_N, + n_rows, + n_active, + &one, + X, + ld_X, + ws, + 1, + &zero, + u_eq, + 1, + stream)); + } + return LarsFitStatus::kOk; +} + +/** + * @brief Calculate the maximum step size (gamma) in the equiangular direction. + * + * Let mu = X beta.T be the current prediction vector. The modified solution + * after taking step gamma is defined as mu' = mu + gamma u. With this + * solution the correlation of the covariates in the active set will decrease + * equally, to a new value |c_j(gamma)| = Cmax - gamma A. At the same time + * the correlation of the values in the inactive set changes according to the + * following formula: c_j(gamma) = c_j - gamma a_j. We increase gamma until + * one of correlations from the inactive set becomes equal with the + * correlation from the active set. + * + * References: + * [1] B. Efron, T. Hastie, I. Johnstone, R Tibshirani, Least Angle Regression + * The Annals of Statistics (2004) Vol 32, No 2, 407-499 + * http://statweb.stanford.edu/~tibs/ftp/lars.pdf + * + * @param handle RAFT handle + * @param max_iter maximum number of iterations + * @param n_rows number of samples + * @param n_cols number of valid feature columns + * @param n_active size of the active set (n_active <= max_iter <= n_cols) + * @param cj value of the maximum correlation + * @param A device pointer to a scalar, as defined by eq 2.5 in [1] + * @param cor device pointer to correlation vector, size [n_active] + * @param G device pointer to Gram matrix of the active set (without signs) + * size [n_active * ld_G] + * @param ld_G leading dimension of G (ld_G >= n_cols) + * @param X device array of training vectors in column major format, + * size [n_rows * n_cols]. 
Only used if the gram matrix is not avaiable. + * @param ld_X leading dimension of X (ld_X >= n_rows) + * @param u device pointer to equiangular vector size [n_rows]. Only used if the + * Gram matrix G is not available. + * @param ws device pointer to the ws vector defined in calcEquiangularVec, + * size [n_active] + * @param gamma device pointer to a scalar. The max step size is returned here. + * @param a_vec device pointer, size [n_cols] + * @param stream CUDA stream + */ +template +void calcMaxStep(const raft::handle_t& handle, + idx_t max_iter, + idx_t n_rows, + idx_t n_cols, + idx_t n_active, + math_t cj, + const math_t* A, + math_t* cor, + const math_t* G, + idx_t ld_G, + const math_t* X, + idx_t ld_X, + const math_t* u, + const math_t* ws, + math_t* gamma, + math_t* a_vec, + cudaStream_t stream) +{ + // In the active set each element has the same correlation, whose absolute + // value is given by Cmax. + math_t Cmax = std::abs(cj); + if (n_active == n_cols) { + // Last iteration, the inactive set is empty we use equation (2.21) + raft::linalg::unaryOp( + gamma, A, 1, [Cmax] __device__(math_t A) { return Cmax / A; }, stream); + } else { + const int n_inactive = n_cols - n_active; + if (G == nullptr) { + // Calculate a = X.T[:,n_active:] * u (2.11) + math_t one = 1; + math_t zero = 0; + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_T, + n_rows, + n_inactive, + &one, + X + n_active * ld_X, + ld_X, + u, + 1, + &zero, + a_vec, + 1, + stream)); + } else { + // Calculate a = X.T[:,n_A:] * u = X.T[:, n_A:] * X[:,:n_A] * ws + // = G[n_A:,:n_A] * ws (2.11) + math_t one = 1; + math_t zero = 0; + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_N, + n_inactive, + n_active, + &one, + G + n_active, + ld_G, + ws, + 1, + &zero, + a_vec, + 1, + stream)); + } + const math_t tiny = std::numeric_limits::min(); + const math_t huge = std::numeric_limits::max(); + // + // gamma = min^+_{j \in inactive} {(Cmax - cor_j) / (A-a_j), + // (Cmax + cor_j) / (A+a_j)} (2.13) + auto map = [Cmax, A, tiny, huge] __device__(math_t c, math_t a) -> math_t { + math_t tmp1 = (Cmax - c) / (*A - a + tiny); + math_t tmp2 = (Cmax + c) / (*A + a + tiny); + // We consider only positive elements while we search for the minimum + math_t val = (tmp1 > 0) ? tmp1 : huge; + if (tmp2 > 0 && tmp2 < val) val = tmp2; + return val; + }; + raft::linalg::mapThenReduce( + gamma, n_inactive, huge, map, cub::Min(), stream, cor + n_active, a_vec); + } +} + +/** + * @brief Initialize for Lars training. + * + * We calculate the initial correlation, initialize the indices array, and set + * up pointers to store the Cholesky factorization. + * + * @param handle RAFT handle + * @param X device array of training vectors in column major format, + * size [ld_X * n_cols]. + * @param n_rows number of samples + * @param n_cols number of valid feature columns + * @param ld_X leading dimension of X (ld_X >= n_rows) + * @param y device pointer to regression targets, size [n_rows] + * @param Gram device pointer to Gram matrix (X.T * X), size [n_cols * ld_G], + * can be nullptr + * @param ld_G leading dimension of G (ld_G >= n_cols) + * @param U_buffer device buffer that will be initialized to store the Cholesky + * factorization. Only used if Gram is nullptr. 
+ * @param U device pointer to U + * @param ld_U leading dimension of U + * @param indices host buffer to store feature column indices + * @param cor device pointer to correlation vector, size [n_cols] + * @param max_iter host pointer to the maximum number of iterations + * @param coef_path device pointer to store coefficients along the + * regularization path size [(max_iter + 1) * max_iter], can be nullptr + * @param stream CUDA stream + */ +template +void larsInit(const raft::handle_t& handle, + const math_t* X, + idx_t n_rows, + idx_t n_cols, + idx_t ld_X, + const math_t* y, + math_t* Gram, + idx_t ld_G, + rmm::device_uvector& U_buffer, + math_t** U, + idx_t* ld_U, + std::vector& indices, + rmm::device_uvector& cor, + int* max_iter, + math_t* coef_path, + cudaStream_t stream) +{ + if (n_cols < *max_iter) { *max_iter = n_cols; } + if (Gram == nullptr) { + const idx_t align_bytes = 256; + *ld_U = raft::alignTo(*max_iter, align_bytes); + try { + U_buffer.resize((*ld_U) * (*max_iter), stream); + } catch (std::bad_alloc const&) { + THROW( + "Not enough GPU memory! The memory usage depends quadraticaly on the " + "n_nonzero_coefs parameter, try to decrease it."); + } + *U = U_buffer.data(); + } else { + // Set U as G. During the solution in larsFit, the Cholesky factorization + // U will overwrite G. + *U = Gram; + *ld_U = ld_G; + } + std::iota(indices.data(), indices.data() + n_cols, 0); + + math_t one = 1; + math_t zero = 0; + // Set initial correlation to X.T * y + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_T, + n_rows, + n_cols, + &one, + X, + ld_X, + y, + 1, + &zero, + cor.data(), + 1, + stream)); + if (coef_path) { + RAFT_CUDA_TRY( + cudaMemsetAsync(coef_path, 0, sizeof(math_t) * (*max_iter + 1) * (*max_iter), stream)); + } +} + +/** + * @brief Update regression coefficient and correlations + * + * After we calculated the equiangular vector and the step size (gamma) we + * adjust the regression coefficients here. + * + * See calcEquiangularVec for definition of ws. + * + * @param handle RAFT handle + * @param max_iter maximum number of iterations + * @param n_cols number of valid feature columns + * @param n_active number of elements in the active set (n_active <= n_cols) + * @param gamma device pointer to the maximum step size (scalar) + * @param ws device pointer to the ws vector, size [n_cols] + * @param cor device pointer to the correlations, size [n_cols] + * @param a_vec device pointer to a = X.T[:,n_A:] * u, size [n_cols] + * @param beta pointer to regression coefficents, size [max_iter] + * @param coef_path device pointer to all the coefficients along the + * regularization path, size [(max_iter + 1) * max_iter] + * @param stream CUDA stream + */ +template +void updateCoef(const raft::handle_t& handle, + idx_t max_iter, + idx_t n_cols, + idx_t n_active, + math_t* gamma, + const math_t* ws, + math_t* cor, + math_t* a_vec, + math_t* beta, + math_t* coef_path, + cudaStream_t stream) +{ + // It is sufficient to update correlations only for the inactive set. 
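+  // (all active-set correlations shrink by the same amount, gamma * A, which
+  // larsFit folds into the returned alphas at the end)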
+ // cor[n_active:] -= gamma * a_vec + int n_inactive = n_cols - n_active; + if (n_inactive > 0) { + raft::linalg::binaryOp( + cor + n_active, + cor + n_active, + a_vec, + n_inactive, + [gamma] __device__(math_t c, math_t a) { return c - *gamma * a; }, + stream); + } + // beta[:n_active] += gamma * ws + raft::linalg::binaryOp( + beta, + beta, + ws, + n_active, + [gamma] __device__(math_t b, math_t w) { return b + *gamma * w; }, + stream); + if (coef_path) { raft::copy(coef_path + n_active * max_iter, beta, n_active, stream); } +} + +/** + * @brief Train a regressor using Least Angle Regression. + * + * Least Angle Regression (LAR or LARS) is a model selection algorithm. It + * builds up the model using the following algorithm: + * + * 1. We start with all the coefficients equal to zero. + * 2. At each step we select the predictor that has the largest absolute + * correlation with the residual. + * 3. We take the largest step possible in the direction which is equiangular + * with all the predictors selected so far. The largest step is determined + * such that using this step a new predictor will have as much correlation + * with the residual as any of the currently active predictors. + * 4. Stop if max_iter reached or all the predictors are used, or if the + * correlation between any unused predictor and the residual is lower than + * a tolerance. + * + * The solver is based on [1]. The equations referred in the comments correspond + * to the equations in the paper. + * + * Note: this algorithm assumes that the offset is removed from X and y, and + * each feature is normalized: + * - sum_i y_i = 0, + * - sum_i x_{i,j} = 0, sum_i x_{i,j}^2=1 for j=0..n_col-1 + * + * References: + * [1] B. Efron, T. Hastie, I. Johnstone, R Tibshirani, Least Angle Regression + * The Annals of Statistics (2004) Vol 32, No 2, 407-499 + * http://statweb.stanford.edu/~tibs/ftp/lars.pdf + * + * @param handle RAFT handle + * @param X device array of training vectors in column major format, + * size [n_rows * n_cols]. Note that the columns of X will be permuted if + * the Gram matrix is not specified. It is expected that X is normalized so + * that each column has zero mean and unit variance. + * @param n_rows number of training samples + * @param n_cols number of feature columns + * @param y device array of the regression targets, size [n_rows]. y should + * be normalized to have zero mean. + * @param beta device array of regression coefficients, has to be allocated on + * entry, size [max_iter] + * @param active_idx device array containing the indices of active variables. + * Must be allocated on entry. Size [max_iter] + * @param alphas device array to return the maximum correlation along the + * regularization path. Must be allocated on entry, size [max_iter+1]. + * @param n_active host pointer to return the number of active elements (scalar) + * @param Gram device array containing Gram matrix containing X.T * X. Can be + * nullptr. + * @param max_iter maximum number of iterations, this equals with the maximum + * number of coefficients returned. max_iter <= n_cols. + * @param coef_path coefficients along the regularization path are returned + * here. Must be nullptr, or a device array already allocated on entry. + * Size [max_iter * (max_iter+1)]. 
+ * @param verbosity verbosity level + * @param ld_X leading dimension of X (stride of columns) + * @param ld_G leading dimesion of G + * @param eps numeric parameter for Cholesky rank one update + */ +template +void larsFit(const raft::handle_t& handle, + math_t* X, + idx_t n_rows, + idx_t n_cols, + const math_t* y, + math_t* beta, + idx_t* active_idx, + math_t* alphas, + idx_t* n_active, + math_t* Gram = nullptr, + int max_iter = 500, + math_t* coef_path = nullptr, + int verbosity = 0, + idx_t ld_X = 0, + idx_t ld_G = 0, + math_t eps = -1) +{ + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 0, "Parameter n_rows: number of rows cannot be less than one"); + ML::Logger::get().setLevel(verbosity); + + // Set default ld parameters if needed. + if (ld_X == 0) ld_X = n_rows; + if (Gram && ld_G == 0) ld_G = n_cols; + + cudaStream_t stream = handle.get_stream(); + + // We will use either U_buffer.data() to store the Cholesky factorization, or + // store it in place at Gram. Pointer U will point to the actual storage. + rmm::device_uvector U_buffer(0, stream); + idx_t ld_U = 0; + math_t* U = nullptr; + + // Indices of elements in the active set. + std::vector indices(n_cols); + // Sign of the correlation at the time when the element was added to the + // active set. + rmm::device_uvector sign(n_cols, stream); + + // Correlation between the residual mu = y - X.T*beta and columns of X + rmm::device_uvector cor(n_cols, stream); + + // Temporary arrays used by the solver + rmm::device_scalar A(stream); + rmm::device_uvector a_vec(n_cols, stream); + rmm::device_scalar gamma(stream); + rmm::device_uvector u_eq(n_rows, stream); + rmm::device_uvector ws(max_iter, stream); + rmm::device_uvector workspace(n_cols, stream); + + larsInit(handle, + X, + n_rows, + n_cols, + ld_X, + y, + Gram, + ld_G, + U_buffer, + &U, + &ld_U, + indices, + cor, + &max_iter, + coef_path, + stream); + + // If we detect collinear features, then we will move them to the end of the + // correlation array and mark them as invalid (simply by decreasing + // n_valid_cols). At every iteration the solver is only working with the valid + // columns stored at X[:,:n_valid_cols], and G[:n_valid_cols, :n_valid_cols] + // cor[:n_valid_cols]. 
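+  // A column found to be collinear is swapped to position n_valid_cols - 1
+  // by swapFeatures and never visited again.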
+ int n_valid_cols = n_cols; + + *n_active = 0; + for (int i = 0; i < max_iter; i++) { + math_t cj; + idx_t j; + LarsFitStatus status = selectMostCorrelated( + *n_active, n_valid_cols, cor.data(), &cj, workspace, &j, n_rows, indices.data(), i, stream); + if (status != LarsFitStatus::kOk) { break; } + + moveToActive(handle.get_cublas_handle(), + n_active, + j, + X, + n_rows, + n_valid_cols, + ld_X, + cor.data(), + indices.data(), + Gram, + ld_G, + sign.data(), + stream); + + status = calcEquiangularVec(handle, + *n_active, + X, + n_rows, + n_valid_cols, + ld_X, + sign.data(), + U, + ld_U, + Gram, + ld_G, + workspace, + ws.data(), + A.data(), + u_eq.data(), + eps, + stream); + + if (status == LarsFitStatus::kError) { + if (*n_active > 1) { RAFT_LOG_WARN("Returning with last valid model."); } + *n_active -= 1; + break; + } else if (status == LarsFitStatus::kCollinear) { + // We move the current feature to the invalid set + swapFeatures(handle.get_cublas_handle(), + n_valid_cols - 1, + *n_active - 1, + X, + n_rows, + n_cols, + ld_X, + cor.data(), + indices.data(), + Gram, + ld_G, + stream); + *n_active -= 1; + n_valid_cols--; + continue; + } + + calcMaxStep(handle, + max_iter, + n_rows, + n_valid_cols, + *n_active, + cj, + A.data(), + cor.data(), + Gram, + ld_G, + X, + ld_X, + u_eq.data(), + ws.data(), + gamma.data(), + a_vec.data(), + stream); + + updateCoef(handle, + max_iter, + n_valid_cols, + *n_active, + gamma.data(), + ws.data(), + cor.data(), + a_vec.data(), + beta, + coef_path, + stream); + } + + if (*n_active > 0) { + // Apply sklearn definition of alphas = cor / n_rows + raft::linalg::unaryOp( + alphas, + cor.data(), + *n_active, + [n_rows] __device__(math_t c) { return abs(c) / n_rows; }, + stream); + + // Calculate the final correlation. We use the correlation from the last + // iteration and apply the changed during the last LARS iteration: + // alpha[n_active] = cor[n_active-1] - gamma * A + math_t* gamma_ptr = gamma.data(); + math_t* A_ptr = A.data(); + raft::linalg::unaryOp( + alphas + *n_active, + cor.data() + *n_active - 1, + 1, + [gamma_ptr, A_ptr, n_rows] __device__(math_t c) { + return abs(c - (*gamma_ptr) * (*A_ptr)) / n_rows; + }, + stream); + + raft::update_device(active_idx, indices.data(), *n_active, stream); + } else { + THROW("Model is not fitted."); + } +} + +/** + * @brief Predict with least angle regressor. + * + * @param handle RAFT handle + * @param X device array of training vectors in column major format, + * size [n_rows * n_cols]. + * @param n_rows number of training samples + * @param n_cols number of feature columns + * @param ld_X leading dimension of X (stride of columns) + * @param beta device array of regression coefficients, size [n_active] + * @param n_active the number of regression coefficients + * @param active_idx device array containing the indices of active variables. + * Only these columns of X will be used for prediction, size [n_active]. + * @param intercept + * @param preds device array to store the predictions, size [n_rows]. Must be + * allocated on entry. 
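+ *
+ * A minimal usage sketch (device allocations and data preparation are only
+ * outlined; names and sizes are illustrative):
+ * @code{.cpp}
+ *   // X: normalized column-major features, y: centered targets (both device)
+ *   int n_active = 0;
+ *   // beta, active_idx: device arrays of size max_iter; alphas: max_iter + 1
+ *   larsFit(handle, X, n_rows, n_cols, y, beta, active_idx, alphas, &n_active);
+ *   larsPredict(handle, X, n_rows, n_cols, n_rows, beta, n_active, active_idx,
+ *               0.f, preds);
+ * @endcode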
+ */ +template +void larsPredict(const raft::handle_t& handle, + const math_t* X, + idx_t n_rows, + idx_t n_cols, + idx_t ld_X, + const math_t* beta, + idx_t n_active, + idx_t* active_idx, + math_t intercept, + math_t* preds) +{ + cudaStream_t stream = handle.get_stream(); + rmm::device_uvector beta_sorted(0, stream); + rmm::device_uvector X_active_cols(0, stream); + auto execution_policy = handle.get_thrust_policy(); + + if (n_active == 0 || n_rows == 0) return; + + if (n_active == n_cols) { + // We make a copy of the beta coefs and sort them + beta_sorted.resize(n_active, stream); + rmm::device_uvector idx_sorted(n_active, stream); + raft::copy(beta_sorted.data(), beta, n_active, stream); + raft::copy(idx_sorted.data(), active_idx, n_active, stream); + thrust::device_ptr beta_ptr(beta_sorted.data()); + thrust::device_ptr idx_ptr(idx_sorted.data()); + thrust::sort_by_key(execution_policy, idx_ptr, idx_ptr + n_active, beta_ptr); + beta = beta_sorted.data(); + } else { + // We collect active columns of X to contiguous space + X_active_cols.resize(n_active * ld_X, stream); + const int TPB = 64; + raft::cache::get_vecs<<>>( + X, ld_X, active_idx, n_active, X_active_cols.data()); + RAFT_CUDA_TRY(cudaGetLastError()); + X = X_active_cols.data(); + } + // Initialize preds = intercept + thrust::device_ptr pred_ptr(preds); + thrust::fill(execution_policy, pred_ptr, pred_ptr + n_rows, intercept); + math_t one = 1; + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_N, + n_rows, + n_active, + &one, + X, + ld_X, + beta, + 1, + &one, + preds, + 1, + stream)); +} +}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/detail/learning_rate.h b/cpp/include/raft/solver/detail/learning_rate.h new file mode 100644 index 0000000000..c83a65d472 --- /dev/null +++ b/cpp/include/raft/solver/detail/learning_rate.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::solver::detail { + +template +math_t max(math_t a, math_t b) +{ + return (a < b) ? 
b : a;
+}
+
+template <typename math_t>
+math_t invScaling(math_t eta, math_t power_t, int t)
+{
+  return (eta / pow(t, power_t));
+}
+
+template <typename math_t>
+math_t regDLoss(math_t a, math_t b)
+{
+  return a - b;
+}
+
+template <typename math_t>
+math_t calOptimalInit(math_t alpha)
+{
+  math_t typw         = sqrt(math_t(1.0) / sqrt(alpha));
+  math_t initial_eta0 = typw / max(math_t(1.0), regDLoss(-typw, math_t(1.0)));
+  return (math_t(1.0) / (initial_eta0 * alpha));
+}
+
+template <typename math_t>
+math_t optimal(math_t alpha, math_t optimal_init, int t)
+{
+  return math_t(1.0) / (alpha * (optimal_init + t - 1));
+}
+
+template <typename math_t>
+math_t calLearningRate(ML::lr_type lr_type, math_t eta, math_t power_t, math_t alpha, math_t t)
+{
+  if (lr_type == ML::lr_type::CONSTANT) {
+    return eta;
+  } else if (lr_type == ML::lr_type::INVSCALING) {
+    return invScaling(eta, power_t, t);
+  } else if (lr_type == ML::lr_type::OPTIMAL) {
+    return optimal(alpha, eta, t);
+  } else {
+    return math_t(0);
+  }
+}
+
+};  // namespace raft::solver::detail
diff --git a/cpp/include/raft/solver/detail/objectives/hinge.cuh b/cpp/include/raft/solver/detail/objectives/hinge.cuh
new file mode 100644
index 0000000000..c6152a8fbe
--- /dev/null
+++ b/cpp/include/raft/solver/detail/objectives/hinge.cuh
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#pragma once + +#include "penalty.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::solver::detail::objectives { + +template +void hingeLossGradMult(math_t* data, + const math_t* vec1, + const math_t* vec2, + idx_type n_row, + idx_type n_col, + cudaStream_t stream) +{ + raft::linalg::matrixVectorOp( + data, + data, + vec1, + vec2, + n_col, + n_row, + false, + false, + [] __device__(math_t a, math_t b, math_t c) { + if (c < math_t(1)) + return -a * b; + else + return math_t(0); + }, + stream); +} + +template +void hingeLossSubtract( + math_t* out, const math_t* in, math_t scalar, idx_type len, cudaStream_t stream) +{ + raft::linalg::unaryOp( + out, + in, + len, + [scalar] __device__(math_t in) { + if (in < scalar) + return math_t(1) - in; + else + return math_t(0); + }, + stream); +} + +template +void hingeH(const raft::handle_t& handle, + const math_t* input, + idx_type n_rows, + idx_type n_cols, + const math_t* coef, + math_t* pred, + math_t intercept, + cudaStream_t stream) +{ + raft::linalg::gemm( + handle, input, n_rows, n_cols, coef, pred, n_rows, 1, CUBLAS_OP_N, CUBLAS_OP_N, stream); + + if (intercept != math_t(0)) raft::linalg::addScalar(pred, pred, intercept, n_rows, stream); + + sign(pred, pred, math_t(1.0), n_rows, stream); +} + +template +void hingeLossGrads(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + const math_t* labels, + const math_t* coef, + math_t* grads, + penalty pen, + math_t alpha, + math_t l1_ratio, + cudaStream_t stream) +{ + rmm::device_uvector labels_pred(n_rows, stream); + + raft::linalg::gemm(handle, + input, + n_rows, + n_cols, + coef, + labels_pred.data(), + n_rows, + 1, + CUBLAS_OP_N, + CUBLAS_OP_N, + stream); + + raft::linalg::eltwiseMultiply(labels_pred.data(), labels_pred.data(), labels, n_rows, stream); + hingeLossGradMult(input, labels, labels_pred.data(), n_rows, n_cols, stream); + raft::stats::mean(grads, input, n_cols, n_rows, false, false, stream); + + rmm::device_uvector pen_grads(0, stream); + + if (pen != penalty::NONE) pen_grads.resize(n_cols, stream); + + if (pen == penalty::L1) { + lassoGrad(pen_grads.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::L2) { + ridgeGrad(pen_grads.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::ELASTICNET) { + elasticnetGrad(pen_grads.data(), coef, n_cols, alpha, l1_ratio, stream); + } + + if (pen != penalty::NONE) { raft::linalg::add(grads, grads, pen_grads.data(), n_cols, stream); } +} + +template +void hingeLoss(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + const math_t* labels, + const math_t* coef, + math_t* loss, + penalty pen, + math_t alpha, + math_t l1_ratio, + cudaStream_t stream) +{ + rmm::device_uvector labels_pred(n_rows, stream); + + raft::linalg::gemm(handle, + input, + n_rows, + n_cols, + coef, + labels_pred.data(), + n_rows, + 1, + CUBLAS_OP_N, + CUBLAS_OP_N, + stream); + + raft::linalg::eltwiseMultiply(labels_pred.data(), labels_pred.data(), labels, n_rows, stream); + + hingeLossSubtract(labels_pred.data(), labels_pred.data(), math_t(1), n_rows, stream); + + raft::stats::sum(loss, labels_pred.data(), 1, n_rows, false, stream); + + rmm::device_uvector pen_val(0, stream); + + if (pen != penalty::NONE) pen_val.resize(1, stream); + + if (pen == penalty::L1) { + lasso(pen_val.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::L2) { + ridge(pen_val.data(), coef, n_cols, alpha, 
stream); + } else if (pen == penalty::ELASTICNET) { + elasticnet(pen_val.data(), coef, n_cols, alpha, l1_ratio, stream); + } + + if (pen != penalty::NONE) { raft::linalg::add(loss, loss, pen_val.data(), 1, stream); } +} + +}; // namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/objectives/linearReg.cuh b/cpp/include/raft/solver/detail/objectives/linearReg.cuh new file mode 100644 index 0000000000..78a22b10e7 --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/linearReg.cuh @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "penalty.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::solver::detail::objectives { + +template +void linearRegH(const raft::handle_t& handle, + const math_t* input, + int n_rows, + int n_cols, + const math_t* coef, + math_t* pred, + math_t intercept, + cudaStream_t stream) +{ + raft::linalg::gemm( + handle, input, n_rows, n_cols, coef, pred, n_rows, 1, CUBLAS_OP_N, CUBLAS_OP_N, stream); + + if (intercept != math_t(0)) raft::linalg::addScalar(pred, pred, intercept, n_rows, stream); +} + +template +void linearRegLossGrads(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + const math_t* labels, + const math_t* coef, + math_t* grads, + penalty pen, + math_t alpha, + math_t l1_ratio, + cudaStream_t stream) +{ + rmm::device_uvector labels_pred(n_rows, stream); + + linearRegH(handle, input, n_rows, n_cols, coef, labels_pred.data(), math_t(0), stream); + raft::linalg::subtract(labels_pred.data(), labels_pred.data(), labels, n_rows, stream); + raft::matrix::matrixVectorBinaryMult( + input, labels_pred.data(), n_rows, n_cols, false, false, stream); + + raft::stats::mean(grads, input, n_cols, n_rows, false, false, stream); + raft::linalg::scalarMultiply(grads, grads, math_t(2), n_cols, stream); + + rmm::device_uvector pen_grads(0, stream); + + if (pen != penalty::NONE) pen_grads.resize(n_cols, stream); + + if (pen == penalty::L1) { + lassoGrad(pen_grads.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::L2) { + ridgeGrad(pen_grads.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::ELASTICNET) { + elasticnetGrad(pen_grads.data(), coef, n_cols, alpha, l1_ratio, stream); + } + + if (pen != penalty::NONE) { raft::linalg::add(grads, grads, pen_grads.data(), n_cols, stream); } +} + +template +void linearRegLoss(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + const math_t* labels, + const math_t* coef, + math_t* loss, + penalty pen, + math_t alpha, + math_t l1_ratio, + cudaStream_t stream) +{ + rmm::device_uvector labels_pred(n_rows, stream); + + linearRegH(handle, input, n_rows, n_cols, coef, labels_pred.data(), math_t(0), stream); + + raft::linalg::subtract(labels_pred.data(), labels, labels_pred.data(), n_rows, stream); + 
raft::matrix::power(labels_pred.data(), n_rows, stream); + raft::stats::mean(loss, labels_pred.data(), 1, n_rows, false, false, stream); + + rmm::device_uvector pen_val(0, stream); + + if (pen != penalty::NONE) pen_val.resize(1, stream); + + if (pen == penalty::L1) { + lasso(pen_val.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::L2) { + ridge(pen_val.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::ELASTICNET) { + elasticnet(pen_val.data(), coef, n_cols, alpha, l1_ratio, stream); + } + + if (pen != penalty::NONE) { raft::linalg::add(loss, loss, pen_val.data(), 1, stream); } +} + +}; // namespace raft::solver::detail::objectives \ No newline at end of file diff --git a/cpp/include/raft/solver/detail/objectives/log.cuh b/cpp/include/raft/solver/detail/objectives/log.cuh new file mode 100644 index 0000000000..c62e7e580c --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/log.cuh @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::solver::detail::objectives { + +template +void f_log(T* out, T* in, T scalar, IdxType len, cudaStream_t stream) +{ + raft::linalg::unaryOp( + out, in, len, [scalar] __device__(T in) { return raft::myLog(in) * scalar; }, stream); +} + +}; // end namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/objectives/logisticReg.cuh b/cpp/include/raft/solver/detail/objectives/logisticReg.cuh new file mode 100644 index 0000000000..40a2d4b2c4 --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/logisticReg.cuh @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
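As a sanity check on the gradient that linearRegLossGrads above computes on device, the same quantity on the host is grad = 2/n * X^T (X w - y) for column-major X. A sketch with made-up data (not raft API):

#include <cstdio>
#include <vector>

int main()
{
  const int n_rows = 3, n_cols = 2;
  std::vector<double> X = {1, 2, 3, 1, 1, 1};  // column-major: col0 = {1,2,3}, col1 = {1,1,1}
  std::vector<double> y = {2, 3, 4};
  std::vector<double> w = {0.5, 0.5};

  std::vector<double> r(n_rows, 0.0);  // residual r = X w - y
  for (int i = 0; i < n_rows; ++i) {
    for (int j = 0; j < n_cols; ++j) r[i] += X[j * n_rows + i] * w[j];
    r[i] -= y[i];
  }
  for (int j = 0; j < n_cols; ++j) {  // grad_j = 2/n * sum_i X[i,j] * r[i]
    double g = 0.0;
    for (int i = 0; i < n_rows; ++i) g += X[j * n_rows + i] * r[i];
    std::printf("grad[%d] = %f\n", j, 2.0 * g / n_rows);
  }
}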
+ */
+
+#pragma once
+
+#include "penalty.cuh"
+#include "sigmoid.cuh"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace raft::solver::detail::objectives {
+
+template <typename math_t>
+void logisticRegH(const raft::handle_t& handle,
+                  const math_t* input,
+                  int n_rows,
+                  int n_cols,
+                  const math_t* coef,
+                  math_t* pred,
+                  math_t intercept,
+                  cudaStream_t stream)
+{
+  raft::linalg::gemm(
+    handle, input, n_rows, n_cols, coef, pred, n_rows, 1, CUBLAS_OP_N, CUBLAS_OP_N, stream);
+
+  if (intercept != math_t(0)) raft::linalg::addScalar(pred, pred, intercept, n_rows, stream);
+
+  sigmoid(pred, pred, n_rows, stream);
+}
+
+template <typename math_t>
+void logisticRegLossGrads(const raft::handle_t& handle,
+                          math_t* input,
+                          int n_rows,
+                          int n_cols,
+                          const math_t* labels,
+                          const math_t* coef,
+                          math_t* grads,
+                          penalty pen,
+                          math_t alpha,
+                          math_t l1_ratio,
+                          cudaStream_t stream)
+{
+  rmm::device_uvector<math_t> labels_pred(n_rows, stream);
+
+  logisticRegH(handle, input, n_rows, n_cols, coef, labels_pred.data(), math_t(0), stream);
+  raft::linalg::subtract(labels_pred.data(), labels_pred.data(), labels, n_rows, stream);
+  raft::matrix::matrixVectorBinaryMult(
+    input, labels_pred.data(), n_rows, n_cols, false, false, stream);
+
+  raft::stats::mean(grads, input, n_cols, n_rows, false, false, stream);
+
+  rmm::device_uvector<math_t> pen_grads(0, stream);
+
+  if (pen != penalty::NONE) pen_grads.resize(n_cols, stream);
+
+  if (pen == penalty::L1) {
+    lassoGrad(pen_grads.data(), coef, n_cols, alpha, stream);
+  } else if (pen == penalty::L2) {
+    ridgeGrad(pen_grads.data(), coef, n_cols, alpha, stream);
+  } else if (pen == penalty::ELASTICNET) {
+    elasticnetGrad(pen_grads.data(), coef, n_cols, alpha, l1_ratio, stream);
+  }
+
+  if (pen != penalty::NONE) { raft::linalg::add(grads, grads, pen_grads.data(), n_cols, stream); }
+}
+
+template <typename T>
+void logLoss(T* out, T* label, T* label_pred, int len, cudaStream_t stream);
+
+template <>
+inline void logLoss(float* out, float* label, float* label_pred, int len, cudaStream_t stream)
+{
+  raft::linalg::binaryOp(
+    out,
+    label,
+    label_pred,
+    len,
+    [] __device__(float y, float y_pred) { return -y * logf(y_pred) - (1 - y) * logf(1 - y_pred); },
+    stream);
+}
+
+template <>
+inline void logLoss(double* out, double* label, double* label_pred, int len, cudaStream_t stream)
+{
+  raft::linalg::binaryOp(
+    out,
+    label,
+    label_pred,
+    len,
+    [] __device__(double y, double y_pred) {
+      // use the double-precision log consistently (logf here would silently lose precision)
+      return -y * log(y_pred) - (1 - y) * log(1 - y_pred);
+    },
+    stream);
+}
+
+template <typename math_t>
+void logisticRegLoss(const raft::handle_t& handle,
+                     math_t* input,
+                     int n_rows,
+                     int n_cols,
+                     math_t* labels,
+                     const math_t* coef,
+                     math_t* loss,
+                     penalty pen,
+                     math_t alpha,
+                     math_t l1_ratio,
+                     cudaStream_t stream)
+{
+  rmm::device_uvector<math_t> labels_pred(n_rows, stream);
+  logisticRegH(handle, input, n_rows, n_cols, coef, labels_pred.data(), math_t(0), stream);
+  logLoss(labels_pred.data(), labels, labels_pred.data(), n_rows, stream);
+
+  raft::stats::mean(loss, labels_pred.data(), 1, n_rows, false, false, stream);
+
+  rmm::device_uvector<math_t> pen_val(0, stream);
+
+  if (pen != penalty::NONE) pen_val.resize(1, stream);
+
+  if (pen == penalty::L1) {
+    lasso(pen_val.data(), coef, n_cols, alpha, stream);
+  } else if (pen == penalty::L2) {
+    ridge(pen_val.data(), coef, n_cols, alpha, stream);
+  } else if (pen == penalty::ELASTICNET) {
+    elasticnet(pen_val.data(), coef, n_cols, alpha, l1_ratio, stream);
+  }
+
+  if (pen != penalty::NONE) { raft::linalg::add(loss, loss,
pen_val.data(), 1, stream); } +} + +}; // namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/objectives/penalty.cuh b/cpp/include/raft/solver/detail/objectives/penalty.cuh new file mode 100644 index 0000000000..db60f029a9 --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/penalty.cuh @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "sign.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::solver::detail::objectives { + +enum penalty { + NONE, + L1, + L2, + ELASTICNET, +}; + +template +void lasso(math_t* out, const math_t* coef, const int len, const math_t alpha, cudaStream_t stream) +{ + raft::linalg::rowNorm(out, coef, len, 1, raft::linalg::NormType::L1Norm, true, stream); + raft::linalg::scalarMultiply(out, out, alpha, 1, stream); +} + +template +void lassoGrad( + math_t* grad, const math_t* coef, const int len, const math_t alpha, cudaStream_t stream) +{ + sign(grad, coef, alpha, len, stream); +} + +template +void ridge(math_t* out, const math_t* coef, const int len, const math_t alpha, cudaStream_t stream) +{ + raft::linalg::rowNorm(out, coef, len, 1, raft::linalg::NormType::L2Norm, true, stream); + raft::linalg::scalarMultiply(out, out, alpha, 1, stream); +} + +template +void ridgeGrad( + math_t* grad, const math_t* coef, const int len, const math_t alpha, cudaStream_t stream) +{ + raft::linalg::scalarMultiply(grad, coef, math_t(2) * alpha, len, stream); +} + +template +void elasticnet(math_t* out, + const math_t* coef, + const int len, + const math_t alpha, + const math_t l1_ratio, + cudaStream_t stream) +{ + rmm::device_scalar out_lasso(stream); + + ridge(out, coef, len, alpha * (math_t(1) - l1_ratio), stream); + lasso(out_lasso.data(), coef, len, alpha * l1_ratio, stream); + + raft::linalg::add(out, out, out_lasso.data(), 1, stream); +} + +template +void elasticnetGrad(math_t* grad, + const math_t* coef, + const int len, + const math_t alpha, + const math_t l1_ratio, + cudaStream_t stream) +{ + rmm::device_uvector grad_lasso(len, stream); + + ridgeGrad(grad, coef, len, alpha * (math_t(1) - l1_ratio), stream); + lassoGrad(grad_lasso.data(), coef, len, alpha * l1_ratio, stream); + + raft::linalg::add(grad, grad, grad_lasso.data(), len, stream); +} + +}; // namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/objectives/sigmoid.cuh b/cpp/include/raft/solver/detail/objectives/sigmoid.cuh new file mode 100644 index 0000000000..a06a305e44 --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/sigmoid.cuh @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
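The elastic-net helpers above are literally the sum of the ridge and lasso terms with alpha split by l1_ratio. A host-side sketch of the decomposition (assuming, as the sqrt fin_op used elsewhere in this PR suggests, that the plain L2Norm scaled by ridge() is the un-rooted sum of squares):

#include <cmath>
#include <cstdio>

int main()
{
  const double w[3] = {0.5, -1.0, 0.0};
  const double alpha = 0.1, l1_ratio = 0.4;

  double l1 = 0.0, l2 = 0.0;
  for (double wi : w) { l1 += std::fabs(wi); l2 += wi * wi; }

  const double ridge_part = alpha * (1.0 - l1_ratio) * l2;  // what ridge() contributes
  const double lasso_part = alpha * l1_ratio * l1;          // what lasso() contributes
  std::printf("elasticnet penalty = %f\n", ridge_part + lasso_part);
}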
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::solver::detail::objectives { + +template +void sigmoid(T* out, T* in, IdxType len, cudaStream_t stream) +{ + T one = T(1); + raft::linalg::unaryOp( + out, in, len, [one] __device__(T in) { return one / (one + raft::myExp(-in)); }, stream); +} + +}; // end namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/objectives/sign.cuh b/cpp/include/raft/solver/detail/objectives/sign.cuh new file mode 100644 index 0000000000..ca37727355 --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/sign.cuh @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::solver::detail::objectives { + +template +void sign( + math_t* out, const math_t* in, const math_t scalar, const idx_type len, cudaStream_t stream) +{ + raft::linalg::unaryOp( + out, + in, + len, + [scalar] __device__(math_t in) { + if (in < math_t(0)) + return (math_t(-1) * scalar); + else if (in > math_t(0)) + return (math_t(1) * scalar); + else + return math_t(0); + }, + stream); +} + +template +void sign(math_t* out, const math_t* in, const idx_type n_len, cudaStream_t stream) +{ + math_t scalar = math_t(1); + sign(out, in, scalar, n_len, stream); +} + +}; // namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/objectives/softThres.cuh b/cpp/include/raft/solver/detail/objectives/softThres.cuh new file mode 100644 index 0000000000..485fc4f688 --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/softThres.cuh @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
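sign() above is what lassoGrad builds on: alpha * sign(w) is the elementwise (sub)gradient of alpha * ||w||_1, with 0 chosen at w == 0. A one-function host sketch:

#include <cstdio>

// alpha * sign(w), the subgradient of alpha * |w| that sign() evaluates elementwise
double l1_subgrad(double w, double alpha) { return w > 0 ? alpha : (w < 0 ? -alpha : 0.0); }

int main()
{
  const double w[] = {0.5, -1.0, 0.0};
  for (double wi : w)
    std::printf("%g ", l1_subgrad(wi, 0.1));
  std::printf("\n");
}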
+ */ + +#pragma once + +#include + +namespace raft::solver::detail::objectives { + +template +void softThres( + math_t* out, const math_t* in, const math_t thres, const int len, cudaStream_t stream) +{ + raft::linalg::unaryOp( + out, + in, + len, + [thres] __device__(math_t in) { + if (in > math_t(0) && thres < raft::myAbs(in)) + return in - thres; + else if (in < math_t(0) && thres < raft::myAbs(in)) + return in + thres; + else + return math_t(0); + }, + stream); +} + +}; // namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/preprocess.cuh b/cpp/include/raft/solver/detail/preprocess.cuh new file mode 100644 index 0000000000..3fd863df42 --- /dev/null +++ b/cpp/include/raft/solver/detail/preprocess.cuh @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::solver::detail { + +/** + * @brief Center and scale the data, depending on the flags fit_intercept and normalize + * + * @tparam math_t the element type + * @param [inout] input the column-major data of size [n_rows, n_cols] + * @param [in] n_rows + * @param [in] n_cols + * @param [inout] labels vector of size [n_rows] + * @param [out] intercept + * @param [out] mu_input the column-wise means of the input of size [n_cols] + * @param [out] mu_labels the scalar mean of the target (labels vector) + * @param [out] norm2_input the column-wise standard deviations of the input of size [n_cols]; + * note, the biased estimator is used to match sklearn's StandardScaler + * (dividing by n_rows, not by (n_rows - 1)). + * @param [in] fit_intercept whether to center the data / to fit the intercept + * @param [in] normalize whether to normalize the data + * @param [in] stream + */ +template +void preProcessData(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + math_t* labels, + math_t* intercept, + math_t* mu_input, + math_t* mu_labels, + math_t* norm2_input, + bool fit_intercept, + bool normalize, + math_t* sample_weight = nullptr) +{ + cudaStream_t stream = handle.get_stream(); + raft::common::nvtx::range fun_scope("ML::GLM::preProcessData-%d-%d", n_rows, n_cols); + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + + if (fit_intercept) { + if (normalize && sample_weight == nullptr) { + raft::stats::meanvar(mu_input, norm2_input, input, n_cols, n_rows, false, false, stream); + raft::linalg::unaryOp( + norm2_input, + norm2_input, + n_cols, + [] __device__(math_t v) { return raft::mySqrt(v); }, + stream); + raft::matrix::linewiseOp( + input, + input, + n_rows, + n_cols, + false, + [] __device__(math_t x, math_t m, math_t s) { return s > 1e-10 ? 
(x - m) / s : 0; }, + stream, + mu_input, + norm2_input); + } else { + if (sample_weight != nullptr) { + raft::stats::weightedMean( + mu_input, input, sample_weight, n_cols, n_rows, false, false, stream); + } else { + raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream); + } + raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream); + if (normalize) { + raft::linalg::colNorm(norm2_input, + input, + n_cols, + n_rows, + raft::linalg::L2Norm, + false, + stream, + [] __device__(math_t v) { return raft::mySqrt(v); }); + raft::matrix::matrixVectorBinaryDivSkipZero( + input, norm2_input, n_rows, n_cols, false, true, stream, true); + } + } + + if (sample_weight != nullptr) { + raft::stats::weightedMean(mu_labels, labels, sample_weight, 1, n_rows, true, false, stream); + } else { + raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream); + } + raft::stats::meanCenter(labels, labels, mu_labels, 1, n_rows, false, true, stream); + } +} + +template +void postProcessData(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + math_t* labels, + math_t* coef, + math_t* intercept, + math_t* mu_input, + math_t* mu_labels, + math_t* norm2_input, + bool fit_intercept, + bool normalize) +{ + cudaStream_t stream = handle.get_stream(); + raft::common::nvtx::range fun_scope("ML::GLM::postProcessData-%d-%d", n_rows, n_cols); + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + + cublasHandle_t cublas_handle = handle.get_cublas_handle(); + rmm::device_scalar d_intercept(stream); + + if (normalize) { + raft::matrix::matrixVectorBinaryDivSkipZero( + coef, norm2_input, 1, n_cols, false, true, stream, true); + } + + raft::linalg::gemm( + handle, mu_input, 1, n_cols, coef, d_intercept.data(), 1, 1, CUBLAS_OP_N, CUBLAS_OP_N, stream); + + raft::linalg::subtract(d_intercept.data(), mu_labels, d_intercept.data(), 1, stream); + *intercept = d_intercept.value(stream); + + if (normalize) { + raft::matrix::linewiseOp( + input, + input, + n_rows, + n_cols, + false, + [] __device__(math_t x, math_t m, math_t s) { return s * x + m; }, + stream, + mu_input, + norm2_input); + } else { + raft::stats::meanAdd(input, input, mu_input, n_cols, n_rows, false, true, stream); + } + raft::stats::meanAdd(labels, labels, mu_labels, 1, n_rows, false, true, stream); +} + +}; // end namespace raft::solver::detail \ No newline at end of file diff --git a/cpp/include/raft/solver/detail/qn/objectives/base.cuh b/cpp/include/raft/solver/detail/qn/objectives/base.cuh new file mode 100644 index 0000000000..288bfac1c8 --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/objectives/base.cuh @@ -0,0 +1,240 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
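The intercept recovery in postProcessData above is the usual centering identity: after fitting on centered data, intercept = mean(y) - dot(mean(X), coef), which is what the gemm followed by subtract computes. A host-side sketch with made-up means and coefficients:

#include <cstdio>

int main()
{
  const int n_cols = 2;
  const double mu_input[n_cols] = {2.0, 1.0};  // column means of X
  const double mu_labels       = 3.0;          // mean of y
  const double coef[n_cols]    = {0.5, -0.25}; // coefficients fit on centered data

  double dot = 0.0;
  for (int j = 0; j < n_cols; ++j) dot += mu_input[j] * coef[j];
  std::printf("intercept = %f\n", mu_labels - dot);  // 3 - 0.75 = 2.25
}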
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +namespace raft::solver::quasi_newton::detail::objectives { + +template +inline void linearFwd(const raft::handle_t& handle, + SimpleDenseMat& Z, + const SimpleMat& X, + const SimpleDenseMat& W, + cudaStream_t stream) +{ + // Forward pass: compute Z <- W * X.T + bias + const bool has_bias = X.n != W.n; + const int D = X.n; + if (has_bias) { + SimpleVec bias; + SimpleDenseMat weights; + col_ref(W, bias, D); + col_slice(W, weights, 0, D); + // We implement Z <- W * X^T + b by + // - Z <- b (broadcast): TODO reads Z unnecessarily atm + // - Z <- W * X^T + Z : TODO can be fused in CUTLASS? + auto set_bias = [] __device__(const T z, const T b) { return b; }; + raft::linalg::matrixVectorOp( + Z.data, Z.data, bias.data, Z.n, Z.m, false, false, set_bias, stream); + + Z.assign_gemm(handle, 1, weights, false, X, true, 1, stream); + } else { + Z.assign_gemm(handle, 1, W, false, X, true, 0, stream); + } +} + +template +inline void linearBwd(const raft::handle_t& handle, + SimpleDenseMat& G, + const SimpleMat& X, + const SimpleDenseMat& dZ, + bool setZero, + cudaStream_t stream) +{ + // Backward pass: + // - compute G <- dZ * X.T + // - for bias: Gb = mean(dZ, 1) + + const bool has_bias = X.n != G.n; + const int D = X.n; + const T beta = setZero ? T(0) : T(1); + if (has_bias) { + SimpleVec Gbias; + SimpleDenseMat Gweights; + col_ref(G, Gbias, D); + col_slice(G, Gweights, 0, D); + + // TODO can this be fused somehow? + Gweights.assign_gemm(handle, 1.0 / X.m, dZ, false, X, false, beta, stream); + raft::stats::mean(Gbias.data, dZ.data, dZ.m, dZ.n, false, true, stream); + } else { + G.assign_gemm(handle, 1.0 / X.m, dZ, false, X, false, beta, stream); + } +} + +template +struct QNLinearBase : LinearDims { + typedef SimpleDenseMat Mat; + typedef SimpleVec Vec; + + const raft::handle_t& handle; + T* sample_weights; + T weights_sum; + + QNLinearBase(const raft::handle_t& handle, int D, int C, bool fit_intercept) + : LinearDims(C, D, fit_intercept), handle(handle), sample_weights(nullptr), weights_sum(0) + { + } + + void add_sample_weights(T* sample_weights, int n_samples, cudaStream_t stream) + { + this->sample_weights = sample_weights; + this->weights_sum = thrust::reduce(thrust::cuda::par.on(stream), + sample_weights, + sample_weights + n_samples, + (T)0, + thrust::plus()); + } + + /* + * Computes the following: + * 1. Z <- dL/DZ + * 2. loss_val <- sum loss(Z) + * + * Default: elementwise application of loss and its derivative + * + * NB: for this method to work, loss implementations must have two functor fields `lz` and `dlz`. + * These two compute loss value and its derivative w.r.t. `z`. 
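A host-only sketch of this static-polymorphism (CRTP) contract, with hypothetical names; it uses no raft types but has the same shape as getLossAndDZ below (the real device-side analogue is SquaredLoss further down in this PR):

#include <cstdio>

template <typename T, typename Loss>
struct LossBase {
  T loss_and_dz(T y, T& z) const
  {
    const Loss* loss = static_cast<const Loss*>(this);  // static polymorphism, no vtable
    T val = loss->lz(y, z);   // loss value at (y, z)
    z     = loss->dlz(y, z);  // overwrite z with dL/dz, as getLossAndDZ does
    return val;
  }
};

template <typename T>
struct Squared : LossBase<T, Squared<T>> {
  struct Lz {
    T operator()(T y, T z) const { return T(0.5) * (z - y) * (z - y); }
  } lz;
  struct Dlz {
    T operator()(T y, T z) const { return z - y; }
  } dlz;
};

int main()
{
  Squared<double> sq;
  double z = 2.0;
  std::printf("loss = %f\n", sq.loss_and_dz(1.0, z));  // 0.5
  std::printf("dz   = %f\n", z);                       // 1.0
}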
+ */ + inline void getLossAndDZ(T* loss_val, + SimpleDenseMat& Z, + const SimpleVec& y, + cudaStream_t stream) + { + // Base impl assumes simple case C = 1 + // TODO would be nice to have a kernel that fuses these two steps + // This would be easy, if mapThenSumReduce allowed outputing the result of + // map (supporting inplace) + auto lz_copy = static_cast(this)->lz; + auto dlz_copy = static_cast(this)->dlz; + if (this->sample_weights) { // Sample weights are in use + T normalization = 1.0 / this->weights_sum; + raft::linalg::mapThenSumReduce( + loss_val, + y.len, + [lz_copy, normalization] __device__(const T y, const T z, const T weight) { + return lz_copy(y, z) * (weight * normalization); + }, + stream, + y.data, + Z.data, + sample_weights); + raft::linalg::map_k( + Z.data, + y.len, + [dlz_copy] __device__(const T y, const T z, const T weight) { + return weight * dlz_copy(y, z); + }, + stream, + y.data, + Z.data, + sample_weights); + } else { // Sample weights are not used + T normalization = 1.0 / y.len; + raft::linalg::mapThenSumReduce( + loss_val, + y.len, + [lz_copy, normalization] __device__(const T y, const T z) { + return lz_copy(y, z) * normalization; + }, + stream, + y.data, + Z.data); + raft::linalg::binaryOp(Z.data, y.data, Z.data, y.len, dlz_copy, stream); + } + } + + inline void loss_grad(T* loss_val, + Mat& G, + const Mat& W, + const SimpleMat& Xb, + const Vec& yb, + Mat& Zb, + cudaStream_t stream, + bool initGradZero = true) + { + Loss* loss = static_cast(this); // static polymorphism + + linearFwd(handle, Zb, Xb, W, stream); // linear part: forward pass + loss->getLossAndDZ(loss_val, Zb, yb, stream); // loss specific part + linearBwd(handle, G, Xb, Zb, initGradZero, + stream); // linear part: backward pass + } +}; + +template +struct QNWithData : LinearDims { + const SimpleMat* X; + const SimpleVec* y; + SimpleDenseMat* Z; + QuasiNewtonObjective* objective; + + QNWithData(QuasiNewtonObjective* obj, + const SimpleMat& X, + const SimpleVec& y, + SimpleDenseMat& Z) + : objective(obj), X(&X), y(&y), Z(&Z), LinearDims(obj->C, obj->D, obj->fit_intercept) + { + } + + // interface exposed to typical non-linear optimizers + inline T operator()(const SimpleVec& wFlat, + SimpleVec& gradFlat, + T* dev_scalar, + cudaStream_t stream) + { + SimpleDenseMat W(wFlat.data, C, dims); + SimpleDenseMat G(gradFlat.data, C, dims); + objective->loss_grad(dev_scalar, G, W, *X, *y, *Z, stream); + T loss_host; + raft::update_host(&loss_host, dev_scalar, 1, stream); + raft::interruptible::synchronize(stream); + return loss_host; + } + + /** + * @brief Calculate a norm of the gradient computed using the given Loss instance. + * + * This function is intended to be used in `check_convergence`; it's output is supposed + * to be proportional to the loss value w.r.t. the number of features (D). + * + * Different loss functions may scale differently with the number of features (D). + * This has an effect on the convergence criteria. To account for that, we let a + * loss function define its preferred metric. Normally, we differentiate between the + * L2 norm (e.g. for Squared loss) and LInf norm (e.g. for Softmax loss). 
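A host-side illustration of why the norm choice matters: with D equally "almost converged" coordinates, the L2 norm grows like sqrt(D) while the LInf norm stays flat, so a single tolerance behaves very differently under each (a sketch with made-up numbers):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
  for (int D : {10, 1000, 100000}) {
    std::vector<double> grad(D, 1e-3);  // every coordinate equally close to zero
    double l2 = 0.0, linf = 0.0;
    for (double g : grad) {
      l2 += g * g;
      linf = std::max(linf, std::fabs(g));
    }
    std::printf("D=%6d  ||g||_2=%g  ||g||_inf=%g\n", D, std::sqrt(l2), linf);
  }
}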
+ */ + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return objective->gradNorm(grad, dev_scalar, stream); + } +}; + +}; // namespace raft::solver::quasi_newton::detail::objectives diff --git a/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh b/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh new file mode 100644 index 0000000000..d90c30dc1c --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "base.cuh" +#include +#include +#include + +namespace raft::solver::quasi_newton::detail::objectives { + +template +struct HingeLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + inline __device__ T operator()(const T y, const T z) const + { + T s = 2 * y - 1; + return raft::myMax(0, 1 - s * z); + } + } lz; + + const struct Dlz { + inline __device__ T operator()(const T y, const T z) const + { + T s = 2 * y - 1; + return s * z <= 1 ? -s : 0; + } + } dlz; + + HingeLoss(const raft::handle_t& handle, int D, bool has_bias) + : Super(handle, D, 1, has_bias), lz{}, dlz{} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return nrm1(grad, dev_scalar, stream); + } +}; + +template +struct SqHingeLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + inline __device__ T operator()(const T y, const T z) const + { + T s = 2 * y - 1; + T t = raft::myMax(0, 1 - s * z); + return t * t; + } + } lz; + + const struct Dlz { + inline __device__ T operator()(const T y, const T z) const + { + T s = 2 * y - 1; + return s * z <= 1 ? z - s : 0; + } + } dlz; + + SqHingeLoss(const raft::handle_t& handle, int D, bool has_bias) + : Super(handle, D, 1, has_bias), lz{}, dlz{} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return squaredNorm(grad, dev_scalar, stream) * 0.5; + } +}; + +template +struct EpsInsHingeLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + T sensitivity; + inline __device__ T operator()(const T y, const T z) const + { + T t = y - z; + return t > sensitivity ? t - sensitivity : t < -sensitivity ? -t - sensitivity : 0; + } + } lz; + + const struct Dlz { + T sensitivity; + inline __device__ T operator()(const T y, const T z) const + { + T t = y - z; + return t > sensitivity ? -1 : (t < -sensitivity ? 
1 : 0); + } + } dlz; + + EpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity) + : Super(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return nrm1(grad, dev_scalar, stream); + } +}; + +template +struct SqEpsInsHingeLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + T sensitivity; + inline __device__ T operator()(const T y, const T z) const + { + T t = y - z; + T s = t > sensitivity ? t - sensitivity : t < -sensitivity ? -t - sensitivity : 0; + return s * s; + } + } lz; + + const struct Dlz { + T sensitivity; + inline __device__ T operator()(const T y, const T z) const + { + T t = y - z; + return -2 * (t > sensitivity ? t - sensitivity : t < -sensitivity ? (t + sensitivity) : 0); + } + } dlz; + + SqEpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity) + : Super(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return squaredNorm(grad, dev_scalar, stream) * 0.5; + } +}; + +}; // namespace raft::solver::quasi_newton::detail::objectives diff --git a/cpp/include/raft/solver/detail/qn/objectives/linear.cuh b/cpp/include/raft/solver/detail/qn/objectives/linear.cuh new file mode 100644 index 0000000000..dfaf83abf0 --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/objectives/linear.cuh @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "base.cuh" +#include +#include +#include + +namespace raft::solver::quasi_newton::detail::objectives { + +template +struct SquaredLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + inline __device__ T operator()(const T y, const T z) const + { + T diff = z - y; + return diff * diff * 0.5; + } + } lz; + + const struct Dlz { + inline __device__ T operator()(const T y, const T z) const { return z - y; } + } dlz; + + SquaredLoss(const raft::handle_t& handle, int D, bool has_bias) + : Super(handle, D, 1, has_bias), lz{}, dlz{} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return squaredNorm(grad, dev_scalar, stream) * 0.5; + } +}; + +template +struct AbsLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + inline __device__ T operator()(const T y, const T z) const { return raft::myAbs(z - y); } + } lz; + + const struct Dlz { + inline __device__ T operator()(const T y, const T z) const + { + return z > y ? 1 : (z < y ? 
-1 : 0); + } + } dlz; + + AbsLoss(const raft::handle_t& handle, int D, bool has_bias) + : Super(handle, D, 1, has_bias), lz{}, dlz{} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return nrm1(grad, dev_scalar, stream); + } +}; + +}; // namespace raft::solver::quasi_newton::detail::objectives diff --git a/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh b/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh new file mode 100644 index 0000000000..ed52069bc6 --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "base.cuh" +#include +#include +#include + +namespace raft::solver::quasi_newton::detail::objectives { + +template +struct LogisticLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + inline __device__ T log_sigmoid(const T x) const + { + // To avoid floating point overflow in the exp function + T temp = raft::myLog(1 + raft::myExp(x < 0 ? x : -x)); + return x < 0 ? x - temp : -temp; + } + + inline __device__ T operator()(const T y, const T z) const + { + T ytil = 2 * y - 1; + return -log_sigmoid(ytil * z); + } + } lz; + + const struct Dlz { + inline __device__ T operator()(const T y, const T z) const + { + // To avoid fp overflow with exp(z) when abs(z) is large + T ez = raft::myExp(z < 0 ? z : -z); + T numerator = z < 0 ? ez : T(1.0); + return numerator / (T(1.0) + ez) - y; + } + } dlz; + + LogisticLoss(const raft::handle_t& handle, int D, bool has_bias) + : Super(handle, D, 1, has_bias), lz{}, dlz{} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return nrmMax(grad, dev_scalar, stream); + } +}; +}; // namespace raft::solver::quasi_newton::detail::objectives \ No newline at end of file diff --git a/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh b/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh new file mode 100644 index 0000000000..68c79ab15d --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
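The log_sigmoid trick in LogisticLoss above avoids floating point overflow by evaluating exp only at -|x|. A host-side sketch of the same identity, compared against the naive formula (which underflows to -inf for large negative x):

#include <cmath>
#include <cstdio>

double log_sigmoid(double x)
{
  double t = std::log1p(std::exp(x < 0 ? x : -x));  // log(1 + e^{-|x|}), never overflows
  return x < 0 ? x - t : -t;
}

int main()
{
  for (double x : {-1000.0, -1.0, 0.0, 1.0, 1000.0})
    std::printf("log_sigmoid(%g) = %g  (naive: %g)\n",
                x,
                log_sigmoid(x),
                std::log(1.0 / (1.0 + std::exp(-x))));
}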
+ */ + +#pragma once + +#include "base.cuh" +#include +#include +#include +#include +#include +#include + +namespace raft::solver::quasi_newton::detail::objectives { + +template +struct Tikhonov { + T l2_penalty; + Tikhonov(T l2) : l2_penalty(l2) {} + Tikhonov(const Tikhonov& other) : l2_penalty(other.l2_penalty) {} + + HDI T operator()(const T w) const { return 0.5 * l2_penalty * w * w; } + + inline void reg_grad(T* reg_val, + SimpleDenseMat& G, + const SimpleDenseMat& W, + const bool has_bias, + cudaStream_t stream) const + { + // NOTE: scikit generally does not penalize biases + SimpleDenseMat Gweights; + SimpleDenseMat Wweights; + col_slice(G, Gweights, 0, G.n - has_bias); + col_slice(W, Wweights, 0, G.n - has_bias); + Gweights.ax(l2_penalty, Wweights, stream); + + raft::linalg::mapThenSumReduce(reg_val, Wweights.len, *this, stream, Wweights.data); + } +}; + +template +struct RegularizedQN : LinearDims { + Reg* reg; + Loss* loss; + + RegularizedQN(Loss* loss, Reg* reg) + : reg(reg), loss(loss), LinearDims(loss->C, loss->D, loss->fit_intercept) + { + } + + inline void loss_grad(T* loss_val, + SimpleDenseMat& G, + const SimpleDenseMat& W, + const SimpleMat& Xb, + const SimpleVec& yb, + SimpleDenseMat& Zb, + cudaStream_t stream, + bool initGradZero = true) + { + T reg_host, loss_host; + SimpleVec lossVal(loss_val, 1); + + G.fill(0, stream); + + reg->reg_grad(lossVal.data, G, W, loss->fit_intercept, stream); + raft::update_host(®_host, lossVal.data, 1, stream); + + loss->loss_grad(lossVal.data, G, W, Xb, yb, Zb, stream, false); + raft::update_host(&loss_host, lossVal.data, 1, stream); + + raft::interruptible::synchronize(stream); + + lossVal.fill(loss_host + reg_host, stream); + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return loss->gradNorm(grad, dev_scalar, stream); + } +}; +}; // namespace raft::solver::quasi_newton::detail::objectives diff --git a/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh b/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh new file mode 100644 index 0000000000..74b78e1158 --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "base.cuh" +#include +#include +#include + +namespace raft::solver::quasi_newton::detail::objectives { +using raft::ceildiv; +using raft::myExp; +using raft::myLog; +using raft::myMax; + +// Input: matrix Z (dims: CxN) +// Computes softmax cross entropy loss across columns, i.e. normalization +// column-wise. +// +// This kernel performs best for small number of classes C. +// It's much faster than implementation based on ml-prims (up to ~2x - ~10x for +// small C <= BX). More importantly, it does not require another CxN scratch +// space. 
In that case the block covers the whole column and warp reduce is fast +// TODO for very large C, there should be maybe rather something along the lines +// of +// coalesced reduce, i.e. blocks should take care of columns +// TODO split into two kernels for small and large case? +template +__global__ void logSoftmaxKernel( + T* out, T* dZ, const T* in, const T* labels, int C, int N, bool getDerivative = true) +{ + typedef cub::WarpReduce WarpRed; + typedef cub::BlockReduce BlockRed; + + __shared__ union { + typename WarpRed::TempStorage warpStore[BY]; + typename BlockRed::TempStorage blockStore; + T sh_val[BY]; + } shm; + + int y = threadIdx.y + blockIdx.x * BY; + int len = C * N; + + bool delta = false; + // TODO is there a better way to read this? + if (getDerivative && threadIdx.x == 0) { + if (y < N) { + shm.sh_val[threadIdx.y] = labels[y]; + } else { + shm.sh_val[threadIdx.y] = std::numeric_limits::lowest(); + } + } + __syncthreads(); + T label = shm.sh_val[threadIdx.y]; + __syncthreads(); + T eta_y = 0; + T myEta = 0; + T etaMax = -1e9; + T lse = 0; + /* + * Phase 1: Find Maximum m over column + */ + for (int x = threadIdx.x; x < C; x += BX) { + int idx = x + y * C; + if (x < C && idx < len) { + myEta = in[idx]; + if (x == label) { + delta = true; + eta_y = myEta; + } + etaMax = myMax(myEta, etaMax); + } + } + T tmpMax = WarpRed(shm.warpStore[threadIdx.y]).Reduce(etaMax, cub::Max()); + if (threadIdx.x == 0) { shm.sh_val[threadIdx.y] = tmpMax; } + __syncthreads(); + etaMax = shm.sh_val[threadIdx.y]; + __syncthreads(); + + /* + * Phase 2: Compute stabilized log-sum-exp over column + * lse = m + log(sum(exp(eta - m))) + */ + // TODO there must be a better way to do this... + if (C <= BX) { // this means one block covers a column and myEta is valid + int idx = threadIdx.x + y * C; + if (threadIdx.x < C && idx < len) { lse = myExp(myEta - etaMax); } + } else { + for (int x = threadIdx.x; x < C; x += BX) { + int idx = x + y * C; + if (x < C && idx < len) { lse += myExp(in[idx] - etaMax); } + } + } + T tmpLse = WarpRed(shm.warpStore[threadIdx.y]).Sum(lse); + if (threadIdx.x == 0) { shm.sh_val[threadIdx.y] = etaMax + myLog(tmpLse); } + __syncthreads(); + lse = shm.sh_val[threadIdx.y]; + __syncthreads(); + + /* + * Phase 3: Compute derivatives dL/dZ = P - delta_y + * P is the softmax distribution, delta_y the kronecker delta for the class of + * label y If we getDerivative=false, dZ will just contain P, which might be + * useful + */ + + if (C <= BX) { // this means one block covers a column and myEta is valid + int idx = threadIdx.x + y * C; + if (threadIdx.x < C && idx < len) { + dZ[idx] = (myExp(myEta - lse) - (getDerivative ? (threadIdx.x == label) : T(0))); + } + } else { + for (int x = threadIdx.x; x < C; x += BX) { + int idx = x + y * C; + if (x < C && idx < len) { + T logP = in[idx] - lse; + dZ[idx] = (myExp(logP) - (getDerivative ? 
(x == label) : T(0))); + } + } + } + + if (!getDerivative) // no need to continue, lossval will be undefined + return; + + T lossVal = 0; + if (delta) { lossVal = (lse - eta_y) / N; } + + /* + * Phase 4: accumulate loss value + */ + T blockSum = BlockRed(shm.blockStore).Sum(lossVal); + if (threadIdx.x == 0 && threadIdx.y == 0) { raft::myAtomicAdd(out, blockSum); } +} + +template +void launchLogsoftmax( + T* loss_val, T* dldZ, const T* Z, const T* labels, int C, int N, cudaStream_t stream) +{ + RAFT_CUDA_TRY(cudaMemsetAsync(loss_val, 0, sizeof(T), stream)); + raft::interruptible::synchronize(stream); + if (C <= 4) { + dim3 bs(4, 64); + dim3 gs(ceildiv(N, 64)); + logSoftmaxKernel<<>>(loss_val, dldZ, Z, labels, C, N); + } else if (C <= 8) { + dim3 bs(8, 32); + dim3 gs(ceildiv(N, 32)); + logSoftmaxKernel<<>>(loss_val, dldZ, Z, labels, C, N); + } else if (C <= 16) { + dim3 bs(16, 16); + dim3 gs(ceildiv(N, 16)); + logSoftmaxKernel<<>>(loss_val, dldZ, Z, labels, C, N); + } else { + dim3 bs(32, 8); + dim3 gs(ceildiv(N, 8)); + logSoftmaxKernel<<>>(loss_val, dldZ, Z, labels, C, N); + } + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +struct Softmax : QNLinearBase> { + typedef QNLinearBase> Super; + + Softmax(const raft::handle_t& handle, int D, int C, bool has_bias) : Super(handle, D, C, has_bias) + { + } + + inline void getLossAndDZ(T* loss_val, + SimpleDenseMat& Z, + const SimpleVec& y, + cudaStream_t stream) + { + launchLogsoftmax(loss_val, Z.data, Z.data, y.data, Z.m, Z.n, stream); + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return nrmMax(grad, dev_scalar, stream); + } +}; + +}; // namespace raft::solver::quasi_newton::detail::objectives diff --git a/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh b/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh new file mode 100644 index 0000000000..28e37da2fb --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "qn_util.cuh" +#include + +/* + * Linesearch functions + */ + +namespace raft::solver::quasi_newton::detail { + +template +struct LSProjectedStep { + typedef SimpleVec Vector; + struct op_pstep { + T step; + op_pstep(const T s) : step(s) {} + + HDI T operator()(const T xp, const T drt, const T pg) const + { + T xi = xp == 0 ? 
-pg : xp; + return project_orth(xp + step * drt, xi); + } + }; + + void operator()(const T step, + Vector& x, + const Vector& drt, + const Vector& xp, + const Vector& pgrad, + cudaStream_t stream) const + { + op_pstep pstep(step); + x.assign_ternary(xp, drt, pgrad, pstep, stream); + } +}; + +template +inline bool ls_success(const LBFGSParam& param, + const T fx_init, + const T dg_init, + const T fx, + const T dg_test, + const T step, + const SimpleVec& grad, + const SimpleVec& drt, + T* width, + T* dev_scalar, + cudaStream_t stream) +{ + if (fx > fx_init + step * dg_test) { + *width = param.ls_dec; + } else { + // Armijo condition is met + if (param.linesearch == LBFGS_LS_BT_ARMIJO) return true; + + const T dg = dot(grad, drt, dev_scalar, stream); + if (dg < param.wolfe * dg_init) { + *width = param.ls_inc; + } else { + // Regular Wolfe condition is met + if (param.linesearch == LBFGS_LS_BT_WOLFE) return true; + + if (dg > -param.wolfe * dg_init) { + *width = param.ls_dec; + } else { + // Strong Wolfe condition is met + return true; + } + } + } + + return false; +} + +/** + * Backtracking linesearch + * + * \param param LBFGS parameters + * \param f A function object such that `f(x, grad)` returns the + * objective function value at `x`, and overwrites `grad` + * with the gradient. + * \param fx In: The objective function value at the current point. + * Out: The function value at the new point. + * \param x Out: The new point moved to. + * \param grad In: The current gradient vector. + * Out: The gradient at the new point. + * \param step In: The initial step length. + * Out: The calculated step length. + * \param drt The current moving direction. + * \param xp The current point. + * \param dev_scalar Device pointer to workspace of at least 1 + * \param stream Device pointer to workspace of at least 1 + */ +template +LINE_SEARCH_RETCODE ls_backtrack(const LBFGSParam& param, + Function& f, + T& fx, + SimpleVec& x, + SimpleVec& grad, + T& step, + const SimpleVec& drt, + const SimpleVec& xp, + T* dev_scalar, + cudaStream_t stream) +{ + // Check the value of step + if (step <= T(0)) return LS_INVALID_STEP; + + // Save the function value at the current x + const T fx_init = fx; + // Projection of gradient on the search direction + const T dg_init = dot(grad, drt, dev_scalar, stream); + // Make sure d points to a descent direction + if (dg_init > 0) return LS_INVALID_DIR; + + const T dg_test = param.ftol * dg_init; + T width; + + RAFT_LOG_TRACE("Starting line search fx_init=%f, dg_init=%f", fx_init, dg_init); + + int iter; + for (iter = 0; iter < param.max_linesearch; iter++) { + // x_{k+1} = x_k + step * d_k + x.axpy(step, drt, xp, stream); + // Evaluate this candidate + fx = f(x, grad, dev_scalar, stream); + RAFT_LOG_TRACE("Line search iter %d, fx=%f", iter, fx); + // if (is_success(fx_init, dg_init, fx, dg_test, step, grad, drt, &width)) + if (ls_success( + param, fx_init, dg_init, fx, dg_test, step, grad, drt, &width, dev_scalar, stream)) + return LS_SUCCESS; + + if (step < param.min_step) return LS_INVALID_STEP_MIN; + + if (step > param.max_step) return LS_INVALID_STEP_MAX; + + step *= width; + } + return LS_MAX_ITERS_REACHED; +} + +template +LINE_SEARCH_RETCODE ls_backtrack_projected(const LBFGSParam& param, + Function& f, + T& fx, + SimpleVec& x, + SimpleVec& grad, + const SimpleVec& pseudo_grad, + T& step, + const SimpleVec& drt, + const SimpleVec& xp, + T l1_penalty, + T* dev_scalar, + cudaStream_t stream) +{ + LSProjectedStep lsstep; + + // Check the value of step + if (step <= T(0)) 
return LS_INVALID_STEP; + + // Save the function value at the current x + const T fx_init = fx; + // Projection of gradient on the search direction + const T dg_init = dot(pseudo_grad, drt, dev_scalar, stream); + // Make sure d points to a descent direction + if (dg_init > 0) return LS_INVALID_DIR; + + const T dg_test = param.ftol * dg_init; + T width; + + int iter; + for (iter = 0; iter < param.max_linesearch; iter++) { + // x_{k+1} = proj_orth(x_k + step * d_k) + lsstep(step, x, drt, xp, pseudo_grad, stream); + // evaluates fx with l1 term, but only grad of the loss term + fx = f(x, grad, dev_scalar, stream); + + // if (is_success(fx_init, dg_init, fx, dg_test, step, pseudo_grad, drt, + // &width)) + if (ls_success( + param, fx_init, dg_init, fx, dg_test, step, pseudo_grad, drt, &width, dev_scalar, stream)) + return LS_SUCCESS; + + if (step < param.min_step) return LS_INVALID_STEP_MIN; + + if (step > param.max_step) return LS_INVALID_STEP_MAX; + + step *= width; + } + return LS_MAX_ITERS_REACHED; +} + +}; // namespace raft::solver::quasi_newton::detail diff --git a/cpp/include/raft/solver/detail/qn/qn_solvers.cuh b/cpp/include/raft/solver/detail/qn/qn_solvers.cuh new file mode 100644 index 0000000000..a9f26096cd --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/qn_solvers.cuh @@ -0,0 +1,469 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +/* + * This file contains implementations of two popular Quasi-Newton methods: + * - Limited-memory Broyden Fletcher Goldfarb Shanno (L-BFGS) [Nocedal, Wright - + * Numerical Optimization (1999)] + * - Orthant-wise limited-memory quasi-newton (OWL-QN) [Andrew, Gao - ICML 2007] + * https://www.microsoft.com/en-us/research/publication/scalable-training-of-l1-regularized-log-linear-models/ + * + * L-BFGS is a classical method to solve unconstrained optimization problems of + * differentiable multi-variate functions f: R^D \mapsto R, i.e. it solves + * + * \min_{x \in R^D} f(x) + * + * iteratively by building up a m-dimensional (inverse) Hessian approximation. + * + * OWL-QN is an extension of L-BFGS that is specifically designed to optimize + * functions of the form + * + * f(x) + \lambda * \sum_i |x_i|, + * + * i.e. functions with an l1 penalty, by leveraging that |z| is differentiable + * when restricted to an orthant. + * + */ + +#include "qn_linesearch.cuh" +#include "qn_util.cuh" +#include +#include +#include +#include + +namespace raft::solver::quasi_newton::detail { + +// TODO better way to deal with alignment? Smaller aligne possible? 
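The workspace sizing below is plain arithmetic: two n-by-m history matrices (S and Y), four length-n vectors (xp, grad, gradp, drt) and one aligned scalar slot, each rounded up to qn_align bytes. A host-side sketch of the same computation (align_to mirrors raft::alignTo):

#include <cstddef>
#include <cstdio>

constexpr size_t align_to(size_t x, size_t a) { return ((x + a - 1) / a) * a; }

int main()
{
  const size_t qn_align = 256;
  const int n = 1000, m = 10;  // number of parameters, L-BFGS history size
  const size_t mat = align_to(sizeof(double) * m * n, qn_align);
  const size_t vec = align_to(sizeof(double) * n, qn_align);
  std::printf("lbfgs workspace = %zu bytes\n", 2 * mat + 4 * vec + qn_align);
}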
+constexpr size_t qn_align = 256; + +template +inline size_t lbfgs_workspace_size(const LBFGSParam& param, const int n) +{ + size_t mat_size = raft::alignTo(sizeof(T) * param.m * n, qn_align); + size_t vec_size = raft::alignTo(sizeof(T) * n, qn_align); + return 2 * mat_size + 4 * vec_size + qn_align; +} + +template +inline size_t owlqn_workspace_size(const LBFGSParam& param, const int n) +{ + size_t vec_size = raft::alignTo(sizeof(T) * n, qn_align); + return lbfgs_workspace_size(param, n) + vec_size; +} + +template +inline bool update_and_check(const char* solver, + const LBFGSParam& param, + int iter, + LINE_SEARCH_RETCODE lsret, + T& fx, + T& fxp, + const T& gnorm, + raft::solver::quasi_newton::SimpleVec& x, + raft::solver::quasi_newton::SimpleVec& xp, + raft::solver::quasi_newton::SimpleVec& grad, + raft::solver::quasi_newton::SimpleVec& gradp, + std::vector& fx_hist, + T* dev_scalar, + OPT_RETCODE& outcode, + cudaStream_t stream) +{ + bool stop = false; + bool converged = false; + bool isLsValid = !isnan(fx) && !isinf(fx); + // Linesearch may fail to converge, but still come closer to the solution; + // if that is not the case, let `check_convergence` ("insufficient change") + // below terminate the loop. + bool isLsNonCritical = lsret == LS_INVALID_STEP_MIN || lsret == LS_MAX_ITERS_REACHED; + // If the error is not critical, check that the target function does not grow. + // This shouldn't really happen, but weird things can happen if the convergence + // thresholds are too small. + bool isLsInDoubt = isLsValid && fx <= fxp + param.ftol && isLsNonCritical; + bool isLsSuccess = lsret == LS_SUCCESS || isLsInDoubt; + + RAFT_LOG_TRACE("%s iteration %d, fx=%f", solver, iter, fx); + + // if the target is at least finite, we can check the convergence + if (isLsValid) converged = check_convergence(param, iter, fx, gnorm, fx_hist); + + if (!isLsSuccess && !converged) { + RAFT_LOG_WARN( + "%s line search failed (code %d); stopping at the last valid step", solver, lsret); + outcode = OPT_LS_FAILED; + stop = true; + } else if (!isLsValid) { + RAFT_LOG_ERROR( + "%s error fx=%f at iteration %d; stopping at the last valid step", solver, fx, iter); + outcode = OPT_NUMERIC_ERROR; + stop = true; + } else if (converged) { + RAFT_LOG_DEBUG("%s converged", solver); + outcode = OPT_SUCCESS; + stop = true; + } else if (isLsInDoubt && fx + param.ftol >= fxp) { + // If a non-critical error has happened during the line search, check if the target + // is improved at least a bit. Otherwise, stop to avoid spinning till the iteration limit. + RAFT_LOG_WARN( + "%s stopped, because the line search failed to advance (step delta = %f)", solver, fx - fxp); + outcode = OPT_LS_FAILED; + stop = true; + } + + // if lineseach wasn't successful, undo the update. 
+template
+inline bool update_and_check(const char* solver,
+                             const LBFGSParam& param,
+                             int iter,
+                             LINE_SEARCH_RETCODE lsret,
+                             T& fx,
+                             T& fxp,
+                             const T& gnorm,
+                             raft::solver::quasi_newton::SimpleVec& x,
+                             raft::solver::quasi_newton::SimpleVec& xp,
+                             raft::solver::quasi_newton::SimpleVec& grad,
+                             raft::solver::quasi_newton::SimpleVec& gradp,
+                             std::vector& fx_hist,
+                             T* dev_scalar,
+                             OPT_RETCODE& outcode,
+                             cudaStream_t stream)
+{
+  bool stop      = false;
+  bool converged = false;
+  bool isLsValid = !isnan(fx) && !isinf(fx);
+  // Linesearch may fail to converge, but still come closer to the solution;
+  // if that is not the case, let `check_convergence` ("insufficient change")
+  // below terminate the loop.
+  bool isLsNonCritical = lsret == LS_INVALID_STEP_MIN || lsret == LS_MAX_ITERS_REACHED;
+  // If the error is not critical, check that the target function does not grow.
+  // This shouldn't really happen, but weird things can happen if the convergence
+  // thresholds are too small.
+  bool isLsInDoubt = isLsValid && fx <= fxp + param.ftol && isLsNonCritical;
+  bool isLsSuccess = lsret == LS_SUCCESS || isLsInDoubt;
+
+  RAFT_LOG_TRACE("%s iteration %d, fx=%f", solver, iter, fx);
+
+  // if the target is at least finite, we can check the convergence
+  if (isLsValid) converged = check_convergence(param, iter, fx, gnorm, fx_hist);
+
+  if (!isLsSuccess && !converged) {
+    RAFT_LOG_WARN(
+      "%s line search failed (code %d); stopping at the last valid step", solver, lsret);
+    outcode = OPT_LS_FAILED;
+    stop    = true;
+  } else if (!isLsValid) {
+    RAFT_LOG_ERROR(
+      "%s error fx=%f at iteration %d; stopping at the last valid step", solver, fx, iter);
+    outcode = OPT_NUMERIC_ERROR;
+    stop    = true;
+  } else if (converged) {
+    RAFT_LOG_DEBUG("%s converged", solver);
+    outcode = OPT_SUCCESS;
+    stop    = true;
+  } else if (isLsInDoubt && fx + param.ftol >= fxp) {
+    // If a non-critical error has happened during the line search, check if the target
+    // is improved at least a bit. Otherwise, stop to avoid spinning till the iteration limit.
+    RAFT_LOG_WARN(
+      "%s stopped, because the line search failed to advance (step delta = %f)", solver, fx - fxp);
+    outcode = OPT_LS_FAILED;
+    stop    = true;
+  }
+
+  // if the line search wasn't successful, undo the update.
+  if (!isLsSuccess || !isLsValid) {
+    fx = fxp;
+    x.copy_async(xp, stream);
+    grad.copy_async(gradp, stream);
+  }
+
+  return stop;
+}
+
+template
+inline OPT_RETCODE min_lbfgs(const LBFGSParam& param,
+                             Function& f,           // function to minimize
+                             SimpleVec& x,          // initial point, holds result
+                             T& fx,                 // output function value
+                             int* k,                // output iterations
+                             SimpleVec& workspace,  // scratch space
+                             cudaStream_t stream,
+                             int verbosity = 0)
+{
+  int n = x.len;
+  const int workspace_size = lbfgs_workspace_size(param, n);
+  ASSERT(workspace.len >= workspace_size, "LBFGS: workspace insufficient");
+
+  // SETUP WORKSPACE
+  size_t mat_size = raft::alignTo(sizeof(T) * param.m * n, qn_align);
+  size_t vec_size = raft::alignTo(sizeof(T) * n, qn_align);
+  T* p_ws = workspace.data;
+  SimpleDenseMat S(p_ws, n, param.m);
+  p_ws += mat_size;
+  SimpleDenseMat Y(p_ws, n, param.m);
+  p_ws += mat_size;
+  SimpleVec xp(p_ws, n);
+  p_ws += vec_size;
+  SimpleVec grad(p_ws, n);
+  p_ws += vec_size;
+  SimpleVec gradp(p_ws, n);
+  p_ws += vec_size;
+  SimpleVec drt(p_ws, n);
+  p_ws += vec_size;
+  T* dev_scalar = p_ws;
+
+  SimpleVec svec, yvec;  // mask vectors
+
+  std::vector ys(param.m);
+  std::vector alpha(param.m);
+  std::vector fx_hist(param.past > 0 ? param.past : 0);
+
+  *k = 0;
+  raft::solver::quasi_newton::Logger::get().setLevel(verbosity);
+  RAFT_LOG_DEBUG("Running L-BFGS");
+
+  // Evaluate function and compute gradient
+  fx      = f(x, grad, dev_scalar, stream);
+  T gnorm = f.gradNorm(grad, dev_scalar, stream);
+
+  if (param.past > 0) fx_hist[0] = fx;
+
+  // Early exit if the initial x is already a minimizer
+  if (check_convergence(param, *k, fx, gnorm, fx_hist)) {
+    RAFT_LOG_DEBUG("Initial solution fulfills optimality condition.");
+    return OPT_SUCCESS;
+  }
+
+  // Initial direction
+  drt.ax(-1.0, grad, stream);
+
+  // Initial step
+  T step = T(1.0) / nrm2(drt, dev_scalar, stream);
+  T fxp  = fx;
+
+  *k        = 1;
+  int end   = 0;
+  int n_vec = 0;  // number of vector updates made in lbfgs_search_dir
+  OPT_RETCODE retcode;
+  LINE_SEARCH_RETCODE lsret;
+  for (; *k <= param.max_iterations; (*k)++) {
+    // Save the current x and gradient
+    xp.copy_async(x, stream);
+    gradp.copy_async(grad, stream);
+    fxp = fx;
+
+    // Line search to update x, fx and gradient
+    lsret = ls_backtrack(param, f, fx, x, grad, step, drt, xp, dev_scalar, stream);
+    gnorm = f.gradNorm(grad, dev_scalar, stream);
+
+    if (update_and_check("L-BFGS",
+                         param,
+                         *k,
+                         lsret,
+                         fx,
+                         fxp,
+                         gnorm,
+                         x,
+                         xp,
+                         grad,
+                         gradp,
+                         fx_hist,
+                         dev_scalar,
+                         retcode,
+                         stream))
+      return retcode;
+
+    // Update s and y
+    // s_{k+1} = x_{k+1} - x_k
+    // y_{k+1} = g_{k+1} - g_k
+    col_ref(S, svec, end);
+    col_ref(Y, yvec, end);
+    svec.axpy(-1.0, xp, x, stream);
+    yvec.axpy(-1.0, gradp, grad, stream);
+    // drt <- -H * g
+    end = lbfgs_search_dir(
+      param, &n_vec, end, S, Y, grad, svec, yvec, drt, ys, alpha, dev_scalar, stream);
+
+    // step = 1.0 as initial guess
+    step = T(1.0);
+  }
+  RAFT_LOG_WARN("L-BFGS: max iterations reached");
+  return OPT_MAX_ITERS_REACHED;
+}
+
+template
+inline void update_pseudo(const SimpleVec& x,
+                          const SimpleVec& grad,
+                          const op_pseudo_grad& pseudo_grad,
+                          const int pg_limit,
+                          SimpleVec& pseudo,
+                          cudaStream_t stream)
+{
+  if (grad.len > pg_limit) {
+    pseudo.copy_async(grad, stream);
+    SimpleVec mask(pseudo.data, pg_limit);
+    mask.assign_binary(x, grad, pseudo_grad, stream);
+  } else {
+    pseudo.assign_binary(x, grad, pseudo_grad, stream);
+  }
+}
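+
+/*
+ * A minimal usage sketch of min_lbfgs (hypothetical names; `f` is any functor
+ * satisfying the Function interface used above, e.g. a GLM objective):
+ *
+ *   LBFGSParam<float> param;
+ *   SimpleVec<float> x(x_dev_ptr, n);  // initial point, holds the result
+ *   rmm::device_uvector<float> buf(lbfgs_workspace_size<float>(param, n), stream);
+ *   SimpleVec<float> workspace(buf.data(), buf.size());
+ *   float fx;
+ *   int iters;
+ *   OPT_RETCODE ret = min_lbfgs(param, f, x, fx, &iters, workspace, stream);
+ */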
+template
+inline OPT_RETCODE min_owlqn(const LBFGSParam& param,
+                             Function& f,
+                             const T l1_penalty,
+                             const int pg_limit,
+                             SimpleVec& x,
+                             T& fx,
+                             int* k,
+                             SimpleVec& workspace,  // scratch space
+                             cudaStream_t stream,
+                             const int verbosity = 0)
+{
+  int n = x.len;
+  const int workspace_size = owlqn_workspace_size(param, n);
+  ASSERT(workspace.len >= workspace_size, "OWL-QN: workspace insufficient");
+  ASSERT(pg_limit <= n && pg_limit > 0, "OWL-QN: Invalid pseudo grad limit parameter");
+
+  // SETUP WORKSPACE
+  size_t mat_size = raft::alignTo(sizeof(T) * param.m * n, qn_align);
+  size_t vec_size = raft::alignTo(sizeof(T) * n, qn_align);
+  T* p_ws = workspace.data;
+  SimpleDenseMat S(p_ws, n, param.m);
+  p_ws += mat_size;
+  SimpleDenseMat Y(p_ws, n, param.m);
+  p_ws += mat_size;
+  SimpleVec xp(p_ws, n);
+  p_ws += vec_size;
+  SimpleVec grad(p_ws, n);
+  p_ws += vec_size;
+  SimpleVec gradp(p_ws, n);
+  p_ws += vec_size;
+  SimpleVec drt(p_ws, n);
+  p_ws += vec_size;
+  SimpleVec pseudo(p_ws, n);
+  p_ws += vec_size;
+  T* dev_scalar = p_ws;
+
+  raft::solver::quasi_newton::Logger::get().setLevel(verbosity);
+
+  SimpleVec svec, yvec;  // mask vectors
+
+  std::vector ys(param.m);
+  std::vector alpha(param.m);
+  std::vector fx_hist(param.past > 0 ? param.past : 0);
+
+  op_project project_neg(T(-1.0));
+
+  auto f_wrap = [&f, &l1_penalty, &pg_limit](
+                  SimpleVec& x, SimpleVec& grad, T* dev_scalar, cudaStream_t stream) {
+    T tmp = f(x, grad, dev_scalar, stream);
+    SimpleVec mask(x.data, pg_limit);
+    return tmp + l1_penalty * nrm1(mask, dev_scalar, stream);
+  };
+
+  *k = 0;
+  RAFT_LOG_DEBUG("Running OWL-QN with lambda=%f", l1_penalty);
+
+  // op to compute the pseudo gradients
+  op_pseudo_grad pseudo_grad(l1_penalty);
+
+  fx = f_wrap(x, grad, dev_scalar,
+              stream);  // fx is loss+regularizer, grad is grad of loss only
+  T gnorm = f.gradNorm(grad, dev_scalar, stream);
+
+  // compute pseudo grad, but don't overwrite grad: used to build H
+  update_pseudo(x, grad, pseudo_grad, pg_limit, pseudo, stream);
+
+  if (param.past > 0) fx_hist[0] = fx;
+
+  // Early exit if the initial x is already a minimizer
+  if (check_convergence(param, *k, fx, gnorm, fx_hist)) {
+    RAFT_LOG_DEBUG("Initial solution fulfills optimality condition.");
+    return OPT_SUCCESS;
+  }
+
+  // Initial direction
+  drt.ax(-1.0, pseudo, stream);  // using pseudo gradient here
+  // below should be done for consistency but seems unnecessary
+  // drt.assign_k_ary(project, pseudo, x);
+
+  // Initial step
+  T step = T(1.0) / std::max(T(1), nrm2(drt, dev_scalar, stream));
+  T fxp  = fx;
+
+  int end   = 0;
+  int n_vec = 0;  // number of vector updates made in lbfgs_search_dir
+  OPT_RETCODE retcode;
+  LINE_SEARCH_RETCODE lsret;
+  for ((*k) = 1; (*k) <= param.max_iterations; (*k)++) {
+    // Save the current x and gradient
+    xp.copy_async(x, stream);
+    gradp.copy_async(grad, stream);
+    fxp = fx;
+
+    // Projected line search to update x, fx and gradient
+    lsret = ls_backtrack_projected(
+      param, f_wrap, fx, x, grad, pseudo, step, drt, xp, l1_penalty, dev_scalar, stream);
+    gnorm = f.gradNorm(grad, dev_scalar, stream);
+
+    if (update_and_check("OWL-QN",
+                         param,
+                         *k,
+                         lsret,
+                         fx,
+                         fxp,
+                         gnorm,
+                         x,
+                         xp,
+                         grad,
+                         gradp,
+                         fx_hist,
+                         dev_scalar,
+                         retcode,
+                         stream))
+      return retcode;
+
+    // recompute pseudo
+    update_pseudo(x, grad, pseudo_grad, pg_limit, pseudo, stream);
+
+    // Update s and y - We should only do this if there is no skipping condition
+    col_ref(S, svec, end);
+    col_ref(Y, yvec, end);
+    svec.axpy(-1.0, xp, x, stream);
+    yvec.axpy(-1.0, gradp, grad, stream);
+    // drt <- -H * pseudo grad
+    end = lbfgs_search_dir(
+      param, &n_vec, end, S, Y, pseudo, svec, yvec, drt, ys, alpha, dev_scalar, stream);
+
+    // Project drt onto the orthant of -pseudo
+    drt.assign_binary(drt, pseudo, project_neg, stream);
+
+    // step = 1.0 as initial guess
+    step = T(1.0);
+  }
+  RAFT_LOG_WARN("OWL-QN: max iterations reached");
+  return OPT_MAX_ITERS_REACHED;
+}
+
+/*
+ * Chooses the right algorithm, depending on the presence of the l1 term
+ */
+template
+inline int qn_minimize(const raft::handle_t& handle,
+                       SimpleVec& x,
+                       T* fx,
+                       int* num_iters,
+                       LossFunction& loss,
+                       const T l1,
+                       const LBFGSParam& opt_param,
+                       cudaStream_t stream,
+                       const int verbosity = 0)
+{
+  // TODO should the workspace allocation happen outside?
+  OPT_RETCODE ret;
+  if (l1 == 0.0) {
+    rmm::device_uvector tmp(lbfgs_workspace_size(opt_param, x.len), stream);
+    SimpleVec workspace(tmp.data(), tmp.size());
+
+    ret = min_lbfgs(opt_param,
+                    loss,       // function to minimize
+                    x,          // initial point, holds result
+                    *fx,        // output function value
+                    num_iters,  // output iterations
+                    workspace,  // scratch space
+                    stream,
+                    verbosity);
+
+    RAFT_LOG_DEBUG("L-BFGS Done");
+  } else {
+    // There might not be a better way to deal with dispatching
+    // for the l1 case:
+    // The algorithm explicitly expects a differentiable
+    // function f(x). It takes care of adding and
+    // handling the term l1norm(x) * l1_pen explicitly, i.e.
+    // it needs to evaluate f(x) and its gradient separately
+
+    rmm::device_uvector tmp(owlqn_workspace_size(opt_param, x.len), stream);
+    SimpleVec workspace(tmp.data(), tmp.size());
+
+    ret = min_owlqn(opt_param,
+                    loss,  // function to minimize
+                    l1,
+                    loss.D * loss.C,
+                    x,          // initial point, holds result
+                    *fx,        // output function value
+                    num_iters,  // output iterations
+                    workspace,  // scratch space
+                    stream,
+                    verbosity);
+
+    RAFT_LOG_DEBUG("OWL-QN Done");
+  }
+  if (ret == OPT_MAX_ITERS_REACHED) {
+    RAFT_LOG_WARN(
+      "Maximum iterations reached before the solver converged. To increase "
+      "model accuracy you can increase the number of iterations (max_iter) or "
+      "improve the scaling of the input data.");
+  }
+  return ret;
+}
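+
+/*
+ * A minimal usage sketch of qn_minimize (hypothetical names; `loss` is any
+ * functor satisfying the LossFunction interface used above):
+ *
+ *   raft::handle_t handle;
+ *   SimpleVec<float> x(coef_dev_ptr, n);  // initial coefficients, holds result
+ *   float fx;
+ *   int iters;
+ *   LBFGSParam<float> opt_param;
+ *   float l1 = 0.1f;  // l1 > 0 dispatches to OWL-QN, l1 == 0 to L-BFGS
+ *   qn_minimize(handle, x, &fx, &iters, loss, l1, opt_param, handle.get_stream());
+ */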
+
+};  // namespace raft::solver::quasi_newton::detail
\ No newline at end of file
diff --git a/cpp/include/raft/solver/detail/qn/qn_util.cuh b/cpp/include/raft/solver/detail/qn/qn_util.cuh
new file mode 100644
index 0000000000..a8df31df8f
--- /dev/null
+++ b/cpp/include/raft/solver/detail/qn/qn_util.cuh
@@ -0,0 +1,169 @@
+/*
+ * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace raft::solver::quasi_newton::detail {
+
+inline bool qn_is_classification(qn_loss_type t)
+{
+  switch (t) {
+    case QN_LOSS_LOGISTIC:
+    case QN_LOSS_SOFTMAX:
+    case QN_LOSS_HINGE:
+    case QN_LOSS_SQ_HINGE: return true;
+    default: return false;
+  }
+}
+
+template
+HDI T project_orth(T x, T y)
+{
+  return x * y <= T(0) ? T(0) : x;
+}
+
+template
+inline bool check_convergence(
+  const LBFGSParam& param, const int k, const T fx, const T gnorm, std::vector& fx_hist)
+{
+  // Positive scale factor for the stop condition
+  T fmag = std::max(fx, param.epsilon);
+
+  RAFT_LOG_DEBUG(
+    "%04d: f(x)=%.8f conv.crit=%.8f (gnorm=%.8f, fmag=%.8f)", k, fx, gnorm / fmag, gnorm, fmag);
+  // Convergence test -- gradient
+  if (gnorm <= param.epsilon * fmag) {
+    RAFT_LOG_DEBUG("Converged after %d iterations: f(x)=%.6f", k, fx);
+    return true;
+  }
+  // Convergence test -- objective function value
+  if (param.past > 0) {
+    if (k >= param.past && std::abs(fx_hist[k % param.past] - fx) <= param.delta * fmag) {
+      RAFT_LOG_DEBUG("Insufficient change in objective value");
+      return true;
+    }
+
+    fx_hist[k % param.past] = fx;
+  }
+  return false;
+}
+
+/*
+ * Multiplies a vector g by the inverse Hessian approximation, i.e.
+ *   drt = - H * g,
+ * e.g. to compute the new search direction for g = \nabla f(x)
+ */
+template
+inline int lbfgs_search_dir(const LBFGSParam& param,
+                            int* n_vec,
+                            const int end_prev,
+                            const SimpleDenseMat& S,
+                            const SimpleDenseMat& Y,
+                            const SimpleVec& g,
+                            const SimpleVec& svec,
+                            const SimpleVec& yvec,
+                            SimpleVec& drt,
+                            std::vector& yhist,
+                            std::vector& alpha,
+                            T* dev_scalar,
+                            cudaStream_t stream)
+{
+  SimpleVec sj, yj;  // mask vectors
+  int end = end_prev;
+  // note: update_state assigned svec, yvec to m_s[:,end], m_y[:,end]
+  T ys = dot(svec, yvec, dev_scalar, stream);
+  T yy = dot(yvec, yvec, dev_scalar, stream);
+  RAFT_LOG_TRACE("ys=%e, yy=%e", ys, yy);
+  // Skipping test:
+  if (ys <= std::numeric_limits::epsilon() * yy) {
+    // We can land here for example if yvec == 0 (no change in the gradient,
+    // g_k == g_k+1). That means the Hessian is approximately zero. We cannot
+    // use the QN model to update the search dir, we just continue along the
+    // previous direction.
+    //
+    // See eq (3.9) and Section 6 in "A limited memory algorithm for bound
+    // constrained optimization" Richard H. Byrd, Peihuang Lu, Jorge Nocedal and
+    // Ciyou Zhu Technical Report NAM-08 (1994) NORTHWESTERN UNIVERSITY.
+    //
+    // Alternative condition to skip update is: ys / (-gs) <= epsmch,
+    // (where epsmch = std::numeric_limits::epsilon) given in Section 5 of
+    // "L-BFGS-B Fortran subroutines for large-scale bound constrained
+    // optimization" Ciyou Zhu, Richard H. Byrd, Peihuang Lu and Jorge Nocedal
+    // (1994).
+    RAFT_LOG_DEBUG("L-BFGS WARNING: skipping update step ys=%f, yy=%f", ys, yy);
+    return end;
+  }
+  (*n_vec)++;
+  yhist[end] = ys;
+
+  // Recursive formula to compute d = -H * g
+  drt.ax(-1.0, g, stream);
+  int bound = std::min(param.m, *n_vec);
+  end       = (end + 1) % param.m;
+  int j     = end;
+  for (int i = 0; i < bound; i++) {
+    j = (j + param.m - 1) % param.m;
+    col_ref(S, sj, j);
+    col_ref(Y, yj, j);
+    alpha[j] = dot(sj, drt, dev_scalar, stream) / yhist[j];
+    drt.axpy(-alpha[j], yj, drt, stream);
+  }
+
+  drt.ax(ys / yy, drt, stream);
+
+  for (int i = 0; i < bound; i++) {
+    col_ref(S, sj, j);
+    col_ref(Y, yj, j);
+    T beta = dot(yj, drt, dev_scalar, stream) / yhist[j];
+    drt.axpy((alpha[j] - beta), sj, drt, stream);
+    j = (j + 1) % param.m;
+  }
+
+  return end;
+}
+
+template
+HDI T get_pseudo_grad(T x, T dlossx, T C)
+{
+  if (x != 0) { return dlossx + raft::sgn(x) * C; }
+  T dplus = dlossx + C;
+  T dmins = dlossx - C;
+  if (dmins > T(0)) return dmins;
+  if (dplus < T(0)) return dplus;
+  return T(0);
+}
+
+template
+struct op_project {
+  T scal;
+  op_project(T s) : scal(s) {}
+
+  HDI T operator()(const T x, const T y) const { return project_orth(x, scal * y); }
+};
+
+template
+struct op_pseudo_grad {
+  T l1;
+  op_pseudo_grad(const T lam) : l1(lam) {}
+
+  HDI T operator()(const T x, const T dlossx) const { return get_pseudo_grad(x, dlossx, l1); }
+};
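+
+/*
+ * Worked example of get_pseudo_grad at the non-differentiable point x = 0
+ * (values chosen for illustration), with l1 penalty C = 0.5:
+ *
+ *   dlossx =  0.7  ->  dmins =  0.2 > 0  ->  returns  0.2  (step towards x < 0)
+ *   dlossx = -0.7  ->  dplus = -0.2 < 0  ->  returns -0.2  (step towards x > 0)
+ *   dlossx =  0.3  ->  neither case applies  ->  returns 0  (x = 0 stays optimal)
+ *
+ * For x != 0 the l1 term is differentiable and sgn(x) * C is simply added.
+ */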
+
+};  // namespace raft::solver::quasi_newton::detail
diff --git a/cpp/include/raft/solver/detail/sgd.cuh b/cpp/include/raft/solver/detail/sgd.cuh
new file mode 100644
index 0000000000..8a5372dc33
--- /dev/null
+++ b/cpp/include/raft/solver/detail/sgd.cuh
@@ -0,0 +1,422 @@
+/*
+ * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace raft::solver::detail {
+
+/**
+ * Fits a linear, lasso, or elastic-net regression model using the Gradient Descent solver
+ * @param handle
+ *        Reference of raft::handle_t
+ * @param input
+ *        pointer to an array in column-major format (size of n_rows, n_cols)
+ * @param n_rows
+ *        n_samples or rows in input
+ * @param n_cols
+ *        n_features or columns in X
+ * @param labels
+ *        pointer to an array for labels (size of n_rows)
+ * @param coef
+ *        pointer to an array for coefficients (size of n_cols). This will be filled with
+ *        coefficients once the function is executed.
+ * @param intercept
+ *        pointer to a scalar for intercept. This will be filled
+ *        once the function is executed
+ * @param fit_intercept
+ *        boolean parameter to control if the intercept will be fitted or not
+ * @param batch_size
+ *        number of rows in the minibatch
+ * @param epochs
+ *        number of iterations that the solver will run
+ * @param lr_type
+ *        type of the learning rate function (i.e. OPTIMAL, CONSTANT, INVSCALING, ADAPTIVE)
+ * @param eta0
+ *        learning rate for the CONSTANT lr_type. It is also used to calculate the learning
+ *        rate for the other lr_type options
+ * @param power_t
+ *        power value in the INVSCALING lr_type
+ * @param loss
+ *        enum to use different loss functions.
+ * @param penalty
+ *        None, L1, L2, or Elastic-net penalty
+ * @param alpha
+ *        alpha value in L1
+ * @param l1_ratio
+ *        ratio of alpha that will be used for L1. (1 - l1_ratio) * alpha will be used for L2.
+ * @param shuffle
+ *        boolean parameter to control whether training rows are visited in random order
+ * @param tol
+ *        tolerance to stop the solver
+ * @param n_iter_no_change
+ *        solver stops if there is no update greater than tol after n_iter_no_change iterations
+ * @param stream
+ *        cuda stream
+ */
+template
+void sgdFit(const raft::handle_t& handle,
+            math_t* input,
+            int n_rows,
+            int n_cols,
+            math_t* labels,
+            math_t* coef,
+            math_t* intercept,
+            bool fit_intercept,
+            int batch_size,
+            int epochs,
+            ML::lr_type lr_type,
+            math_t eta0,
+            math_t power_t,
+            ML::loss_funct loss,
+            Functions::penalty penalty,
+            math_t alpha,
+            math_t l1_ratio,
+            bool shuffle,
+            math_t tol,
+            int n_iter_no_change,
+            cudaStream_t stream)
+{
+  ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
+  ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");
+
+  cublasHandle_t cublas_handle = handle.get_cublas_handle();
+
+  rmm::device_uvector mu_input(0, stream);
+  rmm::device_uvector mu_labels(0, stream);
+  rmm::device_uvector norm2_input(0, stream);
+
+  if (fit_intercept) {
+    mu_input.resize(n_cols, stream);
+    mu_labels.resize(1, stream);
+
+    preProcessData(handle,
+                   input,
+                   n_rows,
+                   n_cols,
+                   labels,
+                   intercept,
+                   mu_input.data(),
+                   mu_labels.data(),
+                   norm2_input.data(),
+                   fit_intercept,
+                   false);
+  }
+
+  rmm::device_uvector grads(n_cols, stream);
+  rmm::device_uvector indices(batch_size, stream);
+  rmm::device_uvector input_batch(batch_size * n_cols, stream);
+  rmm::device_uvector labels_batch(batch_size, stream);
+  rmm::device_scalar loss_value(stream);
+
+  math_t prev_loss_value = math_t(0);
+  math_t curr_loss_value = math_t(0);
+
+  std::vector rand_indices(n_rows);
+  std::mt19937 g(rand());
+  initShuffle(rand_indices, g);
+
+  math_t t             = math_t(1);
+  math_t learning_rate = math_t(0);
+  if (lr_type == ML::lr_type::ADAPTIVE) {
+    learning_rate = eta0;
+  } else if (lr_type == ML::lr_type::OPTIMAL) {
+    eta0 = calOptimalInit(alpha);
+  }
+
+  int n_iter_no_change_curr = 0;
+
+  for (int i = 0; i < epochs; i++) {
+    int cbs = 0;
+    int j   = 0;
+
+    if (i > 0 && shuffle) { Solver::shuffle(rand_indices, g); }
+
+    while (j < n_rows) {
+      if ((j + batch_size) > n_rows) {
+        cbs = n_rows - j;
+      } else {
+        cbs = batch_size;
+      }
+
+      if (cbs == 0) break;
+
+      raft::update_device(indices.data(), &rand_indices[j], cbs, stream);
+      raft::matrix::copyRows(
+        input, n_rows, n_cols, input_batch.data(), indices.data(), cbs, stream);
+      raft::matrix::copyRows(labels, n_rows, 1, labels_batch.data(), indices.data(), cbs, stream);
+
+      if (loss == ML::loss_funct::SQRD_LOSS) {
+        Functions::linearRegLossGrads(handle,
+                                      input_batch.data(),
+                                      cbs,
+                                      n_cols,
+                                      labels_batch.data(),
+                                      coef,
+                                      grads.data(),
+                                      penalty,
+                                      alpha,
+                                      l1_ratio,
+                                      stream);
+      } else if (loss == ML::loss_funct::LOG) {
+        Functions::logisticRegLossGrads(handle,
+                                        input_batch.data(),
+                                        cbs,
+                                        n_cols,
+                                        labels_batch.data(),
+                                        coef,
+                                        grads.data(),
+                                        penalty,
+                                        alpha,
+                                        l1_ratio,
+                                        stream);
+      } else if (loss == ML::loss_funct::HINGE) {
+        Functions::hingeLossGrads(handle,
+                                  input_batch.data(),
+                                  cbs,
+                                  n_cols,
+                                  labels_batch.data(),
+                                  coef,
+                                  grads.data(),
+                                  penalty,
+                                  alpha,
+                                  l1_ratio,
+                                  stream);
+      } else {
+        ASSERT(false, "sgd.cuh: Other loss functions have not been implemented yet!");
+      }
+
+      if (lr_type != ML::lr_type::ADAPTIVE)
+        learning_rate = calLearningRate(lr_type, eta0, power_t, alpha, t);
+
+      raft::linalg::scalarMultiply(grads.data(), grads.data(), learning_rate, n_cols, stream);
+      raft::linalg::subtract(coef, coef, grads.data(), n_cols, stream);
+
+      j = j + cbs;
+      t = t + 1;
+    }
+
+    if (tol > math_t(0)) {
+      if (loss == ML::loss_funct::SQRD_LOSS) {
+        Functions::linearRegLoss(handle,
+                                 input,
+                                 n_rows,
+                                 n_cols,
+                                 labels,
+                                 coef,
+                                 loss_value.data(),
+                                 penalty,
+                                 alpha,
+                                 l1_ratio,
+                                 stream);
+      } else if (loss == ML::loss_funct::LOG) {
+        Functions::logisticRegLoss(handle,
+                                   input,
+                                   n_rows,
+                                   n_cols,
+                                   labels,
+                                   coef,
+                                   loss_value.data(),
+                                   penalty,
+                                   alpha,
+                                   l1_ratio,
+                                   stream);
+      } else if (loss == ML::loss_funct::HINGE) {
+        Functions::hingeLoss(handle,
+                             input,
+                             n_rows,
+                             n_cols,
+                             labels,
+                             coef,
+                             loss_value.data(),
+                             penalty,
+                             alpha,
+                             l1_ratio,
+                             stream);
+      }
+
+      raft::update_host(&curr_loss_value, loss_value.data(), 1, stream);
+      handle.sync_stream(stream);
+
+      if (i > 0) {
+        if (curr_loss_value > (prev_loss_value - tol)) {
+          n_iter_no_change_curr = n_iter_no_change_curr + 1;
+          if (n_iter_no_change_curr > n_iter_no_change) {
+            if (lr_type == ML::lr_type::ADAPTIVE && learning_rate > math_t(1e-6)) {
+              learning_rate         = learning_rate / math_t(5);
+              n_iter_no_change_curr = 0;
+            } else {
+              break;
+            }
+          }
+        } else {
+          n_iter_no_change_curr = 0;
+        }
+      }
+
+      prev_loss_value = curr_loss_value;
+    }
+  }
+
+  if (fit_intercept) {
+    GLM::postProcessData(handle,
+                         input,
+                         n_rows,
+                         n_cols,
+                         labels,
+                         coef,
+                         intercept,
+                         mu_input.data(),
+                         mu_labels.data(),
+                         norm2_input.data(),
+                         fit_intercept,
+                         false);
+  } else {
+    *intercept = math_t(0);
+  }
+}
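+
+/*
+ * A minimal usage sketch of sgdFit (hypothetical device pointers and sizes;
+ * `input` is column-major [n_rows x n_cols], `labels` holds n_rows values):
+ *
+ *   math_t intercept;
+ *   sgdFit(handle, input, n_rows, n_cols, labels, coef, &intercept,
+ *          true,                                 // fit_intercept
+ *          32, 100,                              // batch_size, epochs
+ *          ML::lr_type::CONSTANT, math_t(0.01),  // lr_type, eta0
+ *          math_t(0.5),                          // power_t
+ *          ML::loss_funct::SQRD_LOSS, Functions::penalty::NONE,
+ *          math_t(0), math_t(0),                 // alpha, l1_ratio
+ *          true, math_t(1e-4), 5,                // shuffle, tol, n_iter_no_change
+ *          handle.get_stream());
+ */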
+
+/**
+ * Make predictions
+ * @param handle
+ *        Reference of raft::handle_t
+ * @param input
+ *        pointer to an array in column-major format (size of n_rows, n_cols)
+ * @param n_rows
+ *        n_samples or rows in input
+ * @param n_cols
+ *        n_features or columns in X
+ * @param coef
+ *        pointer to an array for coefficients (size of n_cols). Calculated in the sgdFit function.
+ * @param intercept
+ *        intercept value calculated in the sgdFit function
+ * @param preds
+ *        pointer to an array for predictions (size of n_rows). This will be filled once the
+ *        function is executed.
+ * @param loss
+ *        enum to use different loss functions. Squared, log, and hinge losses are supported.
+ * @param stream
+ *        cuda stream
+ */
+template
+void sgdPredict(const raft::handle_t& handle,
+                const math_t* input,
+                int n_rows,
+                int n_cols,
+                const math_t* coef,
+                math_t intercept,
+                math_t* preds,
+                ML::loss_funct loss,
+                cudaStream_t stream)
+{
+  ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
+  ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");
+
+  if (loss == ML::loss_funct::SQRD_LOSS) {
+    Functions::linearRegH(handle, input, n_rows, n_cols, coef, preds, intercept, stream);
+  } else if (loss == ML::loss_funct::LOG) {
+    Functions::logisticRegH(handle, input, n_rows, n_cols, coef, preds, intercept, stream);
+  } else if (loss == ML::loss_funct::HINGE) {
+    Functions::hingeH(handle, input, n_rows, n_cols, coef, preds, intercept, stream);
+  }
+}
+
+/**
+ * Make binary classifications
+ * @param handle
+ *        Reference of raft::handle_t
+ * @param input
+ *        pointer to an array in column-major format (size of n_rows, n_cols)
+ * @param n_rows
+ *        n_samples or rows in input
+ * @param n_cols
+ *        n_features or columns in X
+ * @param coef
+ *        pointer to an array for coefficients (size of n_cols). Calculated in the sgdFit function.
+ * @param intercept
+ *        intercept value calculated in the sgdFit function
+ * @param preds
+ *        pointer to an array for predictions (size of n_rows). This will be filled once the
+ *        function is executed.
+ * @param loss
+ *        enum to use different loss functions. Squared, log, and hinge losses are supported.
+ * @param stream
+ *        cuda stream
+ */
+template
+void sgdPredictBinaryClass(const raft::handle_t& handle,
+                           const math_t* input,
+                           int n_rows,
+                           int n_cols,
+                           const math_t* coef,
+                           math_t intercept,
+                           math_t* preds,
+                           ML::loss_funct loss,
+                           cudaStream_t stream)
+{
+  sgdPredict(handle, input, n_rows, n_cols, coef, intercept, preds, loss, stream);
+
+  if (loss == ML::loss_funct::SQRD_LOSS || loss == ML::loss_funct::LOG) {
+    raft::linalg::unaryOp(
+      preds,
+      preds,
+      n_rows,
+      [] __device__(math_t in) {
+        if (in >= math_t(0.5))
+          return math_t(1);
+        else
+          return math_t(0);
+      },
+      stream);
+  } else if (loss == ML::loss_funct::HINGE) {
+    raft::linalg::unaryOp(
+      preds,
+      preds,
+      n_rows,
+      [] __device__(math_t in) {
+        if (in >= math_t(0.0))
+          return math_t(1);
+        else
+          return math_t(0);
+      },
+      stream);
+  }
+}
+
+};  // namespace raft::solver::detail
diff --git a/cpp/include/raft/solver/detail/shuffle.h b/cpp/include/raft/solver/detail/shuffle.h
new file mode 100644
index 0000000000..1a815822b4
--- /dev/null
+++ b/cpp/include/raft/solver/detail/shuffle.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace raft::solver::detail {
+
+template
+void initShuffle(std::vector& rand_indices, std::mt19937& g, math_t random_state = 0)
+{
+  g.seed((int)random_state);
+  for (std::size_t i = 0; i < rand_indices.size(); ++i)
+    rand_indices[i] = i;
+}
+
+template
+void shuffle(std::vector& rand_indices, std::mt19937& g)
+{
+  std::shuffle(rand_indices.begin(), rand_indices.end(), g);
+}
+
+};  // namespace raft::solver::detail
diff --git a/cpp/include/raft/solver/gradient_descent.cuh b/cpp/include/raft/solver/gradient_descent.cuh
new file mode 100644
index 0000000000..5188d81ae2
--- /dev/null
+++ b/cpp/include/raft/solver/gradient_descent.cuh
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace raft::solver::gradient_descent {
+
+/**
+ * @brief Minimizes an objective function using the Gradient Descent solver and optional
+ * lasso or elastic-net penalties.
+ *
+ * @param[in] handle: Reference of raft::handle_t
+ * @param[in] A: Input matrix in column-major format (size of n_rows, n_cols)
+ * @param[in] b: Input vector of labels (size of n_rows)
+ * @param[out] x: Output vector of coefficients (size of n_cols)
+ * @param[out] intercept: Optional scalar if fitting the intercept
+ * @param[in] params: solver hyper-parameters
+ */
+template
+void minimize(const raft::handle_t& handle,
+              raft::device_matrix_view A,
+              raft::device_vector_view b,
+              raft::device_vector_view x,
+              std::optional<raft::device_scalar_view> intercept,
+              sgd_params& params)
+{
+  RAFT_EXPECTS(A.extent(0) == b.extent(0),
+               "Number of labels must match the number of rows in input matrix");
+  RAFT_EXPECTS(x.extent(0) == A.extent(1),
+               "Objective is linear. The number of coefficients must match the number features in "
+               "the input matrix");
+
+  auto intercept_ptr = intercept.has_value() ? intercept.value().data_handle() : nullptr;
+  detail::sgdFit(handle,
+                 A.data_handle(),
+                 A.extent(0),
+                 A.extent(1),
+                 b.data_handle(),
+                 x.data_handle(),
+                 intercept_ptr,
+                 intercept.has_value(),
+                 params.batch_size,
+                 params.epochs,
+                 params.lr_type,
+                 params.eta0,
+                 params.power_t,
+                 params.loss,
+                 params.penalty,
+                 params.alpha,
+                 params.l1_ratio,
+                 params.shuffle,
+                 params.tol,
+                 params.n_iter_no_change,
+                 handle.get_stream());
+}
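+
+/*
+ * A minimal usage sketch (hypothetical mdspan views; extents must satisfy the
+ * RAFT_EXPECTS checks above):
+ *
+ *   raft::handle_t handle;
+ *   sgd_params<float> params;  // defaults from raft/solver/solver_types.hpp
+ *   auto A = raft::make_device_matrix_view<float, int, raft::col_major>(A_ptr, n_rows, n_cols);
+ *   auto b = raft::make_device_vector_view<float, int>(b_ptr, n_rows);
+ *   auto x = raft::make_device_vector_view<float, int>(x_ptr, n_cols);
+ *   minimize(handle, A, b, x, std::nullopt, params);
+ */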
+
+}  // namespace raft::solver::gradient_descent
\ No newline at end of file
diff --git a/cpp/include/raft/solver/least_angle_regression.cuh b/cpp/include/raft/solver/least_angle_regression.cuh
new file mode 100644
index 0000000000..6484e40d7d
--- /dev/null
+++ b/cpp/include/raft/solver/least_angle_regression.cuh
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+namespace raft::solver::least_angle_regression {
+
+/**
+ * @brief Train a regression model using Least Angle Regression (LARS).
+ *
+ * Least Angle Regression (LAR or LARS) is a model selection algorithm. It
+ * builds up the model using the following algorithm:
+ *
+ * 1. We start with all the coefficients equal to zero.
+ * 2. At each step we select the predictor that has the largest absolute
+ *    correlation with the residual.
+ * 3. We take the largest step possible in the direction which is equiangular
+ *    with all the predictors selected so far. The largest step is determined
+ *    such that using this step a new predictor will have as much correlation
+ *    with the residual as any of the currently active predictors.
+ * 4. Stop if max_iter is reached or all the predictors are used, or if the
+ *    correlation between any unused predictor and the residual is lower than
+ *    a tolerance.
+ *
+ * The solver is based on [1]. The equations referred to in the comments
+ * correspond to the equations in the paper.
+ *
+ * Note: this algorithm assumes that the offset is removed from A and b, and
+ * each feature is normalized:
+ * - sum_i b_i = 0,
+ * - sum_i A_{i,j} = 0, sum_i A_{i,j}^2 = 1 for j = 0..n_col-1
+ *
+ * References:
+ * [1] B. Efron, T. Hastie, I. Johnstone, R. Tibshirani, Least Angle Regression
+ *     The Annals of Statistics (2004) Vol 32, No 2, 407-499
+ *     http://statweb.stanford.edu/~tibs/ftp/lars.pdf
+ *
+ * @param handle RAFT handle
+ * @param[in] A device array of training vectors in column major format,
+ *     size [n_rows * n_cols]. Note that the columns of A will be permuted if
+ *     the Gram matrix is not specified. It is expected that A is normalized so
+ *     that each column has zero mean and unit variance.
+ * @param[in] b device array of the regression targets, size [n_rows]. b should
+ *     be normalized to have zero mean.
+ * @param[in] Gram device array containing the Gram matrix A.T * A. Can be
+ *     nullptr.
+ * @param[out] x: device array of regression coefficients, has to be allocated on
+ *     entry, size [max_iter]
+ * @param[in] active_idx device vector containing the indices of active variables.
+ *     Must be allocated on entry. Size [max_iter]
+ * @param[out] alphas device array to return the maximum correlation along the
+ *     regularization path. Must be allocated on entry, size [max_iter+1].
+ * @param[out] n_active host pointer to return the number of active elements (scalar)
+ * @param[out] coef_path coefficients along the regularization path are returned
+ *     here. Must be nullptr, or a device array already allocated on entry.
+ *     Size [max_iter * (max_iter+1)].
+ * @param[in] params: lars hyper-parameters
+ * @param[in] ld_X leading dimension of A (stride of columns)
+ * @param[in] ld_G leading dimension of G
+ */
+template
+void minimize(const raft::handle_t& handle,
+              raft::device_matrix_view A,
+              raft::device_vector_view b,
+              std::optional> Gram,
+              raft::device_vector_view x,
+              raft::device_vector_view active_idx,
+              raft::device_vector_view alphas,
+              raft::host_scalar_view n_active,
+              std::optional> coef_path,
+              lars_params& params,
+              idx_t ld_X = 0,
+              idx_t ld_G = 0)
+{
+}
+}  // namespace raft::solver::least_angle_regression
\ No newline at end of file
diff --git a/cpp/include/raft/solver/quasi_newton.cuh b/cpp/include/raft/solver/quasi_newton.cuh
new file mode 100644
index 0000000000..0c731cebd6
--- /dev/null
+++ b/cpp/include/raft/solver/quasi_newton.cuh
@@ -0,0 +1,288 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace raft::solver::quasi_newton {
+
+/**
+ * The following loss functions are wrapped only so they will be included in the docs
+ */
+
+/**
+ * Absolute difference loss function specification
+ * @tparam T
+ */
+template
+struct AbsLoss : detail::objectives::AbsLoss {
+  AbsLoss(const raft::handle_t& handle, int D, bool has_bias)
+    : detail::objectives::AbsLoss(handle, D, has_bias)
+  {
+  }
+};
+
+/**
+ * Squared loss function specification
+ * @tparam T
+ */
+template
+struct SquaredLoss : detail::objectives::SquaredLoss {
+  SquaredLoss(const raft::handle_t& handle, int D, bool has_bias)
+    : detail::objectives::SquaredLoss(handle, D, 1, has_bias), lz{}, dlz{}
+  {
+  }
+};
+
+/**
+ * Standard hinge loss function specification
+ * @tparam T
+ */
+template
+struct HingeLoss : detail::objectives::HingeLoss {
+  HingeLoss(const raft::handle_t& handle, int D, bool has_bias)
+    : detail::objectives::HingeLoss(handle, D, has_bias)
+  {
+  }
+};
+
+/**
+ * Logistic loss function specification
+ * @tparam T
+ */
+template
+struct LogisticLoss : detail::objectives::LogisticLoss {
+  LogisticLoss(const raft::handle_t& handle, int D, bool has_bias)
+    : detail::objectives::LogisticLoss(handle, D, has_bias)
+  {
+  }
+};
+
+/**
+ * Squared hinge loss function specification
+ * @tparam T
+ */
+template
+struct SqHingeLoss : detail::objectives::SqHingeLoss {
+  SqHingeLoss(const raft::handle_t& handle, int D, bool has_bias)
+    : detail::objectives::SqHingeLoss(handle, D, has_bias)
+  {
+  }
+};
+
+/**
+ * Epsilon insensitive (regression) hinge loss function specification
+ * @tparam T
+ */
+template
+struct EpsInsHingeLoss : detail::objectives::EpsInsHingeLoss {
+  EpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity)
+    : detail::objectives::EpsInsHingeLoss(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity}
+  {
+  }
+};
+
+/**
+ * Squared epsilon insensitive (regression) hinge loss function specification
+ * @tparam T
+ */
+template
+struct SqEpsInsHingeLoss : detail::objectives::SqEpsInsHingeLoss {
+  SqEpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity)
+    : detail::objectives::SqEpsInsHingeLoss(handle, D, 1, has_bias),
+      lz{sensitivity},
+      dlz{sensitivity}
+  {
+  }
+};
+
+/**
+ * Tikhonov (l2) penalty function
+ * @tparam T
+ */
+template
+struct Tikhonov : detail::objectives::Tikhonov {
+  Tikhonov(T l2) : detail::objectives::Tikhonov(l2) {}
+
+  Tikhonov(const Tikhonov& other) : detail::objectives::Tikhonov(other.l2_penalty) {}
+};
+
+/**
+ * Loss function wrapper that adds a penalty to another loss function
+ *
+ * Example:
+ *
+ *   raft::handle_t handle;
+ *   AbsLoss abs_loss(handle, 5, true);
+ *   Tikhonov l2_reg(0.3);
+ *   RegularizedQN regularized_loss(&abs_loss, &l2_reg);
+ *
+ * @tparam T
+ * @tparam Loss
+ * @tparam Reg
+ */
+template
+class RegularizedQN : public detail::objectives::RegularizedQN {
+  RegularizedQN(Loss* loss, Reg* reg) : detail::objectives::RegularizedQN(loss, reg) {}
+};
+
+/**
+ * Base loss function that constrains the solution to a linear system
+ * @tparam T
+ * @tparam Loss
+ */
+template
+struct QNLinearBase : detail::objectives::QNLinearBase {
+  QNLinearBase(const raft::handle_t& handle, int D, int C, bool fit_intercept)
+    : detail::objectives::QNLinearBase(C, D, fit_intercept)
+  {
+  }
+};
+
+/**
+ * Softmax loss function specification
+ * @tparam T
+ */
+template
+struct Softmax : detail::objectives::Softmax {
+  Softmax(const raft::handle_t& handle, int D, int C, bool has_bias)
+    : detail::objectives::Softmax(handle, D, C, has_bias)
+  {
+  }
+};
+
+/**
+ * Constructs an end-to-end Quasi-Newton objective function to solve the system
+ * AX = b (where each row in X contains the coefficients for each target)
+ *
+ * @tparam T
+ * @tparam QuasiNewtonObjective
+ */
+template
+struct ObjectiveWithData : detail::objectives::QNWithData {
+  ObjectiveWithData(QuasiNewtonObjective* obj,
+                    const SimpleMat& A,
+                    const SimpleVec& b,
+                    SimpleDenseMat& X)
+    : detail::objectives::QNWithData(obj->C, obj->D, obj->fit_intercept)
+  {
+  }
+};
+
+/**
+ * @brief Minimize the given `raft::solver::quasi_newton::ObjectiveWithData` using
+ * the Limited-Memory Broyden-Fletcher-Goldfarb-Shanno algorithm. This algorithm
+ * estimates the inverse of the Hessian matrix, minimizing the memory footprint of
+ * the original BFGS algorithm by maintaining only a subset of the update history.
+ *
+ * @tparam T
+ * @tparam Function
+ * @param handle
+ * @param param
+ * @param f
+ * @param x
+ * @param fx
+ * @param k
+ * @return
+ */
+template
+OPT_RETCODE lbfgs_minimize(raft::handle_t& handle,
+                           const LBFGSParam& param,
+                           Function& f,   // function to minimize
+                           SimpleVec& x,  // initial point, holds result
+                           T& fx,         // output function value
+                           int* k)        // output iterations
+{
+  rmm::device_uvector tmp(detail::lbfgs_workspace_size(param, x.len), handle.get_stream());
+  SimpleVec workspace(tmp.data(), tmp.size());
+  return detail::min_lbfgs(param, f, x, fx, k, workspace, handle.get_stream(), 0);
+}
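+
+/*
+ * A minimal usage sketch (hypothetical objective `obj`, e.g. an
+ * ObjectiveWithData built from one of the loss functions above):
+ *
+ *   raft::handle_t handle;
+ *   LBFGSParam<float> param;
+ *   SimpleVec<float> x(coef_dev_ptr, n);
+ *   float fx;
+ *   int iters;
+ *   OPT_RETCODE ret = lbfgs_minimize(handle, param, obj, x, fx, &iters);
+ */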
+
+/**
+ * @brief Minimize the given `ObjectiveWithData` using the Orthant-wise
+ * Limited-Memory Quasi-Newton algorithm, an L-BFGS variant for fitting
+ * models with lasso (l1) penalties, enabling it to exploit the sparsity
+ * of the models.
+ *
+ * @tparam T
+ * @tparam Function
+ * @param handle
+ * @param param
+ * @param f
+ * @param l1_penalty
+ * @param pg_limit
+ * @param x
+ * @param fx
+ * @param k
+ * @return
+ */
+template
+OPT_RETCODE owl_minimize(raft::handle_t& handle,
+                         const LBFGSParam& param,
+                         Function& f,
+                         const T l1_penalty,
+                         const int pg_limit,
+                         SimpleVec& x,
+                         T& fx,
+                         int* k)
+{
+  rmm::device_uvector tmp(detail::owlqn_workspace_size(param, x.len), handle.get_stream());
+  SimpleVec workspace(tmp.data(), tmp.size());
+  return detail::min_owlqn(
+    param, f, l1_penalty, pg_limit, x, fx, k, workspace, handle.get_stream(), 0);
+}
+
+/**
+ * @brief Simple wrapper function that chooses the quasi-newton solver to use
+ * based on the presence of the L1 penalty term.
+ * @tparam T
+ * @tparam LossFunction
+ * @param handle
+ * @param x
+ * @param fx
+ * @param num_iters
+ * @param loss
+ * @param l1
+ * @param opt_param
+ * @param stream
+ * @param verbosity
+ * @return
+ */
+template
+inline int minimize(const raft::handle_t& handle,
+                    SimpleVec& x,
+                    T* fx,
+                    int* num_iters,
+                    LossFunction& loss,
+                    const T l1,
+                    const LBFGSParam& opt_param,
+                    cudaStream_t stream,
+                    const int verbosity = 0)
+{
+  return detail::qn_minimize(handle, x, fx, num_iters, loss, l1, opt_param, stream, verbosity);
+}
+}  // namespace raft::solver::quasi_newton
\ No newline at end of file
diff --git a/cpp/include/raft/solver/simple_mat.cuh b/cpp/include/raft/solver/simple_mat.cuh
new file mode 100644
index 0000000000..69bd0acdd8
--- /dev/null
+++ b/cpp/include/raft/solver/simple_mat.cuh
@@ -0,0 +1,624 @@
+/*
+ * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include
+// #TODO: Replace with public header when ready
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+/**
+ * NOTE: This will eventually get replaced with mdspan/mdarray
+ */
+
+namespace raft::solver::quasi_newton {
+
+template
+struct SimpleDenseMat;
+
+template
+struct SimpleMat {
+  int m, n;
+
+  SimpleMat(int m, int n) : m(m), n(n) {}
+
+  void operator=(const SimpleMat& other) = delete;
+
+  virtual void print(std::ostream& oss) const = 0;
+
+  /**
+   * GEMM assigning to C where `this` refers to B.
+   *
+   * ```
+   * C <- alpha * A^transA * (*this)^transB + beta * C
+   * ```
+   */
+  virtual void gemmb(const raft::handle_t& handle,
+                     const T alpha,
+                     const SimpleDenseMat& A,
+                     const bool transA,
+                     const bool transB,
+                     const T beta,
+                     SimpleDenseMat& C,
+                     cudaStream_t stream) const = 0;
+};
+
+template
+struct SimpleDenseMat : SimpleMat {
+  typedef SimpleMat Super;
+  int len;
+  T* data;
+
+  STORAGE_ORDER ord;  // storage order: runtime param for compile time sake
+
+  SimpleDenseMat(STORAGE_ORDER order = COL_MAJOR) : Super(0, 0), data(nullptr), len(0), ord(order)
+  {
+  }
+
+  SimpleDenseMat(T* data, int m, int n, STORAGE_ORDER order = COL_MAJOR)
+    : Super(m, n), data(data), len(m * n), ord(order)
+  {
+  }
+
+  void reset(T* data_, int m_, int n_)
+  {
+    this->m = m_;
+    this->n = n_;
+    data    = data_;
+    len     = m_ * n_;
+  }
+
+  // Implemented GEMM as a static method here to improve readability
+  inline static void gemm(const raft::handle_t& handle,
+                          const T alpha,
+                          const SimpleDenseMat& A,
+                          const bool transA,
+                          const SimpleDenseMat& B,
+                          const bool transB,
+                          const T beta,
+                          SimpleDenseMat& C,
+                          cudaStream_t stream)
+  {
+    int kA = A.n;
+    int kB = B.m;
+
+    if (transA) {
+      ASSERT(A.n == C.m, "GEMM invalid dims: m");
+      kA = A.m;
+    } else {
+      ASSERT(A.m == C.m, "GEMM invalid dims: m");
+    }
+
+    if (transB) {
+      ASSERT(B.m == C.n, "GEMM invalid dims: n");
+      kB = B.n;
+    } else {
+      ASSERT(B.n == C.n, "GEMM invalid dims: n");
+    }
+    ASSERT(kA == kB, "GEMM invalid dims: k");
+
+    if (A.ord == COL_MAJOR && B.ord == COL_MAJOR && C.ord == COL_MAJOR) {
+      // #TODO: Call from public API when ready
+      raft::linalg::detail::cublasgemm(handle.get_cublas_handle(),         // handle
+                                       transA ? CUBLAS_OP_T : CUBLAS_OP_N, // transA
+                                       transB ? CUBLAS_OP_T : CUBLAS_OP_N, // transB
+                                       C.m,
+                                       C.n,
+                                       kA,  // dimensions m,n,k
+                                       &alpha,
+                                       A.data,
+                                       A.m,  // lda
+                                       B.data,
+                                       B.m,  // ldb
+                                       &beta,
+                                       C.data,
+                                       C.m,  // ldc,
+                                       stream);
+      return;
+    }
+    if (A.ord == ROW_MAJOR) {
+      const SimpleDenseMat Acm(A.data, A.n, A.m, COL_MAJOR);
+      gemm(handle, alpha, Acm, !transA, B, transB, beta, C, stream);
+      return;
+    }
+    if (B.ord == ROW_MAJOR) {
+      const SimpleDenseMat Bcm(B.data, B.n, B.m, COL_MAJOR);
+      gemm(handle, alpha, A, transA, Bcm, !transB, beta, C, stream);
+      return;
+    }
+    if (C.ord == ROW_MAJOR) {
+      SimpleDenseMat Ccm(C.data, C.n, C.m, COL_MAJOR);
+      gemm(handle, alpha, B, !transB, A, !transA, beta, Ccm, stream);
+      return;
+    }
+  }
+
+  inline void gemmb(const raft::handle_t& handle,
+                    const T alpha,
+                    const SimpleDenseMat& A,
+                    const bool transA,
+                    const bool transB,
+                    const T beta,
+                    SimpleDenseMat& C,
+                    cudaStream_t stream) const override
+  {
+    SimpleDenseMat::gemm(handle, alpha, A, transA, *this, transB, beta, C, stream);
+  }
+
+  /**
+   * GEMM assigning to C where `this` refers to C.
+   *
+   * ```
+   * *this <- alpha * A^transA * B^transB + beta * (*this)
+   * ```
+   */
+  inline void assign_gemm(const raft::handle_t& handle,
+                          const T alpha,
+                          const SimpleDenseMat& A,
+                          const bool transA,
+                          const SimpleMat& B,
+                          const bool transB,
+                          const T beta,
+                          cudaStream_t stream)
+  {
+    B.gemmb(handle, alpha, A, transA, transB, beta, *this, stream);
+  }
+
+  // this = a*x
+  inline void ax(const T a, const SimpleDenseMat& x, cudaStream_t stream)
+  {
+    ASSERT(ord == x.ord, "SimpleDenseMat::ax: Storage orders must match");
+
+    auto scale = [a] __device__(const T x) { return a * x; };
+    raft::linalg::unaryOp(data, x.data, len, scale, stream);
+  }
+
+  // this = a*x + y
+  inline void axpy(const T a,
+                   const SimpleDenseMat& x,
+                   const SimpleDenseMat& y,
+                   cudaStream_t stream)
+  {
+    ASSERT(ord == x.ord, "SimpleDenseMat::axpy: Storage orders must match");
+    ASSERT(ord == y.ord, "SimpleDenseMat::axpy: Storage orders must match");
+
+    auto axpy = [a] __device__(const T x, const T y) { return a * x + y; };
+    raft::linalg::binaryOp(data, x.data, y.data, len, axpy, stream);
+  }
+
+  template
+  inline void assign_unary(const SimpleDenseMat& other, Lambda f, cudaStream_t stream)
+  {
+    ASSERT(ord == other.ord, "SimpleDenseMat::assign_unary: Storage orders must match");
+
+    raft::linalg::unaryOp(data, other.data, len, f, stream);
+  }
+
+  template
+  inline void assign_binary(const SimpleDenseMat& other1,
+                            const SimpleDenseMat& other2,
+                            Lambda& f,
+                            cudaStream_t stream)
+  {
+    ASSERT(ord == other1.ord, "SimpleDenseMat::assign_binary: Storage orders must match");
+    ASSERT(ord == other2.ord, "SimpleDenseMat::assign_binary: Storage orders must match");
+
+    raft::linalg::binaryOp(data, other1.data, other2.data, len, f, stream);
+  }
+
+  template
+  inline void assign_ternary(const SimpleDenseMat& other1,
+                             const SimpleDenseMat& other2,
+                             const SimpleDenseMat& other3,
+                             Lambda& f,
+                             cudaStream_t stream)
+  {
+    ASSERT(ord == other1.ord, "SimpleDenseMat::assign_ternary: Storage orders must match");
+    ASSERT(ord == other2.ord, "SimpleDenseMat::assign_ternary: Storage orders must match");
+    ASSERT(ord == other3.ord, "SimpleDenseMat::assign_ternary: Storage orders must match");
+
+    raft::linalg::ternaryOp(data, other1.data, other2.data, other3.data, len, f, stream);
+  }
+
+  inline void fill(const T val, cudaStream_t stream)
+  {
+    // TODO this reads data unnecessarily, though it's mostly used for testing
+    auto f = [val] __device__(const T x) { return val; };
+    raft::linalg::unaryOp(data, data, len, f, stream);
+  }
+
+  inline void copy_async(const SimpleDenseMat& other, cudaStream_t stream)
+  {
+    ASSERT((ord == other.ord) && (this->m == other.m) && (this->n == other.n),
+           "SimpleDenseMat::copy: matrices not compatible");
+
+    RAFT_CUDA_TRY(
+      cudaMemcpyAsync(data, other.data, len * sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  }
+
+  void print(std::ostream& oss) const override { oss << (*this) << std::endl; }
+
+  void operator=(const SimpleDenseMat& other) = delete;
+};
+
+template
+struct SimpleVec : SimpleDenseMat {
+  typedef SimpleDenseMat Super;
+
+  SimpleVec(T* data, const int n) : Super(data, n, 1, COL_MAJOR) {}
+  // this = alpha * A * x + beta * this
+  void assign_gemv(const raft::handle_t& handle,
+                   const T alpha,
+                   const SimpleDenseMat& A,
+                   bool transA,
+                   const SimpleVec& x,
+                   const T beta,
+                   cudaStream_t stream)
+  {
+    Super::assign_gemm(handle, alpha, A, transA, x, false, beta, stream);
+  }
+
+  SimpleVec() : Super(COL_MAJOR) {}
+
+  inline void reset(T* new_data, int n) { Super::reset(new_data, n, 1); }
+};
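+
+/*
+ * A minimal sketch of the BLAS-like helpers above (hypothetical device
+ * pointers; all operands must share the same length and storage order):
+ *
+ *   SimpleVec<float> x(x_ptr, n), y(y_ptr, n), out(out_ptr, n);
+ *   out.ax(2.0f, x, stream);       // out = 2 * x
+ *   out.axpy(3.0f, x, y, stream);  // out = 3 * x + y
+ *   out.fill(0.0f, stream);        // out = 0
+ *   out.copy_async(x, stream);     // out = x (device-to-device copy)
+ */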
+template
+inline void col_ref(const SimpleDenseMat& mat, SimpleVec& mask_vec, int c)
+{
+  ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats");
+  T* tmp = &mat.data[mat.m * c];
+  mask_vec.reset(tmp, mat.m);
+}
+
+template
+inline void col_slice(const SimpleDenseMat& mat,
+                      SimpleDenseMat& mask_mat,
+                      int c_from,
+                      int c_to)
+{
+  ASSERT(c_from >= 0 && c_from < mat.n, "col_slice: invalid from");
+  ASSERT(c_to >= 0 && c_to <= mat.n, "col_slice: invalid to");
+
+  ASSERT(mat.ord == COL_MAJOR, "col_slice only available for column major mats");
+  ASSERT(mask_mat.ord == COL_MAJOR, "col_slice only available for column major mask");
+  T* tmp = &mat.data[mat.m * c_from];
+  mask_mat.reset(tmp, mat.m, c_to - c_from);
+}
+
+// Reductions such as dot or norm require an additional location in dev mem
+// to hold the result. We don't want to deal with this in the SimpleVec class
+// as it impedes thread safety and constness
+
+template
+inline T dot(const SimpleVec& u, const SimpleVec& v, T* tmp_dev, cudaStream_t stream)
+{
+  auto f = [] __device__(const T x, const T y) { return x * y; };
+  raft::linalg::mapThenSumReduce(tmp_dev, u.len, f, stream, u.data, v.data);
+  T tmp_host;
+  raft::update_host(&tmp_host, tmp_dev, 1, stream);
+
+  raft::interruptible::synchronize(stream);
+  return tmp_host;
+}
+
+template
+inline T squaredNorm(const SimpleVec& u, T* tmp_dev, cudaStream_t stream)
+{
+  return dot(u, u, tmp_dev, stream);
+}
+
+template
+inline T nrmMax(const SimpleVec& u, T* tmp_dev, cudaStream_t stream)
+{
+  auto f = [] __device__(const T x) { return raft::myAbs(x); };
+  auto r = [] __device__(const T x, const T y) { return raft::myMax(x, y); };
+  raft::linalg::mapThenReduce(tmp_dev, u.len, T(0), f, r, stream, u.data);
+  T tmp_host;
+  raft::update_host(&tmp_host, tmp_dev, 1, stream);
+  raft::interruptible::synchronize(stream);
+  return tmp_host;
+}
+
+template
+inline T nrm2(const SimpleVec& u, T* tmp_dev, cudaStream_t stream)
+{
+  return raft::mySqrt(squaredNorm(u, tmp_dev, stream));
+}
+
+template
+inline T nrm1(const SimpleVec& u, T* tmp_dev, cudaStream_t stream)
+{
+  raft::linalg::rowNorm(
+    tmp_dev, u.data, u.len, 1, raft::linalg::L1Norm, true, stream, raft::Nop());
+  T tmp_host;
+  raft::update_host(&tmp_host, tmp_dev, 1, stream);
+  raft::interruptible::synchronize(stream);
+  return tmp_host;
+}
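+
+/*
+ * These reductions return their result on the host and synchronize the
+ * stream; `tmp_dev` must point to one device element of type T. A sketch
+ * (hypothetical pointers):
+ *
+ *   float d    = dot(u, v, tmp_dev, stream);  // <u, v>
+ *   float l2   = nrm2(u, tmp_dev, stream);    // sqrt(<u, u>)
+ *   float l1   = nrm1(u, tmp_dev, stream);    // sum_i |u_i|
+ *   float linf = nrmMax(u, tmp_dev, stream);  // max_i |u_i|
+ */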
"CM" : "RM") << "\n"; + std::vector out(mat.len); + raft::update_host(&out[0], mat.data, mat.len, rmm::cuda_stream_default); + raft::interruptible::synchronize(rmm::cuda_stream_view()); + if (mat.ord == COL_MAJOR) { + for (int r = 0; r < mat.m; r++) { + int idx = r; + for (int c = 0; c < mat.n - 1; c++) { + os << out[idx] << ","; + idx += mat.m; + } + os << out[idx] << std::endl; + } + } else { + for (int c = 0; c < mat.m; c++) { + int idx = c * mat.n; + for (int r = 0; r < mat.n - 1; r++) { + os << out[idx] << ","; + idx += 1; + } + os << out[idx] << std::endl; + } + } + + return os; +} + +template +struct SimpleVecOwning : SimpleVec { + typedef SimpleVec Super; + typedef rmm::device_uvector Buffer; + Buffer buf; + + SimpleVecOwning() = delete; + + SimpleVecOwning(int n, cudaStream_t stream) : Super(), buf(n, stream) + { + Super::reset(buf.data(), n); + } + + void operator=(const SimpleVec& other) = delete; +}; + +template +struct SimpleMatOwning : SimpleDenseMat { + typedef SimpleDenseMat Super; + typedef rmm::device_uvector Buffer; + Buffer buf; + using Super::m; + using Super::n; + using Super::ord; + + SimpleMatOwning() = delete; + + SimpleMatOwning(int m, int n, cudaStream_t stream, STORAGE_ORDER order = COL_MAJOR) + : Super(order), buf(m * n, stream) + { + Super::reset(buf.data(), m, n); + } + + void operator=(const SimpleVec& other) = delete; +}; + +/** + * Sparse matrix in CSR format. + * + * Note, we use cuSPARSE to manimulate matrices, and it guarantees: + * + * 1. row_ids[m] == nnz + * 2. cols are sorted within rows. + * + * However, when the data comes from the outside, we cannot guarantee that. + */ +template +struct SimpleSparseMat : SimpleMat { + typedef SimpleMat Super; + T* values; + int* cols; + int* row_ids; + int nnz; + + SimpleSparseMat() : Super(0, 0), values(nullptr), cols(nullptr), row_ids(nullptr), nnz(0) {} + + SimpleSparseMat(T* values, int* cols, int* row_ids, int nnz, int m, int n) + : Super(m, n), values(values), cols(cols), row_ids(row_ids), nnz(nnz) + { + check_csr(*this, 0); + } + + void print(std::ostream& oss) const override { oss << (*this) << std::endl; } + + void operator=(const SimpleSparseMat& other) = delete; + + inline void gemmb(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const bool transB, + const T beta, + SimpleDenseMat& C, + cudaStream_t stream) const override + { + const SimpleSparseMat& B = *this; + int kA = A.n; + int kB = B.m; + + if (transA) { + ASSERT(A.n == C.m, "GEMM invalid dims: m"); + kA = A.m; + } else { + ASSERT(A.m == C.m, "GEMM invalid dims: m"); + } + + if (transB) { + ASSERT(B.m == C.n, "GEMM invalid dims: n"); + kB = B.n; + } else { + ASSERT(B.n == C.n, "GEMM invalid dims: n"); + } + ASSERT(kA == kB, "GEMM invalid dims: k"); + + // matrix C must change the order and be transposed, because we need + // to swap arguments A and B in cusparseSpMM. + cusparseDnMatDescr_t descrC; + auto order = C.ord == COL_MAJOR ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( + &descrC, C.n, C.m, order == CUSPARSE_ORDER_COL ? C.n : C.m, C.data, order)); + + /* + The matrix A must have the same order as the matrix C in the input + of function cusparseSpMM (i.e. swapped order w.r.t. original C). + To account this requirement, I may need to flip transA (whether to transpose A). 
+
+         C   C'  rowsC'  colsC'  ldC'    A   A'  rowsA'  colsA'  ldA'   flipTransA
+         c   r   n       m       m       c   r   n       m       m      x
+         c   r   n       m       m       r   r   m       n       n      o
+         r   c   n       m       n       c   c   m       n       m      o
+         r   c   n       m       n       r   c   n       m       n      x
+
+      where:
+        c/r    - column/row major order
+        A,C    - input to gemmb
+        A', C' - input to cusparseSpMM
+        ldX'   - leading dimension - m or n, depending on order and transX
+    */
+    cusparseDnMatDescr_t descrA;
+    RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(&descrA,
+                                                                C.ord == A.ord ? A.n : A.m,
+                                                                C.ord == A.ord ? A.m : A.n,
+                                                                A.ord == COL_MAJOR ? A.m : A.n,
+                                                                A.data,
+                                                                order));
+    auto opA =
+      transA ^ (C.ord == A.ord) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
+
+    cusparseSpMatDescr_t descrB;
+    RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
+      &descrB, B.m, B.n, B.nnz, B.row_ids, B.cols, B.values));
+    auto opB = transB ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
+
+    auto alg = order == CUSPARSE_ORDER_COL ? CUSPARSE_SPMM_CSR_ALG1 : CUSPARSE_SPMM_CSR_ALG2;
+
+    size_t bufferSize;
+    RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(),
+                                                                    opB,
+                                                                    opA,
+                                                                    &alpha,
+                                                                    descrB,
+                                                                    descrA,
+                                                                    &beta,
+                                                                    descrC,
+                                                                    alg,
+                                                                    &bufferSize,
+                                                                    stream));
+
+    raft::interruptible::synchronize(stream);
+    rmm::device_uvector tmp(bufferSize, stream);
+
+    RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(),
+                                                         opB,
+                                                         opA,
+                                                         &alpha,
+                                                         descrB,
+                                                         descrA,
+                                                         &beta,
+                                                         descrC,
+                                                         alg,
+                                                         tmp.data(),
+                                                         stream));
+
+    RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrA));
+    RAFT_CUSPARSE_TRY(cusparseDestroySpMat(descrB));
+    RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrC));
+  }
+};
+
+template
+inline void check_csr(const SimpleSparseMat& mat, cudaStream_t stream)
+{
+  int row_ids_nnz;
+  raft::update_host(&row_ids_nnz, &mat.row_ids[mat.m], 1, stream);
+  raft::interruptible::synchronize(stream);
+  ASSERT(row_ids_nnz == mat.nnz,
+         "SimpleSparseMat: the size of the CSR row_ids array must be `m + 1`, and "
+         "the last element must be equal to nnz.");
+}
+
+template
+std::ostream& operator<<(std::ostream& os, const SimpleSparseMat& mat)
+{
+  check_csr(mat, 0);
+  os << "SimpleSparseMat (CSR)"
+     << "\n";
+  std::vector values(mat.nnz);
+  std::vector cols(mat.nnz);
+  std::vector row_ids(mat.m + 1);
+  raft::update_host(&values[0], mat.values, mat.nnz, rmm::cuda_stream_default);
+  raft::update_host(&cols[0], mat.cols, mat.nnz, rmm::cuda_stream_default);
+  raft::update_host(&row_ids[0], mat.row_ids, mat.m + 1, rmm::cuda_stream_default);
+  raft::interruptible::synchronize(rmm::cuda_stream_view());
+
+  int i, row_end = 0;
+  for (int row = 0; row < mat.m; row++) {
+    i       = row_end;
+    row_end = row_ids[row + 1];
+    for (int col = 0; col < mat.n; col++) {
+      if (i >= row_end || col < cols[i]) {
+        os << "0";
+      } else {
+        os << values[i];
+        i++;
+      }
+      if (col < mat.n - 1) os << ",";
+    }
+
+    os << std::endl;
+  }
+
+  return os;
+}
+
+};  // namespace raft::solver::quasi_newton
diff --git a/cpp/include/raft/solver/solver_types.hpp b/cpp/include/raft/solver/solver_types.hpp
new file mode 100644
index 0000000000..db2f63ee65
--- /dev/null
+++ b/cpp/include/raft/solver/solver_types.hpp
@@ -0,0 +1,307 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
diff --git a/cpp/include/raft/solver/solver_types.hpp b/cpp/include/raft/solver/solver_types.hpp
new file mode 100644
index 0000000000..db2f63ee65
--- /dev/null
+++ b/cpp/include/raft/solver/solver_types.hpp
@@ -0,0 +1,307 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace raft::solver {
+
+enum STORAGE_ORDER { COL_MAJOR = 0, ROW_MAJOR = 1 };
+
+enum lr_type {
+  OPTIMAL,
+  CONSTANT,
+  INVSCALING,
+  ADAPTIVE,
+};
+
+enum loss_funct {
+  SQUARED,
+  HINGE,
+  LOG,
+};
+
+enum penalty { NONE, L1, L2, ELASTICNET };
+
+namespace gradient_descent {
+template <typename math_t>
+struct sgd_params {
+  int batch_size;
+  int epochs;
+  lr_type lr_type;
+  math_t eta0;
+  math_t power_t;
+  loss_funct loss;
+  penalty penalty;
+  math_t alpha;
+  math_t l1_ratio;
+  bool shuffle;
+  math_t tol;
+  int n_iter_no_change;
+
+  sgd_params()
+    : batch_size(100),
+      epochs(100),
+      lr_type(lr_type::OPTIMAL),
+      eta0(0.5),
+      power_t(0.5),
+      loss(loss_funct::SQUARED),
+      penalty(penalty::L1),
+      alpha(0.5),
+      l1_ratio(0.2),
+      shuffle(true),
+      tol(1e-8),
+      n_iter_no_change(5)
+  {
+  }
+};
+}  // namespace gradient_descent
+
+namespace coordinate_descent {
+template <typename math_t>
+struct cd_params {
+  bool normalize;   // whether to normalize the data to zero mean and unit standard deviation
+  int epochs;       // number of iterations
+  loss_funct loss;  // loss function to minimize
+  math_t alpha;     // l1 penalty parameter
+  math_t l1_ratio;  // ratio of alpha used for the l1 penalty; (1 - l1_ratio) * alpha is
+                    // used for the l2 penalty
+  bool shuffle;     // whether to visit the coordinates in random order
+  math_t tol;       // early-stopping convergence tolerance
+
+  cd_params()
+    : normalize(true),
+      epochs(100),
+      loss(loss_funct::SQUARED),
+      alpha(0.3),
+      l1_ratio(0.5),
+      shuffle(true),
+      tol(1e-8)
+  {
+  }
+};
+}  // namespace coordinate_descent
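+
+// Example (illustrative only): an elastic-net configuration using cd_params.
+//
+//   coordinate_descent::cd_params<double> params;
+//   params.alpha    = 0.5;
+//   params.l1_ratio = 0.8;
+//   // Effective penalties: l1 = alpha * l1_ratio       = 0.4
+//   //                      l2 = alpha * (1 - l1_ratio) = 0.1
+//   params.tol      = 1e-6;  // tighter early-stopping tolerance than the default 1e-8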
+
+namespace least_angle_regression {
+template <typename math_t>
+struct lars_params {
+  int max_iter;
+  math_t eps;
+
+  lars_params() : max_iter(500), eps(-1) {}
+};
+}  // namespace least_angle_regression
+
+enum class LarsFitStatus { kOk, kCollinear, kError, kStop };
+
+namespace quasi_newton {
+
+/** Loss function types supported by the Quasi-Newton solvers. */
+enum qn_loss_type {
+  /** Logistic classification.
+   *  Expected target: {0, 1}.
+   */
+  QN_LOSS_LOGISTIC = 0,
+  /** L2 regression.
+   *  Expected target: R.
+   */
+  QN_LOSS_SQUARED = 1,
+  /** Softmax classification.
+   *  Expected target: {0, 1, ...}.
+   */
+  QN_LOSS_SOFTMAX = 2,
+  /** Hinge.
+   *  Expected target: {0, 1}.
+   */
+  QN_LOSS_HINGE = 3,
+  /** Squared-hinge.
+   *  Expected target: {0, 1}.
+   */
+  QN_LOSS_SQ_HINGE = 4,
+  /** Epsilon-insensitive.
+   *  Expected target: R.
+   */
+  QN_LOSS_HINGE_EPS_INS = 5,
+  /** Epsilon-insensitive-squared.
+   *  Expected target: R.
+   */
+  QN_LOSS_HINGE_SQ_EPS_INS = 6,
+  /** L1 regression.
+   *  Expected target: R.
+   */
+  QN_LOSS_ABS = 7,
+  /** Someone forgot to set the loss type! */
+  QN_LOSS_UNKNOWN = 99
+};
+
+struct qn_params {
+  /** Loss type. */
+  qn_loss_type loss;
+  /** Regularization: L1 component. */
+  double penalty_l1;
+  /** Regularization: L2 component. */
+  double penalty_l2;
+  /** Convergence criteria: the threshold on the gradient. */
+  double grad_tol;
+  /** Convergence criteria: the threshold on the function change. */
+  double change_tol;
+  /** Maximum number of iterations. */
+  int max_iter;
+  /** Maximum number of linesearch (inner loop) iterations. */
+  int linesearch_max_iter;
+  /** Number of vectors approximating the Hessian (L-BFGS). */
+  int lbfgs_memory;
+  /** Triggers extra output when greater than zero. */
+  int verbose;
+  /** Whether to fit the bias term. */
+  bool fit_intercept;
+  /**
+   * Whether to divide the L1 and L2 regularization parameters by the sample size.
+   *
+   * Note, the defined QN loss functions are normally scaled for the sample size,
+   * e.g. the average across the data rows is calculated.
+   * Enabling `penalty_normalized` makes this solver's behavior compatible with solvers
+   * that do not scale the loss functions (like sklearn's LogisticRegression).
+   */
+  bool penalty_normalized;
+
+  qn_params()
+    : loss(QN_LOSS_UNKNOWN),
+      penalty_l1(0),
+      penalty_l2(0),
+      grad_tol(1e-4),
+      change_tol(1e-5),
+      max_iter(1000),
+      linesearch_max_iter(50),
+      lbfgs_memory(5),
+      verbose(0),
+      fit_intercept(true),
+      penalty_normalized(true)
+  {
+  }
+};
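+
+// Example (illustrative only): qn_params for an L1-regularized logistic fit.
+//
+//   qn_params pams;
+//   pams.loss          = QN_LOSS_LOGISTIC;
+//   pams.penalty_l1    = 0.01;  // divided by n_samples while penalty_normalized is true
+//   pams.fit_intercept = true;
+//   pams.max_iter      = 200;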
+
+enum LINE_SEARCH_ALGORITHM {
+  LBFGS_LS_BT_ARMIJO       = 1,
+  LBFGS_LS_BT              = 2,  // Default. Alias for Wolfe
+  LBFGS_LS_BT_WOLFE        = 2,
+  LBFGS_LS_BT_STRONG_WOLFE = 3
+};
+
+enum LINE_SEARCH_RETCODE {
+  LS_SUCCESS           = 0,
+  LS_INVALID_STEP_MIN  = 1,
+  LS_INVALID_STEP_MAX  = 2,
+  LS_MAX_ITERS_REACHED = 3,
+  LS_INVALID_DIR       = 4,
+  LS_INVALID_STEP      = 5
+};
+
+enum OPT_RETCODE {
+  OPT_SUCCESS           = 0,
+  OPT_NUMERIC_ERROR     = 1,
+  OPT_LS_FAILED         = 2,
+  OPT_MAX_ITERS_REACHED = 3,
+  OPT_INVALID_ARGS      = 4
+};
+
+template <typename T>
+class LBFGSParam {
+ public:
+  int m;               // L-BFGS memory limit
+  T epsilon;           // gradient-norm threshold controlling convergence
+  int past;            // lookback distance for the function-value-based convergence test
+  T delta;             // tolerance for the function-value-based convergence test
+  int max_iterations;
+  int linesearch;      // see LINE_SEARCH_ALGORITHM above
+  int max_linesearch;
+  T min_step;          // minimum allowed step length
+  T max_step;          // maximum allowed step length
+  T ftol;              // line-search tolerance
+  T wolfe;             // Wolfe parameter
+  T ls_dec;            // line-search decrease factor
+  T ls_inc;            // line-search increase factor
+
+ public:
+  LBFGSParam()
+  {
+    m              = 6;
+    epsilon        = T(1e-5);
+    past           = 0;
+    delta          = T(0);
+    max_iterations = 0;
+    linesearch     = LBFGS_LS_BT_ARMIJO;
+    max_linesearch = 20;
+    min_step       = T(1e-20);
+    max_step       = T(1e+20);
+    ftol           = T(1e-4);
+    wolfe          = T(0.9);
+    ls_dec         = T(0.5);
+    ls_inc         = T(2.1);
+  }
+
+  explicit LBFGSParam(const qn_params& pams) : LBFGSParam()
+  {
+    m       = pams.lbfgs_memory;
+    epsilon = T(pams.grad_tol);
+    // sometimes an even number works better here - to detect zig-zags
+    past           = pams.change_tol > 0 ? 10 : 0;
+    delta          = T(pams.change_tol);
+    max_iterations = pams.max_iter;
+    max_linesearch = pams.linesearch_max_iter;
+    ftol           = pams.change_tol > 0 ? T(pams.change_tol * 0.1) : T(1e-4);
+  }
+
+  inline int check_param() const
+  {  // TODO: replace these error codes with exceptions
+    int ret = 1;
+    if (m <= 0) return ret;
+    ret++;
+    if (epsilon <= 0) return ret;
+    ret++;
+    if (past < 0) return ret;
+    ret++;
+    if (delta < 0) return ret;
+    ret++;
+    if (max_iterations < 0) return ret;
+    ret++;
+    if (linesearch < LBFGS_LS_BT_ARMIJO || linesearch > LBFGS_LS_BT_STRONG_WOLFE) return ret;
+    ret++;
+    if (max_linesearch <= 0) return ret;
+    ret++;
+    if (min_step < 0) return ret;
+    ret++;
+    if (max_step < min_step) return ret;
+    ret++;
+    if (ftol <= 0 || ftol >= 0.5) return ret;
+    ret++;
+    if (wolfe <= ftol || wolfe >= 1) return ret;
+    ret++;
+    return 0;
+  }
+};
+
+struct LinearDims {
+  bool fit_intercept;
+  int C, D, dims, n_param;
+  LinearDims(int C, int D, bool fit_intercept) : fit_intercept(fit_intercept), C(C), D(D)
+  {
+    dims    = D + fit_intercept;
+    n_param = dims * C;
+  }
+};
+}  // namespace quasi_newton
+
+}  // namespace raft::solver
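Reviewer note (not part of the patch): LBFGSParam is the internal translation of the user-facing qn_params, and check_param() reports the first failed sanity check as a 1-based position in its ret++ chain (0 means all checks passed). A minimal validation helper, assuming the header is included from raft/solver/solver_types.hpp as added by this diff:

    #include <raft/solver/solver_types.hpp>

    // Returns 0 when the derived optimizer settings are sane; otherwise the
    // 1-based index of the first failed check in LBFGSParam::check_param().
    int validate(const raft::solver::quasi_newton::qn_params& pams)
    {
      raft::solver::quasi_newton::LBFGSParam<double> opt(pams);
      return opt.check_param();
    }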
diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index 5be8401a6f..c75c07a708 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -197,10 +197,14 @@ if(BUILD_TESTS)
     test/random/sample_without_replacement.cu
   )
 
-  ConfigureTest(
-    NAME SOLVERS_TEST PATH test/cluster_solvers_deprecated.cu test/eigen_solvers.cu test/lap/lap.cu
-    test/mst.cu OPTIONAL DIST
-  )
+  ConfigureTest(NAME SOLVERS_TEST
+    PATH
+    test/cluster_solvers_deprecated.cu
+    test/eigen_solvers.cu
+    test/solver/lap.cu
+    test/solver/quasi_newton.cu
+    test/mst.cu
+  )
 
   ConfigureTest(
     NAME
diff --git a/cpp/test/lap/lap.cu b/cpp/test/solver/lap.cu
similarity index 100%
rename from cpp/test/lap/lap.cu
rename to cpp/test/solver/lap.cu
diff --git a/cpp/test/solver/quasi_newton.cu b/cpp/test/solver/quasi_newton.cu
new file mode 100644
index 0000000000..ff2d2316aa
--- /dev/null
+++ b/cpp/test/solver/quasi_newton.cu
@@ -0,0 +1,824 @@
+/*
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+
+#include
+
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+
+namespace raft::solver::quasi_newton {
+
+template <typename T, typename LossFunction>
+int qn_fit(const raft::handle_t& handle,
+           const qn_params& pams,
+           LossFunction& loss,
+           const SimpleMat<T>& X,
+           const SimpleVec<T>& y,
+           SimpleDenseMat<T>& Z,
+           T* w0_data,  // initial value and result
+           T* fx,
+           int* num_iters)
+{
+  LBFGSParam<T> opt_param(pams);
+  SimpleVec<T> w0(w0_data, loss.n_param);
+
+  // Scale the regularization strength with the number of samples.
+  T l1 = pams.penalty_l1;
+  T l2 = pams.penalty_l2;
+  if (pams.penalty_normalized) {
+    l1 /= X.m;
+    l2 /= X.m;
+  }
+
+  if (l2 == 0) {
+    ObjectiveWithData<T, LossFunction> lossWith(&loss, X, y, Z);
+
+    return minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param);
+
+  } else {
+    Tikhonov<T> reg(l2);
+    RegularizedQN<T, LossFunction, decltype(reg)> obj(&loss, &reg);
+    ObjectiveWithData<T, decltype(obj)> lossWith(&obj, X, y, Z);
+
+    return minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param);
+  }
+}
+
+template <typename T>
+inline void qn_fit_x(const raft::handle_t& handle,
+                     const qn_params& pams,
+                     SimpleMat<T>& X,
+                     T* y_data,
+                     int C,
+                     T* w0_data,
+                     T* f,
+                     int* num_iters,
+                     cudaStream_t stream,
+                     T* sample_weight = nullptr,
+                     T svr_eps        = 0)
+{
+  /*
+    NB:
+      N - number of data rows
+      D - number of data columns (features)
+      C - number of output classes
+
+      X in R^[N, D]
+      w in R^[D, C]
+      y in {0, 1}^[N, C] or {cat}^N
+
+    The dimensionality of w0 depends on the loss, so we initialize it later.
+  */
+  int N         = X.m;
+  int D         = X.n;
+  int n_targets = qn_is_classification(pams.loss) && C == 2 ? 1 : C;
+  rmm::device_uvector<T> tmp(n_targets * N, stream);
+  SimpleDenseMat<T> Z(tmp.data(), n_targets, N);
+  SimpleVec<T> y(y_data, N);
+
+  switch (pams.loss) {
+    case QN_LOSS_LOGISTIC: {
+      ASSERT(C == 2, "qn.h: logistic loss invalid C");
+      LogisticLoss<T> loss(handle, D, pams.fit_intercept);
+      if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
+      qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
+    } break;
+    case QN_LOSS_SQUARED: {
+      ASSERT(C == 1, "qn.h: squared loss invalid C");
+      SquaredLoss<T> loss(handle, D, pams.fit_intercept);
+      if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
+      qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
+    } break;
+    case QN_LOSS_SOFTMAX: {
+      ASSERT(C > 2, "qn.h: softmax invalid C");
+      Softmax<T> loss(handle, D, C, pams.fit_intercept);
+      if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
+      qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
+    } break;
+    case QN_LOSS_HINGE: {
+      ASSERT(C == 2, "qn.h: SVC-L1 loss invalid C");
+      HingeLoss<T> loss(handle, D, pams.fit_intercept);
+      if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
+      qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
+    } break;
+    case QN_LOSS_SQ_HINGE: {
+      ASSERT(C == 2, "qn.h: SVC-L2 loss invalid C");
+      SqHingeLoss<T> loss(handle, D, pams.fit_intercept);
+      if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
+      qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
+    } break;
+    case QN_LOSS_HINGE_EPS_INS: {
+      ASSERT(C == 1, "qn.h: SVR-L1 loss invalid C");
+      EpsInsHingeLoss<T> loss(handle, D, pams.fit_intercept, svr_eps);
+      if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
+      qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
+    } break;
+    case QN_LOSS_HINGE_SQ_EPS_INS: {
+      ASSERT(C == 1, "qn.h: SVR-L2 loss invalid C");
+      SqEpsInsHingeLoss<T> loss(handle, D, pams.fit_intercept, svr_eps);
+      if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
+      qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
+    } break;
+    case QN_LOSS_ABS: {
+      ASSERT(C == 1, "qn.h: abs loss (L1) invalid C");
+      AbsLoss<T> loss(handle, D, pams.fit_intercept);
+      if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
+      qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
+    } break;
+    default: {
+      ASSERT(false, "qn.h: unknown loss function type (id = %d).", pams.loss);
+    }
+  }
+}
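+
+// Usage sketch (comments only; this mirrors how the tests below drive qn_fit_x):
+//
+//   qn_params pams;
+//   pams.loss          = QN_LOSS_LOGISTIC;  // binary: C == 2, so Z has one target row
+//   pams.fit_intercept = true;
+//   double fx;
+//   int num_iters;
+//   // X: N x D SimpleMat, y: N labels in {0, 1}, w: D + 1 coefficients (bias last)
+//   qn_fit_x(handle, pams, X, y, /*C=*/2, w, &fx, &num_iters, stream);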
+
+struct QuasiNewtonTest : ::testing::Test {
+  QuasiNewtonTest() {}
+  void SetUp()
+  {
+    stream = handle.get_stream();
+    Xdev.reset(new SimpleMatOwning<double>(N, D, stream, ROW_MAJOR));
+    raft::update_device(Xdev->data, &X[0][0], Xdev->len, stream);
+
+    ydev.reset(new SimpleVecOwning<double>(N, stream));
+    handle.sync_stream(stream);
+  }
+  void TearDown() {}
+
+  static constexpr int N = 10;
+  static constexpr int D = 2;
+
+  const static double* nobptr;
+  const static double tol;
+  const static double X[N][D];
+  const raft::handle_t handle;
+  cudaStream_t stream = 0;
+  std::shared_ptr<SimpleMatOwning<double>> Xdev;
+  std::shared_ptr<SimpleVecOwning<double>> ydev;
+};
+
+const double* QuasiNewtonTest::nobptr = nullptr;
+const double QuasiNewtonTest::tol     = 5e-6;
+const double QuasiNewtonTest::X[QuasiNewtonTest::N][QuasiNewtonTest::D] = {
+  {-0.2047076594847130, 0.4789433380575482},
+  {-0.5194387150567381, -0.5557303043474900},
+  {1.9657805725027142, 1.3934058329729904},
+  {0.0929078767437177, 0.2817461528302025},
+  {0.7690225676118387, 1.2464347363862822},
+  {1.0071893575830049, -1.2962211091122635},
+  {0.2749916334321240, 0.2289128789353159},
+  {1.3529168351654497, 0.8864293405915888},
+  {-2.0016373096603974, -0.3718425371402544},
+  {1.6690253095248706, -0.4385697358355719}};
+
+template <typename T, class Comp>
+::testing::AssertionResult checkParamsEqual(const raft::handle_t& handle,
+                                            const T* host_weights,
+                                            const T* host_bias,
+                                            const T* w,
+                                            const LinearDims& dims,
+                                            Comp& comp,
+                                            cudaStream_t stream)
+{
+  int C              = dims.C;
+  int D              = dims.D;
+  bool fit_intercept = dims.fit_intercept;
+  // Transpose the row-major [C, D] reference weights into the column-major device layout.
+  std::vector<T> w_ref_cm(C * D);
+  int idx = 0;
+  for (int d = 0; d < D; d++)
+    for (int c = 0; c < C; c++) {
+      w_ref_cm[idx++] = host_weights[c * D + d];
+    }
+
+  SimpleVecOwning<T> w_ref(dims.n_param, stream);
+  raft::update_device(w_ref.data, &w_ref_cm[0], C * D, stream);
+  if (fit_intercept) { raft::update_device(&w_ref.data[C * D], host_bias, C, stream); }
+  handle.sync_stream(stream);
+  return raft::devArrMatch(w_ref.data, w, w_ref.len, comp);
+}
+
+template <typename T, class LossFunction>
+T run(const raft::handle_t& handle,
+      LossFunction& loss,
+      const SimpleMat<T>& X,
+      const SimpleVec<T>& y,
+      T l1,
+      T l2,
+      T* w,
+      SimpleDenseMat<T>& z,
+      int verbosity,
+      cudaStream_t stream)
+{
+  qn_params pams;
+  pams.max_iter            = 100;
+  pams.grad_tol            = 1e-16;
+  pams.change_tol          = 1e-16;
+  pams.linesearch_max_iter = 50;
+  pams.lbfgs_memory        = 5;
+  pams.penalty_l1          = l1;
+  pams.penalty_l2          = l2;
+  pams.verbose             = verbosity;
+
+  int num_iters = 0;
+
+  T fx;
+
+  qn_fit(handle, pams, loss, X, y, z, w, &fx, &num_iters);
+
+  return fx;
+}
+
+template <typename T>
+T run_api(const raft::handle_t& handle,
+          qn_loss_type loss_type,
+          int C,
+          bool fit_intercept,
+          SimpleMat<T>& X,
+          const SimpleVec<T>& y,
+          T l1,
+          T l2,
+          T* w,
+          SimpleDenseMat<T>& z,
+          int verbosity,
+          cudaStream_t stream)
+{
+  qn_params pams;
+
+  pams.max_iter            = 100;
+  pams.grad_tol            = 1e-8;
+  pams.change_tol          = 1e-8;
+  pams.linesearch_max_iter = 50;
+  pams.lbfgs_memory        = 5;
+  pams.penalty_l1          = l1;
+  pams.penalty_l2          = l2;
+  pams.verbose             = verbosity;
+  pams.fit_intercept       = fit_intercept;
+  pams.loss                = loss_type;
+
+  int num_iters = 0;
+
+  SimpleVec<T> w0(w, X.n + fit_intercept);
+  w0.fill(T(0), stream);
+  T fx;
+
+  qn_fit_x(handle, pams, X, y.data, C, w, &fx, &num_iters, stream);
+
+  return fx;
+}
+
+TEST_F(QuasiNewtonTest, binary_logistic_vs_sklearn)
+{
+#if CUDART_VERSION >= 11020
+  GTEST_SKIP();
+#endif
+  raft::CompareApprox<double> compApprox(tol);
+  // Test case generated in python and solved with sklearn
+  double y[N] = {1, 1, 1, 0, 1, 0, 1, 0, 1, 0};
+  raft::update_device(ydev->data, &y[0], ydev->len, stream);
+  handle.sync_stream(stream);
+
+  double alpha = 0.01 * N;
+
+  LogisticLoss<double> loss_b(handle, D, true);
+  LogisticLoss<double> loss_no_b(handle, D, false);
+
+  SimpleVecOwning<double> w0(D + 1, stream);
+  SimpleMatOwning<double> z(1, N, stream);
+
+  double l1, l2, fx;
+
+  double w_l1_b[2] = {-1.6899370396155091, 1.9021577534928300};
+  double b_l1_b    = 0.8057670813749118;
+  double obj_l1_b  = 0.44295941481024703;
+
+  l1 = alpha;
+  l2 = 0.0;
+  fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l1_b, fx));
+  ASSERT_TRUE(checkParamsEqual(handle, &w_l1_b[0], &b_l1_b, w0.data, loss_b, compApprox, stream));
+
+  fx = run_api(
+    handle, QN_LOSS_LOGISTIC, 2, loss_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l1_b, fx));
+
+  double w_l2_b[2] = {-1.5339880402781370, 1.6788639581350926};
+  double b_l2_b    = 0.806087868102401;
+  double obj_l2_b  = 0.4378085369889721;
+
+  l1 = 0;
+  l2 = alpha;
+  fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+
+  ASSERT_TRUE(compApprox(obj_l2_b, fx));
+  ASSERT_TRUE(checkParamsEqual(handle, &w_l2_b[0], &b_l2_b, w0.data, loss_b, compApprox, stream));
+
+  fx = run_api(
+    handle, QN_LOSS_LOGISTIC, 2, loss_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l2_b, fx));
+
+  double w_l1_no_b[2] = {-1.6215035298864591, 2.3650868394981086};
+  double obj_l1_no_b  = 0.4769896009200278;
+
+  l1 = alpha;
+  l2 = 0.0;
+  fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l1_no_b, fx));
+  ASSERT_TRUE(
+    checkParamsEqual(handle, &w_l1_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream));
+
+  fx = run_api(
+    handle, QN_LOSS_LOGISTIC, 2, loss_no_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l1_no_b, fx));
+
+  double w_l2_no_b[2] = {-1.3931049893764620, 2.0140103094119621};
+  double obj_l2_no_b  = 0.47502098062114273;
+
+  l1 = 0;
+  l2 = alpha;
+  fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l2_no_b, fx));
+  ASSERT_TRUE(
+    checkParamsEqual(handle, &w_l2_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream));
+
+  fx = run_api(
+    handle, QN_LOSS_LOGISTIC, 2, loss_no_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l2_no_b, fx));
+}
+
+TEST_F(QuasiNewtonTest, multiclass_logistic_vs_sklearn)
+{
+#if CUDART_VERSION >= 11020
+  GTEST_SKIP();
+#endif
+  // The data seems too small for the objective to be strongly convex,
+  // so we leave out the exact parameter checks.
+
+  raft::CompareApprox<double> compApprox(tol);
+  double y[N] = {2, 2, 0, 3, 3, 0, 0, 0, 1, 0};
+  raft::update_device(ydev->data, &y[0], ydev->len, stream);
+  handle.sync_stream(stream);
+
+  double fx, l1, l2;
+  int C = 4;
+
+  double alpha = 0.016 * N;
+
+  SimpleMatOwning<double> z(C, N, stream);
+  SimpleVecOwning<double> w0(C * (D + 1), stream);
+
+  Softmax<double> loss_b(handle, D, C, true);
+  Softmax<double> loss_no_b(handle, D, C, false);
+
+  l1              = alpha;
+  l2              = 0.0;
+  double obj_l1_b = 0.5407911382311313;
+
+  fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l1_b, fx));
+
+  fx = run_api(
+    handle, QN_LOSS_SOFTMAX, C, loss_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l1_b, fx));
+
+  l1              = 0.0;
+  l2              = alpha;
+  double obj_l2_b = 0.5721784062720949;
+
+  fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l2_b, fx));
+
+  fx = run_api(
+    handle, QN_LOSS_SOFTMAX, C, loss_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l2_b, fx));
+
+  l1                 = alpha;
+  l2                 = 0.0;
+  double obj_l1_no_b = 0.6606929813245878;
+
+  fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l1_no_b, fx));
+
+  fx = run_api(
+    handle, QN_LOSS_SOFTMAX, C, loss_no_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l1_no_b, fx));
+
+  l1 = 0.0;
+  l2 = alpha;
+
+  double obj_l2_no_b = 0.6597171282106854;
+
+  fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l2_no_b, fx));
+
+  fx = run_api(
+    handle, QN_LOSS_SOFTMAX, C, loss_no_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l2_no_b, fx));
+}
+
+TEST_F(QuasiNewtonTest, linear_regression_vs_sklearn)
+{
+  raft::CompareApprox<double> compApprox(tol);
+  double y[N] = {0.2675836026202781,
+                 -0.0678277759663704,
+                 -0.6334027174275105,
+                 -0.1018336189077367,
+                 0.0933815935886932,
+                 -1.1058853496996381,
+                 -0.1658298189619160,
+                 -0.2954290675648911,
+                 0.7966520536712608,
+                 -1.0767450516284769};
+  raft::update_device(ydev->data, &y[0], ydev->len, stream);
+  handle.sync_stream(stream);
+
+  double fx, l1, l2;
+  double alpha = 0.01 * N;
+
+  SimpleVecOwning<double> w0(D + 1, stream);
+  SimpleMatOwning<double> z(1, N, stream);
+  SquaredLoss<double> loss_b(handle, D, true);
+  SquaredLoss<double> loss_no_b(handle, D, false);
+
+  l1               = alpha;
+  l2               = 0.0;
+  double w_l1_b[2] = {-0.4952397281519840, 0.3813315300180231};
+  double b_l1_b    = -0.08140861819001188;
+  double obj_l1_b  = 0.011136986298775138;
+  fx               = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l1_b, fx));
+  ASSERT_TRUE(checkParamsEqual(handle, &w_l1_b[0], &b_l1_b, w0.data, loss_b, compApprox, stream));
+
+  fx = run_api(
+    handle, QN_LOSS_SQUARED, 1, loss_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l1_b, fx));
+
+  l1               = 0.0;
+  l2               = alpha;
+  double w_l2_b[2] = {-0.5022384743587150, 0.3937352417485087};
+  double b_l2_b    = -0.08062397391797513;
+  double obj_l2_b  = 0.004268621967866347;
+
+  fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l2_b, fx));
+  ASSERT_TRUE(checkParamsEqual(handle, &w_l2_b[0], &b_l2_b, w0.data, loss_b, compApprox, stream));
+
+  fx = run_api(
+    handle, QN_LOSS_SQUARED, 1, loss_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l2_b, fx));
+
+  l1                  = alpha;
+  l2                  = 0.0;
+  double w_l1_no_b[2] = {-0.5175178128147135, 0.3720844589831813};
+  double obj_l1_no_b  = 0.013981355746112447;
+
+  fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l1_no_b, fx));
+  ASSERT_TRUE(
+    checkParamsEqual(handle, &w_l1_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream));
+
+  fx = run_api(
+    handle, QN_LOSS_SQUARED, 1, loss_no_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l1_no_b, fx));
+
+  l1                  = 0.0;
+  l2                  = alpha;
+  double w_l2_no_b[2] = {-0.5241651041233270, 0.3846317886627560};
+  double obj_l2_no_b  = 0.007061261366969662;
+
+  fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l2_no_b, fx));
+  ASSERT_TRUE(
+    checkParamsEqual(handle, &w_l2_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream));
+
+  fx = run_api(
+    handle, QN_LOSS_SQUARED, 1, loss_no_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream);
+  ASSERT_TRUE(compApprox(obj_l2_no_b, fx));
+}
+
+TEST_F(QuasiNewtonTest, predict)
+{
+  raft::CompareApprox<double> compApprox(1e-8);
+  std::vector<double> w_host(D);
+  w_host[0] = 1;
+  std::vector<double> preds_host(N);
+  SimpleVecOwning<double> w(D, stream);
+  SimpleVecOwning<double> preds(N, stream);
+
+  raft::update_device(w.data, &w_host[0], w.len, stream);
+  qn_params pams;
+  pams.loss          = QN_LOSS_LOGISTIC;
+  pams.fit_intercept = false;
+
+  qnPredict(handle, pams, Xdev->data, false, N, D, 2, w.data, preds.data, stream);
+  raft::update_host(&preds_host[0], preds.data, preds.len, stream);
+  handle.sync_stream(stream);
+
+  for (int it = 0; it < N; it++) {
+    ASSERT_TRUE(X[it][0] > 0 ? compApprox(preds_host[it], 1) : compApprox(preds_host[it], 0));
+  }
+
+  pams.loss          = QN_LOSS_SQUARED;
+  pams.fit_intercept = false;
+  qnPredict(handle, pams, Xdev->data, false, N, D, 1, w.data, preds.data, stream);
+  raft::update_host(&preds_host[0], preds.data, preds.len, stream);
+  handle.sync_stream(stream);
+
+  for (int it = 0; it < N; it++) {
+    ASSERT_TRUE(compApprox(X[it][0], preds_host[it]));
+  }
+}
+
+TEST_F(QuasiNewtonTest, predict_softmax)
+{
+  raft::CompareApprox<double> compApprox(1e-8);
+  int C = 4;
+  std::vector<double> w_host(C * D);
+  w_host[0]         = 1;
+  w_host[D * C - 1] = 1;
+
+  std::vector<double> preds_host(N);
+  SimpleVecOwning<double> w(w_host.size(), stream);
+  SimpleVecOwning<double> preds(N, stream);
+
+  raft::update_device(w.data, &w_host[0], w.len, stream);
+
+  qn_params pams;
+  pams.loss          = QN_LOSS_SOFTMAX;
+  pams.fit_intercept = false;
+  qnPredict(handle, pams, Xdev->data, false, N, D, C, w.data, preds.data, stream);
+  raft::update_host(&preds_host[0], preds.data, preds.len, stream);
+  handle.sync_stream(stream);
+
+  for (int it = 0; it < N; it++) {
+    if (X[it][0] < 0 && X[it][1] < 0) {
+      ASSERT_TRUE(compApprox(1, preds_host[it]));
+    } else if (X[it][0] > X[it][1]) {
+      ASSERT_TRUE(compApprox(0, preds_host[it]));
+    } else {
+      ASSERT_TRUE(compApprox(C - 1, preds_host[it]));
+    }
+  }
+}
+
+TEST_F(QuasiNewtonTest, dense_vs_sparse_logistic)
+{
+#if CUDART_VERSION >= 11020
+  GTEST_SKIP();
+#endif
+  // Prepare a sparse input matrix from the dense matrix X.
+  // Yes, it's not sparse at all, yet the test does check whether the behaviour
+  // of the dense and sparse variants is the same.
+  rmm::device_uvector<int> mem_X_cols(N * D, stream);
+  rmm::device_uvector<int> mem_X_row_ids(N + 1, stream);
+  int host_X_cols[N][D];
+  int host_X_row_ids[N + 1];
+  for (int i = 0; i < N; i++) {
+    for (int j = 0; j < D; j++) {
+      host_X_cols[i][j] = j;
+    }
+  }
+  for (int i = 0; i < N + 1; i++) {
+    host_X_row_ids[i] = i * D;
+  }
+  raft::update_device(mem_X_cols.data(), &host_X_cols[0][0], mem_X_cols.size(), stream);
+  raft::update_device(mem_X_row_ids.data(), &host_X_row_ids[0], mem_X_row_ids.size(), stream);
+  SimpleSparseMat<double> X_sparse(
+    Xdev->data, mem_X_cols.data(), mem_X_row_ids.data(), N * D, N, D);
+
+  raft::CompareApprox<double> compApprox(tol);
+  double y[N] = {2, 2, 0, 3, 3, 0, 0, 0, 1, 0};
+  raft::update_device(ydev->data, &y[0], ydev->len, stream);
+  handle.sync_stream(stream);
+
+  int C                  = 4;
+  qn_loss_type loss_type = QN_LOSS_SOFTMAX;  // Softmax (loss_b, loss_no_b)
+  double alpha           = 0.016 * N;
+  Softmax<double> loss_b(handle, D, C, true);
+  Softmax<double> loss_no_b(handle, D, C, false);
+
+  SimpleMatOwning<double> z_dense(C, N, stream);
+  SimpleMatOwning<double> z_sparse(C, N, stream);
+  SimpleVecOwning<double> w0_dense(C * (D + 1), stream);
+  SimpleVecOwning<double> w0_sparse(C * (D + 1), stream);
+
+  std::vector<double> preds_dense_host(N);
+  std::vector<double> preds_sparse_host(N);
+  SimpleVecOwning<double> preds_dense(N, stream);
+  SimpleVecOwning<double> preds_sparse(N, stream);
+
+  auto test_run = [&](double l1, double l2, Softmax<double>& loss) {
+    qn_params pams;
+    pams.penalty_l1    = l1;
+    pams.penalty_l2    = l2;
+    pams.loss          = loss_type;
+    pams.fit_intercept = loss.fit_intercept;
+
+    double f_dense, f_sparse;
+    f_dense  = run(handle, loss, *Xdev, *ydev, l1, l2, w0_dense.data, z_dense, 0, stream);
+    f_sparse = run(handle, loss, X_sparse, *ydev, l1, l2, w0_sparse.data, z_sparse, 0, stream);
+    ASSERT_TRUE(compApprox(f_dense, f_sparse));
+
+    qnPredict(handle,
+              pams,
+              Xdev->data,
+              Xdev->ord == COL_MAJOR,
+              N,
+              D,
+              C,
+              w0_dense.data,
+              preds_dense.data,
+              stream);
+    qnPredictSparse(handle,
+                    pams,
+                    X_sparse.values,
+                    X_sparse.cols,
+                    X_sparse.row_ids,
+                    X_sparse.nnz,
+                    N,
+                    D,
+                    C,
+                    w0_sparse.data,
+                    preds_sparse.data,
+                    stream);
+
+    raft::update_host(&preds_dense_host[0], preds_dense.data, preds_dense.len, stream);
+    raft::update_host(&preds_sparse_host[0], preds_sparse.data, preds_sparse.len, stream);
+    handle.sync_stream(stream);
+    for (int i = 0; i < N; i++) {
+      ASSERT_TRUE(compApprox(preds_dense_host[i], preds_sparse_host[i]));
+    }
+
+    f_dense = run_api(
+      handle, QN_LOSS_SOFTMAX, C, loss.fit_intercept, *Xdev, *ydev, l1, l2, w0_dense.data, z_dense, 0, stream);
+    f_sparse = run_api(
+      handle, QN_LOSS_SOFTMAX, C, loss.fit_intercept, X_sparse, *ydev, l1, l2, w0_sparse.data, z_sparse, 0, stream);
+    ASSERT_TRUE(compApprox(f_dense, f_sparse));
+  };
+
+  test_run(alpha, 0.0, loss_b);
+  test_run(0.0, alpha, loss_b);
+  test_run(alpha, 0.0, loss_no_b);
+  test_run(0.0, alpha, loss_no_b);
+}
+
+}  // namespace raft::solver::quasi_newton