Skip to content

Commit

Permalink
ignore unknown parameters when loading from model file
Browse files (browse the repository at this point in the history)
  • Loading branch information
jmoralez committed Oct 4, 2023
1 parent f175ceb commit b459d70
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
11 changes: 8 additions & 3 deletions src/boosting/gbdt.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -179,15 +179,20 @@ class GBDT : public GBDTBase {
const auto pair = Common::Split(line.c_str(), ":");
if (pair[1] == " ]")
continue;
const auto param = pair[0].substr(1);
const auto value_str = pair[1].substr(1, pair[1].size() - 2);
auto iter = param_types.find(param);
if (iter == param_types.end()) {
Log::Warning("Type for param: %s not found. This doesn't affect inference.", param.c_str());
continue;
}
std::string param_type = iter->second;
if (first) {
first = false;
str_buf << "\"";
} else {
str_buf << ",\"";
}
const auto param = pair[0].substr(1);
const auto value_str = pair[1].substr(1, pair[1].size() - 2);
const auto param_type = param_types.at(param);
str_buf << param << "\": ";
if (param_type == "string") {
str_buf << "\"" << value_str << "\"";
Expand Down
19 changes: 17 additions & 2 deletions tests/python_package_test/test_engine.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1470,7 +1470,7 @@ def test_feature_name_with_non_ascii():
assert feature_names == gbm2.feature_name()


def test_parameters_are_loaded_from_model_file(tmp_path):
def test_parameters_are_loaded_from_model_file(tmp_path, capsys):
X = np.hstack([np.random.rand(100, 1), np.random.randint(0, 5, (100, 2))])
y = np.random.rand(100)
ds = lgb.Dataset(X, y)
Expand All @@ -1487,8 +1487,18 @@ def test_parameters_are_loaded_from_model_file(tmp_path):
'num_threads': 1,
}
model_file = tmp_path / 'model.txt'
lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2]).save_model(model_file)
orig_bst = lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2])
orig_bst.save_model(model_file)
with model_file.open('rt') as f:
model_contents = f.readlines()
params_start = model_contents.index('parameters:\n')
model_contents.insert(params_start + 1, '[max_conflict_rate: 0]\n')
with model_file.open('wt') as f:
f.writelines(model_contents)
bst = lgb.Booster(model_file=model_file)
expected_msg = "[LightGBM] [Warning] Type for param: 'max_conflict_rate' not found. This doesn't affect inference."
stdout = capsys.readouterr().out
assert expected_msg in stdout
set_params = {k: bst.params[k] for k in params.keys()}
assert set_params == params
assert bst.params['categorical_feature'] == [1, 2]
Expand All @@ -1498,6 +1508,11 @@ def test_parameters_are_loaded_from_model_file(tmp_path):
bst2 = lgb.Booster(params={'num_leaves': 7}, model_file=model_file)
assert bst.params == bst2.params

# check inference isn't affected by unknown parameter
orig_preds = orig_bst.predict(X)
preds = bst.predict(X)
np.testing.assert_allclose(preds, orig_preds)


def test_save_load_copy_pickle():
def train_and_predict(init_model=None, return_model=False):
Expand Down

0 comments on commit b459d70

Please sign in to comment.