diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees_test.cc b/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees_test.cc index 69a7884c..72c790b6 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees_test.cc +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees_test.cc @@ -1574,6 +1574,21 @@ TEST_F(GradientBoostedTreesOnAdult, MakingAModelPurePureServingModel) { EXPECT_LE(static_cast(post_pruning_size) / pre_pruning_size, 0.80); } +TEST_F(GradientBoostedTreesOnAdult, PoissonLoss) { + SetSortingStrategy(Internal::AUTO, Internal::IN_NODE, &train_config_); + + auto* gbt_config = train_config_.MutableExtension( + gradient_boosted_trees::proto::gradient_boosted_trees_config); + gbt_config->set_loss(proto::Loss::POISSON); + train_config_.set_label("hours_per_week"); + train_config_.set_task(model::proto::Task::REGRESSION); + gbt_config->set_num_trees(10); + gbt_config->mutable_decision_tree() + ->mutable_growing_strategy_best_first_global(); + gbt_config->mutable_decision_tree()->mutable_sparse_oblique_split(); + TrainAndEvaluateModel(); +} + // Helper for the training and testing on two non-overlapping samples from the // Abalone dataset. 
class GradientBoostedTreesOnAbalone : public utils::TrainAndTestTester { diff --git a/yggdrasil_decision_forests/model/abstract_model.cc b/yggdrasil_decision_forests/model/abstract_model.cc index 35c1bef0..af7d970e 100644 --- a/yggdrasil_decision_forests/model/abstract_model.cc +++ b/yggdrasil_decision_forests/model/abstract_model.cc @@ -1439,6 +1439,7 @@ AbstractModel::BuildFastEngine( } else { LOG_EVERY_N_SEC(INFO, 10) << "Engine \"" << engine_factory->name() << "\" built"; + STATUS_CHECK(engine_or.value()); } return engine_or; } diff --git a/yggdrasil_decision_forests/model/gradient_boosted_trees/BUILD b/yggdrasil_decision_forests/model/gradient_boosted_trees/BUILD index f8999b67..4505f5f5 100644 --- a/yggdrasil_decision_forests/model/gradient_boosted_trees/BUILD +++ b/yggdrasil_decision_forests/model/gradient_boosted_trees/BUILD @@ -70,7 +70,11 @@ cc_test( deps = [ ":gradient_boosted_trees", ":gradient_boosted_trees_cc_proto", + "//yggdrasil_decision_forests/dataset:data_spec", + "//yggdrasil_decision_forests/dataset:data_spec_cc_proto", + "//yggdrasil_decision_forests/dataset:vertical_dataset", "//yggdrasil_decision_forests/model:abstract_model", + "//yggdrasil_decision_forests/model:abstract_model_cc_proto", "//yggdrasil_decision_forests/model:model_library", "//yggdrasil_decision_forests/utils:filesystem", "//yggdrasil_decision_forests/utils:test", diff --git a/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.cc b/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.cc index eb5fa94b..c433bad5 100644 --- a/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.cc +++ b/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.cc @@ -317,6 +317,7 @@ void GradientBoostedTreesModel::Predict( dist->set_counts(1, 1.f - proba_true); dist->set_counts(2, proba_true); } break; + case proto::Loss::MULTINOMIAL_LOG_LIKELIHOOD: { absl::FixedArray 
accumulator(num_trees_per_iter_); // Zero initial prediction for the MULTINOMIAL_LOG_LIKELIHOOD. @@ -399,6 +400,7 @@ void GradientBoostedTreesModel::Predict( LOG(FATAL) << "Non supported task"; } } break; + case proto::Loss::POISSON: { double accumulator = initial_predictions_[0]; CallOnAllLeafs(dataset, row_idx, @@ -406,13 +408,16 @@ void GradientBoostedTreesModel::Predict( accumulator += node.regressor().top_value(); }); if (task() == model::proto::REGRESSION) { - double clamped_accumulator = std::clamp(accumulator, -19., 19.); + float clamped_accumulator = + std::clamp(static_cast(accumulator), + -kPoissonLossClampBounds, kPoissonLossClampBounds); prediction->mutable_regression()->set_value( std::exp(clamped_accumulator)); } else { - LOG(FATAL) << "Non supported task"; + LOG(FATAL) << "Only regression is supported with poison loss"; } } break; + case proto::Loss::LAMBDA_MART_NDCG: case proto::Loss::LAMBDA_MART_NDCG5: case proto::Loss::XE_NDCG_MART: { @@ -423,8 +428,8 @@ void GradientBoostedTreesModel::Predict( }); prediction->mutable_ranking()->set_relevance(accumulator); } break; - default: - LOG(FATAL) << "Not implemented"; + case proto::Loss::DEFAULT: + LOG(FATAL) << "Loss not set"; } } @@ -433,6 +438,7 @@ void GradientBoostedTreesModel::Predict( model::proto::Prediction* prediction) const { utils::usage::OnInference(1, metadata()); switch (loss_) { + case proto::Loss::BINARY_FOCAL_LOSS: case proto::Loss::BINOMIAL_LOG_LIKELIHOOD: { double accumulator = initial_predictions_[0]; CallOnAllLeafs(example, @@ -499,6 +505,7 @@ void GradientBoostedTreesModel::Predict( prediction->mutable_classification()->set_value(highest_cell_idx + 1); } break; + case proto::Loss::MEAN_AVERAGE_ERROR: case proto::Loss::SQUARED_ERROR: { double accumulator = initial_predictions_[0]; CallOnAllLeafs(example, @@ -507,6 +514,24 @@ void GradientBoostedTreesModel::Predict( }); prediction->mutable_regression()->set_value(accumulator); } break; + + case proto::Loss::POISSON: { + double 
accumulator = initial_predictions_[0]; + CallOnAllLeafs(example, + [&accumulator](const decision_tree::proto::Node& node) { + accumulator += node.regressor().top_value(); + }); + if (task() == model::proto::REGRESSION) { + float clamped_accumulator = + std::clamp(static_cast(accumulator), + -kPoissonLossClampBounds, kPoissonLossClampBounds); + prediction->mutable_regression()->set_value( + std::exp(clamped_accumulator)); + } else { + LOG(FATAL) << "Only regression is supported with poison loss"; + } + } break; + case proto::Loss::LAMBDA_MART_NDCG: case proto::Loss::LAMBDA_MART_NDCG5: case proto::Loss::XE_NDCG_MART: { @@ -517,8 +542,8 @@ void GradientBoostedTreesModel::Predict( }); prediction->mutable_ranking()->set_relevance(accumulator); } break; - default: - LOG(FATAL) << "Not implemented"; + case proto::Loss::DEFAULT: + LOG(FATAL) << "Loss not set"; } } diff --git a/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h b/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h index d161a53f..707f4d48 100644 --- a/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h +++ b/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h @@ -58,6 +58,10 @@ class GradientBoostedTreesModel : public AbstractModel, public: static constexpr char kRegisteredName[] = "GRADIENT_BOOSTED_TREES"; + // The prediction of a model trained with a poisson loss is computed as: + // p := e^clamp(acc, -kPoissonLossClampBounds, kPoissonLossClampBounds) + static constexpr float kPoissonLossClampBounds = 19.; + GradientBoostedTreesModel() : AbstractModel(kRegisteredName) {} absl::Status Save(absl::string_view directory, const ModelIOOptions& io_options) const override; diff --git a/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees_test.cc b/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees_test.cc index 95c19277..6264804c 100644 --- 
a/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees_test.cc +++ b/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees_test.cc @@ -21,7 +21,11 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/status/status.h" +#include "yggdrasil_decision_forests/dataset/data_spec.h" +#include "yggdrasil_decision_forests/dataset/data_spec.pb.h" +#include "yggdrasil_decision_forests/dataset/vertical_dataset.h" #include "yggdrasil_decision_forests/model/abstract_model.h" +#include "yggdrasil_decision_forests/model/abstract_model.pb.h" #include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.pb.h" #include "yggdrasil_decision_forests/model/model_library.h" #include "yggdrasil_decision_forests/utils/filesystem.h" @@ -211,6 +215,43 @@ TEST(GradientBoostedTrees, NDCGTruncationNonRankingModel) { EXPECT_THAT(description, testing::HasSubstr("BINOMIAL_LOG_LIKELIHOOD\n")); } +TEST(GradientBoostedTrees, PoissonLossWithClassificationTaskOnExample) { + GradientBoostedTreesModel model; + model.set_loss(proto::Loss::POISSON, {}); + model.set_task(model::proto::CLASSIFICATION); + dataset::AddColumn("label", dataset::proto::ColumnType::CATEGORICAL, + model.mutable_data_spec()); + model.set_label_col_idx(0); + model.mutable_initial_predictions()->push_back(0); + // Note: We don't validate the model on purpose. + + // TODO: Make "Predict" return a status. 
+ dataset::proto::Example example; + model::proto::Prediction prediction; + EXPECT_DEATH(model.Predict(example, &prediction), + "Only regression is supported with poison loss"); +} + +TEST(GradientBoostedTrees, PoissonLossWithClassificationTaskOnDataset) { + GradientBoostedTreesModel model; + model.set_loss(proto::Loss::POISSON, {}); + model.set_task(model::proto::CLASSIFICATION); + dataset::AddColumn("label", dataset::proto::ColumnType::CATEGORICAL, + model.mutable_data_spec()); + model.set_label_col_idx(0); + model.mutable_initial_predictions()->push_back(0); + // Note: We don't validate the model on purpose. + + dataset::VerticalDataset dataset; + dataset.set_data_spec(model.data_spec()); + EXPECT_OK(dataset.CreateColumnsFromDataspec()); + + // TODO: Make "Predict" return a status. + model::proto::Prediction prediction; + EXPECT_DEATH(model.Predict(dataset, 0, &prediction), + "Only regression is supported with poison loss"); +} + } // namespace } // namespace gradient_boosted_trees } // namespace model diff --git a/yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py b/yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py index 85a15d2e..ffab3d5e 100644 --- a/yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py +++ b/yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py @@ -1481,6 +1481,18 @@ def test_label_classes_order_str(self, label_data): ).train(data) npt.assert_equal(model_1.label_classes(), np.unique(label_data).astype(str)) + def test_adult_poison(self): + model = specialized_learners.GradientBoostedTreesLearner( + label="hours_per_week", + growing_strategy="BEST_FIRST_GLOBAL", + task=generic_learner.Task.REGRESSION, + loss="POISSON", + split_axis="SPARSE_OBLIQUE", + validation_ratio=0.2, + num_trees=10, + ).train(self.adult.train_pd) + _ = model.analyze(self.adult.test_pd, sampling=0.1) + class LoggingTest(parameterized.TestCase): diff --git a/yggdrasil_decision_forests/serving/decision_forest/BUILD 
b/yggdrasil_decision_forests/serving/decision_forest/BUILD index 33f3ff4b..0e600c4e 100644 --- a/yggdrasil_decision_forests/serving/decision_forest/BUILD +++ b/yggdrasil_decision_forests/serving/decision_forest/BUILD @@ -127,6 +127,7 @@ cc_library_ydf( ], deps = [ "//yggdrasil_decision_forests/model:abstract_model_cc_proto", + "//yggdrasil_decision_forests/model/gradient_boosted_trees", "//yggdrasil_decision_forests/model/isolation_forest", "//yggdrasil_decision_forests/serving:example_set", "//yggdrasil_decision_forests/utils:logging", diff --git a/yggdrasil_decision_forests/serving/decision_forest/decision_forest.cc b/yggdrasil_decision_forests/serving/decision_forest/decision_forest.cc index 1fbc65f4..5f9bd0dc 100644 --- a/yggdrasil_decision_forests/serving/decision_forest/decision_forest.cc +++ b/yggdrasil_decision_forests/serving/decision_forest/decision_forest.cc @@ -903,7 +903,8 @@ absl::Status GenericToSpecializedModel( absl::Status GenericToSpecializedModel( const GradientBoostedTreesModel& src, GradientBoostedTreesRegressionNumericalOnly* dst) { - if (src.loss() != Loss::SQUARED_ERROR || + if ((src.loss() != Loss::SQUARED_ERROR && + src.loss() != Loss::MEAN_AVERAGE_ERROR) || src.initial_predictions().size() != 1) { return absl::InvalidArgumentError("The GBT is not trained for regression."); } @@ -923,7 +924,8 @@ absl::Status GenericToSpecializedModel( absl::Status GenericToSpecializedModel( const GradientBoostedTreesModel& src, GradientBoostedTreesRegressionNumericalAndCategorical* dst) { - if (src.loss() != Loss::SQUARED_ERROR || + if ((src.loss() != Loss::SQUARED_ERROR && + src.loss() != Loss::MEAN_AVERAGE_ERROR) || src.initial_predictions().size() != 1) { return absl::InvalidArgumentError("The GBT is not trained for regression."); } @@ -1153,7 +1155,8 @@ absl::Status GenericToSpecializedModel( template <> absl::Status GenericToSpecializedModel(const GradientBoostedTreesModel& src, GradientBoostedTreesRegression* dst) { - if (src.loss() != 
Loss::SQUARED_ERROR || + if ((src.loss() != Loss::SQUARED_ERROR && + src.loss() != Loss::MEAN_AVERAGE_ERROR) || src.initial_predictions().size() != 1) { return absl::InvalidArgumentError( "The Gradient Boosted Tree is not trained for regression."); @@ -1183,6 +1186,22 @@ absl::Status GenericToSpecializedModel(const GradientBoostedTreesModel& src, SetLeafGradientBoostedTreesRegression, src, dst); } +template <> +absl::Status GenericToSpecializedModel( + const GradientBoostedTreesModel& src, + GradientBoostedTreesPoissonRegression* dst) { + if (src.loss() != Loss::POISSON || src.initial_predictions().size() != 1) { + return absl::InvalidArgumentError( + "The Gradient Boosted Tree is not trained for poison regression."); + } + + dst->initial_predictions = src.initial_predictions()[0]; + + using DstType = std::remove_pointer::type; + return GenericToSpecializedGenericModelHelper( + SetLeafGradientBoostedTreesRegression, src, dst); +} + template absl::Status LoadFlatBatchFromDataset( const VerticalDataset& dataset, VerticalDataset::row_t begin_example_idx, diff --git a/yggdrasil_decision_forests/serving/decision_forest/decision_forest_serving.cc b/yggdrasil_decision_forests/serving/decision_forest/decision_forest_serving.cc index 01ca4ae8..eab2cd47 100644 --- a/yggdrasil_decision_forests/serving/decision_forest/decision_forest_serving.cc +++ b/yggdrasil_decision_forests/serving/decision_forest/decision_forest_serving.cc @@ -22,6 +22,7 @@ #include #include +#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h" #include "yggdrasil_decision_forests/model/isolation_forest/isolation_forest.h" #include "yggdrasil_decision_forests/serving/example_set.h" #include "yggdrasil_decision_forests/utils/logging.h" @@ -52,6 +53,19 @@ float ActivationAddInitialPrediction(const SpecializedModel& model, return value + model.initial_predictions; } +// Activation function for regression GBT models trained with the Poisson loss. 
+template +float ActivationGradientBoostedTreesPoissonRegression( + const SpecializedModel& model, const float value) { + float clamped_value = + std::clamp(value + model.initial_predictions, + -model::gradient_boosted_trees::GradientBoostedTreesModel:: + kPoissonLossClampBounds, + model::gradient_boosted_trees::GradientBoostedTreesModel:: + kPoissonLossClampBounds); + return std::exp(clamped_value); +} + // Final function applied by a Gradient Boosted Trees with // MULTINOMIAL_LOG_LIKELIHOOD loss function. I.e. this is a softmax function. template @@ -745,6 +759,17 @@ void Predict(const GradientBoostedTreesRanking& model, predictions); } +template <> +void Predict( + const GradientBoostedTreesPoissonRegression& model, + const typename GradientBoostedTreesPoissonRegression::ExampleSet& examples, + int num_examples, std::vector* predictions) { + // Add activation + PredictHelper::type, + ActivationGradientBoostedTreesPoissonRegression>( + model, examples, num_examples, predictions); +} + } // namespace decision_forest } // namespace serving } // namespace yggdrasil_decision_forests diff --git a/yggdrasil_decision_forests/serving/decision_forest/decision_forest_serving.h b/yggdrasil_decision_forests/serving/decision_forest/decision_forest_serving.h index 2b197c76..c8961421 100644 --- a/yggdrasil_decision_forests/serving/decision_forest/decision_forest_serving.h +++ b/yggdrasil_decision_forests/serving/decision_forest/decision_forest_serving.h @@ -445,6 +445,18 @@ struct GenericGradientBoostedTreesRanking : ExampleSetModel { }; using GradientBoostedTreesRanking = GenericGradientBoostedTreesRanking<>; +// GBDT model for poisson regression. +template +struct GenericGradientBoostedTreesPoissonRegression + : ExampleSetModel { + static constexpr model::proto::Task kTask = model::proto::Task::REGRESSION; + // Output of the model before any tree is applied, and before the final + // activation function. 
+ float initial_predictions = 0.f; +}; +using GradientBoostedTreesPoissonRegression = + GenericGradientBoostedTreesPoissonRegression<>; + template void Predict(const Model& model, const typename Model::ExampleSet& examples, int num_examples, std::vector* predictions); diff --git a/yggdrasil_decision_forests/serving/decision_forest/decision_forest_test.cc b/yggdrasil_decision_forests/serving/decision_forest/decision_forest_test.cc index 0c79bc9a..3bdbfd4e 100644 --- a/yggdrasil_decision_forests/serving/decision_forest/decision_forest_test.cc +++ b/yggdrasil_decision_forests/serving/decision_forest/decision_forest_test.cc @@ -410,6 +410,34 @@ TEST(AdultBinaryClassGBDT, ManualGeneric) { dataset, *model, engine); } +TEST(AdultBinaryClassGBDT, ManualWithNonCompatibleEngines) { + const auto model = LoadModel("adult_binary_class_gbdt"); + const auto dataset = LoadDataset(model->data_spec(), "adult_test.csv", "csv"); + + auto* gbt_model = dynamic_cast(model.get()); + + { + GradientBoostedTreesPoissonRegression engine; + EXPECT_THAT(GenericToSpecializedModel(*gbt_model, &engine), + test::StatusIs(absl::StatusCode::kInvalidArgument, + "not trained for poison regression")); + } + + { + GradientBoostedTreesRanking engine; + EXPECT_THAT(GenericToSpecializedModel(*gbt_model, &engine), + test::StatusIs(absl::StatusCode::kInvalidArgument, + "not trained for ranking")); + } + + { + GradientBoostedTreesRegression engine; + EXPECT_THAT(GenericToSpecializedModel(*gbt_model, &engine), + test::StatusIs(absl::StatusCode::kInvalidArgument, + "not trained for regression")); + } +} + TEST(AdultBinaryClassGBDT, ManualNumCat32) { const auto model = LoadModel("adult_binary_class_gbdt_32cat"); const auto dataset = LoadDataset(model->data_spec(), "adult_test.csv", "csv"); diff --git a/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.cc b/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.cc index 222512ed..a531136f 100644 --- 
a/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.cc +++ b/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.cc @@ -96,7 +96,12 @@ float ActivationBinomialLogLikelihood(const float value) { // Activation function for binary classification GBDT trained with Binomial // LogLikelihood loss. float ActivationPoisson(const float value) { - return std::exp(std::clamp(value, -19.f, 19.f)); + return std::exp( + std::clamp(value, + -model::gradient_boosted_trees::GradientBoostedTreesModel:: + kPoissonLossClampBounds, + model::gradient_boosted_trees::GradientBoostedTreesModel:: + kPoissonLossClampBounds)); } // Identity activation function. diff --git a/yggdrasil_decision_forests/serving/decision_forest/register_engines.cc b/yggdrasil_decision_forests/serving/decision_forest/register_engines.cc index 49d9c068..850dccac 100644 --- a/yggdrasil_decision_forests/serving/decision_forest/register_engines.cc +++ b/yggdrasil_decision_forests/serving/decision_forest/register_engines.cc @@ -181,10 +181,6 @@ class GradientBoostedTreesGenericFastEngineFactory : public FastEngineFactory { if (gbt_model == nullptr) { return false; } - // TODO: Support Poisson loss - if (gbt_model->loss() == gradient_boosted_trees::proto::POISSON) { - return false; - } return gbt_model->CheckStructure({/*.global_imputation_is_higher =*/false}); } @@ -208,7 +204,6 @@ class GradientBoostedTreesGenericFastEngineFactory : public FastEngineFactory { // More than 65k nodes in a single tree is a likely indication of a problem // with the model. 
- std::unique_ptr engine; switch (gbt_model->task()) { case proto::CLASSIFICATION: if (gbt_model->label_col_spec() @@ -241,11 +236,19 @@ class GradientBoostedTreesGenericFastEngineFactory : public FastEngineFactory { } case proto::REGRESSION: { - auto engine = std::make_unique>(); - RETURN_IF_ERROR(engine->LoadModel(*gbt_model)); - return engine; + if (gbt_model->loss() == gradient_boosted_trees::proto::POISSON) { + auto engine = std::make_unique>(); + RETURN_IF_ERROR(engine->LoadModel(*gbt_model)); + return engine; + } else { + auto engine = std::make_unique>(); + RETURN_IF_ERROR(engine->LoadModel(*gbt_model)); + return engine; + } } case proto::RANKING: { diff --git a/yggdrasil_decision_forests/utils/test_utils.h b/yggdrasil_decision_forests/utils/test_utils.h index 2316676b..56c5388d 100644 --- a/yggdrasil_decision_forests/utils/test_utils.h +++ b/yggdrasil_decision_forests/utils/test_utils.h @@ -76,9 +76,16 @@ namespace utils { // Train and test a model on a dataset stored in the "test_data" folder. class TrainAndTestTester : public ::testing::Test { public: - // Run the training and evaluation of the model. Should be called after - // "train_config_" is set. After this function is called, "evaluation_" - // contains the result of the evaluation. + // Trains, evaluates, serializes & deserializes (save and load a model to + // disk [directory format], or save and load a model from a sequence of bytes + // [byte sequence format]), tests predictions, and checks the equality of the + // predictions from the different inference implementations (e.g., slow + // engine, all available fast engines). + // + // This method should be called after "train_config_" is set. Once this + // function returns, "evaluation_" contains the result of the evaluation, + // "training_duration_" contains the duration of the training, and "model_" + // contains the model. 
void TrainAndEvaluateModel( std::optional numerical_weight_attribute = {}, bool emulate_weight_with_duplication = false,