load_data.py
# -*- coding: utf-8 -*-
import torch
import jieba
from torchtext.legacy import data

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pad short sentences to at least four tokens, so the longest sentence in
# each bucket is no shorter than the biggest convolution filter size.
def tokenizer(text):
    token = [tok for tok in jieba.cut(text)]
    if len(token) < 4:
        token.extend(['<pad>'] * (4 - len(token)))
    return token
TEXT = data.Field(sequential=True, tokenize=tokenizer)
LABEL = data.Field(sequential=False, use_vocab=False)

# Each TSV row has three columns: an index column (paired with None, so it is
# ignored), the label (already numeric, hence use_vocab=False), and the text.
train, val = data.TabularDataset.splits(
    path='../data/',
    train='train.tsv',
    validation='dev.tsv',
    format='tsv',
    skip_header=True,
    fields=[('', None), ('label', LABEL), ('text', TEXT)])
TEXT.build_vocab(train, min_freq=5)
id2vocab = TEXT.vocab.itos  # index-to-token list, for decoding ids back to words
# print(TEXT.vocab.stoi)
# print(TEXT.vocab.itos)
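# A small lookup sketch (an assumption, not in the original file): a legacy
# Field reserves '<unk>' and '<pad>' at indices 0 and 1 by default, so:
#   TEXT.vocab.stoi['<pad>']  ->  1
#   id2vocab[1]               ->  '<pad>'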
train_iter, val_iter = data.BucketIterator.splits(
    (train, val),
    sort_key=lambda x: len(x.text),
    batch_sizes=(256, 128),
    device=device)
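
# A minimal usage sketch (an assumption, not part of the original file): each
# batch exposes one attribute per named field; batch.text is a LongTensor of
# shape [seq_len, batch_size] (batch_first is False by default) and
# batch.label has shape [batch_size].
if __name__ == '__main__':
    for batch in train_iter:
        print(batch.text.shape, batch.label.shape)
        break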