From b0137debe6e9cc92b65ec71b0fe8a56ea213c143 Mon Sep 17 00:00:00 2001 From: chjinche <49483542+chjinche@users.noreply.github.com> Date: Tue, 16 Nov 2021 14:27:23 +0800 Subject: [PATCH] Add customized parser support (#4782) * add customized parser support * fix typo of parser_config_file description * make delimiter as parameter of JoinedLines --- README.md | 2 + docs/Parameters.rst | 8 ++++ include/LightGBM/boosting.h | 2 + include/LightGBM/config.h | 5 +++ include/LightGBM/dataset.h | 59 +++++++++++++++++++++++++ include/LightGBM/utils/common.h | 25 +++++++++++ include/LightGBM/utils/text_reader.h | 11 +++++ src/application/predictor.hpp | 6 ++- src/boosting/gbdt.cpp | 3 ++ src/boosting/gbdt.h | 4 ++ src/boosting/gbdt_model_text.cpp | 31 ++++++++++++- src/io/config_auto.cpp | 4 ++ src/io/dataset_loader.cpp | 36 ++++++++++++--- src/io/parser.cpp | 56 +++++++++++++++++++++++ tests/python_package_test/test_basic.py | 12 +++++ 15 files changed, 253 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 921bfb76f308..da11b743a924 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,8 @@ MLflow (experiment tracking, model monitoring framework): https://github.com/mlf `{mlr3extralearners}` (R `{mlr3}`-compliant interface): https://github.com/mlr-org/mlr3extralearners +lightgbm-transform (feature transformation binding): https://github.com/microsoft/lightgbm-transform + Support ------- diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 75bef7add9bc..7ace90d9b34d 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -850,6 +850,14 @@ Dataset Parameters - **Note**: setting this to ``true`` may lead to much slower text parsing +- ``parser_config_file`` :raw-html:`🔗︎`, default = ``""``, type = string + + - path to a ``.json`` file that specifies customized parser initialized configuration + + - see `lightgbm-transform `__ for usage examples + + - **Note**: ``lightgbm-transform`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should go to `issues page `__ + Predict Parameters ~~~~~~~~~~~~~~~~~~ diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h index ddbcdbc18e44..7530495c0e17 100644 --- a/include/LightGBM/boosting.h +++ b/include/LightGBM/boosting.h @@ -314,6 +314,8 @@ class LIGHTGBM_EXPORT Boosting { static Boosting* CreateBoosting(const std::string& type, const char* filename); virtual bool IsLinear() const { return false; } + + virtual std::string ParserConfigStr() const = 0; }; class GBDTBase : public Boosting { diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index da43a5ec9782..50371f3a2d91 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -721,6 +721,11 @@ struct Config { // desc = **Note**: setting this to ``true`` may lead to much slower text parsing bool precise_float_parser = false; + // desc = path to a ``.json`` file that specifies customized parser initialized configuration + // desc = see `lightgbm-transform `__ for usage examples + // desc = **Note**: ``lightgbm-transform`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should go to `issues page `__ + std::string parser_config_file = ""; + #pragma endregion #pragma region Predict Parameters diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index abef980f5fbe..cf19429322ee 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -254,6 +255,14 @@ class Parser { public: typedef const char* (*AtofFunc)(const char* p, double* out); + /*! \brief Default constructor */ + Parser() {} + + /*! + * \brief Constructor for customized parser. The constructor accepts content not path because need to save/load the config along with model string + */ + explicit Parser(std::string) {} + /*! \brief virtual destructor */ virtual ~Parser() {} @@ -271,12 +280,58 @@ class Parser { /*! * \brief Create an object of parser, will auto choose the format depend on file * \param filename One Filename of data + * \param header whether input file contains header * \param num_features Pass num_features of this data file if you know, <=0 means don't know * \param label_idx index of label column * \param precise_float_parser using precise floating point number parsing if true * \return Object of parser */ static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser); + + /*! + * \brief Create an object of parser, could use customized parser, or auto choose the format depend on file + * \param filename One Filename of data + * \param header whether input file contains header + * \param num_features Pass num_features of this data file if you know, <=0 means don't know + * \param label_idx index of label column + * \param precise_float_parser using precise floating point number parsing if true + * \param parser_config_str Customized parser config content + * \return Object of parser + */ + static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser, + std::string parser_config_str); + + /*! + * \brief Generate parser config str used for custom parser initialization, may save values of label id and header + * \param filename One Filename of data + * \param parser_config_filename One Filename of parser config + * \param header whether input file contains header + * \param label_idx index of label column + * \return Parser config str + */ + static std::string GenerateParserConfigStr(const char* filename, const char* parser_config_filename, bool header, int label_idx); +}; + +/*! \brief Interface for parser factory, used by customized parser */ +class ParserFactory { + private: + ParserFactory() {} + std::map> object_map_; + + public: + ~ParserFactory() {} + static ParserFactory& getInstance(); + void Register(std::string class_name, std::function objc); + Parser* getObject(std::string class_name, std::string config_str); +}; + +/*! \brief Interface for parser reflector, used by customized parser */ +class ParserReflector { + public: + ParserReflector(std::string class_name, std::function objc) { + ParserFactory::getInstance().Register(class_name, objc); + } + virtual ~ParserReflector() {} }; /*! \brief The main class of data set, @@ -605,6 +660,9 @@ class Dataset { /*! \brief Get names of current data set */ inline const std::vector& feature_names() const { return feature_names_; } + /*! \brief Get content of parser config file */ + inline const std::string parser_config_str() const { return parser_config_str_; } + inline void set_feature_names(const std::vector& feature_names) { if (feature_names.size() != static_cast(num_total_features_)) { Log::Fatal("Size of feature_names error, should equal with total number of features"); @@ -722,6 +780,7 @@ class Dataset { /*! map feature (inner index) to its index in the list of numeric (non-categorical) features */ std::vector numeric_feature_map_; int num_numeric_features_; + std::string parser_config_str_; }; } // namespace LightGBM diff --git a/include/LightGBM/utils/common.h b/include/LightGBM/utils/common.h index af9810627b04..1e47700d0a61 100644 --- a/include/LightGBM/utils/common.h +++ b/include/LightGBM/utils/common.h @@ -8,6 +8,7 @@ #if ((defined(sun) || defined(__sun)) && (defined(__SVR4) || defined(__svr4__))) #include #endif +#include #include #include @@ -62,6 +63,8 @@ namespace LightGBM { namespace Common { +using json11::Json; + /*! * Imbues the stream with the C locale. */ @@ -200,6 +203,28 @@ inline static std::vector Split(const char* c_str, const char* deli return ret; } +inline static std::string GetFromParserConfig(std::string config_str, std::string key) { + // parser config should follow json format. + std::string err; + Json config_json = Json::parse(config_str, &err); + if (!err.empty()) { + Log::Fatal("Invalid parser config: %s. Please check if follow json format.", err.c_str()); + } + return config_json[key].string_value(); +} + +inline static std::string SaveToParserConfig(std::string config_str, std::string key, std::string value) { + std::string err; + Json config_json = Json::parse(config_str, &err); + if (!err.empty()) { + Log::Fatal("Invalid parser config: %s. Please check if follow json format.", err.c_str()); + } + CHECK(config_json.is_object()); + std::map config_map = config_json.object_items(); + config_map.insert(std::pair(key, Json(value))); + return Json(config_map).dump(); +} + template inline static const char* Atoi(const char* p, T* out) { int sign; diff --git a/include/LightGBM/utils/text_reader.h b/include/LightGBM/utils/text_reader.h index b6090b1ce9b2..ccb25f960d05 100644 --- a/include/LightGBM/utils/text_reader.h +++ b/include/LightGBM/utils/text_reader.h @@ -84,6 +84,17 @@ class TextReader { * \return Text data, store in std::vector by line */ inline std::vector& Lines() { return lines_; } + /*! + * \brief Get joined text data that read from file + * \return Text data, store in std::string, joined all lines by delimiter + */ + inline std::string JoinedLines(std::string delimiter = "\n") { + std::stringstream ss; + for (auto line : lines_) { + ss << line << delimiter; + } + return ss.str(); + } INDEX_T ReadAllAndProcess(const std::function& process_fun) { last_line_ = ""; diff --git a/src/application/predictor.hpp b/src/application/predictor.hpp index dff23add2df5..d1a8aca4d041 100644 --- a/src/application/predictor.hpp +++ b/src/application/predictor.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -167,7 +168,7 @@ class Predictor { } auto label_idx = header ? -1 : boosting_->LabelIdx(); auto parser = std::unique_ptr(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, label_idx, - precise_float_parser)); + precise_float_parser, boosting_->ParserConfigStr())); if (parser == nullptr) { Log::Fatal("Could not recognize the data format of data file %s", data_filename); @@ -179,7 +180,8 @@ class Predictor { TextReader predict_data_reader(data_filename, header); std::vector feature_remapper(parser->NumFeatures(), -1); bool need_adjust = false; - if (header) { + // skip raw feature remapping if trained model has parser config str which may contain actual feature names. + if (header && boosting_->ParserConfigStr().empty()) { std::string first_line = predict_data_reader.first_line(); std::vector header_words = Common::Split(first_line.c_str(), "\t,"); std::unordered_map header_mapper; diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index d393d46d5133..5b31865748f1 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -120,6 +120,8 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective feature_names_ = train_data_->feature_names(); feature_infos_ = train_data_->feature_infos(); monotone_constraints_ = config->monotone_constraints; + // get parser config file content + parser_config_str_ = train_data_->parser_config_str(); // if need bagging, create buffer ResetBaggingConfig(config_.get(), true); @@ -730,6 +732,7 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* label_idx_ = train_data_->label_idx(); feature_names_ = train_data_->feature_names(); feature_infos_ = train_data_->feature_infos(); + parser_config_str_ = train_data_->parser_config_str(); tree_learner_->ResetTrainingData(train_data, is_constant_hessian_); ResetBaggingConfig(config_.get(), true); diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 472ea1707104..efeacfbfaef0 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -394,6 +394,8 @@ class GBDT : public GBDTBase { bool IsLinear() const override { return linear_tree_; } + inline std::string ParserConfigStr() const override {return parser_config_str_;} + protected: virtual bool GetIsConstHessian(const ObjectiveFunction* objective_function) { if (objective_function != nullptr) { @@ -483,6 +485,8 @@ class GBDT : public GBDTBase { std::vector> models_; /*! \brief Max feature index of training data*/ int max_feature_idx_; + /*! \brief Parser config file content */ + std::string parser_config_str_ = ""; #ifdef USE_CUDA /*! \brief First order derivative of training data */ diff --git a/src/boosting/gbdt_model_text.cpp b/src/boosting/gbdt_model_text.cpp index 0e296c1ec984..73c3ea98d3f6 100644 --- a/src/boosting/gbdt_model_text.cpp +++ b/src/boosting/gbdt_model_text.cpp @@ -16,7 +16,7 @@ namespace LightGBM { -const char* kModelVersion = "v3"; +const char* kModelVersion = "v4"; std::string GBDT::DumpModel(int start_iteration, int num_iteration, int feature_importance_type) const { std::stringstream str_buf; @@ -399,6 +399,11 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration, int ss << loaded_parameter_ << "\n"; ss << "end of parameters" << '\n'; } + if (!parser_config_str_.empty()) { + ss << "\nparser:" << '\n'; + ss << parser_config_str_ << "\n"; + ss << "end of parser" << '\n'; + } return ss.str(); } @@ -568,7 +573,7 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) { num_iteration_for_pred_ = static_cast(models_.size()) / num_tree_per_iteration_; num_init_iteration_ = num_iteration_for_pred_; iter_ = 0; - bool is_inparameter = false; + bool is_inparameter = false, is_inparser = false; std::stringstream ss; Common::C_stringstream(ss); while (p < end) { @@ -594,6 +599,28 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) { if (!ss.str().empty()) { loaded_parameter_ = ss.str(); } + ss.clear(); + ss.str(""); + while (p < end) { + auto line_len = Common::GetLine(p); + if (line_len > 0) { + std::string cur_line(p, line_len); + if (cur_line == std::string("parser:")) { + is_inparser = true; + } else if (cur_line == std::string("end of parser")) { + p += line_len; + p = Common::SkipNewLine(p); + break; + } else if (is_inparser) { + ss << cur_line << "\n"; + } + } + p += line_len; + p = Common::SkipNewLine(p); + } + parser_config_str_ = ss.str(); + ss.clear(); + ss.str(""); return true; } diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index cd24790b820c..682264358893 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -272,6 +272,7 @@ const std::unordered_set& Config::parameter_set() { "forcedbins_filename", "save_binary", "precise_float_parser", + "parser_config_file", "start_iteration_predict", "num_iteration_predict", "predict_raw_score", @@ -540,6 +541,8 @@ void Config::GetMembersFromString(const std::unordered_map parser_config_reader(config_.parser_config_file.c_str(), false); + parser_config_reader.ReadAllLines(); + std::string parser_config_str = parser_config_reader.JoinedLines(); + if (!parser_config_str.empty()) { + std::string header_in_parser_config = Common::GetFromParserConfig(parser_config_str, "header"); + if (!header_in_parser_config.empty()) { + Log::Info("Get raw column names from parser config."); + feature_names_ = Common::Split(header_in_parser_config.c_str(), "\t,"); + } + } } // load label idx first @@ -71,6 +83,15 @@ void DatasetLoader::SetHeader(const char* filename) { } } + if (!config_.parser_config_file.empty()) { + // if parser config file exists, feature names may be changed after customized parser applied. + // clear here so could use default filled feature names during dataset construction. + // may improve by saving real feature names defined in parser in the future. + if (!feature_names_.empty()) { + feature_names_.clear(); + } + } + if (!feature_names_.empty()) { // erase label column name feature_names_.erase(feature_names_.begin() + label_idx_); @@ -196,8 +217,9 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac auto bin_filename = CheckCanLoadFromBin(filename); bool is_load_from_binary = false; if (bin_filename.size() == 0) { + dataset->parser_config_str_ = Parser::GenerateParserConfigStr(filename, config_.parser_config_file.c_str(), config_.header, label_idx_); auto parser = std::unique_ptr(Parser::CreateParser(filename, config_.header, 0, label_idx_, - config_.precise_float_parser)); + config_.precise_float_parser, dataset->parser_config_str_)); if (parser == nullptr) { Log::Fatal("Could not recognize data format of %s", filename); } @@ -257,8 +279,6 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac return dataset.release(); } - - Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) { data_size_t num_global_data = 0; std::vector used_data_indices; @@ -269,7 +289,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, auto bin_filename = CheckCanLoadFromBin(filename); if (bin_filename.size() == 0) { auto parser = std::unique_ptr(Parser::CreateParser(filename, config_.header, 0, label_idx_, - config_.precise_float_parser)); + config_.precise_float_parser, dataset->parser_config_str_)); if (parser == nullptr) { Log::Fatal("Could not recognize data format of %s", filename); } @@ -1010,7 +1030,11 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, categorical_features_); // check the range of label_idx, weight_idx and group_idx - CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_); + // skip label check if user input parser config file, + // because label id is got from raw features while dataset features are consistent with customized parser. + if (dataset->parser_config_str_.empty()) { + CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_); + } CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_); CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_); @@ -1383,8 +1407,6 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) { } } - - std::vector> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features, const std::unordered_set& categorical_features) { std::vector> forced_bins(num_total_features, std::vector()); diff --git a/src/io/parser.cpp b/src/io/parser.cpp index 58f2d5b94467..68e4cdb8116b 100644 --- a/src/io/parser.cpp +++ b/src/io/parser.cpp @@ -4,8 +4,10 @@ */ #include "parser.hpp" +#include #include #include +#include #include namespace LightGBM { @@ -230,6 +232,30 @@ DataType GetDataType(const char* filename, bool header, return type; } +// parser factory implementation. +ParserFactory& ParserFactory::getInstance() { + static ParserFactory factory; + return factory; +} + +void ParserFactory::Register(std::string class_name, std::function m_objc) { + if (m_objc) { + object_map_.insert( + std::map>::value_type(class_name, m_objc)); + } +} + +Parser* ParserFactory::getObject(std::string class_name, std::string config_str) { + std::map>::const_iterator iter = + object_map_.find(class_name); + if (iter != object_map_.end()) { + return iter->second(config_str); + } else { + Log::Fatal("Cannot find parser class '%s', please register first or check config format.", class_name.c_str()); + return nullptr; + } +} + Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser) { const int n_read_line = 32; auto lines = ReadKLineFromFile(filename, header, n_read_line); @@ -258,4 +284,34 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features return ret.release(); } +Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser, std::string parser_config_str) { + // customized parser add-on. + if (!parser_config_str.empty()) { + std::unique_ptr ret; + std::string class_name = Common::GetFromParserConfig(parser_config_str, "className"); + Log::Info("Custom parser class name: %s", class_name.c_str()); + Parser* p = ParserFactory::getInstance().getObject(class_name, parser_config_str); + ret.reset(p); + return ret.release(); + } + return CreateParser(filename, header, num_features, label_idx, precise_float_parser); +} + +std::string Parser::GenerateParserConfigStr(const char* filename, const char* parser_config_filename, bool header, int label_idx) { + TextReader parser_config_reader(parser_config_filename, false); + parser_config_reader.ReadAllLines(); + std::string parser_config_str = parser_config_reader.JoinedLines(); + if (!parser_config_str.empty()) { + // save header to parser config in case needed. + if (header && Common::GetFromParserConfig(parser_config_str, "header").empty()) { + TextReader text_reader(filename, header); + parser_config_str = Common::SaveToParserConfig(parser_config_str, "header", text_reader.first_line()); + } + // save label id to parser config in case needed. + if (Common::GetFromParserConfig(parser_config_str, "labelId").empty()) { + parser_config_str = Common::SaveToParserConfig(parser_config_str, "labelId", std::to_string(label_idx)); + } + } + return parser_config_str; +} } // namespace LightGBM diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 77b5f06362f2..40ad062fb8a7 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -557,3 +557,15 @@ def test_init_score_for_multiclass_classification(init_score_type): ds = lgb.Dataset(data, init_score=init_score).construct() np.testing.assert_equal(ds.get_field('init_score'), init_score) np.testing.assert_equal(ds.init_score, init_score) + + +def test_smoke_custom_parser(tmp_path): + data_path = Path(__file__).absolute().parents[2] / 'examples' / 'binary_classification' / 'binary.train' + parser_config_file = tmp_path / 'parser.ini' + with open(parser_config_file, 'w') as fout: + fout.write('{"className": "dummy", "id": "1"}') + + data = lgb.Dataset(data_path, params={"parser_config_file": parser_config_file}) + with pytest.raises(lgb.basic.LightGBMError, + match="Cannot find parser class 'dummy', please register first or check config format"): + data.construct()