-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata_loader.py
45 lines (37 loc) · 1.34 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
import json
import random
import datasets
from datasets import load_dataset, Dataset, concatenate_datasets
from utils import load_jsonl, lower_keys
def load_data(data_name, split, data_dir="./dataset"):
    """Load one split of a dataset as a list of example dicts sorted by ``idx``.

    A cached copy at ``{data_dir}/{data_name}/{split}.jsonl`` is read with
    ``load_jsonl`` when it exists. Otherwise the split is fetched with
    ``datasets.load_dataset``, every example is passed through ``lower_keys``,
    and the result is written to that path so later calls hit the cache.

    Args:
        data_name: Dataset identifier. Any name works when a cached file is
            present; only ``"math"`` and ``"gsm8k"`` can be downloaded.
        split: Split name; used in the cache file name and forwarded as the
            ``split`` argument of ``load_dataset``.
        data_dir: Root directory of the local cache.

    Returns:
        A list of example dicts ordered by their ``"idx"`` value. When the
        first example has no ``"idx"`` key, every example is given one equal
        to its position in the loaded order. An empty split yields ``[]``.

    Raises:
        NotImplementedError: No cached file exists and ``data_name`` is not
            one of the downloadable datasets.
    """
    data_file = f"{data_dir}/{data_name}/{split}.jsonl"
    if os.path.exists(data_file):
        # NOTE: cached examples are used as stored; lower_keys is only applied
        # on the download path below, before the cache file is written.
        examples = list(load_jsonl(data_file))
    else:
        if data_name == "math":
            dataset = load_dataset(
                "competition_math",
                split=split,
                name="main",
                cache_dir=f"{data_dir}/temp",
            )
        elif data_name == "gsm8k":
            dataset = load_dataset(data_name, split=split)
        else:
            raise NotImplementedError(data_name)

        examples = [lower_keys(example) for example in dataset]
        # Persist the normalized examples so the next call reads the local file.
        os.makedirs(f"{data_dir}/{data_name}", exist_ok=True)
        Dataset.from_list(examples).to_json(data_file)

    # Nothing to index or sort; also avoids IndexError on examples[0] below.
    if not examples:
        return examples

    # Add 'idx' as the first key when the data does not already carry one.
    # Only the first example is inspected to decide this.
    if "idx" not in examples[0]:
        examples = [{"idx": i, **example} for i, example in enumerate(examples)]

    # Sort by 'idx'. The sort is stable and examples sharing an 'idx' are all
    # kept (no deduplication is performed).
    return sorted(examples, key=lambda x: x["idx"])
if __name__ == "__main__":
    # Manual smoke check: load the "math" test split, then print its first
    # example followed by the number of examples.
    math_test = load_data("math", "test")
    first_example = math_test[0]
    print(first_example)
    print(len(math_test))