# data.py
import json
import logging

import datasets
import pandas as pd
import torch
from torch.utils.data import DataLoader

# Map gold_label strings to integer ids so torch.tensor() in NLIDataset gets
# ints in the JSONL branch too; the order matches the Hugging Face snli ids.
_LABEL_TO_ID = {"entailment": 0, "neutral": 1, "contradiction": 2}


def _get_data(file_path: str):
    """Read an SNLI-style JSONL file; return premises, hypotheses, and integer labels."""
    with open(file_path, "r") as f:
        data = [json.loads(line) for line in f]
    dataset = pd.DataFrame(data)
    # Drop examples without annotator consensus (gold_label == "-").
    dataset = dataset[dataset.gold_label != "-"]
    labels = [_LABEL_TO_ID[label] for label in dataset["gold_label"]]
    return dataset["sentence1"].tolist(), dataset["sentence2"].tolist(), labels
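
# Expected input for _get_data: one JSON object per line, SNLI-style. The field
# names come from the code above; this particular record is illustrative only:
#   {"sentence1": "A man inspects a uniform.",
#    "sentence2": "The man is sleeping.",
#    "gold_label": "contradiction"}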

class NLIDataset(torch.utils.data.Dataset):
    """Wraps tokenizer output and labels as a map-style PyTorch dataset."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # Build one example: every tokenizer field plus the gold label.
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings.input_ids)  # equivalently, len(self.labels)
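
# Each item from NLIDataset is a dict of tensors: one entry per tokenizer field
# plus the label. Keys depend on the tokenizer; the values here are illustrative:
#   {"input_ids": tensor([101, ...]), "attention_mask": tensor([1, ...]), "labels": tensor(0)}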

def get_nli_dataset(config, tokenizer):
    """Tokenize train/dev/test NLI splits and return one DataLoader per split.

    Reads JSONL files from config.data_path when it is set; otherwise loads
    the snli dataset from the Hugging Face hub.
    """
    if config.data_path:
        train_premises, train_hypotheses, train_labels = _get_data(
            config.data_path + "/train.jsonl"
        )
        logging.info(
            f"First training example: {train_premises[0]} --> "
            f"{train_hypotheses[0]} ({train_labels[0]})"
        )
        dev_premises, dev_hypotheses, dev_labels = _get_data(
            config.data_path + "/dev.jsonl"
        )
        logging.info(
            f"First dev example: {dev_premises[0]} --> "
            f"{dev_hypotheses[0]} ({dev_labels[0]})"
        )
        test_premises, test_hypotheses, test_labels = _get_data(
            config.data_path + "/test.jsonl"
        )
        logging.info(
            f"First test example: {test_premises[0]} --> "
            f"{test_hypotheses[0]} ({test_labels[0]})"
        )
    else:
        # The language field is passed through as the dataset config name.
        train_dataset = datasets.load_dataset(
            "snli", config.train_language, split="train"
        )
        train_premises = train_dataset["premise"]
        train_hypotheses = train_dataset["hypothesis"]
        train_labels = train_dataset["label"]
        dev_dataset = datasets.load_dataset(
            "snli", config.test_language, split="validation"
        )
        dev_premises = dev_dataset["premise"]
        dev_hypotheses = dev_dataset["hypothesis"]
        dev_labels = dev_dataset["label"]
        test_dataset = datasets.load_dataset("snli", config.test_language, split="test")
        test_premises = test_dataset["premise"]
        test_hypotheses = test_dataset["hypothesis"]
        test_labels = test_dataset["label"]
    # Tokenize premise/hypothesis pairs; padding=True pads each split to its
    # longest sequence so the default collate function can stack the tensors.
    train_encodings = tokenizer(
        train_premises,
        train_hypotheses,
        truncation=True,
        padding=True,
    )
    dev_encodings = tokenizer(
        dev_premises, dev_hypotheses, truncation=True, padding=True
    )
    test_encodings = tokenizer(
        test_premises,
        test_hypotheses,
        truncation=True,
        padding=True,
    )

    train_dataset = NLIDataset(train_encodings, train_labels)
    dev_dataset = NLIDataset(dev_encodings, dev_labels)
    test_dataset = NLIDataset(test_encodings, test_labels)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=config.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=True)
    return train_loader, dev_loader, test_loader
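

# Usage sketch (not from the original file): a minimal, assumed way to drive
# get_nli_dataset. SimpleNamespace stands in for whatever config object the
# project actually uses, and the tokenizer/model name is an assumption.
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    logging.basicConfig(level=logging.INFO)
    config = SimpleNamespace(
        data_path=None,               # or a directory containing train/dev/test.jsonl
        train_language="plain_text",  # snli's only config name on the hub
        test_language="plain_text",
        batch_size=32,
    )
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    train_loader, dev_loader, test_loader = get_nli_dataset(config, tokenizer)

    # Inspect one batch: a dict of stacked tensors, as produced by NLIDataset.
    batch = next(iter(train_loader))
    print({key: tuple(value.shape) for key, value in batch.items()})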