Merge branch 'fix_model' into 'master'
Fix model build for use on ia2-server

See merge request llave-en-mano/liberajus/dataturks-spacy-train-cli-util!41
sgobotta committed Mar 19, 2021
2 parents 6846721 + 8c85f5a commit ad9414a
Showing 3 changed files with 70 additions and 74 deletions.
70 changes: 69 additions & 1 deletion pipeline_components/entity_matcher.py
@@ -1,10 +1,78 @@
import logging
from pipeline_components.generic_matcher import GenericMatcher, repeat_patterns
from spacy.matcher import Matcher
from spacy.tokens import Span
from spacy.lang.es.lex_attrs import _num_words
from spacy.util import filter_spans


def filter_longer_spans(spans, *, seen_tokens=set(), preserve_spans=[]):
    """Filter a sequence of spans and remove duplicates or overlaps. Useful for
    creating named entities (where one token can only be part of one entity) or
    when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
    longest span is preferred over shorter spans.

    spans (iterable): The spans to filter.
    RETURNS (list): The filtered spans.
    """

    def get_sort_key(span):
        return (span.end - span.start, -span.start)

    sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
    result = preserve_spans
    _seen_tokens = seen_tokens
    for span in sorted_spans:
        # Check for end - 1 here because boundaries are inclusive
        if span.start not in _seen_tokens and span.end - 1 not in _seen_tokens:
            result.append(span)
            _seen_tokens.update(range(span.start, span.end))
    result = sorted(result, key=lambda span: span.start)
    return result


def repeat_patterns(patterns, times):
    """
    Utility function that receives a list of patterns and returns a list
    containing those patterns repeated `times` times. The final length of the
    list is equal to len(patterns) * times.
    """
    generated_patterns = []
    for _ in range(0, times):
        generated_patterns.extend(patterns)
    return generated_patterns


class GenericMatcher(object):
    """
    GenericMatcher: given an nlp instance and a list of patterns, provides a
    pipeline component that matches tokens against each of those patterns and
    returns an updated Doc object.
    """

    name = "generic_matcher"

    def __init__(self, nlp, matcher_patterns=[]):
        self.nlp = nlp
        self.matcher = Matcher(self.nlp.vocab, validate=True)
        # Adds each pattern to the Matcher under its entity label
        for entity_label, pattern in matcher_patterns:
            self.matcher.add(entity_label, [pattern], on_match=None)

    def __call__(self, doc):
        matches = self.matcher(doc)
        matched_spans = [Span(doc, start, end, self.nlp.vocab.strings[match_id]) for match_id, start, end in matches]
        # Creates a set of seen tokens so that filter_longer_spans gives
        # priority to the spans produced by this matcher.
        seen_tokens = set()
        merged_matched_spans = filter_spans(matched_spans)
        for span in merged_matched_spans:
            seen_tokens.update(range(span.start, span.end))
        doc_ents = merged_matched_spans + list(doc.ents)
        # Removes overlapping entities, preferring the longest spans
        doc.ents = filter_longer_spans(doc_ents, seen_tokens=seen_tokens, preserve_spans=merged_matched_spans)
        return doc


# Extends built-in lex_attrs from the Spanish lang package
num_words = _num_words + [
    "ciento",
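
A quick sketch of what the two helpers above do, assuming they are importable from pipeline_components.entity_matcher as of this commit and that a blank Spanish spaCy pipeline is available; the text, labels, and pattern are illustrative only:

import spacy
from spacy.tokens import Span
from pipeline_components.entity_matcher import filter_longer_spans, repeat_patterns

nlp = spacy.blank("es")
doc = nlp("Juan Carlos Pérez declaró en la audiencia")

# repeat_patterns: a one-token pattern repeated three times becomes a three-token pattern
print(repeat_patterns([{"IS_TITLE": True}], 3))
# [{'IS_TITLE': True}, {'IS_TITLE': True}, {'IS_TITLE': True}]

# filter_longer_spans: overlapping spans are resolved in favour of the longest one
long_span = Span(doc, 0, 3, label="PER")   # "Juan Carlos Pérez"
short_span = Span(doc, 1, 2, label="PER")  # "Carlos"
print([s.text for s in filter_longer_spans([long_span, short_span])])
# ['Juan Carlos Pérez']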
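
And a minimal sketch of wiring GenericMatcher into a pipeline, assuming spaCy 2.x (where nlp.add_pipe accepts a component instance); the "LEY" label and pattern are illustrative, not necessarily one of the project's patterns:

import spacy
from pipeline_components.entity_matcher import GenericMatcher

nlp = spacy.blank("es")
# Each entry pairs an entity label with a single token pattern
matcher_patterns = [
    ("LEY", [{"LOWER": "ley"}, {"LIKE_NUM": True}]),
]
nlp.add_pipe(GenericMatcher(nlp, matcher_patterns=matcher_patterns), last=True)

doc = nlp("Se aplicó la ley 24660 en la causa.")
print([(ent.text, ent.label_) for ent in doc.ents])
# [('ley 24660', 'LEY')]
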
72 changes: 0 additions & 72 deletions pipeline_components/generic_matcher.py

This file was deleted.

2 changes: 1 addition & 1 deletion train.py
@@ -720,7 +720,7 @@ def build_model_package(

# Copy the files with custom modules into the pipeline_components directory (we don't use model_components so that it stays fixed and always the same in our component, and can be decoupled for when we have multiple clients)
package_components_dir = "pipeline_components"
files_src = ["entity_matcher.py", "entity_custom.py", "generic_matcher.py"]
files_src = ["entity_matcher.py", "entity_custom.py"]
dest_component_dir = os.path.join(package_base_path, package_dir, package_components_dir)
os.mkdir(dest_component_dir)
for f in files_src:
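
For context, the copy step this hunk configures amounts to roughly the following sketch; shutil and the destination paths are assumptions for illustration, since the loop body is collapsed above (the real code uses os.mkdir on a directory created earlier in build_model_package):

import os
import shutil

package_base_path = "dist"            # hypothetical build output directory
package_dir = "es_model-0.0.0"        # hypothetical packaged model directory
package_components_dir = "pipeline_components"
# generic_matcher.py is no longer shipped; its contents now live in entity_matcher.py
files_src = ["entity_matcher.py", "entity_custom.py"]

dest_component_dir = os.path.join(package_base_path, package_dir, package_components_dir)
os.makedirs(dest_component_dir, exist_ok=True)  # makedirs so the sketch runs standalone
for f in files_src:
    shutil.copy(os.path.join(package_components_dir, f), dest_component_dir)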
