slight refactor of the code just so I'm not too embarrassed, and also setting it up better for more similarities in the future
Andrej Karpathy committed Mar 31, 2020
1 parent 0faad4f commit e08cfd1
Showing 2 changed files with 114 additions and 94 deletions.
202 changes: 111 additions & 91 deletions run.py
@@ -6,97 +6,117 @@
import requests
import numpy as np

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction import stop_words
from sklearn import svm

# -----------------------------------------------------------------------------

jstr = requests.get('https://connect.biorxiv.org/relate/collection_json.php?grp=181')
jall = jstr.json()
print(f"writing jall.json with {len(jall['rels'])} papers")
json.dump(jall, open('jall.json', 'w'))

# compute tfidf features with scikit learn
print("fitting tfidf")
max_features = 2000
v = TfidfVectorizer(input='content',
        encoding='utf-8', decode_error='replace', strip_accents='unicode',
        lowercase=True, analyzer='word', stop_words='english',
        token_pattern=r'(?u)\b[a-zA-Z_][a-zA-Z0-9_-]+\b',
        ngram_range=(1, 1), max_features=max_features,
        norm='l2', use_idf=True, smooth_idf=True, sublinear_tf=True,
        max_df=1.0, min_df=1)
corpus = [a['rel_abs'] for a in jall['rels']]
v.fit(corpus)

# use tfidf features to find nearest neighbors cheaply
X = v.transform(corpus)
D = np.dot(X, X.T).todense()
IX = np.argsort(-D, axis=1)
sim = {}
ntake = 40
n = IX.shape[0]
for i in range(n):
    ixc = [int(IX[i,j]) for j in range(ntake)]
    ds = [int(D[i,IX[i,j]]*1000) for j in range(ntake)]
    sim[i] = list(zip(ixc, ds))
print("writing sim.json")
json.dump(sim, open('sim.json', 'w'))

# use exemplar SVM to build similarity instead
print("fitting SVMs per paper")
svm_sim = {}
ntake = 40
for i in range(n):
    # everything is 0 except the current index - i.e. "exemplar svm"
    y = np.zeros(X.shape[0])
    y[i] = 1
    clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-4, C=1.0)
    clf.fit(X, y)
    s = clf.decision_function(X)
    IX = np.argsort(-s)
    ixc = [int(IX[j]) for j in range(ntake)]
    ds = [int(D[i,IX[j]]*1000) for j in range(ntake)]
    svm_sim[i] = list(zip(ixc, ds))
json.dump(svm_sim, open('svm_sim.json', 'w'))

# construct a reverse index for supporting search
vocab = v.vocabulary_
idf = v.idf_
english_stop_words = stop_words.ENGLISH_STOP_WORDS
punc = "'!\"#$%&\'()*+,./:;<=>?@[\\]^_`{|}~'" # removed hyphen from string.punctuation
trans_table = {ord(c): None for c in punc}

def makedict(s, forceidf=None, scale=1.0):
    words = set(s.lower().translate(trans_table).strip().split())
    words = set(w for w in words if len(w) > 1 and (not w in english_stop_words))
    idfd = {}
    for w in words: # todo: if we're using bigrams in vocab then this won't search over them
        if forceidf is None:
            if w in vocab:
                idfval = idf[vocab[w]] * scale # we have idf for this
def write_json(obj, filename, msg=''):
    suffix = f'; {msg}' if msg else ''
    print(f"writing {filename}{suffix}")
    with open(filename, 'w') as f:
        json.dump(obj, f)


def calculate_tfidf_features(rels, max_features=5000, max_df=1.0, min_df=3):
    """ compute tfidf features with scikit learn """
    from sklearn.feature_extraction.text import TfidfVectorizer
    v = TfidfVectorizer(input='content',
            encoding='utf-8', decode_error='replace', strip_accents='unicode',
            lowercase=True, analyzer='word', stop_words='english',
            token_pattern=r'(?u)\b[a-zA-Z_][a-zA-Z0-9_-]+\b',
            ngram_range=(1, 1), max_features=max_features,
            norm='l2', use_idf=True, smooth_idf=True, sublinear_tf=True,
            max_df=max_df, min_df=min_df)
    corpus = [(a['rel_title'] + '. ' + a['rel_abs']) for a in rels]
    X = v.fit_transform(corpus)
    X = np.asarray(X.astype(np.float32).todense())
    print("tfidf calculated array of shape ", X.shape)
    return X, v


def calculate_sim_dot_product(X, ntake=40):
    """ take X (N,D) features and for each index return closest ntake indices via dot product """
    S = np.dot(X, X.T)
    IX = np.argsort(S, axis=1)[:, :-ntake-1:-1] # take last ntake sorted backwards
    return IX.tolist()


def calculate_sim_svm(X, ntake=40):
    """ take X (N,D) features and for each index return closest ntake indices using exemplar SVM """
    from sklearn import svm
    n, d = X.shape
    IX = np.zeros((n, ntake), dtype=np.int64)
    print(f"training {n} svms for each paper...")
    for i in range(n):
        # set all examples as negative except this one
        y = np.zeros(X.shape[0], dtype=np.float32)
        y[i] = 1
        # train an SVM
        clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-4, C=0.1)
        clf.fit(X, y)
        s = clf.decision_function(X)
        ix = np.argsort(s)[:-ntake-1:-1] # take last ntake sorted backwards
        IX[i] = ix
    return IX.tolist()


def build_search_index(rels, v):
    from sklearn.feature_extraction import stop_words

    # construct a reverse index for supporting search
    vocab = v.vocabulary_
    idf = v.idf_
    english_stop_words = stop_words.ENGLISH_STOP_WORDS
    punc = "'!\"#$%&\'()*+,./:;<=>?@[\\]^_`{|}~'" # removed hyphen from string.punctuation
    trans_table = {ord(c): None for c in punc}

    def makedict(s, forceidf=None):
        words = set(s.lower().translate(trans_table).strip().split())
        words = set(w for w in words if len(w) > 1 and (not w in english_stop_words))
        idfd = {}
        for w in words: # todo: if we're using bigrams in vocab then this won't search over them
            if forceidf is None:
                if w in vocab:
                    idfval = idf[vocab[w]] # we have a computed idf for this
                else:
                    idfval = 1.0 # some word we don't know; assume idf 1.0 (low)
            else:
                idfval = 1.0 * scale # assume idf 1.0 (low)
        else:
            idfval = forceidf
        idfd[w] = idfval
    return idfd

def merge_dicts(dlist):
    m = {}
    for d in dlist:
        for k,v in d.items():
            m[k] = m.get(k,0) + v
    return m

search_dict = []
for p in jall['rels']:
    dict_title = makedict(p['rel_title'], forceidf=10, scale=3)
    dict_authors = makedict(p['rel_authors'], forceidf=5)
    dict_summary = makedict(p['rel_abs'])
    qdict = merge_dicts([dict_title, dict_authors, dict_summary])
    search_dict.append(qdict)

print("writing search.json")
json.dump(search_dict, open('search.json', 'w'))
            else:
                idfval = forceidf
            idfd[w] = idfval
        return idfd

    def merge_dicts(dlist):
        m = {}
        for d in dlist:
            for k, v in d.items():
                m[k] = m.get(k, 0) + v
        return m

    search_dict = []
    for p in rels:
        dict_title = makedict(p['rel_title'], forceidf=10)
        dict_authors = makedict(p['rel_authors'], forceidf=5)
        dict_summary = makedict(p['rel_abs'])
        qdict = merge_dicts([dict_title, dict_authors, dict_summary])
        search_dict.append(qdict)

    return search_dict


if __name__ == '__main__':

    # fetch the raw data from biorxiv
    jstr = requests.get('https://connect.biorxiv.org/relate/collection_json.php?grp=181')
    jall = jstr.json()
    write_json(jall, 'jall.json', f"{len(jall['rels'])} papers")

    # calculate similarities using various techniques
    X, v = calculate_tfidf_features(jall['rels'])
    # similarity using simple dot product on tfidf
    sim_tfidf = calculate_sim_dot_product(X)
    write_json(sim_tfidf, 'sim_tfidf_dot.json')
    # similarity using an exemplar svm on tfidf
    sim_svm = calculate_sim_svm(X)
    write_json(sim_svm, 'sim_tfidf_svm.json')

    # calculate the search index to support search
    search_dict = build_search_index(jall['rels'], v)
    write_json(search_dict, 'search.json')
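As an aside on the technique above: the exemplar-SVM similarity trains one linear SVM per paper, with that single paper as the only positive example and every other paper as a negative, then ranks all papers by the SVM's decision function. Below is a minimal standalone sketch of the same idea; the toy corpus and the choice of paper index 0 are made up for illustration, while the TfidfVectorizer and LinearSVC settings echo the ones used in run.py.

import numpy as np
from sklearn import svm
from sklearn.feature_extraction.text import TfidfVectorizer

# toy stand-in for the bioRxiv titles + abstracts fetched by run.py
corpus = [
    "spike protein structure of the coronavirus",
    "antibody response to the spike protein",
    "epidemiological model of transmission dynamics",
    "forecasting transmission with compartmental models",
]

# dense, L2-normalized tfidf features (as in calculate_tfidf_features)
v = TfidfVectorizer(stop_words='english', norm='l2', sublinear_tf=True)
X = np.asarray(v.fit_transform(corpus).astype(np.float32).todense())

# exemplar SVM for paper 0: that paper is the only positive example
y = np.zeros(X.shape[0], dtype=np.float32)
y[0] = 1
clf = svm.LinearSVC(class_weight='balanced', max_iter=10000, tol=1e-4, C=0.1)
clf.fit(X, y)

# rank every paper by its decision value, most similar first
s = clf.decision_function(X)
ranking = np.argsort(s)[::-1]
print(ranking.tolist())  # expect paper 0 first, then its nearest neighbor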
6 changes: 3 additions & 3 deletions serve.py
@@ -3,10 +3,10 @@
"""

import json
import argparse

from flask import Flask, request, redirect, url_for
from flask import render_template

# -----------------------------------------------------------------------------

app = Flask(__name__)
@@ -16,7 +16,7 @@
    jall = json.load(f)

# load computed paper similarities
with open('svm_sim.json', 'r') as f:
with open('sim_tfidf_svm.json', 'r') as f:
    sim_dict = json.load(f)

# load search dictionary for each paper
@@ -63,7 +63,7 @@ def sim(doi_prefix=None, doi_suffix=None):
    if pix is None:
        papers = []
    else:
        sim_ix, match = zip(*sim_dict[str(pix)][:40]) # indices of closest papers
        sim_ix = sim_dict[pix]
        papers = [jall['rels'][cix] for cix in sim_ix]
    gvars = {'sort_order': 'sim', 'num_papers': len(jall['rels'])}
    context = {'papers': papers, 'gvars': gvars}
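The per-paper word-weight dictionaries written to search.json above are what serve.py loads for search. The search route itself is not part of this diff, so the scoring below is only a hypothetical illustration of how such an index could be queried: sum the stored weights of the query words each paper contains.

import json

def score_query(query, search_dict):
    # sum stored word weights for every query word a paper contains (hypothetical helper, not from the repo)
    qwords = [w for w in query.lower().split() if len(w) > 1]
    scores = []
    for ix, word_weights in enumerate(search_dict):
        score = sum(word_weights.get(w, 0.0) for w in qwords)
        scores.append((score, ix))
    scores.sort(reverse=True)
    return scores  # (score, paper index) pairs, best match first

with open('search.json', 'r') as f:
    search_dict = json.load(f)
print(score_query("spike protein antibody", search_dict)[:10])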
