# dataset.py
import collections
import random

from torch.utils.data import Dataset

from utils import sample_negative, batch_sample_negative


def load_user_seqs(dp):
    """Read comma-separated (user, item, timestamp) rows into per-user item and timestamp sequences."""
    user2seq_map = collections.defaultdict(list)
    user2ts_map = collections.defaultdict(list)
    with open(dp) as f:
        for line in f:
            user, item, ts = line.rstrip().split(',')
            user, item, ts = int(user), int(item), int(ts)
            user2seq_map[user].append(item)
            user2ts_map[user].append(ts)
    return user2seq_map, user2ts_map


def partition_dataset(dp, tr=3, test_ratio=0.5):
    """Split users into history/validation/test maps.

    Users with at least `tr` interactions hold out their last item for evaluation;
    those held-out users are split between validation and test by `test_ratio`.
    """
    user2seq_map, user2ts_map = load_user_seqs(dp)
    num_user = len(user2seq_map)
    num_item = max(max(seq) for seq in user2seq_map.values())
    pred_users = {user for user, seq in user2seq_map.items() if len(seq) >= tr}
    # sample from a sorted list: random.sample() no longer accepts sets (removed in Python 3.11)
    test_users = set(random.sample(sorted(pred_users), k=int(len(pred_users) * test_ratio)))
    valid_users = pred_users.difference(test_users)
    history_user2seq_map = dict()
    history_user2ts_map = dict()
    valid_user2seq_map = dict()
    valid_user2ts_map = dict()
    test_user2seq_map = dict()
    test_user2ts_map = dict()
    for user, seq in user2seq_map.items():
        ts_seq = user2ts_map[user]
        if user in valid_users:
            history_user2seq_map[user] = seq[:-1]
            history_user2ts_map[user] = ts_seq[:-1]
            valid_user2seq_map[user] = seq[-1:]
            valid_user2ts_map[user] = ts_seq[-1:]
        elif user in test_users:
            history_user2seq_map[user] = seq[:-1]
            history_user2ts_map[user] = ts_seq[:-1]
            test_user2seq_map[user] = seq[-1:]
            test_user2ts_map[user] = ts_seq[-1:]
        else:
            history_user2seq_map[user] = seq
            history_user2ts_map[user] = ts_seq
    seq_maps = (history_user2seq_map,
                history_user2ts_map,
                valid_user2seq_map,
                valid_user2ts_map,
                test_user2seq_map,
                test_user2ts_map)
    return num_user, num_item, seq_maps
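
# Example (a minimal sketch, assuming a comma-separated "user,item,timestamp" file at
# "data/ratings.csv" -- the path and file are hypothetical, not part of this module):
#
#     num_user, num_item, seq_maps = partition_dataset('data/ratings.csv', tr=3, test_ratio=0.5)
#     (history_seq, history_ts,
#      valid_seq, valid_ts,
#      test_seq, test_ts) = seq_maps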


class TrainDataset(Dataset):
    def __init__(self, num_item,
                 history_user2seq_map,
                 history_user2ts_map,
                 max_seq_len,
                 limit=2,
                 is_bert=True,
                 mask_ratio=0.2,
                 transform_fn=None):
        self.num_item = num_item
        self.is_bert = is_bert
        # masking ratio is only used for BERT-style (cloze) models
        self.mask_ratio = mask_ratio
        # keep only users with at least `limit` interactions, truncated to the last max_seq_len items
        self.user2seq_map = {user: seq[-max_seq_len:] for user, seq in history_user2seq_map.items()
                             if len(seq) >= limit}
        self.user2ts_map = {user: history_user2ts_map[user][-max_seq_len:] for user in self.user2seq_map}
        self.idx2user_map = {i: user for i, user in enumerate(self.user2seq_map)}
        self.transform_fn = transform_fn

    def __len__(self):
        return len(self.idx2user_map)

    def __getitem__(self, idx):
        user = self.idx2user_map.get(idx)
        instance = self._prepare_bert(user) if self.is_bert else self._prepare_left_wise(user)
        return instance if self.transform_fn is None else self.transform_fn(instance)

    def _prepare_bert(self, user):
        seq = list(self.user2seq_map.get(user))
        ts = self.user2ts_map.get(user)
        pos = list()
        # negative sampling might not be used by every loss
        neg = list()
        for i, item in enumerate(seq):
            if random.random() < self.mask_ratio:
                pos.append(item)
                neg.append(sample_negative(seq, self.num_item))
                # replace the item with the [mask] token index
                seq[i] = self.num_item + 1
            else:
                # 0 is padding and is excluded from the loss calculation
                pos.append(0)
                neg.append(0)
        return user, seq, ts, pos, neg

    def _prepare_left_wise(self, user):
        user_history = self.user2seq_map.get(user)
        # left-to-right (next-item) training: inputs are all but the last item, targets are shifted by one
        seq = user_history[:-1]
        ts = self.user2ts_map.get(user)[1:]
        pos = user_history[1:]
        # repeated negatives are allowed when sampling in the training stage
        neg = [sample_negative(user_history, self.num_item) for _ in range(len(pos))]
        return user, seq, ts, pos, neg
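
# A minimal sketch of how this dataset might be fed to a DataLoader (the padding
# collate_fn is an assumption, not part of this module: instances have variable
# length, so the default collate cannot batch them directly; history_seq/history_ts
# follow from the partition_dataset sketch above):
#
#     from torch.utils.data import DataLoader
#
#     def pad_collate(batch):  # hypothetical helper
#         ...  # left-pad seq/ts/pos/neg with 0 to max_seq_len and stack into tensors
#
#     train_set = TrainDataset(num_item, history_seq, history_ts,
#                              max_seq_len=50, is_bert=True, mask_ratio=0.2)
#     train_loader = DataLoader(train_set, batch_size=128, shuffle=True, collate_fn=pad_collate)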


class PredictionDataset(Dataset):
    def __init__(self, num_item,
                 history_user2seq_map,
                 history_user2ts_map,
                 pred_user2seq_map,
                 pred_user2ts_map,
                 max_seq_len,
                 is_bert=True,
                 is_sampling=False,
                 neg_num=100,
                 transform_fn=None):
        self.num_item = num_item
        self.is_bert = is_bert
        # when is_sampling=True, rank the positive against neg_num (default 100) sampled negatives
        self.is_sampling = is_sampling
        self.neg_num = neg_num
        self.history_user2seq_map = {user: seq[-max_seq_len:] for user, seq in history_user2seq_map.items()}
        self.history_user2ts_map = {user: ts[-max_seq_len:] for user, ts in history_user2ts_map.items()}
        self.pred_user2seq_map = pred_user2seq_map
        self.pred_user2ts_map = pred_user2ts_map
        self.idx2user_map = {i: user for i, user in enumerate(pred_user2seq_map)}
        self.transform_fn = transform_fn

    def __len__(self):
        return len(self.idx2user_map)

    def __getitem__(self, idx):
        user = self.idx2user_map.get(idx)
        instance = self._prepare_bert(user) if self.is_bert else self._prepare_left_wise(user)
        return instance if self.transform_fn is None else self.transform_fn(instance)

    def _prepare_bert(self, user):
        # append the [mask] token so the model predicts the held-out item at the last position
        seq = self.history_user2seq_map.get(user) + [self.num_item + 1]
        ts = self.history_user2ts_map.get(user) + self.pred_user2ts_map.get(user)
        indices = self._prepare_indices(user)
        return user, seq, ts, indices

    def _prepare_left_wise(self, user):
        seq = self.history_user2seq_map.get(user)
        # shift timestamps by one so they align with the prediction targets, as in training
        ts = self.history_user2ts_map.get(user)[1:] + self.pred_user2ts_map.get(user)
        indices = self._prepare_indices(user)
        return user, seq, ts, indices

    def _prepare_indices(self, user):
        if self.is_sampling:
            pos = self.pred_user2seq_map.get(user)
            neg = batch_sample_negative(pos, self.num_item, self.neg_num)
            indices = pos + neg
        else:
            # full ranking: score every item, with the positive label swapped into index 0
            pos_label = self.pred_user2seq_map.get(user)[0]
            indices = list(range(1, self.num_item + 1))
            indices[0] = pos_label
            indices[pos_label - 1] = 1
        return indices
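

# Evaluation sketch (hedged: metric code depends on the model; only the dataset
# construction follows this module's API, and valid_seq/valid_ts come from the
# partition_dataset sketch above):
#
#     valid_set = PredictionDataset(num_item, history_seq, history_ts,
#                                   valid_seq, valid_ts, max_seq_len=50,
#                                   is_bert=True, is_sampling=True, neg_num=100)
#     # each __getitem__ returns (user, seq, ts, indices); with is_sampling=True the
#     # positive item is indices[0] followed by neg_num sampled negatives, so ranking
#     # the model's scores over `indices` yields HR@k / NDCG@k for that user.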