-
Notifications
You must be signed in to change notification settings - Fork 0
/
predictor.py
41 lines (34 loc) · 1.76 KB
/
predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from allennlp.common import JsonDict
from allennlp.data import DatasetReader, Instance
from allennlp.data.tokenizers import SpacyTokenizer
from allennlp.models import Model
from allennlp.predictors import Predictor
from overrides import overrides
from typing import List
# You need to name your predictor and register it so that the `allennlp` command can recognize it.
# Note that you need to use "@Predictor.register", not "@Model.register"!
@Predictor.register("sentence_classifier_predictor")
class SentenceClassifierPredictor(Predictor):
    """Predictor that accepts a raw sentence string.

    Registered under the name ``sentence_classifier_predictor``. The
    sentence is tokenized with a ``SpacyTokenizer`` and the token strings
    are handed to the dataset reader's ``text_to_instance``.
    """

    def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
        """Set up the base predictor and create the spaCy tokenizer."""
        super().__init__(model, dataset_reader)
        self._tokenizer = SpacyTokenizer(language='en_core_web_sm', pos_tags=True)

    def predict(self, sentence: str) -> JsonDict:
        """Run prediction on a single raw sentence.

        The sentence is wrapped in a JSON-style dict under the
        ``"sentence"`` key and forwarded to ``predict_json``.
        """
        payload = {"sentence": sentence}
        return self.predict_json(payload)

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """Build an ``Instance`` from a dict carrying a ``"sentence"`` entry."""
        raw_text = json_dict["sentence"]
        # The dataset reader is given plain strings, not tokenizer token objects.
        token_strings = [str(token) for token in self._tokenizer.tokenize(raw_text)]
        return self._dataset_reader.text_to_instance(token_strings)
@Predictor.register("universal_pos_predictor")
class UniversalPOSPredictor(Predictor):
    """Predictor that accepts an already-split list of words.

    Registered under the name ``universal_pos_predictor``. No tokenization
    is performed here; the words are passed to the dataset reader as given.
    """

    def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
        """Set up the base predictor with the given model and dataset reader."""
        super().__init__(model, dataset_reader)

    def predict(self, words: List[str]) -> JsonDict:
        """Run prediction on a list of words.

        The list is wrapped in a JSON-style dict under the ``"words"`` key
        and forwarded to ``predict_json``.
        """
        payload = {"words": words}
        return self.predict_json(payload)

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """Build an ``Instance`` from a dict carrying a ``"words"`` entry."""
        word_list = json_dict["words"]
        # HACK: text_to_instance takes a list of POS tags as its second
        # argument, and that list must have the same length as the words.
        # Tags are not needed for prediction, so the words themselves are
        # passed as a stand-in.
        placeholder_tags = word_list
        return self._dataset_reader.text_to_instance(word_list, placeholder_tags)