diff --git a/docs/code/all_drivers.py b/docs/code/all_drivers.py
index 17975e2..73d0b09 100644
--- a/docs/code/all_drivers.py
+++ b/docs/code/all_drivers.py
@@ -2,6 +2,7 @@
     HubDriver,
     HuggingfaceDriver,
     TorchvisionDriver,
+    # SklearnDriver
 )
 
 HuggingfaceDriver("cifar100").get_iter("train").take(1).map(print).join()
diff --git a/examples/11.TimeSeries_With_Spark_and_Squirrel.ipynb b/examples/11.TimeSeries_With_Spark_and_Squirrel.ipynb
new file mode 100644
index 0000000..3b4d205
--- /dev/null
+++ b/examples/11.TimeSeries_With_Spark_and_Squirrel.ipynb
@@ -0,0 +1,263 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "QcdKnyz6cmMf"
+   },
+   "source": [
+    "# Squirrel for Timeseries Data\n",
+    "\n",
+    "Squirrel also handles timeseries data, or any form of **ordered** data. However, a few modifications are needed to maintain the ordering after storing.\n",
+    "\n",
+    "In this notebook we show two possible approaches for storing and loading timeseries: the first uses squirrel-native functionality, the second makes use of **Squirrel** and **Spark**."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "LQbGHY75apkJ",
+    "outputId": "700c326c-27fc-4cbd-c47f-2ff6ecce2283"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install squirrel-core pyspark\n",
+    "!pip install more-itertools"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "7jFgTFT-bV7Q"
+   },
+   "outputs": [],
+   "source": [
+    "import typing as t\n",
+    "import tempfile\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "\n",
+    "from random import randint\n",
+    "from functools import partial\n",
+    "\n",
+    "from pyspark.sql import SparkSession\n",
+    "from squirrel.driver import MessagepackDriver\n",
+    "from squirrel.store import SquirrelStore\n",
+    "from squirrel.serialization import MessagepackSerializer\n",
+    "from squirrel.iterstream import IterableSource, Composable"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "xX_UjNdzfbsQ"
+   },
+   "outputs": [],
+   "source": [
+    "# Some utility functions to generate timeseries samples and verify the ordering\n",
+    "\n",
+    "\n",
+    "def generate_timeseries_samples(N: int):\n",
+    "    \"\"\"Generate random timeseries samples\"\"\"\n",
+    "    for _ in range(N):\n",
+    "        yield {\"time_stamp_sec\": randint(0, int(1e6)), \"data\": pickle.dumps(np.random.rand(2, 2))}\n",
+    "\n",
+    "\n",
+    "def is_ordered(li: t.List[t.Dict], key=None) -> bool:\n",
+    "    \"\"\"Test if the list is ordered according to a key in li\"\"\"\n",
+    "    return all(li[i].get(key) <= li[i + 1].get(key) for i in range(len(li) - 1))\n",
+    "\n",
+    "\n",
+    "# Constants shared between the experiments\n",
+    "N_SHARDS = 10\n",
+    "N = int(1e4)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "HzyAhF2PaybX"
+   },
+   "source": [
+    "## Timeseries with Squirrel-native\n",
+    "\n",
+    "The only thing we changed here is to use `zip_index` to obtain a key for storing the data. This integer key is formatted as a string and padded with zeros if it does not have the same number of digits as given by `pad_length`. `zip_index` returns an iterable over tuples where the first item is the index and the second item is the value.\n",
+    "\n",
+    "This may sound cumbersome at first, but to preserve the order we sort by the shard keys before yielding them. As shard keys are used as filenames, they are sorted as strings: when sorting in ascending order, a key of `11` would be sorted before `9`. For this reason, the keys are padded with zeros."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "yI7UOZiRbyMM"
+   },
+   "outputs": [],
+   "source": [
+    "# Note that we assume the data is already sorted; we just guarantee that it\n",
+    "# stays sorted after storing\n",
+    "samples_list = sorted(generate_timeseries_samples(N), key=lambda x: x[\"time_stamp_sec\"])\n",
+    "samples = IterableSource(samples_list)\n",
+    "\n",
+    "with tempfile.TemporaryDirectory() as tempdir:\n",
+    "    # Write to a new cleaned store\n",
+    "    store = SquirrelStore(url=str(tempdir), serializer=MessagepackSerializer(), clean=True)\n",
+    "    samples.batched(N_SHARDS).zip_index(pad_length=9).map(lambda x: store.set(key=x[0], value=x[1])).join()\n",
+    "    # Read\n",
+    "    driver = MessagepackDriver(url=str(tempdir))\n",
+    "    retrieved = driver.get_iter().collect()\n",
+    "    assert is_ordered(retrieved, key=\"time_stamp_sec\")\n",
+    "    assert len(retrieved) == N"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "63Vi2C05rgIX"
+   },
+   "source": [
+    "## Timeseries with Squirrel and Spark\n",
+    "\n",
+    "We leverage Spark here to sort the timeseries. Spark is useful when your data does not fit entirely into memory."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "kgVbs6cChNQT"
+   },
+   "outputs": [],
+   "source": [
+    "def save_iterable_as_shard(it, store, pad_len=10) -> None:\n",
+    "    \"\"\"Helper to save a shard into a messagepack store using squirrel.\"\"\"\n",
+    "    it_list = list(it)\n",
+    "    if len(it_list) > 0:\n",
+    "        # use the earliest time_stamp as key\n",
+    "        smallest_timestamp = str(it_list[0][\"time_stamp_sec\"])\n",
+    "        # pad the key similar to zip_index()\n",
+    "        key = \"0\" * (pad_len - len(smallest_timestamp)) + smallest_timestamp\n",
+    "        store.set(value=it_list, key=key)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "oEvVWN7lwGNg"
+   },
+   "outputs": [],
+   "source": [
+    "samples = IterableSource(generate_timeseries_samples(N))\n",
+    "# Initiate Spark\n",
+    "spark = SparkSession.builder.appName(\"test\").getOrCreate()\n",
+    "rdd = spark.sparkContext.parallelize(samples)\n",
+    "# Sort\n",
+    "rdd = rdd.repartition(N_SHARDS).sortBy(lambda x: x[\"time_stamp_sec\"])\n",
+    "\n",
+    "with tempfile.TemporaryDirectory() as tempdir:\n",
+    "    # Store into a new store\n",
+    "    store = SquirrelStore(url=str(tempdir), serializer=MessagepackSerializer(), clean=True)\n",
+    "\n",
+    "    rdd.foreachPartition(partial(save_iterable_as_shard, store=store))\n",
+    "\n",
+    "    # Read\n",
+    "    driver = MessagepackDriver(url=str(tempdir))\n",
+    "    retrieved = driver.get_iter().collect()\n",
+    "\n",
+    "    assert len(retrieved) == N\n",
+    "    assert is_ordered(retrieved, key=\"time_stamp_sec\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "cgrnE61Vza3R"
+   },
+   "source": [
+    "We can also sort the data with Spark during loading."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "itwOpKK9z02k",
+    "outputId": "57d012ab-2d4d-4062-84da-f7aa78c324c9"
+   },
+   "outputs": [],
+   "source": [
+    "class SparkSource(Composable):\n",
+    "    def __init__(self, url: str, sort_callback):\n",
+    "        \"\"\"Define a helper class to encapsulate an iterator over Spark contents\"\"\"\n",
+    "        self.url = url\n",
+    "        self.sort_callback = sort_callback\n",
+    "        self.spark = SparkSession.builder.appName(\"test\").getOrCreate()\n",
+    "\n",
+    "    def __iter__(self):\n",
+    "        store = SquirrelStore(url=self.url, serializer=MessagepackSerializer())\n",
+    "        keys = list(store.keys())\n",
+    "        # Here we do the sorting\n",
+    "        rdd = self.spark.sparkContext.parallelize(keys).map(lambda k: list(store.get(k))).flatMap(lambda x: x)\n",
+    "        rdd = rdd.sortBy(self.sort_callback)\n",
+    "        for item in rdd.toLocalIterator():\n",
+    "            yield item\n",
+    "\n",
+    "\n",
+    "# unsorted data\n",
+    "samples = IterableSource(generate_timeseries_samples(N))\n",
+    "\n",
+    "with tempfile.TemporaryDirectory() as tempdir:\n",
+    "    print(tempdir)\n",
+    "    # Write\n",
+    "    store = SquirrelStore(url=tempdir, serializer=MessagepackSerializer(), clean=True)\n",
+    "    samples.batched(N_SHARDS).map(store.set).join()\n",
+    "\n",
+    "    # Read\n",
+    "    spark_iterable = SparkSource(tempdir, lambda x: x[\"time_stamp_sec\"]).collect()\n",
+    "    assert is_ordered(spark_iterable, key=\"time_stamp_sec\")\n",
+    "    assert len(spark_iterable) == N"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/examples/12.Split_Data_Into_Different_Stores.ipynb b/examples/12.Split_Data_Into_Different_Stores.ipynb
new file mode 100644
index 0000000..ea4e660
--- /dev/null
+++ b/examples/12.Split_Data_Into_Different_Stores.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Xd4DVhIQUXG-"
+   },
+   "source": [
+    "Sometimes it is useful to store your data in different stores based on a categorical label of your data. In this notebook, we demonstrate how this can be done with the additional help of Spark."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "LQbGHY75apkJ",
+    "outputId": "1b5150a5-da94-4dba-c18d-fecfc101f1f4"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install squirrel-core pyspark\n",
+    "!pip install more-itertools"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "7jFgTFT-bV7Q"
+   },
+   "outputs": [],
+   "source": [
+    "import tempfile\n",
+    "from random import randint\n",
+    "from functools import partial\n",
+    "from pyspark.sql import SparkSession\n",
+    "from squirrel.store import SquirrelStore\n",
+    "from squirrel.serialization import MessagepackSerializer\n",
+    "from squirrel.iterstream import IterableSource, FilePathGenerator"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "OmcD5juDTE74",
+    "outputId": "e4049ff8-2923-414a-ec7e-49cb59c63ab6"
+   },
+   "outputs": [],
+   "source": [
+    "def generate_categorical_samples(N, C):\n",
+    "    \"\"\"Generate data where the uid field is used as a categorical label to split\"\"\"\n",
+    "    return [{\"uid\": randint(1, C), \"data\": 0} for _ in range(N)]\n",
+    "\n",
+    "\n",
+    "def save_shards(tuple_, shard_size, uri):\n",
+    "    \"\"\"Used as a partial function to save the data into a different store based on the uid\"\"\"\n",
+    "    key = tuple_[0]\n",
+    "    store = SquirrelStore(url=f\"{uri}/{key}\", serializer=MessagepackSerializer())\n",
+    "    iterab = tuple_[1]\n",
+    "    store.set(value=iterab, key=key)\n",
+    "\n",
+    "\n",
+    "N_SHARDS = 50\n",
+    "N = 100_000\n",
+    "C = 10\n",
+    "# Generate samples\n",
+    "samples = IterableSource(generate_categorical_samples(N, C))\n",
+    "\n",
+    "# Initiate Spark\n",
+    "spark = SparkSession.builder.appName(\"test\").getOrCreate()\n",
+    "rdd = spark.sparkContext.parallelize(samples)\n",
+    "with tempfile.TemporaryDirectory() as tempdir:\n",
+    "\n",
+    "    def to_list(a):\n",
+    "        return [a]\n",
+    "\n",
+    "    def append(a, b):\n",
+    "        a.append(b)\n",
+    "        return a\n",
+    "\n",
+    "    def extend(a, b):\n",
+    "        a.extend(b)\n",
+    "        return a\n",
+    "\n",
+    "    _ = (\n",
+    "        rdd.map(lambda x: (x[\"uid\"], x))\n",
+    "        .combineByKey(to_list, append, extend)\n",
+    "        .foreach(partial(save_shards, uri=f\"{tempdir}\", shard_size=100))\n",
+    "    )\n",
+    "    # We can see that each uid now has its own storage URI\n",
+    "    print(FilePathGenerator(tempdir, nested=True).collect())"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/requirements.in.sklearn b/requirements.in.sklearn
new file mode 100644
index 0000000..ff88936
--- /dev/null
+++ b/requirements.in.sklearn
@@ -0,0 +1 @@
+scikit-learn
\ No newline at end of file
diff --git a/src/squirrel_datasets_core/datasets/sklearn/__init__.py b/src/squirrel_datasets_core/datasets/sklearn/__init__.py
new file mode 100644
index 0000000..2048431
--- /dev/null
+++ b/src/squirrel_datasets_core/datasets/sklearn/__init__.py
@@ -0,0 +1,12 @@
+# Exemplary sources for the sklearn driver
+
+from squirrel.catalog import Source, CatalogKey
+
+from squirrel_datasets_core.driver.sklearn import TOY_DATASETS, REAL_DATASETS
+
+__all__ = ["SOURCES"]
+
+SKLEARN_DATASETS = TOY_DATASETS + REAL_DATASETS
+
+SOURCES = [
+    (CatalogKey(f"{name}_sklearn", 1), Source(driver_name="sklearn", driver_kwargs={"name": name}))
+    for name in SKLEARN_DATASETS
+]
diff --git a/src/squirrel_datasets_core/driver/__init__.py b/src/squirrel_datasets_core/driver/__init__.py
index 272a51d..433a885 100644
--- a/src/squirrel_datasets_core/driver/__init__.py
+++ b/src/squirrel_datasets_core/driver/__init__.py
@@ -1,5 +1,6 @@
 from squirrel_datasets_core.driver.huggingface import HuggingfaceDriver
 from squirrel_datasets_core.driver.torchvision import TorchvisionDriver
 from squirrel_datasets_core.driver.hub import HubDriver
+from squirrel_datasets_core.driver.sklearn import SklearnDriver
 
-__all__ = ["HuggingfaceDriver", "TorchvisionDriver", "HubDriver"]
+__all__ = ["HuggingfaceDriver", "TorchvisionDriver", "HubDriver", "SklearnDriver"]
diff --git a/src/squirrel_datasets_core/driver/sklearn.py b/src/squirrel_datasets_core/driver/sklearn.py
new file mode 100644
index 0000000..dda5ddc
--- /dev/null
+++ b/src/squirrel_datasets_core/driver/sklearn.py
@@ -0,0 +1,71 @@
+import inspect
+import logging
+from typing import Dict, List, Optional
+
+from sklearn import datasets
+from sklearn.utils import Bunch
+from squirrel.catalog.catalog import Catalog
+from squirrel.driver import IterDriver
+from squirrel.iterstream import IterableSource
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["SklearnDriver"]
+
+TOY_DATASETS = ["_".join(name.split("_")[1:]) for name, _ in inspect.getmembers(datasets) if name.startswith("load_")]
+REAL_DATASETS = ["_".join(name.split("_")[1:]) for name, _ in inspect.getmembers(datasets) if name.startswith("fetch_")]
+
+
+class SklearnDriver(IterDriver):
+    name = "sklearn"
+
+    def __init__(
+        self,
+        name: str,
+        data_home: Optional[str] = None,
+        download_if_missing: bool = True,
+        catalog: Optional[Catalog] = None,
+    ) -> None:
+        """A driver for sklearn datasets. There are toy datasets, which are internal to the package,
+        and real-world datasets, which need to be downloaded.
+
+        Args:
+            name (str): Name of the dataset. Visit
+                https://scikit-learn.org/stable/modules/classes.html#module-sklearn.datasets or call the
+                get_dataset_names() method for the available names.
+            data_home (Optional[str], optional): Location to which the dataset is downloaded.
+                Defaults to "~/scikit_learn_data".
+            download_if_missing (bool, optional): Only download the data if it is missing locally;
+                otherwise an IOError is raised. Defaults to True.
+            catalog (Optional[Catalog], optional): Default argument for catalog registration. Defaults to None.
+        """
+        super().__init__(catalog)
+        self.name = name
+        self._data = self._get_data(data_home, download_if_missing)
+        self._data_home = data_home
+
+    def _get_data(self, data_home: Optional[str] = None, download_if_missing: bool = True) -> Bunch:
+        """Return a sklearn Bunch holding the inputs and targets."""
+        if self.name in TOY_DATASETS:
+            load_ds = getattr(datasets, f"load_{self.name}")
+            return load_ds()
+        elif self.name in REAL_DATASETS:
+            load_ds = getattr(datasets, f"fetch_{self.name}")
+            return load_ds(data_home=data_home, download_if_missing=download_if_missing)
+        else:
+            raise ValueError(f"Dataset {self.name} is not available. Use one of {TOY_DATASETS + REAL_DATASETS}.")
+
+    def get_iter(self) -> IterableSource:
+        """Return an iterstream over the (input, target) pairs of the dataset."""
+        return IterableSource(zip(self._data["data"], self._data["target"]))
+
+    def get_info(self) -> Dict[str, str]:
+        """Return the description, feature names, and target names of the dataset."""
+        info = {
+            "description": self._data.get("DESCR", "No description available"),
+            "feature_names": self._data.get("feature_names", "No feature names available"),
+            "target_names": self._data.get("target_names", "No target names available"),
+        }
+        return info
+
+    @staticmethod
+    def get_dataset_names() -> Dict[str, List[str]]:
+        """Return the names of the available toy and real-world datasets."""
+        return {"toy": TOY_DATASETS, "real": REAL_DATASETS}
diff --git a/src/squirrel_datasets_core/squirrel_plugin.py b/src/squirrel_datasets_core/squirrel_plugin.py
index 4bca4c4..12f2ce2 100644
--- a/src/squirrel_datasets_core/squirrel_plugin.py
+++ b/src/squirrel_datasets_core/squirrel_plugin.py
@@ -6,6 +6,11 @@
 from squirrel.driver import Driver
 from squirrel.framework.plugins.hookimpl import hookimpl
 
+def get_sklearn_driver() -> Driver:
+    """Imports and returns the sklearn driver class"""
+    from squirrel_datasets_core.driver.sklearn import SklearnDriver
+
+    return SklearnDriver
 
 def get_hub_driver() -> Driver:
     """Imports and returns the hub driver class"""
@@ -38,6 +43,7 @@
         "hub": get_hub_driver,
         "huggingface": get_huggingface_driver,
         "torchvision": get_torchvision_driver,
+        "sklearn": get_sklearn_driver,
     }
 
     for d in add_drivers:
diff --git a/test/test_datasets/test_sklearn_ds.py b/test/test_datasets/test_sklearn_ds.py
new file mode 100644
index 0000000..77dfde3
--- /dev/null
+++ b/test/test_datasets/test_sklearn_ds.py
@@ -0,0 +1,48 @@
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+from squirrel.catalog import Catalog
+
+from squirrel_datasets_core.driver.sklearn import SklearnDriver, TOY_DATASETS
+
+TAKE = 10
+
+
+def test_get_dataset_names() -> None:
+    """Test that the driver reports the available dataset names."""
+    names = SklearnDriver.get_dataset_names()
+    assert len(names) > 0
+
+
+@pytest.mark.parametrize("cat_key", TOY_DATASETS)
+def test_get_data(plugin_catalog: Catalog, cat_key: str) -> None:
+    """Test that the underlying sklearn data can be loaded for the toy datasets."""
+    data = plugin_catalog[f"{cat_key}_sklearn"].get_driver()._get_data()
+    assert data is not None
+
+
+@pytest.mark.parametrize("cat_key", TOY_DATASETS)
+def test_toy_data(plugin_catalog: Catalog, cat_key: str) -> None:
+    """Test toy datasets which are internal to sklearn."""
+    sample = plugin_catalog[f"{cat_key}_sklearn"].get_driver().get_iter().take(TAKE).collect()
+    assert len(sample) > 0
+
+
+@pytest.mark.skip(reason="Requires downloading the dataset from public storage.")
+@pytest.mark.parametrize("cat_key", ["20newsgroups", "california_housing"])
+def test_real_data(plugin_catalog: Catalog, cat_key: str) -> None:
+    """Test real-world datasets which have to be downloaded."""
+    sample = plugin_catalog[f"{cat_key}_sklearn"].get_driver().get_iter().take(TAKE).collect()
+    assert len(sample) > 0
+
+
+@pytest.mark.skip(reason="Requires downloading the dataset from public storage.")
+@pytest.mark.parametrize("cat_key", ["olivetti_faces"])
+def test_sklearn_data_downloading_exception(plugin_catalog: Catalog, cat_key: str) -> None:
+    """Test that data loading fails if downloading is disabled."""
+    with pytest.raises(Exception):
+        plugin_catalog[f"{cat_key}_sklearn"].get_driver(download_if_missing=False).get_iter().take(TAKE).collect()
+
+
+@pytest.mark.skip(reason="Requires downloading the dataset from public storage.")
+@pytest.mark.parametrize("cat_key", ["california_housing", "olivetti_faces"])
+def test_real_data_downloading(plugin_catalog: Catalog, cat_key: str) -> None:
+    """Test that we can download data and yield from it."""
+    with TemporaryDirectory() as temp_dir:
+        driver = plugin_catalog[f"{cat_key}_sklearn"].get_driver(data_home=temp_dir, download_if_missing=True)
+        sample = driver.get_iter().take(TAKE).collect()
+        assert len(sample) > 0
+        assert len(list(Path(temp_dir).glob("*"))) > 0
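A minimal usage sketch of the new driver (illustrative only, not part of the patch; it assumes the package is installed together with scikit-learn and uses only the methods defined in src/squirrel_datasets_core/driver/sklearn.py above):

    from squirrel_datasets_core.driver.sklearn import SklearnDriver

    # Dataset names are derived from sklearn's load_* (toy) and fetch_* (real-world) functions.
    print(SklearnDriver.get_dataset_names()["toy"])

    # Toy datasets such as "iris" ship with scikit-learn, so no download is needed.
    driver = SklearnDriver("iris")
    print(driver.get_info()["feature_names"])

    # get_iter() yields (input, target) pairs as an iterstream.
    samples = driver.get_iter().take(5).collect()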