Skip to content

Commit

Permalink
Cleanup ranking code a bit.
Browse files Browse the repository at this point in the history
  • Loading branch information
coleifer committed Jan 7, 2019
1 parent ccbf5b6 commit 1696908
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 98 deletions.
145 changes: 59 additions & 86 deletions playhouse/_sqlite_ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -736,31 +736,37 @@ cdef inline unicode decode(key):
return ukey


cdef double *get_weights(int ncol, tuple raw_weights):
cdef:
int argc = len(raw_weights)
int icol
double *weights = <double *>malloc(sizeof(double) * ncol)

for icol in range(ncol):
if argc == 0:
weights[icol] = 1.0
elif icol < argc:
weights[icol] = <double>raw_weights[icol]
else:
weights[icol] = 0.0
return weights


def peewee_rank(py_match_info, *raw_weights):
cdef:
unsigned int *match_info
unsigned int *phrase_info
bytes _match_info_buf = bytes(py_match_info)
char *match_info_buf = _match_info_buf
int argc = len(raw_weights)
int ncol, nphrase, icol, iphrase, hits, global_hits
int nphrase, ncol, icol, iphrase, hits, global_hits
int P_O = 0, C_O = 1, X_O = 2
double score = 0.0, weight
double *weights

match_info = <unsigned int *>match_info_buf
nphrase = match_info[P_O]
ncol = match_info[C_O]

# Normalize weights and load them into an array of weight for each column.
weights = <double *>malloc(sizeof(double) * ncol)
for icol in range(ncol):
if argc == 0:
weights[icol] = 1.0
elif icol < argc:
weights[icol] = <double>raw_weights[icol]
else:
weights[icol] = 0.0
weights = get_weights(ncol, raw_weights)

# matchinfo X value corresponds to, for each phrase in the search query, a
# list of 3 values for each column in the search table.
Expand Down Expand Up @@ -792,46 +798,35 @@ def peewee_lucene(py_match_info, *raw_weights):
# Usage: peewee_lucene(matchinfo(table, 'pcnalx'), 1)
cdef:
unsigned int *match_info
unsigned int *phrase_info
bytes _match_info_buf = bytes(py_match_info)
char *match_info_buf = _match_info_buf
int argc = len(raw_weights)
int term_count, col_count
int nphrase, ncol
double total_docs, term_frequency
double doc_length, docs_with_term, avg_length
double idf, weight, rhs, denom
double *weights
int P_O = 0, C_O = 1, N_O = 2, L_O, X_O
int i, j, x

int iphrase, icol, x
double score = 0.0

match_info = <unsigned int *>match_info_buf
term_count = match_info[P_O]
col_count = match_info[C_O]
nphrase = match_info[P_O]
ncol = match_info[C_O]
total_docs = match_info[N_O]

L_O = 3 + col_count
X_O = L_O + col_count

weights = <double *>malloc(sizeof(double) * col_count)
for i in range(col_count):
if argc == 0:
weights[i] = 1.
elif i < argc:
weights[i] = <double>raw_weights[i]
else:
weights[i] = 0
L_O = 3 + ncol
X_O = L_O + ncol
weights = get_weights(ncol, raw_weights)

for i in range(term_count):
for j in range(col_count):
weight = weights[j]
for iphrase in range(nphrase):
for icol in range(ncol):
weight = weights[icol]
if weight == 0:
continue
doc_length = match_info[L_O + j]
x = X_O + (3 * (j + i * col_count))
term_frequency = match_info[x]
docs_with_term = match_info[x + 2]
doc_length = match_info[L_O + icol]
x = X_O + (3 * (icol + iphrase * ncol))
term_frequency = match_info[x] # f(qi)
docs_with_term = match_info[x + 2] # n(qi)
idf = log(total_docs / (docs_with_term + 1.))
tf = sqrt(term_frequency)
fieldNorms = 1.0 / sqrt(doc_length)
Expand All @@ -847,19 +842,16 @@ def peewee_bm25(py_match_info, *raw_weights):
# the 3rd and 4th specify k and b.
cdef:
unsigned int *match_info
unsigned int *phrase_info
bytes _match_info_buf = bytes(py_match_info)
char *match_info_buf = _match_info_buf
int argc = len(raw_weights)
int term_count, col_count
int nphrase, ncol
double B = 0.75, K = 1.2, D
double total_docs, term_frequency
double doc_length, docs_with_term, avg_length
double idf, weight, rhs, denom
double *weights
int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O
int i, j, x

int iphrase, icol, x
double score = 0.0

match_info = <unsigned int *>match_info_buf
Expand All @@ -873,35 +865,27 @@ def peewee_bm25(py_match_info, *raw_weights):
# * phrase count within column for current row.
# * phrase count within column for all rows.
# * total rows for which column contains phrase.
term_count = match_info[P_O]
col_count = match_info[C_O]
nphrase = match_info[P_O]
ncol = match_info[C_O]
total_docs = match_info[N_O]

L_O = A_O + col_count
X_O = L_O + col_count
L_O = A_O + ncol
X_O = L_O + ncol
weights = get_weights(ncol, raw_weights)

weights = <double *>malloc(sizeof(double) * col_count)
for i in range(col_count):
if argc == 0:
weights[i] = 1.
elif i < argc:
weights[i] = <double>raw_weights[i]
else:
weights[i] = 0

for i in range(term_count):
for j in range(col_count):
weight = weights[j]
for iphrase in range(nphrase):
for icol in range(ncol):
weight = weights[icol]
if weight == 0:
continue
avg_length = match_info[A_O + j]
doc_length = match_info[L_O + j]
avg_length = match_info[A_O + icol]
doc_length = match_info[L_O + icol]
if avg_length == 0:
D = 0
else:
D = 1 - B + (B * (doc_length / avg_length))

x = X_O + (3 * (j + i * col_count))
x = X_O + (3 * (icol + iphrase * ncol))
term_frequency = match_info[x]
docs_with_term = match_info[x + 2]
idf = max(
Expand All @@ -927,50 +911,39 @@ def peewee_bm25f(py_match_info, *raw_weights):
# the 3rd and 4th specify k and b.
cdef:
unsigned int *match_info
unsigned int *phrase_info
bytes _match_info_buf = bytes(py_match_info)
char *match_info_buf = _match_info_buf
int argc = len(raw_weights)
int term_count, col_count
int nphrase, ncol
double B = 0.75, K1 = 1.2, D, epsilon
double total_docs, term_frequency, docs_with_term
double doc_length = 0.0, avg_length = 0.0
double idf, weight, rhs, denom
double *weights
int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O
int i, j, current_x

int iphrase, icol, current_x
double score = 0.0

match_info = <unsigned int *>match_info_buf
term_count = match_info[P_O]
col_count = match_info[C_O]
nphrase = match_info[P_O]
ncol = match_info[C_O]
total_docs = match_info[N_O]

L_O = A_O + col_count
X_O = L_O + col_count
L_O = A_O + ncol
X_O = L_O + ncol

for j in range(col_count):
avg_length += match_info[A_O + j]
doc_length += match_info[L_O + j]
for icol in range(ncol):
avg_length += match_info[A_O + icol]
doc_length += match_info[L_O + icol]

epsilon = 1.0 / (total_docs * avg_length)
weights = get_weights(ncol, raw_weights)

weights = <double *>malloc(sizeof(double) * col_count)
for i in range(col_count):
if argc == 0:
weights[i] = 1.
elif i < argc:
weights[i] = <double>raw_weights[i]
else:
weights[i] = 0

for i in range(term_count):
for j in range(col_count):
weight = weights[j]
for iphrase in range(nphrase):
for icol in range(ncol):
weight = weights[icol]
if weight == 0:
continue
current_x = X_O + (3 * (j + i * col_count))
current_x = X_O + (3 * (icol + iphrase * ncol))
term_frequency = match_info[current_x]
docs_with_term = match_info[current_x + 2]
idf = log(
Expand Down
23 changes: 11 additions & 12 deletions playhouse/sqlite_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,15 @@ def _parse_match_info(buf):
bufsize = len(buf) # Length in bytes.
return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]

def get_weights(ncol, raw_weights):
if not raw_weights:
return [1] * ncol
else:
weights = [0] * ncol
for i, weight in enumerate(raw_weights):
weights[i] = weight
return weights

# Ranking implementation, which parse matchinfo.
def rank(raw_match_info, *raw_weights):
# Handle match_info called w/default args 'pcx' - based on the example rank
Expand All @@ -1110,12 +1119,7 @@ def rank(raw_match_info, *raw_weights):
score = 0.0

p, c = match_info[:2]
if not raw_weights:
weights = [1] * c
else:
weights = [0] * c
for i, weight in enumerate(raw_weights):
weights[i] = weight
weights = get_weights(c, raw_weights)

# matchinfo X value corresponds to, for each phrase in the search query, a
# list of 3 values for each column in the search table.
Expand Down Expand Up @@ -1165,12 +1169,7 @@ def bm25(raw_match_info, *args):
L_O = A_O + col_count
X_O = L_O + col_count

if not args:
weights = [1] * col_count
else:
weights = [0] * col_count
for i, weight in enumerate(args):
weights[i] = args[i]
weights = get_weights(col_count, args)

for i in range(term_count):
for j in range(col_count):
Expand Down

0 comments on commit 1696908

Please sign in to comment.