Divide metric into train_metric and valid_metric
henry0312 committed Dec 17, 2018
1 parent cba8244 commit bc0ea4a
Showing 8 changed files with 119 additions and 60 deletions.
4 changes: 2 additions & 2 deletions include/LightGBM/c_api.h
@@ -484,15 +484,15 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int*
* \param out_len total number of eval results
* \return 0 when succeed, -1 when failure happens
*/
-LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len);
+LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int data_idx, int* out_len);

/*!
* \brief Get name of eval
* \param out_len total number of eval results
* \param out_strs names of eval result, need to pre-allocate memory before call this
* \return 0 when succeed, -1 when failure happens
*/
-LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs);
+LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int data_idx, int* out_len, char** out_strs);

/*!
* \brief Get name of features
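
A minimal ctypes sketch of the new calling convention (illustrative only, not part of the commit): it assumes lib is a lib_lightgbm shared library built at this revision and handle is a valid BoosterHandle, and it mirrors the 255-byte name buffers that python-package/lightgbm/basic.py allocates further down. data_idx 0 selects the training metrics; any other value selects the validation metrics.

import ctypes

def get_eval_names(lib, handle, data_idx):
    # Ask how many eval results exist for this dataset (train or valid).
    num_eval = ctypes.c_int(0)
    if lib.LGBM_BoosterGetEvalCounts(handle, ctypes.c_int(data_idx),
                                     ctypes.byref(num_eval)) != 0:
        raise RuntimeError("LGBM_BoosterGetEvalCounts failed")
    # Pre-allocate one buffer per name, as c_api.h requires.
    buffers = [ctypes.create_string_buffer(255) for _ in range(num_eval.value)]
    ptrs = (ctypes.c_char_p * num_eval.value)(*map(ctypes.addressof, buffers))
    out_len = ctypes.c_int(0)
    if lib.LGBM_BoosterGetEvalNames(handle, ctypes.c_int(data_idx),
                                    ctypes.byref(out_len), ptrs) != 0:
        raise RuntimeError("LGBM_BoosterGetEvalNames failed")
    return [b.value.decode() for b in buffers[:out_len.value]]
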
3 changes: 2 additions & 1 deletion include/LightGBM/config.h
@@ -697,7 +697,8 @@ struct Config {
// descl2 = ``xentlambda``, "intensity-weighted" cross-entropy, aliases: ``cross_entropy_lambda``
// descl2 = ``kldiv``, `Kullback-Leibler divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`__, aliases: ``kullback_leibler``
// desc = support multiple metrics, separated by ``,``
-std::vector<std::string> metric;
+std::vector<std::string> train_metric;
+std::vector<std::string> valid_metric;

// check = >0
// alias = output_freq
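
With the split fields above, training and validation metrics are configured independently. A hypothetical parameter set for this revision (keys map one-to-one onto the new Config members; values keep the comma-separated format the old metric parameter used):

params = {
    "objective": "binary",
    # computed on the training data when training metrics are enabled
    "train_metric": "binary_logloss",
    # computed on every validation dataset
    "valid_metric": "binary_logloss,auc",
}
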
53 changes: 25 additions & 28 deletions python-package/lightgbm/basic.py
@@ -1472,7 +1472,6 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False):
"""
self.handle = None
self.network = False
self.__need_reload_eval_info = True
self.__train_data_name = "training"
self.__attr = {}
self.__set_objective_to_none = False
@@ -1528,7 +1527,7 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False):
            # buffer for inner predict
            self.__inner_predict_buffer = [None]
            self.__is_predicted_cur_iter = [False]
-            self.__get_eval_info()
+            self.__get_eval_info(0)  # init by train data (data_idx=0)
            self.pandas_categorical = train_set.pandas_categorical
        elif model_file is not None:
            # Prediction task
@@ -1708,8 +1707,6 @@ def reset_parameter(self, params):
        self : Booster
            Booster with new parameters.
        """
-        if any(metric_alias in params for metric_alias in ('metric', 'metrics', 'metric_types')):
-            self.__need_reload_eval_info = True
        params_str = param_dict_to_str(params)
        if params_str:
            _safe_call(_LIB.LGBM_BoosterResetParameter(
@@ -2310,7 +2307,7 @@ def __inner_eval(self, data_name, data_idx, feval=None):
"""Evaluate training or validation data."""
if data_idx >= self.__num_dataset:
raise ValueError("Data_idx should be smaller than number of dataset")
self.__get_eval_info()
self.__get_eval_info(data_idx)
ret = []
if self.__num_inner_eval > 0:
result = np.zeros(self.__num_inner_eval, dtype=np.float64)
@@ -2363,31 +2360,31 @@ def __inner_predict(self, data_idx):
            self.__is_predicted_cur_iter[data_idx] = True
        return self.__inner_predict_buffer[data_idx]

-    def __get_eval_info(self):
+    def __get_eval_info(self, data_idx):
        """Get inner evaluation count and names."""
-        if self.__need_reload_eval_info:
-            self.__need_reload_eval_info = False
-            out_num_eval = ctypes.c_int(0)
-            # Get num of inner evals
-            _safe_call(_LIB.LGBM_BoosterGetEvalCounts(
-                self.handle,
-                ctypes.byref(out_num_eval)))
-            self.__num_inner_eval = out_num_eval.value
-            if self.__num_inner_eval > 0:
-                # Get name of evals
-                tmp_out_len = ctypes.c_int(0)
-                string_buffers = [ctypes.create_string_buffer(255) for i in range_(self.__num_inner_eval)]
-                ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))
-                _safe_call(_LIB.LGBM_BoosterGetEvalNames(
-                    self.handle,
-                    ctypes.byref(tmp_out_len),
-                    ptr_string_buffers))
-                if self.__num_inner_eval != tmp_out_len.value:
-                    raise ValueError("Length of eval names doesn't equal with num_evals")
-                self.__name_inner_eval = \
-                    [string_buffers[i].value.decode() for i in range_(self.__num_inner_eval)]
-                self.__higher_better_inner_eval = \
-                    [name.startswith(('auc', 'ndcg@', 'map@')) for name in self.__name_inner_eval]
+        out_num_eval = ctypes.c_int(0)
+        # Get num of inner evals
+        _safe_call(_LIB.LGBM_BoosterGetEvalCounts(
+            self.handle,
+            ctypes.c_int(data_idx),
+            ctypes.byref(out_num_eval)))
+        self.__num_inner_eval = out_num_eval.value
+        if self.__num_inner_eval > 0:
+            # Get name of evals
+            tmp_out_len = ctypes.c_int(0)
+            string_buffers = [ctypes.create_string_buffer(255) for i in range_(self.__num_inner_eval)]
+            ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))
+            _safe_call(_LIB.LGBM_BoosterGetEvalNames(
+                self.handle,
+                ctypes.c_int(data_idx),
+                ctypes.byref(tmp_out_len),
+                ptr_string_buffers))
+            if self.__num_inner_eval != tmp_out_len.value:
+                raise ValueError("Length of eval names doesn't equal with num_evals")
+            self.__name_inner_eval = \
+                [string_buffers[i].value.decode() for i in range_(self.__num_inner_eval)]
+            self.__higher_better_inner_eval = \
+                [name.startswith(('auc', 'ndcg@', 'map@')) for name in self.__name_inner_eval]

def attr(self, key):
"""Get attribute string from the Booster.
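
The net effect of these Python changes: eval names are no longer cached once behind __need_reload_eval_info but are re-queried per dataset, since the training and validation metric sets may now differ. A hedged end-to-end sketch on synthetic data (assuming the new parameters flow through params just as the old metric did; not from the commit itself):

import numpy as np
import lightgbm as lgb

X = np.random.rand(500, 10)
y = np.random.randint(0, 2, 500)
dtrain = lgb.Dataset(X[:400], y[:400])
dvalid = lgb.Dataset(X[400:], y[400:], reference=dtrain)

params = {
    "objective": "binary",
    "train_metric": "binary_logloss",  # reported for the training set only
    "valid_metric": "auc",             # reported for each validation set
}
bst = lgb.train(params, dtrain, num_boost_round=10,
                valid_sets=[dtrain, dvalid])
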
6 changes: 3 additions & 3 deletions src/application/application.cpp
@@ -116,7 +116,7 @@ void Application::LoadData() {
}
// create training metric
if (config_.is_provide_training_metric) {
-for (auto metric_type : config_.metric) {
+for (auto metric_type : config_.train_metric) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
if (metric == nullptr) { continue; }
metric->Init(train_data_->metadata(), train_data_->num_data());
@@ -126,7 +126,7 @@ void Application::LoadData() {
train_metric_.shrink_to_fit();


-if (!config_.metric.empty()) {
+if (!config_.train_metric.empty()) {
// only when have metrics then need to construct validation data

// Add validation data, if it exists
@@ -146,7 +146,7 @@

// add metric for validation data
valid_metrics_.emplace_back();
-for (auto metric_type : config_.metric) {
+for (auto metric_type : config_.valid_metric) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
if (metric == nullptr) { continue; }
metric->Init(valid_datas_.back()->metadata(),
50 changes: 35 additions & 15 deletions src/c_api.cpp
@@ -103,7 +103,7 @@ class Booster {

// create training metric
train_metric_.clear();
-for (auto metric_type : config_.metric) {
+for (auto metric_type : config_.train_metric) {
auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_));
if (metric == nullptr) { continue; }
@@ -164,7 +164,7 @@ class Booster {
void AddValidData(const Dataset* valid_data) {
std::lock_guard<std::mutex> lock(mutex_);
valid_metrics_.emplace_back();
-for (auto metric_type : config_.metric) {
+for (auto metric_type : config_.valid_metric) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
if (metric == nullptr) { continue; }
metric->Init(valid_data->metadata(), valid_data->num_data());
@@ -297,21 +297,41 @@ class Booster {
boosting_->ShuffleModels(start_iter, end_iter);
}

-  int GetEvalCounts() const {
+  int GetEvalCounts(int data_idx) const {
    int ret = 0;
-    for (const auto& metric : train_metric_) {
-      ret += static_cast<int>(metric->GetName().size());
+    if (data_idx == 0) {
+      for (const auto& metric : train_metric_) {
+        ret += static_cast<int>(metric->GetName().size());
+      }
+    } else {
+      for (size_t i = 0; i < valid_metrics_.size(); ++i) {
+        for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
+          ret += static_cast<int>(valid_metrics_[i][j]->GetName().size());
+        }
+      }
    }
    return ret;
  }

-  int GetEvalNames(char** out_strs) const {
+  int GetEvalNames(int data_idx, char** out_strs) const {
    int idx = 0;
-    for (const auto& metric : train_metric_) {
-      for (const auto& name : metric->GetName()) {
-        std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
-        ++idx;
+    if (data_idx == 0) {
+      for (const auto& metric : train_metric_) {
+        for (const auto& name : metric->GetName()) {
+          std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
+          ++idx;
+        }
+      }
+    } else {
+      for (size_t i = 0; i < valid_metrics_.size(); ++i) {
+        for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
+          auto& metric = valid_metrics_[i][j];
+          for (const auto& name : metric->GetName()) {
+            std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
+            ++idx;
+          }
+        }
      }
    }
    return idx;
  }
@@ -1029,17 +1049,17 @@ int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
API_END();
}

-int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
+int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int data_idx, int* out_len) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
-*out_len = ref_booster->GetEvalCounts();
+*out_len = ref_booster->GetEvalCounts(data_idx);
API_END();
}

-int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
+int LGBM_BoosterGetEvalNames(BoosterHandle handle, int data_idx, int* out_len, char** out_strs) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
-*out_len = ref_booster->GetEvalNames(out_strs);
+*out_len = ref_booster->GetEvalNames(data_idx, out_strs);
API_END();
}

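
Note a behavioral detail of the new GetEvalCounts/GetEvalNames above: unlike LGBM_BoosterGetEval, where data_idx selects a single dataset, here any non-zero data_idx aggregates over all attached validation datasets. A small sketch of the resulting arithmetic (hypothetical counts; single-name metrics assumed, though metrics such as ndcg can emit several names per metric):

# Suppose valid_metric = "binary_logloss,auc" and 3 validation sets are attached.
n_valid_metrics = 2
n_valid_sets = 3

# data_idx == 0 -> len(train_metric) names; data_idx != 0 -> names summed
# across *all* validation sets, i.e. n_valid_sets * n_valid_metrics entries.
expected_valid_names = n_valid_sets * n_valid_metrics
assert expected_valid_names == 6
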
50 changes: 45 additions & 5 deletions src/io/config.cpp
@@ -68,9 +68,9 @@ void GetObjectiveType(const std::unordered_map<std::string, std::string>& params
}
}

-void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
+void GetTrainMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
std::string value;
if (Config::GetString(params, "metric", &value)) {
if (Config::GetString(params, "train_metric", &value)) {
// clear old metrics
metric->clear();
// to lower
@@ -99,6 +99,31 @@ void GetMetricType(const std::unordered_map<std::string, std::string>& params, s
}
}

+void GetValidMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
+  std::string value;
+  if (Config::GetString(params, "valid_metric", &value)) {
+    // clear old metrics
+    metric->clear();
+    // to lower
+    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
+    // split
+    std::vector<std::string> metrics = Common::Split(value.c_str(), ',');
+    // remove duplicate
+    std::unordered_set<std::string> metric_sets;
+    for (auto& met : metrics) {
+      std::transform(met.begin(), met.end(), met.begin(), Common::tolower);
+      if (metric_sets.count(met) <= 0) {
+        metric_sets.insert(met);
+      }
+    }
+    for (auto& met : metric_sets) {
+      metric->push_back(met);
+    }
+    metric->shrink_to_fit();
+  }
+}


void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task) {
std::string value;
if (Config::GetString(params, "task", &value)) {
@@ -164,7 +189,8 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {

GetTaskType(params, &task);
GetBoostingType(params, &boosting);
-GetMetricType(params, &metric);
+GetTrainMetricType(params, &train_metric);
+GetValidMetricType(params, &valid_metric);
GetObjectiveType(params, &objective);
GetDeviceType(params, &device_type);
GetTreeLearnerType(params, &tree_learner);
@@ -215,7 +241,7 @@ void Config::CheckParamConflict() {
Log::Fatal("Number of classes must be 1 for non-multiclass training");
}
}
-for (std::string metric_type : metric) {
+for (std::string metric_type : train_metric) {
bool metric_custom_or_none = metric_type == std::string("none") || metric_type == std::string("null")
|| metric_type == std::string("custom") || metric_type == std::string("na");
bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
@@ -227,6 +253,19 @@ void Config::CheckParamConflict() {
Log::Fatal("Multiclass objective and metrics don't match");
}
}
+  for (std::string metric_type : valid_metric) {
+    bool metric_custom_or_none = metric_type == std::string("none") || metric_type == std::string("null")
+      || metric_type == std::string("custom") || metric_type == std::string("na");
+    bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
+                                   || metric_type == std::string("multi_logloss")
+                                   || metric_type == std::string("multi_error")
+                                   || (metric_custom_or_none && num_class_check > 1));
+    if ((objective_type_multiclass && !metric_type_multiclass)
+        || (!objective_type_multiclass && metric_type_multiclass)) {
+      Log::Fatal("Multiclass objective and metrics don't match");
+    }
+  }


if (num_machines > 1) {
is_parallel = true;
@@ -270,7 +309,8 @@ std::string Config::ToString() const {
std::stringstream str_buf;
str_buf << "[boosting: " << boosting << "]\n";
str_buf << "[objective: " << objective << "]\n";
str_buf << "[metric: " << Common::Join(metric, ",") << "]\n";
str_buf << "[train_metric: " << Common::Join(train_metric, ",") << "]\n";
str_buf << "[valid_metric: " << Common::Join(valid_metric, ",") << "]\n";
str_buf << "[tree_learner: " << tree_learner << "]\n";
str_buf << "[device_type: " << device_type << "]\n";
str_buf << SaveMembersToString();
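
The parsing added in GetValidMetricType (lowercase, split on ',', de-duplicate) can be sketched in Python as below; one caveat is that the C++ version copies names out of an std::unordered_set, so its output order is unspecified, while this sketch keeps first-seen order:

def parse_metric_list(value):
    """Order-preserving sketch of GetValidMetricType's parsing."""
    seen = set()
    result = []
    for met in value.lower().split(","):
        if met and met not in seen:  # skip empties and duplicates
            seen.add(met)
            result.append(met)
    return result

assert parse_metric_list("AUC,auc,binary_logloss") == ["auc", "binary_logloss"]
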
4 changes: 2 additions & 2 deletions src/io/config_auto.cpp
@@ -140,7 +140,6 @@ std::unordered_map<std::string, std::string> Config::alias_table({
{"output_freq", "metric_freq"},
{"training_metric", "is_provide_training_metric"},
{"is_training_metric", "is_provide_training_metric"},
{"train_metric", "is_provide_training_metric"},
{"ndcg_eval_at", "eval_at"},
{"ndcg_at", "eval_at"},
{"map_eval_at", "eval_at"},
@@ -249,7 +248,8 @@ std::unordered_set<std::string> Config::parameter_set({
"tweedie_variance_power",
"max_position",
"label_gain",
"metric",
"train_metric",
"valid_metric",
"metric_freq",
"is_provide_training_metric",
"eval_at",
9 changes: 5 additions & 4 deletions src/lightgbm_R.cpp
@@ -426,19 +426,20 @@ LGBM_SE LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
LGBM_SE buf_len,
LGBM_SE actual_len,
LGBM_SE eval_names,
-LGBM_SE call_state) {
+LGBM_SE call_state,
+LGBM_SE data_idx) {

R_API_BEGIN();
int len;
-CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
+CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), R_AS_INT(data_idx), &len));
std::vector<std::vector<char>> names(len);
std::vector<char*> ptr_names(len);
for (int i = 0; i < len; ++i) {
names[i].resize(128);
ptr_names[i] = names[i].data();
}
int out_len;
-CHECK_CALL(LGBM_BoosterGetEvalNames(R_GET_PTR(handle), &out_len, ptr_names.data()));
+CHECK_CALL(LGBM_BoosterGetEvalNames(R_GET_PTR(handle), R_AS_INT(data_idx), &out_len, ptr_names.data()));
CHECK(out_len == len);
auto merge_names = Common::Join<char*>(ptr_names, "\t");
EncodeChar(eval_names, merge_names.c_str(), buf_len, actual_len, merge_names.size() + 1);
@@ -451,7 +452,7 @@ LGBM_SE LGBM_BoosterGetEval_R(LGBM_SE handle,
LGBM_SE call_state) {
R_API_BEGIN();
int len;
-CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
+CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), R_AS_INT(data_idx), &len));
double* ptr_ret = R_REAL_PTR(out_result);
int out_len;
CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), R_AS_INT(data_idx), &out_len, ptr_ret));
