Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Reduce code duplication in c_api.cpp #3539

Merged
merged 3 commits into from
Nov 24, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 73 additions & 142 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2309,48 +2309,37 @@ int LGBM_NetworkInitWithFunctions(int num_machines, int rank,

// ---- start of some helper functions


// Builds a row accessor over a dense matrix whose elements have type T.
//
// data         - pointer to the first element; reinterpreted as const T*.
// num_row      - number of rows (used as the column stride in column-major layout).
// num_col      - number of columns; also the length of every returned vector.
// is_row_major - non-zero: element (r, c) is read from offset r * num_col + c;
//                zero:     element (r, c) is read from offset c * num_row + r.
//
// Returns a callable that, given a row index, copies that row into a freshly
// allocated std::vector<double>, converting each element to double.
// The callable keeps only the raw pointer, so the buffer behind `data` must
// stay alive for as long as the callable is used.
template<typename T>
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric_helper(const void* data, int num_row, int num_col, int is_row_major) {
  const T* values = reinterpret_cast<const T*>(data);

  if (!is_row_major) {
    // Column-major: consecutive elements of one row are num_row apart.
    const size_t col_stride = static_cast<size_t>(num_row);
    return [=] (int row_idx) {
      std::vector<double> row(num_col);
      for (int col = 0; col < num_col; ++col) {
        const T* col_start = values + col_stride * col;
        row[col] = static_cast<double>(col_start[row_idx]);
      }
      return row;
    };
  }

  // Row-major: a row is a contiguous run of num_col elements.
  const size_t row_stride = static_cast<size_t>(num_col);
  return [=] (int row_idx) {
    std::vector<double> row(num_col);
    const T* row_start = values + row_stride * row_idx;
    for (int col = 0; col < num_col; ++col) {
      row[col] = static_cast<double>(row_start[col]);
    }
    return row;
  };
}

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(tmp_ptr + i));
}
return ret;
};
} else {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
}
return ret;
};
}
return RowFunctionFromDenseMatric_helper<float>(data, num_row, num_col, is_row_major);
} else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(tmp_ptr + i));
}
return ret;
};
} else {
return [=] (int row_idx) {
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
}
return ret;
};
}
return RowFunctionFromDenseMatric_helper<double>(data, num_row, num_col, is_row_major);
}
Log::Fatal("Unknown data type in RowFunctionFromDenseMatric");
return nullptr;
Expand Down Expand Up @@ -2392,136 +2381,78 @@ RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
};
}

// Builds a row accessor over a CSR-style sparse matrix.
//
// Template parameters:
//   T  - type of the row index accepted by the returned callable.
//   T1 - element type of `data`.
//   T2 - element type of `indptr`.
//
// indptr  - row offsets; row r occupies positions [indptr[r], indptr[r + 1]).
// indices - column index for every stored position.
// data    - stored value for every position; reinterpreted as const T1*.
//
// Returns a callable that collects the (column index, value) pairs of one row,
// in storage order. The callable keeps only raw pointers, so the three arrays
// must stay alive for as long as the callable is used.
template<typename T, typename T1, typename T2>
std::function<std::vector<std::pair<int, double>>(T idx)>
RowFunctionFromCSR_helper(const void* indptr, const int32_t* indices, const void* data) {
  const T2* row_offsets = reinterpret_cast<const T2*>(indptr);
  const T1* values = reinterpret_cast<const T1*>(data);
  return [=] (T idx) {
    const int64_t first = row_offsets[idx];
    const int64_t last = row_offsets[idx + 1];
    const int64_t count = last - first;
    std::vector<std::pair<int, double>> row;
    // Only reserve for a positive count; an empty or inverted range yields an empty row.
    if (count > 0) {
      row.reserve(count);
    }
    for (int64_t pos = first; pos < last; ++pos) {
      row.emplace_back(indices[pos], values[pos]);
    }
    return row;
  };
}

// Selects the RowFunctionFromCSR_helper instantiation that matches the runtime
// type tags of a CSR matrix and returns the resulting row accessor.
//
// indptr      - row offset array; element type is given by indptr_type.
// indptr_type - C_API_DTYPE_INT32 or C_API_DTYPE_INT64.
// indices     - column index for every stored position.
// data        - stored values; element type is given by data_type.
// data_type   - C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64.
// The two trailing int64_t parameters are unnamed and unused here.
//
// For any unsupported combination of type tags, Log::Fatal is called with
// "Unknown data type in RowFunctionFromCSR"; the trailing `return nullptr`
// matters only if that call returns (Log::Fatal is defined elsewhere).
//
// Every branch delegates to RowFunctionFromCSR_helper rather than repeating
// the same lambda inline, so the row-extraction logic lives in one place.
template<typename T>
std::function<std::vector<std::pair<int, double>>(T idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
  if (data_type == C_API_DTYPE_FLOAT32) {
    if (indptr_type == C_API_DTYPE_INT32) {
      return RowFunctionFromCSR_helper<T, float, int32_t>(indptr, indices, data);
    } else if (indptr_type == C_API_DTYPE_INT64) {
      return RowFunctionFromCSR_helper<T, float, int64_t>(indptr, indices, data);
    }
  } else if (data_type == C_API_DTYPE_FLOAT64) {
    if (indptr_type == C_API_DTYPE_INT32) {
      return RowFunctionFromCSR_helper<T, double, int32_t>(indptr, indices, data);
    } else if (indptr_type == C_API_DTYPE_INT64) {
      return RowFunctionFromCSR_helper<T, double, int64_t>(indptr, indices, data);
    }
  }
  Log::Fatal("Unknown data type in RowFunctionFromCSR");
  return nullptr;
}



// Builds an iterator-style accessor over a single column of a CSC-style
// sparse matrix.
//
// Template parameters:
//   T1 - element type of `data`.
//   T2 - element type of `col_ptr`.
//
// col_ptr - column offsets; column c occupies positions [col_ptr[c], col_ptr[c + 1]).
// indices - row index for every stored position.
// data    - stored value for every position; reinterpreted as const T1*.
// col_idx - column to iterate; its bounds are read once, when this function runs.
//
// Returns a callable taking an offset into the column. It yields
// (row index, value) for the entry at that offset, or (-1, 0.0) once the
// offset reaches or passes the end of the column. The callable keeps only raw
// pointers to `indices` and `data`, so they must outlive it.
template <typename T1, typename T2>
std::function<std::pair<int, double>(int idx)> IterateFunctionFromCSC_helper(const void* col_ptr, const int32_t* indices, const void* data, int col_idx) {
  const T2* col_offsets = reinterpret_cast<const T2*>(col_ptr);
  const T1* values = reinterpret_cast<const T1*>(data);
  const int64_t col_begin = col_offsets[col_idx];
  const int64_t col_end = col_offsets[col_idx + 1];
  return [=] (int offset) -> std::pair<int, double> {
    const int64_t pos = col_begin + offset;
    if (pos < col_end) {
      return {static_cast<int>(indices[pos]), static_cast<double>(values[pos])};
    }
    // Past the last stored entry of this column: signal the end with a sentinel pair.
    return {-1, 0.0};
  };
}

std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
CHECK(col_idx < ncol_ptr && col_idx >= 0);
if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (col_ptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [=] (int offset) {
int64_t i = static_cast<int64_t>(start + offset);
if (i >= end) {
return std::make_pair(-1, 0.0);
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
return std::make_pair(idx, val);
};
return IterateFunctionFromCSC_helper<float, int32_t>(col_ptr, indices, data, col_idx);
} else if (col_ptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [=] (int offset) {
int64_t i = static_cast<int64_t>(start + offset);
if (i >= end) {
return std::make_pair(-1, 0.0);
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
return std::make_pair(idx, val);
};
return IterateFunctionFromCSC_helper<float, int64_t>(col_ptr, indices, data, col_idx);
}
} else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (col_ptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [=] (int offset) {
int64_t i = static_cast<int64_t>(start + offset);
if (i >= end) {
return std::make_pair(-1, 0.0);
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
return std::make_pair(idx, val);
};
return IterateFunctionFromCSC_helper<double, int32_t>(col_ptr, indices, data, col_idx);
} else if (col_ptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [=] (int offset) {
int64_t i = static_cast<int64_t>(start + offset);
if (i >= end) {
return std::make_pair(-1, 0.0);
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
return std::make_pair(idx, val);
};
return IterateFunctionFromCSC_helper<double, int64_t>(col_ptr, indices, data, col_idx);
}
}
Log::Fatal("Unknown data type in CSC matrix");
Expand Down