-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
36 lines (30 loc) · 1.42 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""Find character coreference clusters in movie data jsonlines"""
from movie_coref.movie_coref import MovieCoreference
from absl import flags
from absl import app
FLAGS = flags.FLAGS

# --- Required inputs -------------------------------------------------------
flags.DEFINE_string(
    "input_file",
    None,
    "Input preprocess movie data jsonlines file.",
    required=True,
)
flags.DEFINE_string(
    "weights_file",
    None,
    "Trained model weights.",
    required=True,
)

# --- Inference configuration -----------------------------------------------
# Subdocument and overlap lengths are forwarded to MovieCoreference as
# document_len / overlap_len; the lower bounds are enforced by absl at parse
# time.
flags.DEFINE_integer("subdocument_len", 5120, "Subdocument length", lower_bound=512)
flags.DEFINE_integer("overlap_len", 2048, "Overlap length (fusion)", lower_bound=256)
flags.DEFINE_integer(
    "repk",
    3,
    "Number of representative mentions (hierarchical)",
    lower_bound=1,
)
# When False (the default), fusion-based inference is used instead.
flags.DEFINE_bool(
    "hierarchical",
    False,
    "Set for hierarchical inference, otherwise fusion-based inference is performed.",
)
def predict(argv):
    """Run character coreference inference configured by command-line flags.

    Intended as the ``main`` callable for ``absl.app.run``. Builds a
    ``MovieCoreference`` from the parsed flag values and calls its
    ``predict`` method, with log, prediction and loss-curve saving all
    disabled.

    Args:
        argv: Positional arguments remaining after absl flag parsing;
            ``argv[0]`` is the program name.

    Raises:
        app.UsageError: If positional arguments beyond the program name are
            supplied, or if fusion-based inference is requested with an
            ``overlap_len`` that is not smaller than ``subdocument_len``.
    """
    # Raising UsageError (instead of print + return) makes absl print the
    # usage text and exit with a non-zero status, so scripts notice the error.
    # Report only the unexpected arguments, not the program name in argv[0].
    if len(argv) > 1:
        raise app.UsageError(f"Extra command-line arguments: {argv[1:]}")

    # NOTE(review): fusion-based inference presumably slides overlapping
    # windows of subdocument_len tokens; an overlap >= the window length would
    # keep them from advancing. The flag lower bounds alone do not prevent
    # that (e.g. 512 vs 600) — confirm against MovieCoreference.
    if not FLAGS.hierarchical and FLAGS.overlap_len >= FLAGS.subdocument_len:
        raise app.UsageError(
            f"--overlap_len ({FLAGS.overlap_len}) must be smaller than "
            f"--subdocument_len ({FLAGS.subdocument_len}) for fusion-based "
            "inference."
        )

    movie_coref = MovieCoreference(
        full_length_scripts_file=FLAGS.input_file,
        weights_file=FLAGS.weights_file,
        document_len=FLAGS.subdocument_len,
        overlap_len=FLAGS.overlap_len,
        hierarchical=FLAGS.hierarchical,
        n_representative_mentions=FLAGS.repk,
        save_log=False,
        save_predictions=False,
        save_loss_curve=False,
    )
    movie_coref.predict()
if __name__ == "__main__":
    # absl parses the flags defined above, then calls predict() with the
    # leftover positional arguments.
    app.run(main=predict)