forked from castorini/anserini
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tune_bm25.py
87 lines (73 loc) · 3.56 KB
/
tune_bm25.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# -*- coding: utf-8 -*-
'''
Anserini: A Lucene toolkit for replicable information retrieval research
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''
# Simple script for tuning BM25 parameters (k1 and b) for MS MARCO
import argparse
import os
import re
import subprocess
# Grid of BM25 parameter values swept by the tuning run.
K1_VALUES = [0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
B_VALUES = [0.5, 0.6, 0.7, 0.8, 0.9]

# Patterns extracting the aggregate ("all") MAP and R@1000 from trec_eval output, and MRR@10
# from msmarco_eval.py output. Raw strings keep "\d" from being an invalid string escape.
MAP_PATTERN = re.compile(r'map +\tall\t([0-9.]+)')
RECALL_PATTERN = re.compile(r'recall_1000 +\tall\t([0-9.]+)')
MRR_PATTERN = re.compile(r'MRR @10: ([\d.]+)')


def parse_args(argv=None):
    """Parse the tuning script's command-line options.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]`` (argparse default).

    Returns:
        An ``argparse.Namespace`` with ``base_directory``, ``index``, ``queries`` and ``qrels``.
    """
    parser = argparse.ArgumentParser(description='Tunes BM25 parameters for MS MARCO Passages')
    parser.add_argument('--base_directory', required=True, help='base directory for storing runs')
    parser.add_argument('--index', required=True, help='index to use')
    parser.add_argument('--queries', required=True, help='queries for evaluation')
    parser.add_argument('--qrels', required=True, help='qrels for evaluation')
    return parser.parse_args(argv)


def run_filename(k1, b):
    """Return the run file name used for BM25 parameters ``k1`` and ``b``."""
    return 'run.bm25.k1_{}.b_{}.txt'.format(k1, b)


def extract_score(pattern, text, label):
    """Return the float captured by the first group of ``pattern`` in ``text``.

    Args:
        pattern: compiled regex with one capturing group holding the score.
        text: decoded evaluator output to search.
        label: human-readable metric name used in the error message.

    Raises:
        RuntimeError: if ``pattern`` does not match. The message includes the evaluator
            output, so the failure is diagnosable instead of an opaque ``AttributeError``.
    """
    match = pattern.search(text)
    if match is None:
        raise RuntimeError('Could not find {} in evaluation output:\n{}'.format(label, text))
    return float(match.group(1))


def generate_runs(base_directory, index, queries):
    """Produce one retrieval run per (k1, b) pair in the grid, reusing runs that already exist.

    If a retrieval exits with a non-zero status, the failure is reported and any (possibly
    partial) output file is removed, so a later invocation does not mistake it for a finished
    run. The sweep then continues with the next setting.
    """
    for k1 in K1_VALUES:
        for b in B_VALUES:
            print('Trying... k1 = {}, b = {}'.format(k1, b))
            output_path = os.path.join(base_directory, run_filename(k1, b))
            if os.path.isfile(output_path):
                print('Run already exists, skipping!')
                continue
            # Argument list (no shell): paths containing spaces or shell metacharacters
            # are passed through verbatim instead of being re-parsed by a shell.
            completed = subprocess.run(['python', 'src/main/python/msmarco/retrieve.py',
                                        '--index', index,
                                        '--qid_queries', queries,
                                        '--output', output_path,
                                        '--k1', str(k1), '--b', str(b),
                                        '--hits', '1000'])
            if completed.returncode != 0:
                print('Retrieval failed with exit status {}; discarding {}'.format(
                    completed.returncode, output_path))
                if os.path.isfile(output_path):
                    os.remove(output_path)


def evaluate_runs(base_directory, qrels):
    """Score every run in ``base_directory`` and return the one with the highest R@1000.

    Tuning maximizes recall; MRR@10 and MAP are computed and printed for reference only.
    The following entries are skipped:
      - files ending in ``trec`` (converted copies, perhaps left over from a previous
        tuning run);
      - entries that are not regular files.

    Returns:
        ``(best_filename, best_recall)``, or ``(None, None)`` if no run was evaluated.
    """
    best_file = None
    best_recall = None
    for filename in sorted(os.listdir(base_directory)):
        run_path = os.path.join(base_directory, filename)
        if filename.endswith('trec') or not os.path.isfile(run_path):
            continue
        trec_path = run_path + '.trec'
        # Convert to a TREC run. check=True ensures a failed conversion aborts here,
        # rather than letting trec_eval score a stale .trec file from an earlier run.
        subprocess.run(['python', 'src/main/python/msmarco/convert_msmarco_to_trec_run.py',
                        '--input', run_path, '--output', trec_path], check=True)
        trec_output = subprocess.check_output(['eval/trec_eval.9.0.4/trec_eval', qrels, trec_path,
                                               '-mrecall.1000', '-mmap']).decode('utf-8')
        ap = extract_score(MAP_PATTERN, trec_output, 'MAP')
        recall = extract_score(RECALL_PATTERN, trec_output, 'R@1000')
        # Evaluate with the official MS MARCO scoring script for MRR@10.
        msmarco_output = subprocess.check_output(['python', 'src/main/python/msmarco/msmarco_eval.py',
                                                  qrels, run_path]).decode('utf-8')
        rr = extract_score(MRR_PATTERN, msmarco_output, 'MRR@10')
        print('{}: MRR@10 = {}, MAP = {}, R@1000 = {}'.format(filename, rr, ap, recall))
        if best_recall is None or recall > best_recall:
            best_file, best_recall = filename, recall
    return best_file, best_recall


def main():
    """Sweep the BM25 (k1, b) grid, evaluate each run, and report the best R@1000."""
    args = parse_args()
    base_directory = args.base_directory
    index = args.index
    qrels = args.qrels
    queries = args.queries

    # exist_ok avoids the check-then-create race of a separate existence test.
    os.makedirs(base_directory, exist_ok=True)

    print('# Settings')
    print('base directory: {}'.format(base_directory))
    print('index: {}'.format(index))
    print('queries: {}'.format(queries))
    print('qrels: {}'.format(qrels))
    print('\n')

    generate_runs(base_directory, index, queries)

    print('\n\nStarting evaluation...')
    best_file, best_recall = evaluate_runs(base_directory, qrels)
    if best_file is None:
        print('\n\nNo runs found in {} to evaluate.'.format(base_directory))
    else:
        print('\n\nBest parameters: {}: R@1000 = {}'.format(best_file, best_recall))


if __name__ == '__main__':
    main()