|
| 1 | +import spacy |
1 | 2 | import logging
|
2 |
| -import textwrap |
3 |
| -from servicelayer.rpc import ExtractedEntity |
4 |
| -from servicelayer.rpc import EntityExtractService |
5 | 3 |
|
6 | 4 | from aleph import settings
|
7 |
| -from aleph.tracing import trace_function |
8 |
| -from aleph.logic.extractors.result import PersonResult, LocationResult |
9 |
| -from aleph.logic.extractors.result import OrganizationResult, LanguageResult |
| 5 | +from aleph.logic.extractors.result import PersonResult |
| 6 | +from aleph.logic.extractors.result import LocationResult |
| 7 | +from aleph.logic.extractors.result import OrganizationResult |
10 | 8 |
|
11 | 9 | log = logging.getLogger(__name__)
|
12 |
| - |
13 |
| - |
14 |
class NERService(EntityExtractService):
    """Entity extraction client that feeds text to ``self.Extract``.

    ``extract_all`` skips missing or very short texts, splits over-long
    texts into pieces and yields one tuple per entity returned by
    ``self.Extract``.
    """

    # Texts shorter than this many characters are not sent for extraction.
    MIN_LENGTH = 60
    # Texts longer than this many characters are split before extraction.
    MAX_LENGTH = 100000
    # Maps the entity type reported by the service to a result class.
    TYPES = {
        ExtractedEntity.ORGANIZATION: OrganizationResult,
        ExtractedEntity.PERSON: PersonResult,
        ExtractedEntity.LOCATION: LocationResult,
        ExtractedEntity.LANGUAGE: LanguageResult,
    }

    @trace_function(span_name='NER')
    def extract_all(self, text, languages):
        """Yield ``(text, result_class, start, end)`` for each entity found.

        Nothing is yielded when ``text`` is ``None`` or shorter than
        ``MIN_LENGTH``. ``result_class`` is ``None`` when the reported
        entity type has no entry in ``TYPES``.
        """
        if text is None or len(text) < self.MIN_LENGTH:
            return
        # NOTE(review): start/end are passed through exactly as reported for
        # each piece; they are not re-based onto the original text here.
        pieces = (
            [text]
            if len(text) <= self.MAX_LENGTH
            else textwrap.wrap(text, self.MAX_LENGTH)
        )
        for piece in pieces:
            for entity in self.Extract(piece, languages):
                result_class = self.TYPES.get(entity.type)
                yield (entity.text, result_class, entity.start, entity.end)
# Texts shorter than this many characters are skipped by extract_entities.
MIN_LENGTH = 60
# NOTE(review): nothing in extract_entities reads MAX_LENGTH; confirm whether
# long texts are still meant to be split before tagging.
MAX_LENGTH = 100000

# Maps spaCy entity labels (``ent.label_``) to result classes; entities whose
# label is not listed here are dropped by extract_entities.
# https://spacy.io/api/annotation#named-entities
SPACY_TYPES = {
    # People
    'PER': PersonResult,
    'PERSON': PersonResult,
    # Organizations
    'ORG': OrganizationResult,
    # Places
    'LOC': LocationResult,
    'GPE': LocationResult,
}
36 | 20 |
|
37 | 21 |
|
def extract_entities(ctx, text, languages):
    """Yield a result object for each recognised entity in ``text``.

    Nothing is yielded when ``text`` is ``None`` or shorter than
    ``MIN_LENGTH``. The spaCy pipeline is loaded on first use and kept on
    ``settings._nlp`` so later calls reuse it. Entities whose label is not
    in ``SPACY_TYPES``, or whose text is empty after stripping whitespace,
    are skipped.

    ``languages`` is accepted but not used by this implementation.
    """
    if text is None:
        return
    if len(text) < MIN_LENGTH:
        return

    # Load the model lazily; 'xx' is presumably spaCy's multi-language
    # model shortcut -- confirm against the installed models.
    if not hasattr(settings, '_nlp'):
        settings._nlp = spacy.load('xx')

    for span in settings._nlp(text).ents:
        result_class = SPACY_TYPES.get(span.label_)
        if result_class is None:
            continue
        name = span.text.strip()
        if not name:
            continue
        # NOTE(review): spaCy documents Span.start/.end as token indices, not
        # character offsets -- confirm that is what create() expects.
        yield result_class.create(ctx, name, span.start, span.end)
0 commit comments