# Origin: data_set.py (64 lines), forked from liucongg/ChatGLM-Finetuning.
# -*- coding:utf-8 -*-
# @project: ChatGPT
# @filename: data_set
# @author: 刘聪NLP
# @zhihu: https://www.zhihu.com/people/LiuCongNLP
# @contact: [email protected]
# @time: 2023/4/4 14:42
"""
文件说明:
"""
import json
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
class Seq2SeqDataSet(Dataset):
    """Map-style dataset that encodes JSON-lines samples for seq2seq fine-tuning.

    Each non-empty line of the data file must be a JSON object with the keys
    ``"text"`` and ``"answer"``. A sample is laid out as::

        prompt + text + ["[gMASK]", "<sop>"] + answer + ["<eop>"]

    and stored as a dict with the keys ``"text"``, ``"answer"``,
    ``"input_ids"`` and ``"labels"``. Both id lists are plain Python lists
    padded up to ``max_len``.
    """

    def __init__(self, data_path: str, tokenizer, max_len: int, max_src_len: int,
                 prompt_text: str) -> None:
        """Read ``data_path`` and pre-encode every sample.

        Args:
            data_path: Path to a UTF-8 file with one JSON object per line.
            tokenizer: Object providing ``tokenize``, ``convert_tokens_to_ids``,
                ``bos_token_id`` and ``pad_token_id``.
            max_len: Length every encoded sample is padded to.
            max_src_len: Token budget shared by the prompt and the source text.
            prompt_text: Instruction placed in front of every source text
                (for example, a request to extract relation triples).

        Raises:
            json.JSONDecodeError: If a non-empty line is not valid JSON.
            KeyError: If a sample lacks ``"text"`` or ``"answer"``.
            ValueError: If ``tokenizer.bos_token_id`` is not among the ids
                produced for a sample.
        """
        # Three positions are reserved for "[gMASK]", "<sop>" and "<eop>".
        # Clamp at zero: a negative bound would make the slices below drop
        # tokens from the END instead of truncating to nothing.
        max_tgt_len = max(max_len - max_src_len - 3, 0)

        # The prompt is identical for every sample, so tokenize it only once.
        prompt_tokens = tokenizer.tokenize(prompt_text)
        src_budget = max(max_src_len - len(prompt_tokens), 0)
        # NOTE(review): when the prompt alone exceeds max_src_len the encoded
        # sample can still be longer than max_len; it is then left unpadded.

        self.all_data = []
        with open(data_path, "r", encoding="utf-8") as fh:
            for raw_line in fh:
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                sample = json.loads(line)

                src_tokens = tokenizer.tokenize(sample["text"])[:src_budget]
                tgt_tokens = tokenizer.tokenize(sample["answer"])[:max_tgt_len]

                tokens = prompt_tokens + src_tokens + ["[gMASK]", "<sop>"] + tgt_tokens + ["<eop>"]
                input_ids = tokenizer.convert_tokens_to_ids(tokens)

                # Everything before the first bos id is context and is masked
                # out of the labels with -100; the rest is the training target.
                context_length = input_ids.index(tokenizer.bos_token_id)
                labels = [-100] * context_length + input_ids[context_length:]

                # A negative pad_len multiplies to an empty list, i.e. no padding.
                pad_len = max_len - len(input_ids)
                input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
                labels = labels + [-100] * pad_len

                self.all_data.append(
                    {"text": sample["text"], "answer": sample["answer"],
                     "input_ids": input_ids, "labels": labels})

    def __len__(self) -> int:
        """Return the number of encoded samples."""
        return len(self.all_data)

    def __getitem__(self, item):
        """Return the encoded sample dict stored at index ``item``."""
        return self.all_data[item]
def coll_fn(batch, pad_token_id=20003):
    """Collate dataset rows into padded ``input_ids`` / ``labels`` tensors.

    Args:
        batch: Sequence of dicts, each holding integer lists under the keys
            ``"input_ids"`` and ``"labels"``.
        pad_token_id: Id used to right-pad ``input_ids``. Defaults to 20003,
            the value this function previously hard-coded (presumably the
            tokenizer's pad id; confirm against ``tokenizer.pad_token_id``).

    Returns:
        Dict with ``"input_ids"`` and ``"labels"``, each a ``torch.long``
        tensor of shape ``(len(batch), longest_row)``.
    """
    input_ids_list = [torch.tensor(row["input_ids"], dtype=torch.long) for row in batch]
    labels_list = [torch.tensor(row["labels"], dtype=torch.long) for row in batch]
    return {
        "input_ids": pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_id),
        # Labels are padded with -100, the same ignore marker the dataset uses
        # for masked label positions, so padding is never a training target.
        "labels": pad_sequence(labels_list, batch_first=True, padding_value=-100),
    }