preprocess_data.py
import json
import numpy as np
import pandas as pd
from pathlib import Path
import argparse


def preprocess(data_name):
  """Parse the raw CSV into an interaction DataFrame and an edge-feature array."""
  u_list, i_list, ts_list, label_list = [], [], [], []
  feat_l = []
  idx_list = []

  with open(data_name) as f:
    s = next(f)  # skip the header line
    for idx, line in enumerate(f):
      e = line.strip().split(',')
      u = int(e[0])        # source (user) node id
      i = int(e[1])        # destination (item) node id
      ts = float(e[2])     # interaction timestamp
      label = float(e[3])  # int(e[3])
      feat = np.array([float(x) for x in e[4:]])  # edge features

      u_list.append(u)
      i_list.append(i)
      ts_list.append(ts)
      label_list.append(label)
      idx_list.append(idx)
      feat_l.append(feat)

  return pd.DataFrame({'u': u_list,
                       'i': i_list,
                       'ts': ts_list,
                       'label': label_list,
                       'idx': idx_list}), np.array(feat_l)


def reindex(df, bipartite=True):
  """Shift node ids so that user and item ids form disjoint, 1-based ranges."""
  new_df = df.copy()
  if bipartite:
    # Ids of each node type are expected to be contiguous.
    assert (df.u.max() - df.u.min() + 1 == len(df.u.unique()))
    assert (df.i.max() - df.i.min() + 1 == len(df.i.unique()))

    # Move item ids above the largest user id so the two sets do not overlap.
    upper_u = df.u.max() + 1
    new_i = df.i + upper_u

    new_df.i = new_i
    new_df.u += 1
    new_df.i += 1
    new_df.idx += 1
  else:
    new_df.u += 1
    new_df.i += 1
    new_df.idx += 1

  return new_df


def run(data_name, bipartite=True):
  Path("data/").mkdir(parents=True, exist_ok=True)
  PATH = './data_raw/{}.csv'.format(data_name)
  OUT_DF = './data/ml_{}.csv'.format(data_name)
  OUT_FEAT = './data/ml_{}.npy'.format(data_name)
  OUT_NODE_FEAT = './data/ml_{}_node.npy'.format(data_name)

  df, feat = preprocess(PATH)
  new_df = reindex(df, bipartite)

  # Prepend an all-zero row so edge features align with the 1-based idx column.
  empty = np.zeros(feat.shape[1])[np.newaxis, :]
  feat = np.vstack([empty, feat])

  # Node features: zero vectors of dimension 172, one per node id (plus id 0).
  max_idx = max(new_df.u.max(), new_df.i.max())
  rand_feat = np.zeros((max_idx + 1, 172))

  new_df.to_csv(OUT_DF)
  np.save(OUT_FEAT, feat)
  np.save(OUT_NODE_FEAT, rand_feat)


parser = argparse.ArgumentParser('Interface for TGN data preprocessing')
parser.add_argument('--data', type=str, help='Dataset name (e.g. wikipedia or reddit)',
                    default='TemFin')
parser.add_argument('--bipartite', action='store_true', help='Whether the graph is bipartite')

args = parser.parse_args()

run(args.data, bipartite=args.bipartite)
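
# A minimal usage sketch (assumes the raw interactions live in ./data_raw/wikipedia.csv,
# formatted as: user id, item id, timestamp, label, edge features...):
#
#   python preprocess_data.py --data wikipedia --bipartite
#
# This writes ./data/ml_wikipedia.csv (reindexed interactions), ./data/ml_wikipedia.npy
# (edge features with a zero row prepended) and ./data/ml_wikipedia_node.npy (node features).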