forked from karpathy/arxiv-sanity-preserver
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuildsvm.py
86 lines (62 loc) · 2.55 KB
/
buildsvm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import pickle
import numpy as np
from sklearn import svm
from sqlite3 import dbapi2 as sqlite3
from utils import safe_pickle_dump, Config
def get_sqldb(database_path=None, schema_path='schema.sql'):
    """Open the SQLite database, creating it from a schema script if it is absent.

    Args:
        database_path: database file to open; defaults to Config.database_path.
        schema_path: SQL script executed once, only when the database file does
            not exist yet. Relative paths resolve against the current directory.

    Returns:
        A connection whose rows are sqlite3.Row objects, so columns can be read
        by name (e.g. row['user_id']).

    Raises:
        Any error from reading or executing the schema script. In that case the
        connection is closed and the freshly created file is removed.
    """
    if database_path is None:
        database_path = Config.database_path
    # Check before connecting: connect() itself creates the file.
    is_new = not os.path.isfile(database_path)
    sqldb = sqlite3.connect(database_path)
    if is_new:
        try:
            with open(schema_path) as fp:
                sqldb.executescript(fp.read())
        except Exception:
            sqldb.close()
            # Don't leave an empty database behind: the next run would see the
            # file, skip the schema, and then fail on missing tables.
            if os.path.isfile(database_path):
                os.remove(database_path)
            raise
    sqldb.row_factory = sqlite3.Row  # to return dicts rather than tuples
    return sqldb
def query_db(sqldb, query, args=(), one=False):
    """Execute *query* with *args* on *sqldb* and return the fetched rows.

    When *one* is true, only the first row is returned, or None if the query
    produced no rows.
    """
    rows = sqldb.execute(query, args).fetchall()
    if not one:
        return rows
    return rows[0] if rows else None
def get_users(sqldb):
    """Return every row of the user table, printing how many were found."""
    all_users = query_db(sqldb, '''select * from user''')
    user_count = len(all_users)
    print('number of users: ', user_count)
    return all_users
def get_libs(sqldb, user_id):
    """Return the library rows stored for the user with id *user_id*."""
    library_sql = '''select * from library where user_id = ?'''
    params = [user_id]
    return query_db(sqldb, library_sql, params)
def get_tfidf(meta_path=None, tfidf_path=None):
    """Load the paper metadata and the tf-idf feature matrix from pickle files.

    Args:
        meta_path: pickle holding a dict with 'pids' and 'ptoi' entries;
            defaults to Config.meta_path.
        tfidf_path: pickle holding a dict whose 'X' entry is a sparse matrix
            (presumably scipy.sparse; TODO confirm against the producer);
            defaults to Config.tfidf_path.

    Returns:
        (pids, ptoi, X): the 'pids' and 'ptoi' metadata entries, and X as a
        dense float32 numpy.ndarray.
    """
    if meta_path is None:
        meta_path = Config.meta_path
    if tfidf_path is None:
        tfidf_path = Config.tfidf_path
    # NOTE: pickle can execute arbitrary code; only load files this pipeline produced.
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    with open(tfidf_path, 'rb') as f:
        out = pickle.load(f)
    # Cast while still sparse so no float64 dense copy is materialized, then use
    # toarray() rather than todense(): todense() yields np.matrix, which
    # scikit-learn >= 1.2 rejects as estimator input.
    X = out['X'].astype(np.float32).toarray()
    return meta['pids'], meta['ptoi'], X
def _rank_papers(X, positive_idx, num_recommendations):
    """Fit a linear SVM (rows in positive_idx vs. all other rows of X) and return
    the indices of the highest-scoring rows, best first, or None if it can't be fit."""
    y = np.zeros(X.shape[0])
    y[positive_idx] = 1
    if y.all():
        # Every row is a positive: LinearSVC needs two classes and would raise.
        return None
    clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1)
    clf.fit(X, y)
    scores = clf.decision_function(X)
    # Crop recommendations to save space; slicing already clamps to the length.
    return np.argsort(-scores)[:num_recommendations]


def get_user_sim(sqldb, users, meta_pids, ptoi, X, num_recommendations):
    """Build per-user paper recommendations with one linear SVM per user.

    For each user, the papers in their library (looked up with get_libs and
    mapped to rows of X through ptoi) are the positive class and every other
    row of X is negative. All papers are ranked by SVM decision score, and the
    top num_recommendations paper ids from meta_pids are kept. Users whose
    library has no paper present in ptoi are skipped.

    Args:
        sqldb: database connection passed to get_libs.
        users: rows exposing 'username' and 'user_id' columns.
        meta_pids: paper ids indexed by row of X.
        ptoi: mapping from paper id to row index of X.
        X: feature matrix, one row per paper.
        num_recommendations: maximum number of paper ids kept per user.

    Returns:
        dict mapping user_id to a list of recommended paper ids. The same dict
        is written to Config.user_sim_path with safe_pickle_dump.
    """
    user_sim = {}
    n_users = len(users)
    for ii, u in enumerate(users):
        # Format the str directly: .encode('utf-8') would print b'...' on Python 3.
        print("%d/%d building an SVM for %s" % (ii, n_users, u['username']))
        user_id = u['user_id']
        user_raw_pids = [x['paper_id'] for x in get_libs(sqldb, user_id)]  # raw pids without version
        # NOTE(review): library ids are looked up in ptoi verbatim; this assumes
        # ptoi keys carry no version suffix either. Confirm against the meta file.
        user_pid_idx = [ptoi[p] for p in user_raw_pids if p in ptoi]
        if not user_pid_idx:
            continue  # empty library, or none of its papers are in ptoi
        print(user_raw_pids)
        sortix = _rank_papers(X, user_pid_idx, num_recommendations)
        if sortix is None:
            print("skipping %s: library contains every indexed paper" % (u['username'],))
            continue
        user_sim[user_id] = [meta_pids[ix] for ix in sortix]
    print('writing', Config.user_sim_path)
    safe_pickle_dump(user_sim, Config.user_sim_path)
    return user_sim
def run():
    """Script entry point: build SVM recommendations for every user and save them.

    Opens the database, loads the tf-idf features, and hands both to
    get_user_sim, which writes the result to disk. The database connection is
    closed even if any step fails.
    """
    num_recommendations = 1000  # papers to recommend per user
    sqldb = get_sqldb()
    try:
        meta_pids, ptoi, X = get_tfidf()
        get_user_sim(sqldb, get_users(sqldb), meta_pids, ptoi, X, num_recommendations)
    finally:
        sqldb.close()
if __name__ == '__main__':
    # Build recommendations only when executed as a script, not on import.
    run()