Enable Poisson loss for model analysis and fast inference
PiperOrigin-RevId: 705085641
achoum authored and copybara-github committed Dec 11, 2024
1 parent 50e3ef7 commit e6501b7
Showing 15 changed files with 225 additions and 23 deletions.
@@ -1574,6 +1574,21 @@ TEST_F(GradientBoostedTreesOnAdult, MakingAModelPurePureServingModel) {
EXPECT_LE(static_cast<float>(post_pruning_size) / pre_pruning_size, 0.80);
}

TEST_F(GradientBoostedTreesOnAdult, PoissonLoss) {
SetSortingStrategy(Internal::AUTO, Internal::IN_NODE, &train_config_);

auto* gbt_config = train_config_.MutableExtension(
gradient_boosted_trees::proto::gradient_boosted_trees_config);
gbt_config->set_loss(proto::Loss::POISSON);
train_config_.set_label("hours_per_week");
train_config_.set_task(model::proto::Task::REGRESSION);
gbt_config->set_num_trees(10);
gbt_config->mutable_decision_tree()
->mutable_growing_strategy_best_first_global();
gbt_config->mutable_decision_tree()->mutable_sparse_oblique_split();
TrainAndEvaluateModel();
}
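For reference, the loss being enabled here is the standard Poisson negative log-likelihood: with the raw model score $f(x)$ interpreted as a log-rate, and up to terms constant in $f$,

$$\ell(y, f(x)) = e^{f(x)} - y \, f(x), \qquad \hat{\lambda}(x) = e^{f(x)},$$

which is why the regression prediction further down in this diff is the exponential of the accumulated tree outputs.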

// Helper for the training and testing on two non-overlapping samples from the
// Abalone dataset.
class GradientBoostedTreesOnAbalone : public utils::TrainAndTestTester {
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/model/abstract_model.cc
@@ -1439,6 +1439,7 @@ AbstractModel::BuildFastEngine(
} else {
LOG_EVERY_N_SEC(INFO, 10)
<< "Engine \"" << engine_factory->name() << "\" built";
STATUS_CHECK(engine_or.value());
}
return engine_or;
}
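The added STATUS_CHECK guards against an engine factory returning a null engine. For context, a minimal sketch of the fast-inference path this feeds into — the names follow the usual YDF serving API (LoadModel, BuildFastEngine, AllocateExamples, Predict), but treat the exact signatures as illustrative rather than authoritative:

```c++
#include <memory>
#include <string>
#include <vector>

#include "yggdrasil_decision_forests/model/model_library.h"

namespace ydf = yggdrasil_decision_forests;

std::vector<float> PredictBatch(const std::string& model_dir) {
  // Load the generic model, then compile it into the fastest compatible
  // specialized engine; this is where BuildFastEngine() is exercised.
  const auto model = ydf::model::LoadModel(model_dir).value();
  const auto engine = model->BuildFastEngine().value();

  // Allocate a small batch, default-fill the features (real code would set
  // per-example values here), and run batched inference.
  const int num_examples = 2;
  auto examples = engine->AllocateExamples(num_examples);
  examples->FillMissing(engine->features());
  std::vector<float> predictions;
  engine->Predict(*examples, num_examples, &predictions);
  return predictions;
}
```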
4 changes: 4 additions & 0 deletions yggdrasil_decision_forests/model/gradient_boosted_trees/BUILD
@@ -70,7 +70,11 @@ cc_test(
deps = [
":gradient_boosted_trees",
":gradient_boosted_trees_cc_proto",
"//yggdrasil_decision_forests/dataset:data_spec",
"//yggdrasil_decision_forests/dataset:data_spec_cc_proto",
"//yggdrasil_decision_forests/dataset:vertical_dataset",
"//yggdrasil_decision_forests/model:abstract_model",
"//yggdrasil_decision_forests/model:abstract_model_cc_proto",
"//yggdrasil_decision_forests/model:model_library",
"//yggdrasil_decision_forests/utils:filesystem",
"//yggdrasil_decision_forests/utils:test",
@@ -317,6 +317,7 @@ void GradientBoostedTreesModel::Predict(
dist->set_counts(1, 1.f - proba_true);
dist->set_counts(2, proba_true);
} break;

case proto::Loss::MULTINOMIAL_LOG_LIKELIHOOD: {
absl::FixedArray<float> accumulator(num_trees_per_iter_);
// Zero initial prediction for the MULTINOMIAL_LOG_LIKELIHOOD.
@@ -399,20 +400,24 @@
LOG(FATAL) << "Non supported task";
}
} break;

case proto::Loss::POISSON: {
double accumulator = initial_predictions_[0];
CallOnAllLeafs(dataset, row_idx,
[&accumulator](const decision_tree::proto::Node& node) {
accumulator += node.regressor().top_value();
});
if (task() == model::proto::REGRESSION) {
- double clamped_accumulator = std::clamp(accumulator, -19., 19.);
+ float clamped_accumulator =
+     std::clamp(static_cast<float>(accumulator),
+                -kPoissonLossClampBounds, kPoissonLossClampBounds);
prediction->mutable_regression()->set_value(
std::exp(clamped_accumulator));
} else {
- LOG(FATAL) << "Non supported task";
+ LOG(FATAL) << "Only regression is supported with Poisson loss";
}
} break;

case proto::Loss::LAMBDA_MART_NDCG:
case proto::Loss::LAMBDA_MART_NDCG5:
case proto::Loss::XE_NDCG_MART: {
@@ -423,8 +428,8 @@
});
prediction->mutable_ranking()->set_relevance(accumulator);
} break;
- default:
-   LOG(FATAL) << "Not implemented";
+ case proto::Loss::DEFAULT:
+   LOG(FATAL) << "Loss not set";
}
}

@@ -433,6 +438,7 @@ void GradientBoostedTreesModel::Predict(
model::proto::Prediction* prediction) const {
utils::usage::OnInference(1, metadata());
switch (loss_) {
case proto::Loss::BINARY_FOCAL_LOSS:
case proto::Loss::BINOMIAL_LOG_LIKELIHOOD: {
double accumulator = initial_predictions_[0];
CallOnAllLeafs(example,
@@ -499,6 +505,7 @@ void GradientBoostedTreesModel::Predict(
prediction->mutable_classification()->set_value(highest_cell_idx + 1);
} break;

case proto::Loss::MEAN_AVERAGE_ERROR:
case proto::Loss::SQUARED_ERROR: {
double accumulator = initial_predictions_[0];
CallOnAllLeafs(example,
@@ -507,6 +514,24 @@
});
prediction->mutable_regression()->set_value(accumulator);
} break;

case proto::Loss::POISSON: {
double accumulator = initial_predictions_[0];
CallOnAllLeafs(example,
[&accumulator](const decision_tree::proto::Node& node) {
accumulator += node.regressor().top_value();
});
if (task() == model::proto::REGRESSION) {
float clamped_accumulator =
std::clamp(static_cast<float>(accumulator),
-kPoissonLossClampBounds, kPoissonLossClampBounds);
prediction->mutable_regression()->set_value(
std::exp(clamped_accumulator));
} else {
LOG(FATAL) << "Only regression is supported with Poisson loss";
}
} break;

case proto::Loss::LAMBDA_MART_NDCG:
case proto::Loss::LAMBDA_MART_NDCG5:
case proto::Loss::XE_NDCG_MART: {
@@ -517,8 +542,8 @@
});
prediction->mutable_ranking()->set_relevance(accumulator);
} break;
- default:
-   LOG(FATAL) << "Not implemented";
+ case proto::Loss::DEFAULT:
+   LOG(FATAL) << "Loss not set";
}
}
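To make the Poisson case concrete, here is a self-contained sketch of the arithmetic performed above: sum the initial prediction and the per-tree leaf values, clamp the resulting log-rate to ±19, and exponentiate. The leaf values are made up for illustration:

```c++
#include <algorithm>
#include <cmath>
#include <iostream>

int main() {
  const float kClamp = 19.f;   // mirrors kPoissonLossClampBounds
  double accumulator = 3.68;   // plays the role of initial_predictions_[0]
  // One regressor leaf value per tree (illustrative numbers).
  for (const double leaf_value : {0.05, -0.12, 0.08}) {
    accumulator += leaf_value;
  }
  const float clamped =
      std::clamp(static_cast<float>(accumulator), -kClamp, kClamp);
  // The regression value is the predicted Poisson rate.
  std::cout << std::exp(clamped) << "\n";  // exp(3.69) ~ 40.0
  return 0;
}
```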

@@ -58,6 +58,10 @@ class GradientBoostedTreesModel : public AbstractModel,
public:
static constexpr char kRegisteredName[] = "GRADIENT_BOOSTED_TREES";

// The prediction of a model trained with the Poisson loss is computed as:
// p := e^clamp(acc, -kPoissonLossClampBounds, kPoissonLossClampBounds)
static constexpr float kPoissonLossClampBounds = 19.;

GradientBoostedTreesModel() : AbstractModel(kRegisteredName) {}
absl::Status Save(absl::string_view directory,
const ModelIOOptions& io_options) const override;
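A note on the constant: exp(19) ≈ 1.8e8 and exp(-19) ≈ 5.6e-9, both comfortably inside float range (a float only overflows near exp(88)), so the clamp guarantees a finite, strictly positive predicted rate. A quick, purely illustrative check:

```c++
#include <cassert>
#include <cmath>

int main() {
  const float upper = std::exp(19.f);   // ~1.78e8, finite
  const float lower = std::exp(-19.f);  // ~5.6e-9, strictly positive
  assert(std::isfinite(upper) && lower > 0.f);
  return 0;
}
```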
@@ -21,7 +21,11 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/status/status.h"
#include "yggdrasil_decision_forests/dataset/data_spec.h"
#include "yggdrasil_decision_forests/dataset/data_spec.pb.h"
#include "yggdrasil_decision_forests/dataset/vertical_dataset.h"
#include "yggdrasil_decision_forests/model/abstract_model.h"
#include "yggdrasil_decision_forests/model/abstract_model.pb.h"
#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.pb.h"
#include "yggdrasil_decision_forests/model/model_library.h"
#include "yggdrasil_decision_forests/utils/filesystem.h"
@@ -211,6 +215,43 @@ TEST(GradientBoostedTrees, NDCGTruncationNonRankingModel) {
EXPECT_THAT(description, testing::HasSubstr("BINOMIAL_LOG_LIKELIHOOD\n"));
}

TEST(GradientBoostedTrees, PoissonLossWithClassificationTaskOnExample) {
GradientBoostedTreesModel model;
model.set_loss(proto::Loss::POISSON, {});
model.set_task(model::proto::CLASSIFICATION);
dataset::AddColumn("label", dataset::proto::ColumnType::CATEGORICAL,
model.mutable_data_spec());
model.set_label_col_idx(0);
model.mutable_initial_predictions()->push_back(0);
// Note: We don't validate the model on purpose.

// TODO: Make "Predict" return a status.
dataset::proto::Example example;
model::proto::Prediction prediction;
EXPECT_DEATH(model.Predict(example, &prediction),
"Only regression is supported with poison loss");
}

TEST(GradientBoostedTrees, PoissonLossWithClassificationTaskOnDataset) {
GradientBoostedTreesModel model;
model.set_loss(proto::Loss::POISSON, {});
model.set_task(model::proto::CLASSIFICATION);
dataset::AddColumn("label", dataset::proto::ColumnType::CATEGORICAL,
model.mutable_data_spec());
model.set_label_col_idx(0);
model.mutable_initial_predictions()->push_back(0);
// Note: We don't validate the model on purpose.

dataset::VerticalDataset dataset;
dataset.set_data_spec(model.data_spec());
EXPECT_OK(dataset.CreateColumnsFromDataspec());

// TODO: Make "Predict" return a status.
model::proto::Prediction prediction;
EXPECT_DEATH(model.Predict(dataset, 0, &prediction),
"Only regression is supported with poison loss");
}

} // namespace
} // namespace gradient_boosted_trees
} // namespace model
12 changes: 12 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py
@@ -1481,6 +1481,18 @@ def test_label_classes_order_str(self, label_data):
).train(data)
npt.assert_equal(model_1.label_classes(), np.unique(label_data).astype(str))

def test_adult_poisson(self):
model = specialized_learners.GradientBoostedTreesLearner(
label="hours_per_week",
growing_strategy="BEST_FIRST_GLOBAL",
task=generic_learner.Task.REGRESSION,
loss="POISSON",
split_axis="SPARSE_OBLIQUE",
validation_ratio=0.2,
num_trees=10,
).train(self.adult.train_pd)
_ = model.analyze(self.adult.test_pd, sampling=0.1)


class LoggingTest(parameterized.TestCase):

1 change: 1 addition & 0 deletions yggdrasil_decision_forests/serving/decision_forest/BUILD
@@ -127,6 +127,7 @@ cc_library_ydf(
],
deps = [
"//yggdrasil_decision_forests/model:abstract_model_cc_proto",
"//yggdrasil_decision_forests/model/gradient_boosted_trees",
"//yggdrasil_decision_forests/model/isolation_forest",
"//yggdrasil_decision_forests/serving:example_set",
"//yggdrasil_decision_forests/utils:logging",
@@ -903,7 +903,8 @@ absl::Status GenericToSpecializedModel(
absl::Status GenericToSpecializedModel(
const GradientBoostedTreesModel& src,
GradientBoostedTreesRegressionNumericalOnly* dst) {
- if (src.loss() != Loss::SQUARED_ERROR ||
+ if ((src.loss() != Loss::SQUARED_ERROR &&
+      src.loss() != Loss::MEAN_AVERAGE_ERROR) ||
src.initial_predictions().size() != 1) {
return absl::InvalidArgumentError("The GBT is not trained for regression.");
}
Expand All @@ -923,7 +924,8 @@ absl::Status GenericToSpecializedModel(
absl::Status GenericToSpecializedModel(
const GradientBoostedTreesModel& src,
GradientBoostedTreesRegressionNumericalAndCategorical* dst) {
- if (src.loss() != Loss::SQUARED_ERROR ||
+ if ((src.loss() != Loss::SQUARED_ERROR &&
+      src.loss() != Loss::MEAN_AVERAGE_ERROR) ||
src.initial_predictions().size() != 1) {
return absl::InvalidArgumentError("The GBT is not trained for regression.");
}
@@ -1153,7 +1155,8 @@ absl::Status GenericToSpecializedModel(
template <>
absl::Status GenericToSpecializedModel(const GradientBoostedTreesModel& src,
GradientBoostedTreesRegression* dst) {
- if (src.loss() != Loss::SQUARED_ERROR ||
+ if ((src.loss() != Loss::SQUARED_ERROR &&
+      src.loss() != Loss::MEAN_AVERAGE_ERROR) ||
src.initial_predictions().size() != 1) {
return absl::InvalidArgumentError(
"The Gradient Boosted Tree is not trained for regression.");
@@ -1183,6 +1186,22 @@ absl::Status GenericToSpecializedModel(const GradientBoostedTreesModel& src,
SetLeafGradientBoostedTreesRegression<DstType>, src, dst);
}

template <>
absl::Status GenericToSpecializedModel(
const GradientBoostedTreesModel& src,
GradientBoostedTreesPoissonRegression* dst) {
if (src.loss() != Loss::POISSON || src.initial_predictions().size() != 1) {
return absl::InvalidArgumentError(
"The Gradient Boosted Tree is not trained for poison regression.");
}

dst->initial_predictions = src.initial_predictions()[0];

using DstType = std::remove_pointer<decltype(dst)>::type;
return GenericToSpecializedGenericModelHelper(
SetLeafGradientBoostedTreesRegression<DstType>, src, dst);
}
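The DstType alias in the specialization above recovers the engine struct type from the pointer-typed parameter so that a single helper can be shared across engines. A minimal illustration of the idiom, using a toy type rather than the library's:

```c++
#include <type_traits>

struct ToyEngine {
  float initial_predictions = 0.f;
};

void Specialize(ToyEngine* dst) {
  // decltype(dst) is ToyEngine*; std::remove_pointer recovers ToyEngine
  // itself, which can then be passed as a template argument to a shared
  // helper, exactly as the specializations above do.
  using DstType = std::remove_pointer<decltype(dst)>::type;
  static_assert(std::is_same<DstType, ToyEngine>::value,
                "DstType is the pointee type");
}

int main() {
  ToyEngine engine;
  Specialize(&engine);
  return 0;
}
```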

template <typename Value>
absl::Status LoadFlatBatchFromDataset(
const VerticalDataset& dataset, VerticalDataset::row_t begin_example_idx,
@@ -22,6 +22,7 @@
#include <type_traits>
#include <vector>

#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h"
#include "yggdrasil_decision_forests/model/isolation_forest/isolation_forest.h"
#include "yggdrasil_decision_forests/serving/example_set.h"
#include "yggdrasil_decision_forests/utils/logging.h"
@@ -52,6 +53,19 @@ float ActivationAddInitialPrediction(const SpecializedModel& model,
return value + model.initial_predictions;
}

// Activation function for a regression GBT trained with the Poisson loss.
template <typename SpecializedModel>
float ActivationGradientBoostedTreesPoissonRegression(
const SpecializedModel& model, const float value) {
float clamped_value =
std::clamp(value + model.initial_predictions,
-model::gradient_boosted_trees::GradientBoostedTreesModel::
kPoissonLossClampBounds,
model::gradient_boosted_trees::GradientBoostedTreesModel::
kPoissonLossClampBounds);
return std::exp(clamped_value);
}

// Final function applied by a Gradient Boosted Trees with
// MULTINOMIAL_LOG_LIKELIHOOD loss function. I.e. this is a softmax function.
template <typename SpecializedModel>
@@ -745,6 +759,17 @@ void Predict(const GradientBoostedTreesRanking& model,
predictions);
}

template <>
void Predict(
const GradientBoostedTreesPoissonRegression& model,
const typename GradientBoostedTreesPoissonRegression::ExampleSet& examples,
int num_examples, std::vector<float>* predictions) {
// Apply the Poisson activation (exponential of the clamped sum) on top of
// the accumulated leaf values.
PredictHelper<std::remove_reference<decltype(model)>::type,
ActivationGradientBoostedTreesPoissonRegression>(
model, examples, num_examples, predictions);
}
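PredictHelper takes the activation as a template parameter, so the per-example transform is resolved at compile time and can be inlined into the batched loop. A self-contained sketch of that pattern with toy names (not the library's types):

```c++
#include <cmath>
#include <vector>

struct ToyModel {
  float initial_predictions = 0.1f;
};

// Per-example activation, mirroring the exp-of-shifted-sum shape above
// (clamping omitted for brevity).
template <typename Model>
float ExpActivation(const Model& model, const float raw) {
  return std::exp(raw + model.initial_predictions);
}

// The activation is a non-type template parameter: there is no indirect
// call in the inner loop, so the compiler can inline it.
template <typename Model, float (*Activation)(const Model&, float)>
void PredictAll(const Model& model, const std::vector<float>& raw_scores,
                std::vector<float>* predictions) {
  predictions->clear();
  predictions->reserve(raw_scores.size());
  for (const float raw : raw_scores) {
    predictions->push_back(Activation(model, raw));
  }
}

int main() {
  ToyModel model;
  std::vector<float> predictions;
  PredictAll<ToyModel, ExpActivation<ToyModel>>(model, {0.f, 1.f},
                                                &predictions);
  return 0;
}
```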

} // namespace decision_forest
} // namespace serving
} // namespace yggdrasil_decision_forests
@@ -445,6 +445,18 @@ struct GenericGradientBoostedTreesRanking : ExampleSetModel<NodeOffsetRep> {
};
using GradientBoostedTreesRanking = GenericGradientBoostedTreesRanking<>;

// GBDT model for Poisson regression.
template <typename NodeOffsetRep = uint16_t>
struct GenericGradientBoostedTreesPoissonRegression
: ExampleSetModel<NodeOffsetRep> {
static constexpr model::proto::Task kTask = model::proto::Task::REGRESSION;
// Output of the model before any tree is applied, and before the final
// activation function.
float initial_predictions = 0.f;
};
using GradientBoostedTreesPoissonRegression =
GenericGradientBoostedTreesPoissonRegression<>;

template <typename Model>
void Predict(const Model& model, const typename Model::ExampleSet& examples,
int num_examples, std::vector<float>* predictions);
@@ -410,6 +410,34 @@ TEST(AdultBinaryClassGBDT, ManualGeneric) {
dataset, *model, engine);
}

TEST(AdultBinaryClassGBDT, ManualWithNonCompatibleEngines) {
const auto model = LoadModel("adult_binary_class_gbdt");
const auto dataset = LoadDataset(model->data_spec(), "adult_test.csv", "csv");

auto* gbt_model = dynamic_cast<GradientBoostedTreesModel*>(model.get());

{
GradientBoostedTreesPoissonRegression engine;
EXPECT_THAT(GenericToSpecializedModel(*gbt_model, &engine),
test::StatusIs(absl::StatusCode::kInvalidArgument,
"not trained for poison regression"));
}

{
GradientBoostedTreesRanking engine;
EXPECT_THAT(GenericToSpecializedModel(*gbt_model, &engine),
test::StatusIs(absl::StatusCode::kInvalidArgument,
"not trained for ranking"));
}

{
GradientBoostedTreesRegression engine;
EXPECT_THAT(GenericToSpecializedModel(*gbt_model, &engine),
test::StatusIs(absl::StatusCode::kInvalidArgument,
"not trained for regression"));
}
}

TEST(AdultBinaryClassGBDT, ManualNumCat32) {
const auto model = LoadModel("adult_binary_class_gbdt_32cat");
const auto dataset = LoadDataset(model->data_spec(), "adult_test.csv", "csv");