From 11d7608f2d755899a88bfa1c0fbc779b9e9b3505 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 24 Aug 2021 00:48:16 +0200 Subject: [PATCH] [python] add parameter object_hook to method dump_model (#4533) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add parameter object_hook to function dump_model (python API) * eol * fix syntax * lint * better documentation * Update python-package/lightgbm/basic.py Co-authored-by: Nikita Titov Co-authored-by: xavier dupré Co-authored-by: Nikita Titov --- python-package/lightgbm/basic.py | 13 +++++++++++-- tests/python_package_test/test_engine.py | 20 ++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 5d8efb950f00..d33ba3fd6ebb 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -3342,7 +3342,7 @@ def model_to_string(self, num_iteration=None, start_iteration=0, importance_type ret += _dump_pandas_categorical(self.pandas_categorical) return ret - def dump_model(self, num_iteration=None, start_iteration=0, importance_type='split'): + def dump_model(self, num_iteration=None, start_iteration=0, importance_type='split', object_hook=None): """Dump Booster to JSON format. Parameters @@ -3357,6 +3357,15 @@ def dump_model(self, num_iteration=None, start_iteration=0, importance_type='spl What type of feature importance should be dumped. If "split", result contains numbers of times the feature is used in a model. If "gain", result contains total gains of splits which use the feature. + object_hook : callable or None, optional (default=None) + If not None, ``object_hook`` is a function called while parsing the json + string returned by the C API. It may be used to alter the json, to store + specific values while building the json structure. It avoids + walking through the structure again. It saves a significant amount + of time if the number of trees is huge. + Signature is ``def object_hook(node: dict) -> dict``. + None is equivalent to ``lambda node: node``. + See documentation of ``json.loads()`` for further details. Returns ------- @@ -3391,7 +3400,7 @@ def dump_model(self, num_iteration=None, start_iteration=0, importance_type='spl ctypes.c_int64(actual_len), ctypes.byref(tmp_out_len), ptr_string_buffer)) - ret = json.loads(string_buffer.value.decode('utf-8')) + ret = json.loads(string_buffer.value.decode('utf-8'), object_hook=object_hook) ret['pandas_categorical'] = json.loads(json.dumps(self.pandas_categorical, default=json_default_with_numpy)) return ret diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index f7e4d0c5b02f..b44cee469a22 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -2846,3 +2846,23 @@ def test_dump_model(): assert "leaf_const" in dumped_model_str assert "leaf_value" in dumped_model_str assert "leaf_count" in dumped_model_str + + +def test_dump_model_hook(): + + def hook(obj): + if 'leaf_value' in obj: + obj['LV'] = obj['leaf_value'] + del obj['leaf_value'] + return obj + + X, y = load_breast_cancer(return_X_y=True) + train_data = lgb.Dataset(X, label=y) + params = { + "objective": "binary", + "verbose": -1 + } + bst = lgb.train(params, train_data, num_boost_round=5) + dumped_model_str = str(bst.dump_model(5, 0, object_hook=hook)) + assert "leaf_value" not in dumped_model_str + assert "LV" in dumped_model_str