Skip to content

Commit

Permalink
[YDF] Fix errors for HASH and BOOLEAN features
Browse files Browse the repository at this point in the history
- HASH features: Convert crash to an error message
- BOOLEAN features: Make a more clear error message

PiperOrigin-RevId: 689677523
  • Loading branch information
rstz authored and copybara-github committed Oct 25, 2024
1 parent 7607f49 commit 7123adb
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 0 deletions.
17 changes: 17 additions & 0 deletions yggdrasil_decision_forests/learner/decision_tree/training.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4073,6 +4073,23 @@ absl::Status DecisionTreeTrain(
std::vector<UnsignedExampleIdx> selected_examples_buffer;
std::vector<UnsignedExampleIdx> leaf_examples_buffer;

// Fail if the data spec has invalid columns.
for (const auto feature_idx : config_link.features()) {
const auto& data_spec_columns = train_dataset.data_spec().columns();
const auto column_type = data_spec_columns[feature_idx].type();
if (column_type != dataset::proto::NUMERICAL &&
column_type != dataset::proto::CATEGORICAL &&
column_type != dataset::proto::CATEGORICAL_SET &&
column_type != dataset::proto::BOOLEAN &&
column_type != dataset::proto::DISCRETIZED_NUMERICAL) {
return absl::InvalidArgumentError(
absl::Substitute("Column $0 has type $1, which is not supported "
"for decision tree training.",
data_spec_columns[feature_idx].name(),
dataset::proto::ColumnType_Name(column_type)));
}
}

// Check monotonic constraints
if (config.monotonic_constraints_size() > 0 &&
!dt_config.keep_non_leaf_label_distribution()) {
Expand Down
7 changes: 7 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@ def _add_column(
if not isinstance(column_data, np.ndarray):
column_data = np.array(column_data, np.bool_)
ydf_dtype = dataspec.np_dtype_to_ydf_dtype(column_data.dtype)
if column_data.dtype != np.bool_:
message = (
f"Cannot import column {column.name!r} with"
f" semantic={column.semantic} as it does not contain boolean"
f" values. Got {original_column_data!r}."
)
raise ValueError(message)

self._dataset.PopulateColumnBooleanNPBool(
column.name,
Expand Down
39 changes: 39 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,45 @@ def test_from_column_less_pandas(self):
):
dataset.create_vertical_dataset(pd.DataFrame([[1, 2, 3], [4, 5, 6]]))

def test_boolean_column(self):
data = {
"f1": np.array([True, True, True, False, False, False, False]),
}
ds = dataset.create_vertical_dataset(
data,
columns=[("f1", dataspec.Semantic.BOOLEAN)],
)
self.assertEqual(
ds.data_spec(),
ds_pb.DataSpecification(
created_num_rows=7,
columns=[
ds_pb.Column(
name="f1",
type=ds_pb.ColumnType.BOOLEAN,
count_nas=0,
boolean=ds_pb.BooleanSpec(count_true=3, count_false=4),
dtype=ds_pb.DType.DTYPE_BOOL,
)
],
),
)
self.assertEqual(ds._dataset.DebugString(), "f1\n1\n1\n1\n0\n0\n0\n0\n")

def test_fail_gracefully_for_incorrect_boolean_type(self):
data = {
"f1": np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),
}
with self.assertRaisesRegex(
test_utils.AbslInvalidArgumentError,
"Cannot import column 'f1' with semantic=Semantic.BOOLEAN as it does"
" not contain boolean values.*",
):
dataset.create_vertical_dataset(
data,
columns=[("f1", dataspec.Semantic.BOOLEAN)],
)


class CategoricalSetTest(absltest.TestCase):

Expand Down
29 changes: 29 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,35 @@ def create_dataset(n: int) -> Dict[str, np.ndarray]:
).train(create_dataset(1_000))
_ = model.analyze(create_dataset(100_000))

def test_boolean_feature(self):
data = {
"f1": np.array(
[True, True, True, True, True, False, False, False, False, False]
),
"label": np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]),
}
model = specialized_learners.RandomForestLearner(
label="label",
features=[("f1", dataspec.Semantic.BOOLEAN)],
num_trees=1,
bootstrap_training_dataset=False,
).train(data)
npt.assert_equal(model.predict(data), data["label"])

def test_fail_gracefully_for_hash_columns(self):
data = {
"f1": np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
"label": np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]),
}
with self.assertRaisesRegex(
test_utils.AbslInvalidArgumentError,
"Column f1 has type HASH, which is not supported for decision tree"
" training.",
):
_ = specialized_learners.RandomForestLearner(
label="label", features=[("f1", dataspec.Semantic.HASH)], num_trees=2
).train(data)


class CARTLearnerTest(LearnerTest):

Expand Down

0 comments on commit 7123adb

Please sign in to comment.