import math
import torch
from torch import Tensor
import io
import time
import os
import pandas as pd
import json
from datetime import datetime
import pickle
from pathlib import Path
from torch.utils.data import Dataset
from collections import Counter
from torch.nn.utils.rnn import pad_sequence
import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
import torch.nn.functional as F
from . import data_selection

# (directory, encoding) pairs: the MMS annotation directories and the matching text directories.
mms_directories = [
    ("mms-subset91", 'latin-1'),
    ("modified/location/mms", 'utf-8'),
    ("modified/platform/mms", 'utf-8'),
    ("modified/time/mms", 'utf-8'),
    ("modified/train_name/mms", 'utf-8'),
]
text_directories = [
    ("annotations_full/annotations", 'latin-1'),
    ("modified/location/text", 'utf-8'),
    ("modified/platform/text", 'utf-8'),
    ("modified/time/text", 'utf-8'),
    ("modified/train_name/text", 'utf-8'),
]

checkpoint = 'facebook/nllb-200-distilled-600M'  # NLLB checkpoint used for the shared tokenizer
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def read():
    """Read every (text, mms) directory pair and return (original, modified, full) data lists."""
    data_list_only_original = []
    data_list_only_modified = []
    for i, text_info in enumerate(text_directories):
        mms_info = mms_directories[i]
        data_list_one = data_selection.read(text_info, mms_info)
        if i == 0:
            # index 0 is the original annotation set; the remaining entries are modified variants
            data_list_only_original += data_list_one
        else:
            data_list_only_modified += data_list_one
    data_list_full = data_list_only_original + data_list_only_modified
    return (data_list_only_original, data_list_only_modified, data_list_full)
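
# Note (assumption, inferred from SignLanguageDataset below): each entry returned by
# data_selection.read is expected to be a dict carrying at least the keys
# 'file_ID', 'text', and 'maingloss' (a list of gloss tokens).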

class SignLanguageDataset(Dataset):
    """Dataset yielding (file_ID, text token ids, gloss token ids) for each annotation entry."""

    def __init__(self, data_list, tokenizer, max_length=512):
        self.data_list = data_list
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.vocab_size = len(tokenizer)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        data = self.data_list[idx]
        file_Id = data['file_ID']
        # Encode the spoken-language sentence and the lowercased gloss sequence
        # with the shared NLLB tokenizer.
        text_tokens = self.tokenizer.encode(data['text'], add_special_tokens=True)
        text_tokens = torch.tensor(text_tokens)
        maingloss_tokens = self.tokenizer.encode(' '.join(data['maingloss']).lower(), add_special_tokens=True)
        maingloss_tokens = torch.tensor(maingloss_tokens)
        return file_Id, text_tokens, maingloss_tokens

def collate_fn(batch):
    """Pad a batch of (file_ID, text, gloss) samples to a common sequence length."""
    file_Id, text_tokens, maingloss_tokens = zip(*batch)
    padding_value = tokenizer.pad_token_id  # for NLLB the padding token id is 1
    text_tokens_padded = torch.nn.utils.rnn.pad_sequence(text_tokens, batch_first=True, padding_value=padding_value)
    maingloss_tokens_padded = torch.nn.utils.rnn.pad_sequence(maingloss_tokens, batch_first=True, padding_value=padding_value)
    # Ensure text and gloss batches share the same sequence length.
    max_len = max(text_tokens_padded.size(1), maingloss_tokens_padded.size(1))
    text_tokens_padded = torch.nn.functional.pad(text_tokens_padded, (0, max_len - text_tokens_padded.size(1)), value=padding_value)
    maingloss_tokens_padded = torch.nn.functional.pad(maingloss_tokens_padded, (0, max_len - maingloss_tokens_padded.size(1)), value=padding_value)
    return file_Id, text_tokens_padded, maingloss_tokens_padded
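

# Minimal usage sketch (not part of the original module; the batch size and the choice of
# split are illustrative assumptions). It shows how the pieces above fit together: read()
# gathers the annotation entries, SignLanguageDataset tokenizes them, and collate_fn pads
# each batch for a torch DataLoader.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    _, _, data_list_full = read()
    dataset = SignLanguageDataset(data_list_full, tokenizer)
    loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

    file_ids, text_batch, gloss_batch = next(iter(loader))
    # Both padded batches share the same sequence length by construction.
    print(text_batch.shape, gloss_batch.shape)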