From a800d118695cd6bf593e32a2551dfa3ef0fe2dff Mon Sep 17 00:00:00 2001 From: Mathieu Guillame-Bert Date: Thu, 19 Oct 2023 08:53:24 -0700 Subject: [PATCH] [PYDF] Link YDF get leaves to PYDF. PiperOrigin-RevId: 574887613 --- .../port/python/ydf/cc/ydf.pyi | 4 +++ .../port/python/ydf/model/BUILD | 4 +++ .../python/ydf/model/decision_forest_model.py | 25 +++++++++++++++++++ .../port/python/ydf/model/model.cc | 3 ++- .../port/python/ydf/model/model_test.py | 17 +++++++++++++ .../port/python/ydf/model/model_wrapper.cc | 20 +++++++++++++++ .../port/python/ydf/model/model_wrapper.h | 4 +++ 7 files changed, 76 insertions(+), 1 deletion(-) diff --git a/yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi b/yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi index 5eaba934..17b922c4 100644 --- a/yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi +++ b/yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi @@ -74,6 +74,10 @@ class GenericCCModel: class DecisionForestCCModel(GenericCCModel): def num_trees(self) -> int: ... + def PredictLeaves( + self, + dataset: VerticalDataset, + ) -> npt.NDArray[np.int32]: ... class RandomForestCCModel(DecisionForestCCModel): @property diff --git a/yggdrasil_decision_forests/port/python/ydf/model/BUILD b/yggdrasil_decision_forests/port/python/ydf/model/BUILD index f58e1f97..23100409 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/BUILD +++ b/yggdrasil_decision_forests/port/python/ydf/model/BUILD @@ -29,7 +29,9 @@ pybind_library( ":pydf_models", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", "@ydf_cc//yggdrasil_decision_forests/dataset:data_spec_cc_proto", + "@ydf_cc//yggdrasil_decision_forests/dataset:types", "@ydf_cc//yggdrasil_decision_forests/dataset:vertical_dataset", "@ydf_cc//yggdrasil_decision_forests/metric:metric_cc_proto", "@ydf_cc//yggdrasil_decision_forests/model:abstract_model", @@ -88,7 +90,9 @@ py_library( srcs = ["decision_forest_model.py"], deps = [ ":generic_model", + # numpy dep, "//ydf/cc:ydf", + "//ydf/dataset", ], ) diff --git a/yggdrasil_decision_forests/port/python/ydf/model/decision_forest_model.py b/yggdrasil_decision_forests/port/python/ydf/model/decision_forest_model.py index 04d2a0ff..4a9a070a 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/decision_forest_model.py +++ b/yggdrasil_decision_forests/port/python/ydf/model/decision_forest_model.py @@ -14,14 +14,39 @@ """Definitions for generic decision forest models.""" +import numpy as np + from ydf.cc import ydf +from ydf.dataset import dataset from ydf.model import generic_model class DecisionForestModel(generic_model.GenericModel): """A generic decision forest model for prediction and inspection.""" + _model: ydf.DecisionForestCCModel def num_trees(self): """Returns the number of trees in the decision forest.""" return self._model.num_trees() + + def predict_leaves(self, data: dataset.InputDataset) -> np.ndarray: + """Gets the index of the active leaf in each tree. + + The active leaf is the leave that that receive the example during inference. + + The returned value "leaves[i,j]" is the index of the active leaf for the + i-th example and the j-th tree. Leaves are indexed by depth first + exploration with the negative child visited before the positive one. + + Args: + data: Dataset. + + Returns: + Index of the active leaf for each tree in the model. + """ + + ds = dataset.create_vertical_dataset( + data, data_spec=self._model.data_spec() + ) + return self._model.PredictLeaves(ds._dataset) # pylint: disable=protected-access diff --git a/yggdrasil_decision_forests/port/python/ydf/model/model.cc b/yggdrasil_decision_forests/port/python/ydf/model/model.cc index f3a6f262..568ea52c 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/model.cc +++ b/yggdrasil_decision_forests/port/python/ydf/model/model.cc @@ -88,7 +88,8 @@ void init_model(py::module_& m) { return absl::Substitute( "(m, "RandomForestCCModel") diff --git a/yggdrasil_decision_forests/port/python/ydf/model/model_test.py b/yggdrasil_decision_forests/port/python/ydf/model/model_test.py index aa321979..8f7404e1 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/model_test.py +++ b/yggdrasil_decision_forests/port/python/ydf/model/model_test.py @@ -279,6 +279,23 @@ def test_model_to_cpp(self): cc = model.to_cpp() logging.info("cc:\n%s", cc) + def test_predict_leaves(self): + model_path = os.path.join( + test_utils.ydf_test_data_path(), + "model", + "adult_binary_class_gbdt", + ) + model = model_lib.load_model(model_path) + + dataset_path = os.path.join( + test_utils.ydf_test_data_path(), "dataset", "adult_test.csv" + ) + dataset = pd.read_csv(dataset_path) + + leaves = model.predict_leaves(dataset) + self.assertEqual(leaves.shape, (dataset.shape[0], model.num_trees())) + self.assertTrue(np.all(leaves >= 0)) + class RandomForestModelTest(absltest.TestCase): diff --git a/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.cc b/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.cc index 00c93328..b9dc1eb5 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.cc +++ b/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.cc @@ -29,6 +29,8 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "yggdrasil_decision_forests/dataset/types.h" #include "yggdrasil_decision_forests/dataset/vertical_dataset.h" #include "yggdrasil_decision_forests/metric/metric.pb.h" #include "yggdrasil_decision_forests/model/abstract_model.h" @@ -137,6 +139,24 @@ absl::Status GenericCCModel::Save( return model::SaveModel(directory, model_.get(), {file_prefix}); } +absl::StatusOr> DecisionForestCCModel::PredictLeaves( + const dataset::VerticalDataset& dataset) { + py::array_t leaves; + + const size_t num_examples = dataset.nrow(); + const size_t num_trees = df_model_->num_trees(); + + leaves.resize({num_examples, num_trees}); + auto unchecked_leaves = leaves.mutable_unchecked(); + for (size_t example_idx = 0; example_idx < num_examples; example_idx++) { + auto dst = absl::MakeSpan(unchecked_leaves.mutable_data(example_idx, 0), + num_trees); + RETURN_IF_ERROR(df_model_->PredictGetLeaves(dataset, example_idx, dst)); + } + + return leaves; +} + absl::StatusOr> RandomForestCCModel::Create(std::unique_ptr& model_ptr) { auto* rf_model = dynamic_cast(model_ptr.get()); diff --git a/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.h b/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.h index d81533cc..54c46de7 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.h +++ b/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.h @@ -18,6 +18,7 @@ #include +#include #include #include #include @@ -101,6 +102,9 @@ class DecisionForestCCModel : public GenericCCModel { public: int num_trees() const { return df_model_->num_trees(); } + absl::StatusOr> PredictLeaves( + const dataset::VerticalDataset& dataset); + protected: // `model` and `df_model` must correspond to the same object. DecisionForestCCModel(std::unique_ptr&& model,