Add conversational entity linking into REL #144

Closed · wants to merge 3 commits
9 changes: 6 additions & 3 deletions scripts/efficiency_test.py
@@ -1,12 +1,15 @@
 import numpy as np
 import requests
+import os

 from REL.training_datasets import TrainingEvaluationDatasets

 np.random.seed(seed=42)

-base_url = "/Users/vanhulsm/Desktop/projects/data/"
-wiki_version = "wiki_2014"
+base_url = os.environ.get("REL_BASE_URL")
+wiki_version = "wiki_2019"
+host = 'localhost'
+port = '5555'
 datasets = TrainingEvaluationDatasets(base_url, wiki_version).load()["aida_testB"]

 # random_docs = np.random.choice(list(datasets.keys()), 50)
@@ -40,7 +43,7 @@
 print(myjson)

 print("Output API:")
-print(requests.post("http://192.168.178.11:1235", json=myjson).json())
+print(requests.post(f"http://{host}:{port}", json=myjson).json())
 print("----------------------------")


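With this change the script reads its data directory from the environment instead of a hard-coded local path. A hypothetical invocation (path illustrative, assuming a REL server is listening on localhost:5555 as in test_server.py below):

    REL_BASE_URL=/path/to/rel/data python scripts/efficiency_test.py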
57 changes: 57 additions & 0 deletions scripts/test_server.py
@@ -0,0 +1,57 @@
import os

import requests

# Script for testing the implementation of the conversational entity linking API
#
# To run:
#
#   python .\src\REL\server.py $REL_BASE_URL wiki_2019
# or
#   python .\src\REL\server.py $env:REL_BASE_URL wiki_2019
#
# Set $REL_BASE_URL to where your data are stored (`base_url`)
#
# These paths must exist:
# - `$REL_BASE_URL/bert_conv`
# - `$REL_BASE_URL/s2e_ast_onto`
#
# (see https://github.com/informagi/conversational-entity-linking-2022/tree/main/tool#step-1-download-models)


host = 'localhost'
port = '5555'

text1 = {
    "text": "REL is a modular Entity Linking package that can both be integrated in existing pipelines or be used as an API.",
    "spans": [],
}

conv1 = {
    "conversation": "True",
    "text": [
        {
            "speaker": "USER",
            "utterance": "I am allergic to tomatoes but we have a lot of famous Italian restaurants here in London.",
        },
        {
            "speaker": "SYSTEM",
            "utterance": "Some people are allergic to histamine in tomatoes.",
        },
        {
            "speaker": "USER",
            "utterance": "Talking of food, can you recommend me a restaurant in my city for our anniversary?",
        },
    ],
}

for myjson in (text1, conv1):
    print('Input API:')
    print(myjson)
    print()
    print('Output API:')
    print(requests.post(f"http://{host}:{port}", json=myjson).json())
    print('----------------------------')
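If the server simply returns ConvEL.annotate's output (an assumption; the server changes are not shown in this diff), the conversational response can be unpacked as in this hedged sketch:

# Hedged sketch: assumes the response is a list of turn dicts where USER
# turns carry an "annotations" list of [start, length, mention, entity]
# entries (the format produced by conv_el.py below).
response = requests.post(f"http://{host}:{port}", json=conv1).json()
for turn in response:
    for start, length, mention, entity in turn.get("annotations", []):
        print(f'{turn["speaker"]}: {mention!r} -> {entity!r} at offset {start}')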
1 change: 1 addition & 0 deletions setup.cfg
@@ -47,6 +47,7 @@ install_requires =
 konoha
 flair>=0.11
 segtok
+spacy
 torch
 nltk
 anyascii
Empty file added src/REL/crel/__init__.py
94 changes: 94 additions & 0 deletions src/REL/crel/bert_md.py
@@ -0,0 +1,94 @@
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline


class BERT_MD:
    def __init__(self, file_pretrained):
        """
        Args:
            file_pretrained: path to the fine-tuned model, e.g. "./tmp/ft-conel/"

        Note:
            The output of self.ner_model(s_input) looks like:
            - s_input: e.g., 'Burger King franchise'
            - return: e.g., [{'entity': 'B-ment', 'score': 0.99364895, 'index': 1, 'word': 'Burger', 'start': 0, 'end': 6}, ...]
        """
        model = AutoModelForTokenClassification.from_pretrained(file_pretrained)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model.to(device)
        tokenizer = AutoTokenizer.from_pretrained(file_pretrained)
        self.ner_model = pipeline(
            "ner",
            model=model,
            tokenizer=tokenizer,
            device=device.index if device.index is not None else -1,
            ignore_labels=[],
        )

    def md(self, s, flag_warning=False):
        """Perform mention detection.

        Args:
            s: input string
            flag_warning: if True, print warning messages

        Returns:
            REL-style annotation results: [[start_position, length, mention], ...]
            E.g., [[0, 15, 'The Netherlands'], ...]
        """
        ann = self.ner_model(s)  # Get annotations from the BERT-NER model

        ret = []
        pos_start, pos_end = -1, -1  # Initialize variables

        for i in range(len(ann)):
            w, ner = ann[i]["word"], ann[i]["entity"]
            assert ner in [
                "B-ment",
                "I-ment",
                "O",
            ], f"Unexpected ner tag: {ner}. If you use BERT-NER as it is, then you should set flag_use_normal_bert_ner_tag=True."
            if ner == "B-ment" and w[:2] != "##":
                if (pos_start != -1) and (pos_end != -1):  # A B-ment was already found
                    ret.append(
                        [pos_start, pos_end - pos_start, s[pos_start:pos_end]]
                    )  # Save the previously identified mention
                    pos_start, pos_end = -1, -1  # Initialize
                pos_start, pos_end = ann[i]["start"], ann[i]["end"]

            elif ner == "B-ment" and w[:2] == "##":
                if (ann[i]["index"] == ann[i - 1]["index"] + 1) and (
                    ann[i - 1]["entity"] != "B-ment"
                ):  # The previous token has an entity label that is NOT "B-ment" (a ##xxx token should not begin an entity)
                    if flag_warning:
                        print(
                            f"WARNING: ##xxx (in this case {w}) should not be the beginning of an entity"
                        )

            elif (
                i > 0
                and (ner == "I-ment")
                and (ann[i]["index"] == ann[i - 1]["index"] + 1)
            ):  # w is I-ment and the previous token (i.e., ann[i-1]) belongs to the same mention
                pos_end = ann[i]["end"]  # Update pos_end

            # This only happens when flag_ignore_o is False
            elif (
                ner == "O"
                and w[:2] == "##"
                and (
                    ann[i - 1]["entity"] == "B-ment" or ann[i - 1]["entity"] == "I-ment"
                )
            ):  # w is "O" and a ##xxx token, and the previous token is B-ment or I-ment
                pos_end = ann[i]["end"]  # Update pos_end

        # Append the remaining mention, if any
        if (pos_start != -1) and (pos_end != -1):
            ret.append(
                [pos_start, pos_end - pos_start, s[pos_start:pos_end]]
            )  # Save the last mention

        return ret
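A hedged usage sketch for BERT_MD, with the model path taken from the docstring above (output illustrative):

# Hypothetical: assumes a fine-tuned mention-detection model at this path.
bert_md = BERT_MD("./tmp/ft-conel/")
print(bert_md.md("Burger King franchise"))
# e.g. [[0, 11, 'Burger King']] -- REL-style [start, length, mention] triples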
142 changes: 142 additions & 0 deletions src/REL/crel/conv_el.py
@@ -0,0 +1,142 @@
import importlib
import sys
from pathlib import Path

from .bert_md import BERT_MD
from .rel_ed import REL_ED
from .s2e_pe import pe_data
from .s2e_pe.pe import EEMD, PEMD


class ConvEL:
    def __init__(
        self, base_url=".", wiki_version="wiki_2019", user_config=None, threshold=0
    ):
        self.threshold = threshold

        self.wiki_version = wiki_version
        self.base_url = base_url
        self.file_pretrained = str(Path(base_url) / "bert_conv-td")

        self.bert_md = BERT_MD(self.file_pretrained)
        self.rel_ed = REL_ED(self.base_url, self.wiki_version)
        self.eemd = EEMD(s2e_pe_model=str(Path(base_url) / "s2e_ast_onto"))
        self.pemd = PEMD()

        self.preprocess = pe_data.PreProcess()
        self.postprocess = pe_data.PostProcess()

        # These are always re-initialized when annotate() is called
        self.conv_hist_for_pe = []  # History of the conversation, used in PE Linking
        self.ment2ent = {}  # Mention-to-entity mapping, used in PE Linking

    def _error_check(self, conv):
        assert type(conv) == list
        for turn in conv:
            assert type(turn) == dict
            assert set(turn.keys()) == {"speaker", "utterance"}
            assert turn["speaker"] in [
                "USER",
                "SYSTEM",
            ], f'Speaker should be either "USER" or "SYSTEM", but got {turn["speaker"]}'

    def _el(self, utt):
        """Perform entity linking."""
        # MD
        md_results = self.bert_md.md(utt)

        # ED
        spans = [[r[0], r[1]] for r in md_results]  # r[0]: start, r[1]: length
        el_results = self.rel_ed.ed(utt, spans)  # ED

        self.conv_hist_for_pe[-1]["mentions"] = [r[2] for r in el_results]
        self.ment2ent.update(
            {r[2]: r[3] for r in el_results}
        )  # If annotations disagree for the same mention, the most recent one (from the turn closest to the PEM) is used.

        return [r[:4] for r in el_results]  # [start_pos, length, mention, entity]

    def _pe(self, utt):
        """Perform PE Linking."""
        ret = []

        # Step 1: PE Mention Detection
        pem_results = self.pemd.pem_detector(utt)
        pem2result = {r[2]: r for r in pem_results}

        # Step 2: Find the corresponding explicit entity mentions (EEMs)
        # NOTE: The current implementation can handle only one target PEM at a time
        outputs = []
        for _, _, pem in pem_results:  # pem_results: [[start_pos, length, pem], ...]
            self.conv_hist_for_pe[-1]["pems"] = [
                pem
            ]  # Create a conversation for each target PEM to be linked

            # Preprocessing
            token_with_info = self.preprocess.get_tokens_with_info(
                self.conv_hist_for_pe
            )
            input_data = self.preprocess.get_input_of_pe_linking(token_with_info)

            assert (
                len(input_data) == 1
            ), "Current implementation can handle only one target PEM at a time"
            input_data = input_data[0]

            # Find the corresponding explicit entity mentions (EEMs)
            scores = self.eemd.get_scores(input_data)

            # Postprocessing
            outputs += self.postprocess.get_results(
                input_data, self.conv_hist_for_pe, self.threshold, scores
            )

            self.conv_hist_for_pe[-1]["pems"] = []  # Remove the target PEM

        # Step 3: Get the corresponding entity
        for r in outputs:
            pem = r["personal_entity_mention"]
            pem_result = pem2result[pem]  # [start_pos, length, pem]
            eem = r["mention"]  # Explicit entity mention
            ent = self.ment2ent[eem]  # Corresponding entity
            ret.append(
                [pem_result[0], pem_result[1], pem_result[2], ent]
            )  # [start_pos, length, PEM, entity]

        return ret

    def annotate(self, conv):
        """Get conversational entity linking annotations.

        Args:
            conv: A list of dicts, each containing "speaker" and "utterance" keys.

        Returns:
            A list of dicts, each containing the input keys plus an
            "annotations" key for USER turns.
        """
        self._error_check(conv)
        ret = []
        self.conv_hist_for_pe = []  # Initialize
        self.ment2ent = {}  # Initialize

        for turn in conv:
            utt = turn["utterance"]
            assert turn["speaker"] in [
                "USER",
                "SYSTEM",
            ], f'Speaker should be either "USER" or "SYSTEM", but got {turn["speaker"]}'
            ret.append({"speaker": turn["speaker"], "utterance": utt})

            self.conv_hist_for_pe.append({})
            self.conv_hist_for_pe[-1]["speaker"] = turn["speaker"]
            self.conv_hist_for_pe[-1]["utterance"] = utt

            if turn["speaker"] == "USER":
                el_results = self._el(utt)
                pe_results = self._pe(utt)
                ret[-1]["annotations"] = el_results + pe_results

        return ret
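A hedged usage sketch for ConvEL, reusing the conversation from scripts/test_server.py; it assumes the model directories listed there exist under base_url:

# Hypothetical setup: $REL_BASE_URL must contain wiki_2019, bert_conv-td
# and s2e_ast_onto (see scripts/test_server.py above).
import os

from REL.crel.conv_el import ConvEL

cel = ConvEL(base_url=os.environ["REL_BASE_URL"])

conversation = [
    {"speaker": "USER",
     "utterance": "I am allergic to tomatoes but we have a lot of famous "
                  "Italian restaurants here in London."},
    {"speaker": "SYSTEM",
     "utterance": "Some people are allergic to histamine in tomatoes."},
    {"speaker": "USER",
     "utterance": "Talking of food, can you recommend me a restaurant in "
                  "my city for our anniversary?"},
]

for turn in cel.annotate(conversation):
    if turn["speaker"] == "USER":
        print(turn["annotations"])  # [[start, length, mention, entity], ...]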
60 changes: 60 additions & 0 deletions src/REL/crel/rel_ed.py
@@ -0,0 +1,60 @@
import sys

from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection
from REL.utils import process_results


class REL_ED:
    def __init__(self, base_url, wiki_version):
        config = {
            "mode": "eval",
            "model_path": f"{base_url}/{wiki_version}/generated/model",
        }

        self.mention_detection = MentionDetection(
            base_url, wiki_version
        )  # Only used to format spans
        self.model = EntityDisambiguation(base_url, wiki_version, config)

    def generate_response(self, text, spans):
        """Generate ED results.

        Returns:
            A list of tuples for each entity found.

        Note:
            Original code: https://github.com/informagi/REL/blob/9ca253b1d371966c39219ed672f39784fd833d8d/REL/server.py#L101
        """
        API_DOC = "API_DOC"

        if len(text) == 0 or len(spans) == 0:
            return []

        # Get the mentions from the spans
        processed = {API_DOC: [text, spans]}
        mentions_dataset, total_ment = self.mention_detection.format_spans(processed)

        # Disambiguation
        predictions, timing = self.model.predict(mentions_dataset)

        # Process the results.
        result = process_results(
            mentions_dataset,
            predictions,
            processed,
            include_offset=False if (len(spans) > 0) else True,
        )

        # Singular document.
        if len(result) > 0:
            return [*result.values()][0]

        return []

    def ed(self, text, spans):
        """Convert tuples to lists to match the output format of the REL API."""
        response = self.generate_response(text, spans)
        return [list(ent) for ent in response]
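A hedged usage sketch for REL_ED; spans are REL-style [start, length] pairs, and the wiki_2019 data under base_url are assumed to be in place (path illustrative):

# Hypothetical path; the span [0, 3] covers the mention "REL".
rel_ed = REL_ED(base_url="/path/to/rel/data", wiki_version="wiki_2019")
for ann in rel_ed.ed("REL is a modular Entity Linking package.", [[0, 3]]):
    start, length, mention, entity = ann[:4]  # same prefix format as conv_el.py
    print(mention, "->", entity)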
Empty file added src/REL/crel/s2e_pe/__init__.py
3 changes: 3 additions & 0 deletions src/REL/crel/s2e_pe/consts.py
@@ -0,0 +1,3 @@
SPEAKER_START = 49518 # 'Ġ#####'
SPEAKER_END = 22560 # 'Ġ###'
NULL_ID_FOR_COREF = 0
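The ids above look like byte-level BPE vocabulary ids (the 'Ġ' prefix suggests a RoBERTa-family tokenizer). A hedged sanity check, assuming such a tokenizer:

# Hedged sketch: "roberta-base" is an assumption about the tokenizer family;
# verify against the tokenizer the s2e model actually ships with.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("roberta-base")
print(tok.convert_ids_to_tokens([SPEAKER_START, SPEAKER_END]))
# Expected, per the comments above: ['Ġ#####', 'Ġ###']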