-
Notifications
You must be signed in to change notification settings - Fork 108
/
load_data.py
32 lines (26 loc) · 855 Bytes
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# -*- coding: utf-8 -*-
import torch
import jieba
from torchtext.legacy import data
# Select the compute device once at import time: use CUDA when torch reports
# it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def tokenizer(text):
    """Split *text* into a list of tokens using jieba word segmentation.

    Args:
        text: The string to segment.

    Returns:
        A list holding every item produced by ``jieba.cut(text)``, in the
        order jieba yields them.
    """
    # list() materialises the iterable returned by jieba.cut directly; an
    # identity comprehension would only copy it element by element.
    return list(jieba.cut(text))
# Field for the sentence column: sequential data, split into tokens by the
# jieba-based `tokenizer` defined in this module.
TEXT = data.Field(tokenize=tokenizer, sequential=True)
# Field for the label column: non-sequential and created with use_vocab=False,
# so no vocabulary is built for it (the stored labels are presumably already
# numeric -- confirm against the TSV files).
LABEL = data.Field(use_vocab=False, sequential=False)
# Load the training and validation TSV files from ../data/ as two datasets.
# The header row of each file is skipped. Columns are mapped in order: the
# first is paired with a None field (ignored), the second is read as `label`
# and the third as `text`.
train, val = data.TabularDataset.splits(
    path='../data/',
    format='tsv',
    skip_header=True,
    train='train.tsv',
    validation='dev.tsv',
    fields=[
        ('', None),
        ('label', LABEL),
        ('text', TEXT),
    ],
)
# Build the TEXT vocabulary from the training split only, passing min_freq=5
# as the minimum token frequency (see torchtext's Vocab documentation for the
# exact cut-off semantics).
TEXT.build_vocab(train, min_freq=5)
# Module-level alias for the vocabulary's `itos` attribute (index -> token).
id2vocab = TEXT.vocab.itos
# Debugging aids: uncomment to print the vocabulary's stoi / itos mappings.
#print(TEXT.vocab.stoi)
#print(TEXT.vocab.itos)
# Create one BucketIterator per split, placed on `device`. batch_sizes is
# given in the same order as the datasets tuple (256 for train, 128 for val).
# The sort key is the number of tokens in an example's `text`.
train_iter, val_iter = data.BucketIterator.splits(
    (train, val),
    device=device,
    batch_sizes=(256, 128),
    sort_key=lambda example: len(example.text),
)