Tabular transformer (part 4)
PiperOrigin-RevId: 704295716
achoum authored and copybara-github committed Dec 9, 2024
1 parent 19a7892 commit b43657d
Showing 14 changed files with 885 additions and 2 deletions.
34 changes: 33 additions & 1 deletion yggdrasil_decision_forests/port/python/ydf/dataset/io/BUILD
@@ -25,6 +25,7 @@ py_library(
srcs = ["dataset_io.py"],
deps = [
":dataset_io_types",
":numpy_io",
":pandas_io",
":polars_io",
":pygrain_io",
@@ -39,8 +40,9 @@ py_library(
srcs = ["pandas_io.py"],
deps = [
":dataset_io_types",
":generator",
# absl/logging dep,
# pandas dep,
# numpy dep,
],
)

@@ -79,17 +81,36 @@ py_library(
],
)

py_library(
name = "generator",
srcs = ["generator.py"],
deps = [
# numpy dep,
],
)

py_library(
name = "numpy_io",
srcs = ["numpy_io.py"],
deps = [
":generator",
# numpy dep,
],
)

# Tests
# =====

py_test(
name = "pandas_io_test",
srcs = ["pandas_io_test.py"],
data = ["@ydf_cc//yggdrasil_decision_forests/test_data"],
deps = [
":pandas_io",
# absl/testing:absltest dep,
# pandas dep,
# polars dep,
"//ydf/utils:test_utils",
],
)

@@ -143,3 +164,14 @@ py_test(
# absl/testing:parameterized dep,
],
)

py_test(
name = "numpy_io_test",
srcs = ["numpy_io_test.py"],
deps = [
":numpy_io",
# absl/testing:absltest dep,
# numpy dep,
"//ydf/utils:test_utils",
],
)
yggdrasil_decision_forests/port/python/ydf/dataset/io/dataset_io.py
@@ -21,12 +21,14 @@
import numpy as np

from ydf.dataset.io import dataset_io_types
from ydf.dataset.io import numpy_io
from ydf.dataset.io import pandas_io
from ydf.dataset.io import polars_io
from ydf.dataset.io import pygrain_io
from ydf.dataset.io import tensorflow_io
from ydf.dataset.io import xarray_io


def unrolled_feature_names(name: str, num_dims: int) -> Sequence[str]:
"""Returns the names of an unrolled feature."""

@@ -207,3 +209,38 @@ def cast_input_dataset_to_dict(
"Unsupported dataset type: "
f"{type(data)}\n{dataset_io_types.SUPPORTED_INPUT_DATA_DESCRIPTION}"
)


def build_batched_example_generator(
data: dataset_io_types.IODataset,
):
"""Converts any support dataset format into a batched example generator.
Usage example:
```python
generator = build_batched_example_generator({
"a":np.array([1, 2, 3]),
"b":np.array(["x", "y", "z"]),
})
for batch in generator.generate(batch_size=2, shuffle=False):
print(batch)
>> { "a":np.array([1, 2]), "b":np.array(["x", "y"]) }
>> { "a":np.array([3]), "b":np.array(["z"]) }
```
Args:
data: Support dataset format.
Returns:
Example generator.
"""
if pandas_io.is_pandas_dataframe(data):
return pandas_io.PandasBatchedExampleGenerator(data)
elif isinstance(data, dict):
return numpy_io.NumpyDictBatchedExampleGenerator(data)
else:
# TODO: Add support for other YDF dataset formats.
raise ValueError(
f"Unsupported dataset type to train a Deep YDF model: {type(data)}"
)
50 changes: 50 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/dataset/io/generator.py
@@ -0,0 +1,50 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility to handle datasets."""

import abc
from typing import Dict, Iterator, Optional
import numpy as np

# A single batch of data as a dictionary of numpy arrays. The attribute values
# are indexed by attribute names.
NumpyExampleBatch = Dict[str, np.ndarray]


class BatchedExampleGenerator(abc.ABC):
"""A class able to generate batches of examples."""

def __init__(self, num_examples: Optional[int]):
self._num_examples = num_examples

@property
def num_examples(self) -> Optional[int]:
"""Number of examples in the dataset."""
return self._num_examples

  def num_batches(self, batch_size: int) -> Optional[int]:
    """Number of batches of size `batch_size`, or None if unknown."""
    if self._num_examples is None:
      return None
    return (self._num_examples + batch_size - 1) // batch_size

@abc.abstractmethod
def generate(
self,
batch_size: int,
shuffle: bool,
seed: Optional[int] = None,
) -> Iterator[NumpyExampleBatch]:
"""Generate an iterator."""
raise NotImplementedError
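A minimal usage sketch of the `BatchedExampleGenerator` contract, assuming the `NumpyDictBatchedExampleGenerator` implementation added in `numpy_io.py` below; callers iterate over batches without knowing the underlying dataset format:

```python
import numpy as np

from ydf.dataset.io import numpy_io

gen = numpy_io.NumpyDictBatchedExampleGenerator({
    "f1": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),
    "f2": np.array(["a", "b", "c", "d", "e"]),
})

# With 5 examples and batch_size=2, num_batches uses ceiling division: 3 batches.
assert gen.num_batches(batch_size=2) == 3

for batch in gen.generate(batch_size=2, shuffle=False):
  # Each batch is a dict mapping the attribute name to a numpy array slice.
  print({k: v.tolist() for k, v in batch.items()})
```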
54 changes: 54 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/dataset/io/numpy_io.py
@@ -0,0 +1,54 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataset generator for dict of numpy arrays."""

from typing import Dict, Iterator, Optional
import numpy as np
from ydf.dataset.io import generator as generator_lib


class NumpyDictBatchedExampleGenerator(generator_lib.BatchedExampleGenerator):
"""Class to consume dictionaries of Numpy arrays."""

def __init__(self, data: Dict[str, np.ndarray]):
self._data = data
super().__init__(num_examples=len(next(iter(data.values()))))

def generate(
self,
batch_size: int,
shuffle: bool,
seed: Optional[int] = None,
) -> Iterator[generator_lib.NumpyExampleBatch]:
assert self._num_examples is not None
if not shuffle:
i = 0
while i < self._num_examples:
begin_idx = i
end_idx = min(i + batch_size, self._num_examples)
yield {str(k): v[begin_idx:end_idx] for k, v in self._data.items()}
i += batch_size
else:
if seed is None:
raise ValueError("seed is required if shuffle=True")
rng = np.random.default_rng(seed)
idxs = rng.permutation(self._num_examples)
i = 0
while i < self._num_examples:
begin_idx = i
end_idx = min(i + batch_size, self._num_examples)
selected_idxs = idxs[begin_idx:end_idx]
yield {str(k): v[selected_idxs] for k, v in self._data.items()}
i += batch_size
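A short, hedged sketch of shuffled generation with the class above: the seed is required, and the permutation is over example indices, so every example appears exactly once per pass:

```python
import numpy as np

from ydf.dataset.io import numpy_io

gen = numpy_io.NumpyDictBatchedExampleGenerator({
    "a": np.array([1, 2, 3, 4]),
    "b": np.array(["w", "x", "y", "z"]),
})

# shuffle=True without a seed raises ValueError; pass a seed for reproducibility.
for batch in gen.generate(batch_size=2, shuffle=True, seed=42):
  print(batch)  # two examples per batch, in permuted order
```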
yggdrasil_decision_forests/port/python/ydf/dataset/io/numpy_io_test.py
@@ -0,0 +1,69 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test dataspec utilities for pandas."""

from absl.testing import absltest
import numpy as np

from ydf.dataset.io import numpy_io
from ydf.utils import test_utils


class NumpyIOTest(absltest.TestCase):

def test_numpy_generator(self):
ds = numpy_io.NumpyDictBatchedExampleGenerator({
"a": np.array([1, 2, 3]),
"b": np.array(["x", "y", "z"]),
})

for batch_idx, batch in enumerate(ds.generate(batch_size=2, shuffle=False)):
if batch_idx == 0:
test_utils.assert_almost_equal(
batch, {"a": np.array([1, 2]), "b": np.array(["x", "y"])}
)
elif batch_idx == 1:
test_utils.assert_almost_equal(
batch, {"a": np.array([3]), "b": np.array(["z"])}
)
else:
assert False

def test_numpy_generator_shuffle(self):
ds = numpy_io.NumpyDictBatchedExampleGenerator({
"a": np.array([1, 2, 3]),
"b": np.array(["x", "y", "z"]),
})
count_per_first_a_value = [0] * 4
num_runs = 100
    for i in range(num_runs):
num_sum_a = 0
num_batches = 0
for batch_idx, batch in enumerate(
ds.generate(batch_size=2, shuffle=True, seed=i)
):
num_sum_a += np.sum(batch["a"])
num_batches += 1
if batch_idx == 0:
first_value = batch["a"][0]
count_per_first_a_value[first_value] += 1
self.assertEqual(num_batches, 2)
self.assertEqual(num_sum_a, 1 + 2 + 3)
for i in range(1, 3):
self.assertGreater(count_per_first_a_value[i], num_runs / 10)


if __name__ == "__main__":
absltest.main()
yggdrasil_decision_forests/port/python/ydf/dataset/io/pandas_io.py
@@ -15,11 +15,13 @@
"""Connectors for loading data from Pandas dataframes."""

import sys
from typing import Dict
from typing import Any, Dict, Iterator, Optional

from absl import logging
import numpy as np

from ydf.dataset.io import dataset_io_types
from ydf.dataset.io import generator as generator_lib


def import_pd():
@@ -66,3 +68,46 @@ def clean(values):
data_dict = {k: clean(v) for k, v in data_dict.items()}

return data_dict


class PandasBatchedExampleGenerator(generator_lib.BatchedExampleGenerator):
"""Class to consume Pandas Dataframes."""

def __init__(self, dataframe: Any):
pd = import_pd()
assert isinstance(dataframe, pd.DataFrame)
self._dataframe = dataframe
super().__init__(num_examples=len(dataframe))

def generate(
self,
batch_size: int,
shuffle: bool,
seed: Optional[int] = None,
) -> Iterator[generator_lib.NumpyExampleBatch]:
assert self._num_examples is not None
if not shuffle:
i = 0
while i < self._num_examples:
begin_idx = i
end_idx = min(i + batch_size, self._num_examples)
yield {
str(k): v[begin_idx:end_idx].to_numpy()
for k, v in self._dataframe.items()
}
i += batch_size
else:
if seed is None:
raise ValueError("seed is required if shuffle=True")
rng = np.random.default_rng(seed)
idxs = rng.permutation(self._num_examples)
i = 0
while i < self._num_examples:
begin_idx = i
end_idx = min(i + batch_size, self._num_examples)
selected_idxs = idxs[begin_idx:end_idx]
yield {
str(k): v[selected_idxs].to_numpy()
for k, v in self._dataframe.items()
}
i += batch_size
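And a comparable sketch for the Pandas connector (assuming pandas is installed); each batch is yielded as a dict of numpy arrays rather than a DataFrame slice:

```python
import pandas as pd

from ydf.dataset.io import pandas_io

df = pd.DataFrame({
    "feature": [0.1, 0.2, 0.3],
    "label": ["a", "b", "c"],
})

gen = pandas_io.PandasBatchedExampleGenerator(df)
for batch in gen.generate(batch_size=2, shuffle=False):
  print(batch)
# First batch:  {'feature': array([0.1, 0.2]), 'label': array(['a', 'b'], dtype=object)}
# Second batch: {'feature': array([0.3]), 'label': array(['c'], dtype=object)}
```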