Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
lijinf2 committed Nov 20, 2024
1 parent 2a7d7dc commit 083f204
Showing 1 changed file with 32 additions and 122 deletions.
154 changes: 32 additions & 122 deletions cpp/src/glm/qn/mg/standardization.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,25 @@ namespace ML {
namespace GLM {
namespace opg {

#include <string>

template<typename T>
void print_data(const raft::handle_t& handle, T* data, int len, std::string tag="data") {
ML::Logger::get().setLevel(6);
auto log_data = raft::arr2Str(data, len, tag.c_str(), handle.get_stream());
CUML_LOG_DEBUG("tag %s:, data: %s", tag.c_str(), log_data.c_str());
}

/**
* @brief Compute variance of the input matrix across all GPUs
*
* Variance operation is assumed to be performed on a given column.
*
* @tparam T the data type
* @param handle the internal cuml handle object
* @param X the input dense matrix
* @param n_samples number of rows of data across all GPUs
* @param mean_vector_all_samples the mean vector of rows of data across all GPUs
* @param var_vector the output variance vector
*/
template <typename T>
void var_mg(const raft::handle_t& handle,
const SimpleDenseMat<T>& X,
size_t n_samples,
T* mean_vector,
T* stddev_vector) {
void vars(const raft::handle_t& handle,
const SimpleDenseMat<T>& X,
size_t n_samples,
T* mean_vector_all_samples,
T* var_vector)
{
const T* input_data = X.data;
int D = X.n;
int num_rows = X.m;
Expand All @@ -67,34 +71,27 @@ void var_mg(const raft::handle_t& handle,
zero_vec.fill(0., stream);

// get sum of squares on every column
raft::stats::vars(stddev_vector, input_data, zero.data(), D, num_rows, false, !col_major, stream);
raft::stats::vars(var_vector, input_data, zero.data(), D, num_rows, false, !col_major, stream);
T weight = n_samples < 1 ? T(0) : T(1) * num_rows / T(n_samples - 1);
raft::linalg::multiplyScalar(stddev_vector, stddev_vector, weight, D, stream);
comm.allreduce(stddev_vector, stddev_vector, D, raft::comms::op_t::SUM, stream);
raft::linalg::multiplyScalar(var_vector, var_vector, weight, D, stream);
comm.allreduce(var_vector, var_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);
print_data(handle, stddev_vector, D, "sum squares");


// subtract mean
weight = n_samples <= 1 ? T(1) : T(n_samples) / T(n_samples - 1);
raft::linalg::binaryOp(
stddev_vector,
stddev_vector,
mean_vector,
var_vector,
var_vector,
mean_vector_all_samples,
D,
[weight]__device__(const T v, const T m) {
T scaled_m = weight * m * m;
T diff = v - scaled_m;
// avoid negative variance that is due to precision loss of floating point arithmetic
if (diff < 0) {
diff += scaled_m;
}
[weight] __device__(const T v, const T m) {
T scaled_m = weight * m * m;
T diff = v - scaled_m;
// avoid negative variance that is due to precision loss of floating point arithmetic
if (diff < 0) { diff += scaled_m; }
return diff;
},
},
stream);

raft::linalg::sqrt(stddev_vector, stddev_vector, D, handle.get_stream());
print_data(handle, stddev_vector, D, "sqrt");

}

template <typename T>
Expand All @@ -117,88 +114,10 @@ void mean_stddev(const raft::handle_t& handle,
comm.allreduce(mean_vector, mean_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);

var_mg<T>(handle, X, n_samples, mean_vector, stddev_vector);
/*
raft::stats::vars(stddev_vector, input_data, mean_vector, D, num_rows, false, !col_major, stream);
weight = n_samples < 1 ? T(0) : T(1) * num_rows / T(n_samples - 1);
raft::linalg::multiplyScalar(stddev_vector, stddev_vector, weight, D, stream);
comm.allreduce(stddev_vector, stddev_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);
// avoid negative variance that is due to precision loss of floating point arithmetic
weight = n_samples < 1 ? T(0) : T(1) / T(n_samples - 1);
weight = n_samples * weight;
auto no_neg_op = [weight] __device__(const T a, const T b) -> T {
if (a >= 0) return a;
return a + weight * b * b;
};
raft::linalg::binaryOp(stddev_vector, stddev_vector, mean_vector, D, no_neg_op, stream);
vars<T>(handle, X, n_samples, mean_vector, stddev_vector);
raft::linalg::sqrt(stddev_vector, stddev_vector, D, handle.get_stream());
*/
}


/*
template <typename T>
void mean_stddev(const raft::handle_t& handle,
const SimpleDenseMat<T>& X,
size_t n_samples,
T* mean_vector,
T* stddev_vector)
{
const T* input_data = X.data;
int D = X.n;
int num_rows = X.m;
bool col_major = (X.ord == COL_MAJOR);
auto stream = handle.get_stream();
auto& comm = handle.get_comms();
raft::stats::sum(mean_vector, input_data, D, num_rows, !col_major, stream);
T weight = T(1) / T(n_samples);
raft::linalg::multiplyScalar(mean_vector, mean_vector, weight, D, stream);
comm.allreduce(mean_vector, mean_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);
raft::stats::vars(stddev_vector, input_data, mean_vector, D, num_rows, false, !col_major, stream);
ML::Logger::get().setLevel(6);
auto log_data = raft::arr2Str(X.data, D * num_rows, "data ", stream);
auto log_vars = raft::arr2Str(stddev_vector, D, "", stream);
CUML_LOG_DEBUG("data: %s, row_major: %s, and debug mean_vars is %s", log_data.c_str(), !col_major? "true" : "false", log_vars.c_str());
weight = n_samples < 1 ? T(0) : T(1) * num_rows / T(n_samples - 1);
raft::linalg::multiplyScalar(stddev_vector, stddev_vector, weight, D, stream);
comm.allreduce(stddev_vector, stddev_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);
// avoid negative variance that is due to precision loss of floating point arithmetic
weight = n_samples < 1 ? T(0) : T(1) / T(n_samples - 1);
weight = n_samples * weight;
auto no_neg_op = [weight] __device__(const T a, const T b) -> T {
if (a >= 0) return a;
return a + weight * b * b;
};
ML::Logger::get().setLevel(6);
auto mean_log = raft::arr2Str(mean_vector, D, "", stream);
CUML_LOG_DEBUG("debug start reporting results");
CUML_LOG_DEBUG("debug mean is %s", mean_log.c_str());
raft::linalg::binaryOp(stddev_vector, stddev_vector, mean_vector, D, no_neg_op, stream);
auto stddev_log = raft::arr2Str(stddev_vector, D, "", stream);
CUML_LOG_DEBUG("debug square stddev is %s", stddev_log.c_str());
raft::linalg::sqrt(stddev_vector, stddev_vector, D, handle.get_stream());
stddev_log = raft::arr2Str(stddev_vector, D, "", stream);
CUML_LOG_DEBUG("debug stddev is %s", stddev_log.c_str());
}
*/
template <typename T, typename I = int>
SimpleSparseMat<T, I> get_sub_mat(const raft::handle_t& handle,
SimpleSparseMat<T, I> mat,
Expand Down Expand Up @@ -293,11 +212,6 @@ void mean_stddev(const raft::handle_t& handle,

mean(handle, X, n_samples, mean_vector);

ML::Logger::get().setLevel(6);
auto mean_log = raft::arr2Str(mean_vector, D, "", stream);
CUML_LOG_DEBUG("debug start reporting results");
CUML_LOG_DEBUG("debug mean is %s", mean_log.c_str());

// calculate stdev.S
rmm::device_uvector<T> X_values_squared(X.nnz, stream);
raft::copy(X_values_squared.data(), X.values, X.nnz, stream);
Expand All @@ -320,12 +234,8 @@ void mean_stddev(const raft::handle_t& handle,
return res;
};
raft::linalg::binaryOp(stddev_vector, stddev_vector, mean_vector, X.n, submean_no_neg_op, stream);
auto stddev_log = raft::arr2Str(stddev_vector, D, "", stream);
CUML_LOG_DEBUG("debug square stddev is %s", stddev_log.c_str());

raft::linalg::sqrt(stddev_vector, stddev_vector, X.n, handle.get_stream());
stddev_log = raft::arr2Str(stddev_vector, D, "", stream);
CUML_LOG_DEBUG("debug stddev is %s", stddev_log.c_str());
}

struct inverse_op {
Expand Down

0 comments on commit 083f204

Please sign in to comment.