diff --git a/dspy/datasets/dataloader.py b/dspy/datasets/dataloader.py
index b4593bd8af..1e1f0dbed5 100644
--- a/dspy/datasets/dataloader.py
+++ b/dspy/datasets/dataloader.py
@@ -1,84 +1,100 @@
-import dspy
 import random
-from dspy.datasets import Dataset
+from collections.abc import Mapping
+from typing import Union
 
 from datasets import load_dataset
-from typing import Union, List, Mapping, Tuple
+
+import dspy
+from dspy.datasets import Dataset
+
 
 class DataLoader(Dataset):
-    def __init__(self,):
+    def __init__(
+        self,
+    ):
         pass
 
     def from_huggingface(
         self,
         dataset_name: str,
         *args,
-        input_keys: Tuple[str] = (),
-        fields: Tuple[str] = None,
-        **kwargs
-    ) -> Union[Mapping[str, List[dspy.Example]], List[dspy.Example]]:
+        input_keys: tuple[str] = (),
+        fields: tuple[str] = None,
+        **kwargs,
+    ) -> Union[Mapping[str, list[dspy.Example]], list[dspy.Example]]:
         if fields and not isinstance(fields, tuple):
-            raise ValueError(f"Invalid fields provided. Please provide a tuple of fields.")
+            raise ValueError("Invalid fields provided. Please provide a tuple of fields.")
 
         if not isinstance(input_keys, tuple):
-            raise ValueError(f"Invalid input keys provided. Please provide a tuple of input keys.")
+            raise ValueError("Invalid input keys provided. Please provide a tuple of input keys.")
 
         dataset = load_dataset(dataset_name, *args, **kwargs)
-
+
         if isinstance(dataset, list) and isinstance(kwargs["split"], list):
-            dataset = {split_name:dataset[idx] for idx, split_name in enumerate(kwargs["split"])}
+            dataset = {split_name: dataset[idx] for idx, split_name in enumerate(kwargs["split"])}
 
         try:
             returned_split = {}
             for split_name in dataset.keys():
                 if fields:
-                    returned_split[split_name] = [dspy.Example({field:row[field] for field in fields}).with_inputs(input_keys) for row in dataset[split_name]]
+                    returned_split[split_name] = [
+                        dspy.Example({field: row[field] for field in fields}).with_inputs(input_keys)
+                        for row in dataset[split_name]
+                    ]
                 else:
-                    returned_split[split_name] = [dspy.Example({field:row[field] for field in row.keys()}).with_inputs(input_keys) for row in dataset[split_name]]
+                    returned_split[split_name] = [
+                        dspy.Example({field: row[field] for field in row.keys()}).with_inputs(input_keys)
+                        for row in dataset[split_name]
+                    ]
 
             return returned_split
         except AttributeError:
             if fields:
-                return [dspy.Example({field:row[field] for field in fields}).with_inputs(input_keys) for row in dataset]
+                return [
+                    dspy.Example({field: row[field] for field in fields}).with_inputs(input_keys) for row in dataset
+                ]
             else:
-                return [dspy.Example({field:row[field] for field in row.keys()}).with_inputs(input_keys) for row in dataset]
+                return [
+                    dspy.Example({field: row[field] for field in row.keys()}).with_inputs(input_keys) for row in dataset
+                ]
 
-    def from_csv(self, file_path:str, fields: List[str] = None, input_keys: Tuple[str] = ()) -> List[dspy.Example]:
+    def from_csv(self, file_path: str, fields: list[str] = None, input_keys: tuple[str] = ()) -> list[dspy.Example]:
         dataset = load_dataset("csv", data_files=file_path)["train"]
-
+
         if not fields:
             fields = list(dataset.features)
-
-        return [dspy.Example({field:row[field] for field in fields}).with_inputs(input_keys) for row in dataset]
 
-    def from_json(self, file_path:str, fields: List[str] = None, input_keys: Tuple[str] = ()) -> List[dspy.Example]:
+        return [dspy.Example({field: row[field] for field in fields}).with_inputs(input_keys) for row in dataset]
+
+    def from_json(self, file_path: str, fields: list[str] = None, input_keys: tuple[str] = ()) -> list[dspy.Example]:
         dataset = load_dataset("json", data_files=file_path)["train"]
-
+
         if not fields:
             fields = list(dataset.features)
-
-        return [dspy.Example({field:row[field] for field in fields}).with_inputs(input_keys) for row in dataset]
+
+        return [dspy.Example({field: row[field] for field in fields}).with_inputs(input_keys) for row in dataset]
 
-    def sample(
-        self,
-        dataset: List[dspy.Example],
-        n: int,
-        *args,
-        **kwargs
-    ) -> List[dspy.Example]:
+    def from_parquet(self, file_path: str, fields: list[str] = None, input_keys: tuple[str] = ()) -> list[dspy.Example]:
+        dataset = load_dataset("parquet", data_files=file_path)["train"]
+
+        if not fields:
+            fields = list(dataset.features)
+
+        return [dspy.Example({field: row[field] for field in fields}).with_inputs(input_keys) for row in dataset]
+
+    def sample(self, dataset: list[dspy.Example], n: int, *args, **kwargs) -> list[dspy.Example]:
         if not isinstance(dataset, list):
             raise ValueError(f"Invalid dataset provided of type {type(dataset)}. Please provide a list of examples.")
-
+
         return random.sample(dataset, n, *args, **kwargs)
 
     def train_test_split(
         self,
-        dataset: List[dspy.Example],
+        dataset: list[dspy.Example],
         train_size: Union[int, float] = 0.75,
         test_size: Union[int, float] = None,
-        random_state: int = None
-    ) -> Mapping[str, List[dspy.Example]]:
+        random_state: int = None,
+    ) -> Mapping[str, list[dspy.Example]]:
         if random_state is not None:
             random.seed(random_state)
 
@@ -105,6 +121,6 @@ def train_test_split(
         test_end = len(dataset_shuffled) - train_end
 
         train_dataset = dataset_shuffled[:train_end]
-        test_dataset = dataset_shuffled[train_end:train_end + test_end]
+        test_dataset = dataset_shuffled[train_end : train_end + test_end]
 
-        return {'train': train_dataset, 'test': test_dataset}
+        return {"train": train_dataset, "test": test_dataset}