seq2seq_utils.py
import logging
import os
import pickle
from multiprocessing import Pool
from typing import Tuple

import pandas as pd
import torch
from tokenizers.implementations import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizer

logger = logging.getLogger(__name__)


def preprocess_data(data):
    """Tokenize a single (input_text, target_text) pair for generic encoder-decoder models."""
    input_text, target_text, encoder_tokenizer, decoder_tokenizer, args = data

    # Note: pad_to_max_length is deprecated in newer transformers releases;
    # padding="max_length" is the current equivalent.
    input_text = encoder_tokenizer.encode(
        input_text, max_length=args.max_seq_length, pad_to_max_length=True, return_tensors="pt",
    )
    target_text = decoder_tokenizer.encode(
        target_text, max_length=args.max_seq_length, pad_to_max_length=True, return_tensors="pt",
    )

    return (torch.flatten(input_text), torch.flatten(target_text))
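

# Illustrative sketch (not part of the original module): calling preprocess_data on a single
# example. The tokenizer checkpoint and max_seq_length below are assumptions chosen for
# demonstration only; this helper is never executed on import.
def _example_preprocess_data():
    from types import SimpleNamespace

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
    args = SimpleNamespace(max_seq_length=32)  # the only args field preprocess_data reads

    input_ids, target_ids = preprocess_data(
        ("some source text", "some target text", tokenizer, tokenizer, args)
    )
    # Both are flat tensors of length max_seq_length (padded/truncated to 32).
    return input_ids.shape, target_ids.shape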


class Seq2SeqDataset(Dataset):
    def __init__(self, encoder_tokenizer, decoder_tokenizer, args, data, mode):
        cached_features_file = os.path.join(
            args.cache_dir, args.model_name + "_cached_" + str(args.max_seq_length) + str(len(data))
        )

        # Reuse cached features unless reprocessing is requested or caching is disabled.
        if os.path.exists(cached_features_file) and (
            (not args.reprocess_input_data and not args.no_cache)
            or (mode == "dev" and args.use_cached_eval_features and not args.no_cache)
        ):
            logger.info(" Loading features from cached file %s", cached_features_file)
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info(" Creating features from dataset file at %s", args.cache_dir)

            data = [
                (input_text, target_text, encoder_tokenizer, decoder_tokenizer, args)
                for input_text, target_text in zip(data["input_text"], data["target_text"])
            ]

            if args.use_multiprocessing:
                with Pool(args.process_count) as p:
                    self.examples = list(
                        tqdm(
                            p.imap(preprocess_data, data, chunksize=args.multiprocessing_chunksize),
                            total=len(data),
                            disable=args.silent,
                        )
                    )
            else:
                self.examples = [preprocess_data(d) for d in tqdm(data, disable=args.silent)]

            logger.info(" Saving features into cached file %s", cached_features_file)
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        return self.examples[index]
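

# Illustrative sketch (not part of the original module): building a Seq2SeqDataset from an
# in-memory DataFrame. The checkpoint name, hyperparameter values, and cache directory are
# assumptions for demonstration; the args fields mirror what this class actually reads.
def _example_seq2seq_dataset():
    from types import SimpleNamespace

    from torch.utils.data import DataLoader
    from transformers import BertTokenizer

    args = SimpleNamespace(
        cache_dir="cache_dir/",            # processed features are pickled here
        model_name="bert-base-uncased",    # used only to name the cache file
        max_seq_length=64,
        reprocess_input_data=True,
        no_cache=False,
        use_cached_eval_features=False,
        use_multiprocessing=False,         # single-process preprocessing for a tiny example
        process_count=1,
        multiprocessing_chunksize=500,
        silent=True,
    )
    os.makedirs(args.cache_dir, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    train_df = pd.DataFrame(
        {
            "input_text": ["first source sentence", "second source sentence"],
            "target_text": ["first target sentence", "second target sentence"],
        }
    )

    # Same tokenizer for encoder and decoder here; real encoder-decoder setups may differ.
    train_dataset = Seq2SeqDataset(tokenizer, tokenizer, args, train_df, mode="train")
    loader = DataLoader(train_dataset, batch_size=2)
    input_ids, target_ids = next(iter(loader))  # each of shape (2, max_seq_length)
    return input_ids, target_ids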


def preprocess_data_bart(data):
    """Tokenize a single (input_text, target_text) pair for BART-style models."""
    input_text, target_text, tokenizer, args = data

    input_ids = tokenizer.batch_encode_plus(
        [input_text], max_length=args.max_seq_length, padding="max_length", truncation=True, return_tensors="pt",
    )
    target_ids = tokenizer.batch_encode_plus(
        [target_text], max_length=args.max_seq_length, padding="max_length", truncation=True, return_tensors="pt",
    )

    return {
        "source_ids": input_ids["input_ids"].squeeze(),
        "source_mask": input_ids["attention_mask"].squeeze(),
        "target_ids": target_ids["input_ids"].squeeze(),
    }


class SimpleSummarizationDataset(Dataset):
    def __init__(self, tokenizer, args, data, mode):
        self.tokenizer = tokenizer

        cached_features_file = os.path.join(
            args.cache_dir, args.model_name + "_cached_" + str(args.max_seq_length) + str(len(data))
        )

        # Reuse cached features unless reprocessing is requested or caching is disabled.
        if os.path.exists(cached_features_file) and (
            (not args.reprocess_input_data and not args.no_cache)
            or (mode == "dev" and args.use_cached_eval_features and not args.no_cache)
        ):
            logger.info(" Loading features from cached file %s", cached_features_file)
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info(" Creating features from dataset file at %s", args.cache_dir)

            data = [
                (input_text, target_text, tokenizer, args)
                for input_text, target_text in zip(data["input_text"], data["target_text"])
            ]

            if args.use_multiprocessing:
                with Pool(args.process_count) as p:
                    self.examples = list(
                        tqdm(
                            p.imap(preprocess_data_bart, data, chunksize=args.multiprocessing_chunksize),
                            total=len(data),
                            disable=args.silent,
                        )
                    )
            else:
                self.examples = [preprocess_data_bart(d) for d in tqdm(data, disable=args.silent)]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        return self.examples[index]
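

# Illustrative sketch (not part of the original module): building a SimpleSummarizationDataset
# with a BART tokenizer. The checkpoint name and hyperparameter values are assumptions for
# demonstration; unlike Seq2SeqDataset, this class does not write a feature cache.
def _example_simple_summarization_dataset():
    from types import SimpleNamespace

    from transformers import BartTokenizer

    args = SimpleNamespace(
        cache_dir="cache_dir/",
        model_name="facebook/bart-base",   # assumed checkpoint
        max_seq_length=64,
        reprocess_input_data=True,
        no_cache=True,
        use_cached_eval_features=False,
        use_multiprocessing=False,
        process_count=1,
        multiprocessing_chunksize=500,
        silent=True,
    )

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    train_df = pd.DataFrame(
        {
            "input_text": ["a long article to be summarized"],
            "target_text": ["a short summary"],
        }
    )

    dataset = SimpleSummarizationDataset(tokenizer, args, train_df, mode="train")
    example = dataset[0]
    # Keys produced by preprocess_data_bart: source_ids, source_mask, target_ids.
    return example["source_ids"], example["source_mask"], example["target_ids"]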