Skip to content

Commit

Permalink
[PYDF] Link YDF get leaves to PYDF.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 574887613
  • Loading branch information
achoum authored and copybara-github committed Oct 19, 2023
1 parent df3728b commit a800d11
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 1 deletion.
4 changes: 4 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ class GenericCCModel:

class DecisionForestCCModel(GenericCCModel):
def num_trees(self) -> int: ...
def PredictLeaves(
self,
dataset: VerticalDataset,
) -> npt.NDArray[np.int32]: ...

class RandomForestCCModel(DecisionForestCCModel):
@property
Expand Down
4 changes: 4 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ pybind_library(
":pydf_models",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@ydf_cc//yggdrasil_decision_forests/dataset:data_spec_cc_proto",
"@ydf_cc//yggdrasil_decision_forests/dataset:types",
"@ydf_cc//yggdrasil_decision_forests/dataset:vertical_dataset",
"@ydf_cc//yggdrasil_decision_forests/metric:metric_cc_proto",
"@ydf_cc//yggdrasil_decision_forests/model:abstract_model",
Expand Down Expand Up @@ -88,7 +90,9 @@ py_library(
srcs = ["decision_forest_model.py"],
deps = [
":generic_model",
# numpy dep,
"//ydf/cc:ydf",
"//ydf/dataset",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,39 @@

"""Definitions for generic decision forest models."""

import numpy as np

from ydf.cc import ydf
from ydf.dataset import dataset
from ydf.model import generic_model


class DecisionForestModel(generic_model.GenericModel):
"""A generic decision forest model for prediction and inspection."""

_model: ydf.DecisionForestCCModel

def num_trees(self):
"""Returns the number of trees in the decision forest."""
return self._model.num_trees()

def predict_leaves(self, data: dataset.InputDataset) -> np.ndarray:
"""Gets the index of the active leaf in each tree.
The active leaf is the leave that that receive the example during inference.
The returned value "leaves[i,j]" is the index of the active leaf for the
i-th example and the j-th tree. Leaves are indexed by depth first
exploration with the negative child visited before the positive one.
Args:
data: Dataset.
Returns:
Index of the active leaf for each tree in the model.
"""

ds = dataset.create_vertical_dataset(
data, data_spec=self._model.data_spec()
)
return self._model.PredictLeaves(ds._dataset) # pylint: disable=protected-access
3 changes: 2 additions & 1 deletion yggdrasil_decision_forests/port/python/ydf/model/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ void init_model(py::module_& m) {
return absl::Substitute(
"<model_cc.DecisionForestCCModel of type $0.", a.name());
})
.def("num_trees", &DecisionForestCCModel::num_trees);
.def("num_trees", &DecisionForestCCModel::num_trees)
.def("PredictLeaves", &DecisionForestCCModel::PredictLeaves);

py::class_<RandomForestCCModel,
/*parent class*/ DecisionForestCCModel>(m, "RandomForestCCModel")
Expand Down
17 changes: 17 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/model/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,23 @@ def test_model_to_cpp(self):
cc = model.to_cpp()
logging.info("cc:\n%s", cc)

def test_predict_leaves(self):
model_path = os.path.join(
test_utils.ydf_test_data_path(),
"model",
"adult_binary_class_gbdt",
)
model = model_lib.load_model(model_path)

dataset_path = os.path.join(
test_utils.ydf_test_data_path(), "dataset", "adult_test.csv"
)
dataset = pd.read_csv(dataset_path)

leaves = model.predict_leaves(dataset)
self.assertEqual(leaves.shape, (dataset.shape[0], model.num_trees()))
self.assertTrue(np.all(leaves >= 0))


class RandomForestModelTest(absltest.TestCase):

Expand Down
20 changes: 20 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "yggdrasil_decision_forests/dataset/types.h"
#include "yggdrasil_decision_forests/dataset/vertical_dataset.h"
#include "yggdrasil_decision_forests/metric/metric.pb.h"
#include "yggdrasil_decision_forests/model/abstract_model.h"
Expand Down Expand Up @@ -137,6 +139,24 @@ absl::Status GenericCCModel::Save(
return model::SaveModel(directory, model_.get(), {file_prefix});
}

absl::StatusOr<py::array_t<int32_t>> DecisionForestCCModel::PredictLeaves(
const dataset::VerticalDataset& dataset) {
py::array_t<int32_t, py::array::c_style | py::array::forcecast> leaves;

const size_t num_examples = dataset.nrow();
const size_t num_trees = df_model_->num_trees();

leaves.resize({num_examples, num_trees});
auto unchecked_leaves = leaves.mutable_unchecked();
for (size_t example_idx = 0; example_idx < num_examples; example_idx++) {
auto dst = absl::MakeSpan(unchecked_leaves.mutable_data(example_idx, 0),
num_trees);
RETURN_IF_ERROR(df_model_->PredictGetLeaves(dataset, example_idx, dst));
}

return leaves;
}

absl::StatusOr<std::unique_ptr<RandomForestCCModel>>
RandomForestCCModel::Create(std::unique_ptr<model::AbstractModel>& model_ptr) {
auto* rf_model = dynamic_cast<YDFModel*>(model_ptr.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <pybind11/numpy.h>

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -101,6 +102,9 @@ class DecisionForestCCModel : public GenericCCModel {
public:
int num_trees() const { return df_model_->num_trees(); }

absl::StatusOr<py::array_t<int32_t>> PredictLeaves(
const dataset::VerticalDataset& dataset);

protected:
// `model` and `df_model` must correspond to the same object.
DecisionForestCCModel(std::unique_ptr<model::AbstractModel>&& model,
Expand Down

0 comments on commit a800d11

Please sign in to comment.