[PYDF] Add support for hash columns
PiperOrigin-RevId: 577754979
rstz authored and copybara-github committed Oct 30, 2023
1 parent cf7eb8c commit 5e6a00e
Showing 4 changed files with 110 additions and 0 deletions.
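
The change lets PYDF users declare columns with the HASH semantic when building a dataset. A minimal usage sketch, mirroring the new test at the bottom of this commit (the import paths for dataset and dataspec are assumed from the test's context and may differ in the public API):

import pandas as pd

from ydf.dataset import dataset  # assumed module path
from ydf.dataset import dataspec  # assumed: the diff references dataspec.Semantic

df = pd.DataFrame({"col_hash": ["a", "b", "abc"]})
# Declare the column as HASH: each value is hashed to a uint64.
ds = dataset.create_vertical_dataset(
    df, columns=[("col_hash", dataspec.Semantic.HASH)]
)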
6 changes: 6 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi
@@ -45,6 +45,12 @@ class VerticalDataset:
data: npt.NDArray[np.bool_],
column_idx: Optional[int],
) -> None: ...
def PopulateColumnHashNPBytes(
self,
name: str,
data: npt.NDArray[np.bytes_],
column_idx: Optional[int] = None,
) -> None: ...
def CreateFromPathWithDataSpec(
self, path: str, data_spec: data_spec_pb2.DataSpecification
) -> None: ...
51 changes: 51 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/dataset/dataset.cc
Expand Up @@ -66,6 +66,8 @@ using BooleanColumn =
::yggdrasil_decision_forests::dataset::VerticalDataset::BooleanColumn;
using CategoricalColumn =
::yggdrasil_decision_forests::dataset::VerticalDataset::CategoricalColumn;
using HashColumn =
::yggdrasil_decision_forests::dataset::VerticalDataset::HashColumn;

// Checks if all columns of the dataset have the same number of rows and sets
// the dataset's number of rows accordingly. If requested, also modifies the
@@ -540,6 +542,52 @@ absl::Status PopulateColumnCategoricalNPBytes(
return absl::OkStatus();
}

// Append contents of `data` to a HASH column. If no `column_idx` is
// given, a new column is created.
//
// Note that this function only creates the column and copies the data, but it
// does not set `num_rows` on the dataset. Before using the dataset, `num_rows`
// has to be set (e.g. using SetAndCheckNumRows).
absl::Status PopulateColumnHashNPBytes(dataset::VerticalDataset& self,
const std::string& name, py::array& data,
std::optional<int> column_idx) {
ASSIGN_OR_RETURN(const auto values, NPByteArray::Create(data));

HashColumn* column;
size_t offset = 0;
if (!column_idx.has_value()) {
// Create column spec
dataset::proto::Column column_spec;
column_spec.set_name(name);
column_spec.set_type(dataset::proto::ColumnType::HASH);

// Import column data
ASSIGN_OR_RETURN(auto* abstract_column, self.AddColumn(column_spec));
ASSIGN_OR_RETURN(column,
abstract_column->MutableCastWithStatus<HashColumn>());
column_idx = self.ncol() - 1;
} else {
ASSIGN_OR_RETURN(column, self.MutableColumnWithCastWithStatus<HashColumn>(
column_idx.value()));
offset = column->values().size();
}
column->Resize(offset + values.size());
auto& dst_values = *column->mutable_values();

for (size_t value_idx = 0; value_idx < values.size(); value_idx++) {
const auto value = values[value_idx];
uint64_t dst_value;
if (value.empty()) {
dst_value = dataset::VerticalDataset::HashColumn::kNaValue;
} else {
dst_value = dataset::HashColumnString(value);
}
dst_values[offset + value_idx] = dst_value;
}

return absl::OkStatus();
}

absl::Status CreateColumnsFromDataSpec(
dataset::VerticalDataset& self,
const dataset::proto::DataSpecification& data_spec) {
@@ -647,6 +695,9 @@ void init_dataset(py::module_& m) {
&PopulateColumnNumericalNPFloat32, py::arg("name"),
py::arg("data").noconvert(), py::arg("column_idx") = std::nullopt)
.def("PopulateColumnBooleanNPBool", &PopulateColumnBooleanNPBool,
py::arg("name"), py::arg("data").noconvert(),
py::arg("column_idx") = std::nullopt)
.def("PopulateColumnHashNPBytes", &PopulateColumnHashNPBytes,
py::arg("name"), py::arg("data").noconvert(),
py::arg("column_idx") = std::nullopt);
}
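
The .pyi stub above fixes the binding's Python-side signature; the following is a minimal sketch of exercising it directly on a VerticalDataset, assuming the compiled module is importable as ydf.cc.ydf (inferred from the stub's location) and keeping in mind that num_rows must still be set before the dataset is used:

import numpy as np

from ydf.cc import ydf  # assumed import path, matching the .pyi stub's location

vds = ydf.VerticalDataset()
data = np.array([b"a", b"b", b""], dtype=np.bytes_)
# No column_idx: a new HASH column is created. The empty string is stored
# as HashColumn::kNaValue, per the C++ loop above.
vds.PopulateColumnHashNPBytes("col_hash", data)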
34 changes: 34 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/dataset/dataset.py
@@ -151,6 +151,40 @@ def _add_column(
)
return

elif column.semantic == dataspec.Semantic.HASH:
if not isinstance(column_data, np.ndarray):
column_data = np.array(column_data, dtype=np.bytes_)
elif column_data.dtype.type in [
np.object_,
np.string_,
np.bool_,
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
]:
column_data = column_data.astype(np.bytes_)
elif column_data.dtype.type in [
np.float16,
np.float32,
np.float64,
]:
raise ValueError(
f"Cannot import column {column.name!r} with"
f" semantic={column.semantic} as it contains floating point values."
f" Got {original_column_data!r}."
)

if column_data.dtype.type == np.bytes_:
self._dataset.PopulateColumnHashNPBytes(
column.name, column_data, column_idx=column_idx
)
return

raise ValueError(
f"Cannot import column {column.name!r} with semantic={column.semantic},"
f" type={_type(original_column_data)} and"
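
Per the dtype branches above, object, string, boolean, and integer inputs are coerced to np.bytes_ before hashing, while floating-point inputs raise a ValueError. A sketch of both paths under the same assumed entry point as the tests:

import numpy as np
import pandas as pd

ints = pd.DataFrame({"ids": np.array([1, 2, 3], dtype=np.int64)})
# Accepted: integers are converted to byte strings (b"1", b"2", ...) and hashed.
dataset.create_vertical_dataset(ints, columns=[("ids", dataspec.Semantic.HASH)])

floats = pd.DataFrame({"x": [1.5, 2.5]})
try:
    dataset.create_vertical_dataset(floats, columns=[("x", dataspec.Semantic.HASH)])
except ValueError as e:
    print(e)  # "Cannot import column 'x' ... contains floating point values."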
19 changes: 19 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/dataset/dataset_test.py
@@ -260,6 +260,25 @@ def test_create_vds_pd_boolean(self):
)
self.assertEqual(ds.data_spec(), expected_data_spec)

def test_create_vds_pd_hash(self):
df = pd.DataFrame(
{"col_hash": ["a", "b", "abc"]},
)

ds = dataset.create_vertical_dataset(
df, columns=[("col_hash", Semantic.HASH)]
)
expected_data_spec = ds_pb.DataSpecification(
created_num_rows=3,
columns=(
ds_pb.Column(
name="col_hash",
type=ds_pb.ColumnType.HASH,
),
),
)
self.assertEqual(ds.data_spec(), expected_data_spec)

@parameterized.parameters(
(["col_numerical"],),
([Column("col_numerical")],),
