Skip to content

Commit

Permalink
Further cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertoEAF committed Nov 8, 2020
1 parent 1d30846 commit 57b9e58
Showing 1 changed file with 27 additions and 38 deletions.
65 changes: 27 additions & 38 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2309,48 +2309,37 @@ int LGBM_NetworkInitWithFunctions(int num_machines, int rank,

// ---- start of some help functions


template<typename T>
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric_helper(const void* data, int num_row, int num_col, int is_row_major) {
const T* data_ptr = reinterpret_cast<const T*>(data);
if (is_row_major) {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(tmp_ptr + i));
}
return ret;
};
} else {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
}
return ret;
};
}
}

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(tmp_ptr + i));
}
return ret;
};
} else {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
}
return ret;
};
}
return RowFunctionFromDenseMatric_helper<float>(data, num_row, num_col, is_row_major);
} else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(tmp_ptr + i));
}
return ret;
};
} else {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
}
return ret;
};
}
return RowFunctionFromDenseMatric_helper<double>(data, num_row, num_col, is_row_major);
}
Log::Fatal("Unknown data type in RowFunctionFromDenseMatric");
return nullptr;
Expand Down

0 comments on commit 57b9e58

Please sign in to comment.