Coreference search with LSH #153

Open · wants to merge 43 commits into main

Commits (43)

53540b3  add option for whether with_coref() should be used (f-hafner, Nov 25, 2022)
cdd8c7b  time ED for different dataset sizes (f-hafner, Dec 7, 2022)
aa9c741  change coref switch name, more efficiency tests (f-hafner, Dec 14, 2022)
463245d  add coreference indicator to prediction output (f-hafner, Dec 16, 2022)
93fa8f5  efficiency test: also pickle data after mention detection (f-hafner, Dec 20, 2022)
c218ce1  integrate lsh and first run (f-hafner, Jan 6, 2023)
9f5657d  3 options for coreferences (f-hafner, Jan 6, 2023)
7d8a17c  adjust update_efficiency_tests.sh (f-hafner, Jan 6, 2023)
f080855  make printout backwards compatible (f-hafner, Jan 9, 2023)
1a432d9  add basic logging to lsh class (f-hafner, Jan 10, 2023)
ae8e9e1  fix bug for single mention, add logging to efficiency test (f-hafner, Jan 11, 2023)
b8b4ea0  restore run_efficiency_test.sh (f-hafner, Jan 11, 2023)
0cc7598  scale fake data more (f-hafner, Jan 16, 2023)
2ca1bbc  switch to hashing with random projections (f-hafner, Jan 16, 2023)
5fc4357  add some more debugging to lsh (f-hafner, Jan 16, 2023)
e6894d5  speed up get_candidates() (f-hafner, Jan 18, 2023)
6cbb668  use sklearn binarizer for encoding (f-hafner, Jan 18, 2023)
6ac9ff0  test higher precision for lsh (f-hafner, Jan 18, 2023)
f19c904  vectorize banding (f-hafner, Jan 19, 2023)
9f41745  small speed ups for get_candidates_new() (f-hafner, Jan 20, 2023)
7570688  small changes to efficiency tests (f-hafner, Jan 20, 2023)
aa79b24  start tidying lsh (f-hafner, Jan 23, 2023)
50f6bfd  drop most of old code (f-hafner, Jan 23, 2023)
017b03e  lsh class: tidy, add docstrings (f-hafner, Jan 24, 2023)
780cee2  give right name to main class: random projections (f-hafner, Jan 24, 2023)
afa63d9  start tests, fix bug in cols_to_int_multidim (f-hafner, Jan 24, 2023)
136659b  improve docstrings (f-hafner, Jan 24, 2023)
35a0a32  n_bands and band_length as main inputs to class (f-hafner, Jan 24, 2023)
ff6778c  document the lsh class (f-hafner, Jan 24, 2023)
2b6315a  update docstring for with_coref (f-hafner, Jan 24, 2023)
86f89c2  small fixes to lsh and training_datasets (f-hafner, Jan 24, 2023)
7586fe6  tidy efficiency_test (f-hafner, Jan 24, 2023)
d218d54  set lsh params according to validation data (f-hafner, Jan 24, 2023)
94146a7  update docstrings; optimize lsh parameters (f-hafner, Jan 25, 2023)
575cb1d  small changes in lsh.py (f-hafner, Jan 25, 2023)
eb65bee  add __repr__ to lsh (f-hafner, Jan 25, 2023)
6907eca  improve docstrings, reorder imports (f-hafner, Jan 25, 2023)
f39fa94  further tidy efficiency tests (f-hafner, Jan 25, 2023)
20785e4  tidy docstring, add test for short mentions (f-hafner, Jan 25, 2023)
5e19915  some more comments, and reference online sources (f-hafner, Jan 25, 2023)
231ca45  make dirs for output of efficiency test if necessary (f-hafner, Jan 25, 2023)
3f06dac  use logging in with_coref (f-hafner, Jan 25, 2023)
5aa84db  add base_url argument to efficiency tests (f-hafner, Feb 15, 2023)
scripts/efficiency_test.py (167 additions, 15 deletions)
@@ -1,17 +1,107 @@
-import numpy as np
+import argparse
+import cProfile
+import logging
+import numpy as np
+import os
+import pickle
+import pandas as pd
+import pstats
+import requests
 
 from REL.training_datasets import TrainingEvaluationDatasets
 
 np.random.seed(seed=42)
 
-base_url = "/Users/vanhulsm/Desktop/projects/data/"
-wiki_version = "wiki_2014"
-datasets = TrainingEvaluationDatasets(base_url, wiki_version).load()["aida_testB"]
-
-# random_docs = np.random.choice(list(datasets.keys()), 50)
-
-server = True
+def profile_to_df(call):
+    """Helper function to profile a function call and save the timing in a pd df.
+
+    Source: https://stackoverflow.com/questions/44302726/pandas-how-to-store-cprofile-output-in-a-pandas-dataframe
+    """
+    cProfile.run(call, filename="temp.txt")
+    st = pstats.Stats("temp.txt")
+
+    keys_from_k = ['file', 'line', 'fn']
+    keys_from_v = ['cc', 'ncalls', 'tottime', 'cumtime', 'callers']
+    data = {k: [] for k in keys_from_k + keys_from_v}
+
+    s = st.stats
+
+    for k in s.keys():
+        for i, kk in enumerate(keys_from_k):
+            data[kk].append(k[i])
+
+        for i, kk in enumerate(keys_from_v):
+            data[kk].append(s[k][i])
+
+    df = pd.DataFrame(data)
+    os.remove('temp.txt')
+    return df
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--url",
+    dest="base_url",
+    type=str,
+    help="path to input and output data"
+)
+parser.add_argument(
+    '--search_corefs',
+    type=str,
+    choices=['all', 'lsh', 'off'],
+    default='all',
+    help="Setting for search_corefs in Entity Disambiguation."
+)
+parser.add_argument(
+    "--profile",
+    action="store_true",
+    default=False,
+    help="Profile the disambiguation step."
+)
+parser.add_argument(
+    "--scale_mentions",
+    action="store_true",
+    default=False,
+    help="""Stack mentions in each dataset and time the disambiguation step by document.
+    This is to assess the time complexity of the program."""
+)
+parser.add_argument(
+    "--name_dataset",
+    type=str,
+    default="aida_testB",
+    help="Name of the training dataset to be used"
+)
+parser.add_argument(
+    "--n_docs",
+    type=int,
+    default=50,
+    help="Number of documents to be processed."
+)
+logging.basicConfig(level=logging.INFO)  # do not print to file
+
+args = parser.parse_args()
+print(f"args.search_corefs is {args.search_corefs}")
+
+
+# base_url = "/home/flavio/projects/rel20/data"
+wiki_version = "wiki_2019"
+datasets = TrainingEvaluationDatasets(args.base_url, wiki_version, args.search_corefs).load()[args.name_dataset]
+
+# create directories where to save the output from the tests
+dir_efficiency_test = os.path.join(args.base_url, "efficiency_test")
+sub_directories = {
+    "profile": "profile",
+    "predictions": "predictions",
+    "n_mentions_time": "n_mentions_time"
+}
+sub_directories = {k: os.path.join(dir_efficiency_test, v) for k, v in sub_directories.items()}
+
+for d in sub_directories.values():
+    if not os.path.exists(d):
+        os.makedirs(d)
+
+
+server = False
 docs = {}
 for i, doc in enumerate(datasets):
     sentences = []
@@ -20,8 +110,8 @@
             sentences.append(x["sentence"])
     text = ". ".join([x for x in sentences])
 
-    if len(docs) == 50:
-        print("length docs is 50.")
+    if len(docs) == args.n_docs:
+        print(f"length docs is {args.n_docs}.")
         print("====================")
         break

@@ -56,11 +146,11 @@
     from REL.entity_disambiguation import EntityDisambiguation
     from REL.mention_detection import MentionDetection
 
-    base_url = "C:/Users/mickv/desktop/data_back/"
+    # base_url = "C:/Users/mickv/desktop/data_back/"  # why is this defined again here?
 
-    flair.device = torch.device("cuda:0")
+    flair.device = torch.device("cpu")
 
-    mention_detection = MentionDetection(base_url, wiki_version)
+    mention_detection = MentionDetection(args.base_url, wiki_version)
 
     # Alternatively use Flair NER tagger.
     tagger_ner = SequenceTagger.load("ner-fast")
@@ -72,11 +162,73 @@
     # 3. Load model.
     config = {
         "mode": "eval",
-        "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
+        "model_path": "{}/{}/generated/model".format(args.base_url, wiki_version),
     }
-    model = EntityDisambiguation(base_url, wiki_version, config)
+    model = EntityDisambiguation(args.base_url, wiki_version, config, search_corefs=args.search_corefs)
 
     # 4. Entity disambiguation.
     start = time()
     predictions, timing = model.predict(mentions_dataset)
     print("ED took: {}".format(time() - start))
 
+    output = {
+        "mentions": mentions_dataset,
+        "predictions": predictions,
+        "timing": timing
+    }
+
+    iteration_identifier = f"{args.name_dataset}_{args.n_docs}_{args.search_corefs}"
+    filename = os.path.join(sub_directories["predictions"], iteration_identifier)
+
+    with open(f"{filename}.pickle", "wb") as f:
+        pickle.dump(output, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+    # ## 4.b Profile the disambiguation part
+    if args.profile:
+        print("Profiling disambiguation")
+        filename = os.path.join(sub_directories["profile"], iteration_identifier)
+
+        df_stats = profile_to_df(call="model.predict(mentions_dataset)")
+        df_stats.to_csv(f"{filename}.csv", index=False)
+
+    # ## 4.c time disambiguation by document, vary number of mentions
+    if args.scale_mentions:
+        print("Scaling the mentions per document")
+        logging.basicConfig(level=logging.DEBUG)
+        mentions_dataset_scaled = {}
+
+        for k, data in mentions_dataset.items():
+            mentions_dataset_scaled[k] = data  # add the baseline data as in mentions_dataset
+            for f in [5, 50, 100]:
+                d = data * f
+                key = f"{k}_{f}"
+                mentions_dataset_scaled[key] = d
+
+        print("Timing disambiguation per document")
+        timing_by_dataset = {}
+        for name, mentions in mentions_dataset_scaled.items():
+            print(f"predicting for dataset {name}", flush=True)
+            tempdict = {name: mentions}  # format so that model.predict() works
+            start = time()
+            predictions, timing = model.predict(tempdict)
+            t = time() - start
+
+            timing_by_dataset[name] = {
+                "n_mentions": len(mentions),
+                "time": t
+            }
+
+            if args.profile:
+                print("Profiling disambiguation for synthetic data set")
+                df_profile = profile_to_df(call="model.predict(tempdict)")
+                timing_by_dataset[name]['profile'] = df_profile
+
+        # save timing by dataset
+        filename = os.path.join(sub_directories["n_mentions_time"], f"{args.name_dataset}_{args.search_corefs}")
+
+        with open(f"{filename}.pickle", "wb") as f:
+            pickle.dump(timing_by_dataset, f, protocol=pickle.HIGHEST_PROTOCOL)
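
For reference, a sketch of how the pickled timings written above might be read back for analysis. The path fragments mirror what the script writes; base_url is an assumption standing in for whatever was passed as --url:

import os
import pickle

import pandas as pd

base_url = "/path/to/data"  # assumption: the value passed as --url
name_dataset, search_corefs = "aida_testB", "lsh"

path = os.path.join(base_url, "efficiency_test", "n_mentions_time",
                    f"{name_dataset}_{search_corefs}.pickle")
with open(path, "rb") as f:
    timing_by_dataset = pickle.load(f)

# One row per (possibly stacked) document: mention count vs. ED wall-clock time.
df = pd.DataFrame(
    [{"dataset": k, "n_mentions": v["n_mentions"], "time": v["time"]}
     for k, v in timing_by_dataset.items()]
)
print(df.sort_values("n_mentions"))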




scripts/run_efficiency_tests.sh (42 additions, 0 deletions)
@@ -0,0 +1,42 @@

+BASE_URL="$1"
+
+DATASETS=("aida_testB")
+DOCSIZES=(50 500)
+COREF_OPTIONS=("all" "off" "lsh")
+
+echo "${DATASETS[@]}"
+
+echo "--Running efficiency tests by data set, n_docs and coref option--"
+
+# do profiling and checking predictions in one
+for size in "${DOCSIZES[@]}"; do
+    for ds in "${DATASETS[@]}"; do
+        for option in "${COREF_OPTIONS[@]}"; do
+            echo "$ds, $size, $option"
+            python scripts/efficiency_test.py \
+                --url "$BASE_URL" \
+                --profile \
+                --n_docs "$size" \
+                --name_dataset "$ds" \
+                --search_corefs "$option"
+        done
+    done
+done
+
+# echo "--Scaling number of mentions--"
+
+# for ds in "${DATASETS[@]}"; do
+#     echo "$ds"
+#     python scripts/efficiency_test.py --name_dataset "$ds" --scale_mentions --profile --search_corefs "all"
+#     python scripts/efficiency_test.py --name_dataset "$ds" --scale_mentions --profile --search_corefs "lsh"
+#     python scripts/efficiency_test.py --name_dataset "$ds" --scale_mentions --profile --search_corefs "off"
+# done
+
+echo "Done."
src/REL/entity_disambiguation.py (20 additions, 4 deletions)
@@ -32,7 +32,12 @@ class EntityDisambiguation:
     Parent Entity Disambiguation class that directs the various subclasses used
     for the ED step.
     """
-    def __init__(self, base_url, wiki_version, user_config, reset_embeddings=False):
+    def __init__(self, base_url, wiki_version, user_config, reset_embeddings=False, search_corefs="all"):
+        """
+        Argument search_corefs: one of 'all' (default), 'lsh', 'off'.
+        If 'off', no coreference search is done.
+        Otherwise, the value is passed on as the `search_corefs_in` argument of `with_coref`.
+        """
         self.base_url = base_url
         self.wiki_version = wiki_version
         self.embeddings = {}
@@ -53,7 +58,9 @@ def __init__(self, base_url, wiki_version, user_config, reset_embeddings=False):
), "Glove embeddings in wrong folder..? Test embedding not found.."

self.__load_embeddings()
self.coref = TrainingEvaluationDatasets(base_url, wiki_version)
assert search_corefs in ['all', 'lsh', 'off']
self.search_corefs = search_corefs
self.coref = TrainingEvaluationDatasets(base_url, wiki_version, search_corefs)
self.prerank_model = PreRank(self.config).to(self.device)

self.__max_conf = None
@@ -470,7 +477,9 @@ def predict(self, data):
         :return: predictions and time taken for the ED step.
         """
 
-        self.coref.with_coref(data)
+        if self.search_corefs != "off":
+            self.coref.with_coref(data, search_corefs_in=self.search_corefs)
+
         data = self.get_data_items(data, "raw", predict=True)
         predictions, timing = self.__predict(data, include_timing=True, eval_raw=True)
 
@@ -664,7 +673,12 @@ def __predict(self, data, include_timing=False, eval_raw=False):
             ]
             doc_names = [m["doc_name"] for m in batch]
 
-            for dname, entity in zip(doc_names, pred_entities):
+            if self.search_corefs != 'off':
+                coref_indicators = [m['raw']['is_coref'] for m in batch]
+            else:
+                coref_indicators = [None for m in batch]
+
+            for dname, entity, is_coref in zip(doc_names, pred_entities, coref_indicators):
                 if entity[0] != "NIL":
                     predictions[dname].append(
                         {
@@ -673,6 +687,7 @@ def __predict(self, data, include_timing=False, eval_raw=False):
"candidates": entity[2],
"conf_ed": entity[4],
"scores": list([str(x) for x in entity[3]]),
"is_coref": is_coref
}
)

@@ -683,6 +698,7 @@ def __predict(self, data, include_timing=False, eval_raw=False):
"prediction": entity[0],
"candidates": entity[2],
"scores": [],
"is_coref": is_coref
}
)
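
Taken together, the new switch is used from the caller's side as below; a minimal sketch with placeholder paths, not code from this PR:

from REL.entity_disambiguation import EntityDisambiguation

base_url = "/path/to/data"  # placeholder
wiki_version = "wiki_2019"
config = {
    "mode": "eval",
    "model_path": f"{base_url}/{wiki_version}/generated/model",
}

# "all" keeps the exhaustive coreference search, "lsh" restricts it to
# LSH candidates, "off" skips the coreference step entirely.
model = EntityDisambiguation(base_url, wiki_version, config, search_corefs="lsh")
# mentions_dataset: output of MentionDetection.find_mentions, as in the efficiency test
predictions, timing = model.predict(mentions_dataset)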
