diff --git a/robotreviewer/robots/pico_span_robot.py b/robotreviewer/robots/pico_span_robot.py index e7aebbe..1787e25 100644 --- a/robotreviewer/robots/pico_span_robot.py +++ b/robotreviewer/robots/pico_span_robot.py @@ -98,23 +98,23 @@ def annotate(self, article): for sent in chain(article['title'].sents, article['abstract'].sents): words = [w.text for w in sent] preds = self.model.predict(words) - + last_label = "N" - span = [] + start_idx = 0 - for w, p in zip(words, preds): + for i, p in enumerate(preds): - if p != last_label and span: - out[label_dict[last_label]].append(' '.join(span).strip()) - span = [] - - if p != "N": - span.append(w) + if p != last_label and last_label != "N": + out[label_dict[last_label]].append(sent[start_idx: i].text.strip()) + start_idx = i + + if p != last_label and last_label == "N": + start_idx = i last_label = p if last_label != "N": - out[label_dict[last_label]].append(' '.join(span).strip()) + out[label_dict[last_label]].append(sent[start_idx:].text.strip()) return out