Skip to content

Commit

Permalink
feat: add from_parquet to dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesHWade committed Feb 27, 2024
1 parent a4a3397 commit 25b7e5f
Showing 1 changed file with 53 additions and 37 deletions.
90 changes: 53 additions & 37 deletions dspy/datasets/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,100 @@
import dspy
import random
from dspy.datasets import Dataset
from collections.abc import Mapping
from typing import Union

from datasets import load_dataset
from typing import Union, List, Mapping, Tuple

import dspy
from dspy.datasets import Dataset


class DataLoader(Dataset):
def __init__(self):
    """Create a DataLoader; no state is kept, each loader method is self-contained."""
    pass

def from_huggingface(
    self,
    dataset_name: str,
    *args,
    input_keys: tuple[str] = (),
    fields: tuple[str] = None,
    **kwargs,
) -> Union[Mapping[str, list[dspy.Example]], list[dspy.Example]]:
    """Load a Hugging Face dataset and convert its rows into ``dspy.Example``s.

    Args:
        dataset_name: Name or path forwarded to ``datasets.load_dataset``.
        *args: Extra positional arguments forwarded to ``load_dataset``.
        input_keys: Keys to register as inputs on each example.
        fields: Optional subset of columns to keep; ``None`` keeps every column.
        **kwargs: Extra keyword arguments forwarded to ``load_dataset``
            (e.g. ``split``).

    Returns:
        A list of examples when a single split is loaded, otherwise a mapping
        of split name -> list of examples.

    Raises:
        ValueError: If ``fields`` or ``input_keys`` is not a tuple.
    """
    if fields and not isinstance(fields, tuple):
        raise ValueError("Invalid fields provided. Please provide a tuple of fields.")

    if not isinstance(input_keys, tuple):
        raise ValueError("Invalid input keys provided. Please provide a tuple of input keys.")

    dataset = load_dataset(dataset_name, *args, **kwargs)

    def to_examples(rows) -> list[dspy.Example]:
        # Keep only the requested columns (all columns when fields is falsy).
        # input_keys is unpacked so each key is registered individually.
        return [
            dspy.Example({field: row[field] for field in (fields or row.keys())}).with_inputs(*input_keys)
            for row in rows
        ]

    # When `split` is a list, load_dataset returns a list of datasets in the
    # same order; re-key them by split name. kwargs.get() avoids a KeyError
    # when "split" was never passed.
    if isinstance(dataset, list) and isinstance(kwargs.get("split"), list):
        return {split_name: to_examples(dataset[idx]) for idx, split_name in enumerate(kwargs["split"])}

    # A DatasetDict is a mapping of split name -> dataset; a single split
    # comes back as a plain Dataset. An explicit check (instead of catching
    # AttributeError around the whole loop) avoids masking AttributeErrors
    # raised while processing rows.
    if isinstance(dataset, Mapping):
        return {split_name: to_examples(split) for split_name, split in dataset.items()}
    return to_examples(dataset)

def from_csv(self, file_path: str, fields: list[str] = None, input_keys: tuple[str] = ()) -> list[dspy.Example]:
    """Load a CSV file into a list of ``dspy.Example``s.

    Args:
        file_path: Path to the CSV file (read via ``datasets.load_dataset``).
        fields: Columns to keep; defaults to every column in the file.
        input_keys: Keys to register as inputs on each example.

    Returns:
        One example per CSV row, restricted to ``fields``.
    """
    dataset = load_dataset("csv", data_files=file_path)["train"]

    if not fields:
        fields = list(dataset.features)

    # Unpack input_keys so each key is registered individually as an input.
    return [dspy.Example({field: row[field] for field in fields}).with_inputs(*input_keys) for row in dataset]

def from_json(self, file_path: str, fields: list[str] = None, input_keys: tuple[str] = ()) -> list[dspy.Example]:
    """Load a JSON/JSONL file into a list of ``dspy.Example``s.

    Args:
        file_path: Path to the JSON file (read via ``datasets.load_dataset``).
        fields: Keys to keep; defaults to every feature in the file.
        input_keys: Keys to register as inputs on each example.

    Returns:
        One example per record, restricted to ``fields``.
    """
    dataset = load_dataset("json", data_files=file_path)["train"]

    if not fields:
        fields = list(dataset.features)

    # Unpack input_keys so each key is registered individually as an input.
    return [dspy.Example({field: row[field] for field in fields}).with_inputs(*input_keys) for row in dataset]

def sample(
self,
dataset: List[dspy.Example],
n: int,
*args,
**kwargs
) -> List[dspy.Example]:
def from_parquet(self, file_path: str, fields: list[str] = None, input_keys: tuple[str] = ()) -> list[dspy.Example]:
    """Load a Parquet file into a list of ``dspy.Example``s.

    Args:
        file_path: Path to the Parquet file (read via ``datasets.load_dataset``).
        fields: Columns to keep; defaults to every column in the file.
        input_keys: Keys to register as inputs on each example.

    Returns:
        One example per row, restricted to ``fields``.
    """
    dataset = load_dataset("parquet", data_files=file_path)["train"]

    if not fields:
        fields = list(dataset.features)

    # Unpack input_keys so each key is registered individually as an input.
    return [dspy.Example({field: row[field] for field in fields}).with_inputs(*input_keys) for row in dataset]

def sample(self, dataset: list[dspy.Example], n: int, *args, **kwargs) -> list[dspy.Example]:
    """Draw ``n`` examples from ``dataset`` without replacement.

    Extra positional/keyword arguments are forwarded to ``random.sample``.

    Raises:
        ValueError: If ``dataset`` is not a list.
    """
    if isinstance(dataset, list):
        return random.sample(dataset, n, *args, **kwargs)
    raise ValueError(f"Invalid dataset provided of type {type(dataset)}. Please provide a list of examples.")

def train_test_split(
self,
dataset: List[dspy.Example],
dataset: list[dspy.Example],
train_size: Union[int, float] = 0.75,
test_size: Union[int, float] = None,
random_state: int = None
) -> Mapping[str, List[dspy.Example]]:
random_state: int = None,
) -> Mapping[str, list[dspy.Example]]:
if random_state is not None:
random.seed(random_state)

Expand All @@ -105,6 +121,6 @@ def train_test_split(
test_end = len(dataset_shuffled) - train_end

train_dataset = dataset_shuffled[:train_end]
test_dataset = dataset_shuffled[train_end:train_end + test_end]
test_dataset = dataset_shuffled[train_end : train_end + test_end]

return {'train': train_dataset, 'test': test_dataset}
return {"train": train_dataset, "test": test_dataset}

0 comments on commit 25b7e5f

Please sign in to comment.