diff --git a/scripts/efficiency_test.py b/scripts/efficiency_test.py
index fc66f54..c47d5ed 100644
--- a/scripts/efficiency_test.py
+++ b/scripts/efficiency_test.py
@@ -1,17 +1,107 @@
-import numpy as np
+import argparse
+import cProfile
+import logging
+import numpy as np
+import os
+import pickle
+import pandas as pd
+import pstats
 import requests
 from REL.training_datasets import TrainingEvaluationDatasets
 
 np.random.seed(seed=42)
 
-base_url = "/Users/vanhulsm/Desktop/projects/data/"
-wiki_version = "wiki_2014"
-datasets = TrainingEvaluationDatasets(base_url, wiki_version).load()["aida_testB"]
-
-# random_docs = np.random.choice(list(datasets.keys()), 50)
-
-server = True
+def profile_to_df(call):
+    """Helper function to profile a function call and save the timing in a pd df.
+
+    Source: https://stackoverflow.com/questions/44302726/pandas-how-to-store-cprofile-output-in-a-pandas-dataframe
+    """
+    cProfile.run(call, filename="temp.txt")
+    st = pstats.Stats("temp.txt")
+
+    keys_from_k = ['file', 'line', 'fn']
+    keys_from_v = ['cc', 'ncalls', 'tottime', 'cumtime', 'callers']
+    data = {k: [] for k in keys_from_k + keys_from_v}
+
+    s = st.stats
+
+    for k in s.keys():
+        for i, kk in enumerate(keys_from_k):
+            data[kk].append(k[i])
+
+        for i, kk in enumerate(keys_from_v):
+            data[kk].append(s[k][i])
+
+    df = pd.DataFrame(data)
+    os.remove('temp.txt')
+    return df
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--url",
+    dest="base_url",
+    type=str,
+    help="path to input and output data"
+)
+parser.add_argument(
+    '--search_corefs',
+    type=str,
+    choices=['all', 'lsh', 'off'],
+    default='all',
+    help="Setting for search_corefs in Entity Disambiguation."
+)
+parser.add_argument(
+    "--profile",
+    action="store_true",
+    default=False,
+    help="Profile the disambiguation step."
+)
+parser.add_argument(
+    "--scale_mentions",
+    action="store_true",
+    default=False,
+    help="""Stack mentions in each dataset and time the disambiguation step by document.
+    This is to assess the time complexity of the program."""
+)
+parser.add_argument(
+    "--name_dataset",
+    type=str,
+    default="aida_testB",
+    help="Name of the training dataset to be used"
+)
+parser.add_argument(
+    "--n_docs",
+    type=int,
+    default=50,
+    help="Number of documents to be processed."
+)
+logging.basicConfig(level=logging.INFO)  # do not print to file
+
+args = parser.parse_args()
+print(f"args.search_corefs is {args.search_corefs}")
+
+
+# base_url = "/home/flavio/projects/rel20/data"
+wiki_version = "wiki_2019"
+datasets = TrainingEvaluationDatasets(args.base_url, wiki_version, args.search_corefs).load()[args.name_dataset]
+
+# create directories where to save the output from the tests
+dir_efficiency_test = os.path.join(args.base_url, "efficiency_test")
+sub_directories = {
+    "profile": "profile",
+    "predictions": "predictions",
+    "n_mentions_time": "n_mentions_time"
+}
+sub_directories = {k: os.path.join(dir_efficiency_test, v) for k, v in sub_directories.items()}
+
+for d in sub_directories.values():
+    if not os.path.exists(d):
+        os.makedirs(d)
+
+
+server = False
 docs = {}
 for i, doc in enumerate(datasets):
     sentences = []
     for x in doc:
         if x["sentence"] not in sentences:
@@ -20,8 +110,8 @@
             sentences.append(x["sentence"])
     text = ". ".join([x for x in sentences])
 
-    if len(docs) == 50:
-        print("length docs is 50.")
+    if len(docs) == args.n_docs:
+        print(f"length docs is {args.n_docs}.")
         print("====================")
         break
 
@@ -56,11 +146,11 @@
     from REL.entity_disambiguation import EntityDisambiguation
     from REL.mention_detection import MentionDetection
 
-    base_url = "C:/Users/mickv/desktop/data_back/"
+    # base_url = "C:/Users/mickv/desktop/data_back/"  # why is this defined again here?
 
-    flair.device = torch.device("cuda:0")
+    flair.device = torch.device("cpu")
 
-    mention_detection = MentionDetection(base_url, wiki_version)
+    mention_detection = MentionDetection(args.base_url, wiki_version)
 
     # Alternatively use Flair NER tagger.
     tagger_ner = SequenceTagger.load("ner-fast")
@@ -72,11 +162,73 @@
     # 3. Load model.
     config = {
         "mode": "eval",
-        "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
+        "model_path": "{}/{}/generated/model".format(args.base_url, wiki_version),
     }
-    model = EntityDisambiguation(base_url, wiki_version, config)
+    model = EntityDisambiguation(args.base_url, wiki_version, config, search_corefs=args.search_corefs)
 
     # 4. Entity disambiguation.
     start = time()
     predictions, timing = model.predict(mentions_dataset)
     print("ED took: {}".format(time() - start))
+
+    output = {
+        "mentions": mentions_dataset,
+        "predictions": predictions,
+        "timing": timing
+    }
+
+    iteration_identifier = f"{args.name_dataset}_{args.n_docs}_{args.search_corefs}"
+    filename = os.path.join(sub_directories["predictions"], iteration_identifier)
+
+    with open(f"{filename}.pickle", "wb") as f:
+        pickle.dump(output, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+    # ## 4.b Profile the disambiguation part
+    if args.profile:
+        print("Profiling disambiguation")
+        filename = os.path.join(sub_directories["profile"], iteration_identifier)
+
+        df_stats = profile_to_df(call="model.predict(mentions_dataset)")
+        df_stats.to_csv(f"{filename}.csv", index=False)
+
+    # ## 4.c time disambiguation by document, vary number of mentions
+    if args.scale_mentions:
+        print("Scaling the mentions per document")
+        logging.basicConfig(level=logging.DEBUG)
+        mentions_dataset_scaled = {}
+
+        for k, data in mentions_dataset.items():
+            mentions_dataset_scaled[k] = data  # add the baseline data as in mentions_dataset
+            for f in [5, 50, 100]:
+                d = data * f
+                key = f"{k}_{f}"
+                mentions_dataset_scaled[key] = d
+
+        print("Timing disambiguation per document")
+        timing_by_dataset = {}
+        for name, mentions in mentions_dataset_scaled.items():
+            print(f"predicting for dataset {name}", flush=True)
+            tempdict = {name: mentions}  # format so that model.predict() works
+            start = time()
+            predictions, timing = model.predict(tempdict)
+            t = time() - start
+
+            timing_by_dataset[name] = {
+                "n_mentions": len(mentions),
+                "time": t
+            }
+
+            if args.profile:
+                print("Profiling disambiguation for synthetic data set")
+                df_profile = profile_to_df(call="model.predict(tempdict)")
+                timing_by_dataset[name]['profile'] = df_profile
+
+        # save timing by dataset
+        filename = os.path.join(sub_directories["n_mentions_time"], f"{args.name_dataset}_{args.search_corefs}")
+
+        with open(f"{filename}.pickle", "wb") as f:
+            pickle.dump(timing_by_dataset, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+
diff --git a/scripts/run_efficiency_tests.sh b/scripts/run_efficiency_tests.sh
new file mode 100644
index 0000000..85059e0
--- /dev/null
+++ b/scripts/run_efficiency_tests.sh
@@ -0,0 +1,42 @@
+
+BASE_URL="$1"
+
+DATASETS=("aida_testB")
+DOCSIZES=(50 500)
+COREF_OPTIONS=("all" "off" "lsh")
+
+
+echo $DATASETS
+
+
+echo "--Running efficiency tests by data set, n_docs and coref option--"
+
+# do profiling and checking predictions in one
+for size in ${DOCSIZES[@]}; do
+    for ds in ${DATASETS[@]}; do
+        for option in ${COREF_OPTIONS[@]}; do
+            echo $ds, echo $size, echo $option
+            python scripts/efficiency_test.py \
+                --url "$BASE_URL" \
+                --profile \
+                --n_docs $size \
+                --name_dataset "$ds" \
+                --search_corefs $option
+        done
+    done
+done
+
+# echo "--Scaling number of mentions--"
+
+# for ds in ${datasets[@]}; do
+#     echo $ds
+#     python scripts/efficiency_test.py --name_dataset "$ds" --scale_mentions --profile --search_corefs "all"
+#     python scripts/efficiency_test.py --name_dataset "$ds" --scale_mentions --profile --search_corefs "lsh"
+#     python scripts/efficiency_test.py --name_dataset "$ds" --scale_mentions --profile --search_corefs "off"
+# done
+
+
+echo "Done."
+
+
+
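The shell script's three `--search_corefs` values map directly onto the new constructor argument wired through below. A minimal sketch of using it from Python (the data directory is a placeholder and the config mirrors the one in the test script):

from REL.entity_disambiguation import EntityDisambiguation

base_url = "/path/to/rel/data"  # placeholder
wiki_version = "wiki_2019"

config = {
    "mode": "eval",
    "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
}

# "all" reproduces the previous exhaustive coreference search,
# "lsh" pre-filters candidate mentions with locality-sensitive hashing,
# "off" skips coreference search entirely.
model = EntityDisambiguation(base_url, wiki_version, config, search_corefs="lsh")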
""" - self.coref.with_coref(data) + if self.search_corefs != "off": + self.coref.with_coref(data, search_corefs_in=self.search_corefs) + data = self.get_data_items(data, "raw", predict=True) predictions, timing = self.__predict(data, include_timing=True, eval_raw=True) @@ -664,7 +673,12 @@ def __predict(self, data, include_timing=False, eval_raw=False): ] doc_names = [m["doc_name"] for m in batch] - for dname, entity in zip(doc_names, pred_entities): + if self.search_corefs != 'off': + coref_indicators = [m['raw']['is_coref'] for m in batch] + else: + coref_indicators = [None for m in batch] + + for dname, entity, is_coref in zip(doc_names, pred_entities, coref_indicators): if entity[0] != "NIL": predictions[dname].append( { @@ -673,6 +687,7 @@ def __predict(self, data, include_timing=False, eval_raw=False): "candidates": entity[2], "conf_ed": entity[4], "scores": list([str(x) for x in entity[3]]), + "is_coref": is_coref } ) @@ -683,6 +698,7 @@ def __predict(self, data, include_timing=False, eval_raw=False): "prediction": entity[0], "candidates": entity[2], "scores": [], + "is_coref": is_coref } ) diff --git a/src/REL/lsh.py b/src/REL/lsh.py new file mode 100644 index 0000000..bd174e1 --- /dev/null +++ b/src/REL/lsh.py @@ -0,0 +1,337 @@ +"""Implement a simple version of locality-sensitive hashing. + +To deal with high-dimensional data (=many mentions), the class stores the feature vectors +as sparse matrices and uses random projections as hash functions. + +See chapter 3 in "Mining of Massive Datasets" (http://www.mmds.org/). +The time complexity is explained at the end of this video: https://www.youtube.com/watch?v=Arni-zkqMBA +(number of hyperplanes = band length). +The use of multiple bands is called amplification, which is discussed in the book +but not in the video. +""" + +import itertools +import logging +import math +import numpy as np +import time + +from scipy import sparse +from sklearn.preprocessing import MultiLabelBinarizer + +# First, define a bunch of functions. TODO: should they be defined elsewhere? put in utils? + +def k_shingle(s, k): + "Convert string s into shingles of length k." + shingle = [] + for i in range(len(s) - k + 1): + shingle.append(s[i:(i+k)]) + return shingle + + +def cols_to_int_multidim(a): + """Combine columns in all rows to an integer. + + For instance, [[1,20,3], [1,4,10]] becomes [1203,1410]. + + Notes + ----- + The advantage is that it uses vectorized numpy to collapse an + entire row into one integer. The disadvantage is that one additional row increases + the size of the integer at least by an order of magnitude, which only works for cases where + the bands are not too large. But in practice, optimal bands are typically not long enough + to cause problems. + + :param a: 2-dimensional array + :type a: np.ndarray + :returns: An array of shape (n, 1), where the horizontally neighboring column values + are appended together. + :rtype: np.ndarray + """ + existing_powers = np.floor(np.log10(a)) + n_bands, nrows, ncols = a.shape + + # sum existing powers from right to left + cumsum_powers = np.flip(np.cumsum(np.flip(existing_powers, axis=2), axis=2), axis=2) + + add_powers = [x for x in reversed(range(ncols))] + add_powers = np.tile(add_powers, (nrows, 1)) + + mult_factor = cumsum_powers - existing_powers + add_powers + summationvector = np.ones((ncols, 1)) + out = np.matmul(a * 10**mult_factor, summationvector) + return out + +def signature_to_3d_bands(a, n_bands, band_length): + """Convert a signature from 2d to 3d. 
diff --git a/src/REL/lsh.py b/src/REL/lsh.py
new file mode 100644
index 0000000..bd174e1
--- /dev/null
+++ b/src/REL/lsh.py
@@ -0,0 +1,337 @@
+"""Implement a simple version of locality-sensitive hashing.
+
+To deal with high-dimensional data (=many mentions), the class stores the feature vectors
+as sparse matrices and uses random projections as hash functions.
+
+See chapter 3 in "Mining of Massive Datasets" (http://www.mmds.org/).
+The time complexity is explained at the end of this video: https://www.youtube.com/watch?v=Arni-zkqMBA
+(number of hyperplanes = band length).
+The use of multiple bands is called amplification, which is discussed in the book
+but not in the video.
+"""
+
+import itertools
+import logging
+import math
+import numpy as np
+import time
+
+from scipy import sparse
+from sklearn.preprocessing import MultiLabelBinarizer
+
+# First, define a bunch of functions. TODO: should they be defined elsewhere? put in utils?
+
+def k_shingle(s, k):
+    "Convert string s into shingles of length k."
+    shingle = []
+    for i in range(len(s) - k + 1):
+        shingle.append(s[i:(i+k)])
+    return shingle
+
+
+def cols_to_int_multidim(a):
+    """Combine columns in all rows to an integer.
+
+    For instance, [[1,20,3], [1,4,10]] becomes [1203,1410].
+
+    Notes
+    -----
+    The advantage is that it uses vectorized numpy to collapse an
+    entire row into one integer. The disadvantage is that one additional column increases
+    the size of the integer at least by an order of magnitude, which only works for cases where
+    the bands are not too large. But in practice, optimal bands are typically not long enough
+    to cause problems.
+
+    :param a: 3-dimensional array
+    :type a: np.ndarray
+    :returns: An array of shape (n_bands, n_items, 1), where the horizontally neighboring column values
+        are appended together.
+    :rtype: np.ndarray
+    """
+    existing_powers = np.floor(np.log10(a))
+    n_bands, nrows, ncols = a.shape
+
+    # sum existing powers from right to left
+    cumsum_powers = np.flip(np.cumsum(np.flip(existing_powers, axis=2), axis=2), axis=2)
+
+    add_powers = [x for x in reversed(range(ncols))]
+    add_powers = np.tile(add_powers, (nrows, 1))
+
+    mult_factor = cumsum_powers - existing_powers + add_powers
+    summationvector = np.ones((ncols, 1))
+    out = np.matmul(a * 10**mult_factor, summationvector)
+    return out
+
+
+def signature_to_3d_bands(a, n_bands, band_length):
+    """Convert a signature from 2d to 3d.
+
+    Convert a signature array of dimension (n_items, signature_length) into an array
+    of (n_bands, n_items, band_length).
+
+    Notes
+    -----
+    This produces the same output as np.vstack(np.split(a, indices_or_sections=n_bands, axis=1)).
+    When further processing the output, this is a useful alternative to looping on the output of
+    np.split(a, indices_or_sections=n_bands, axis=1) because a single vectorized call can be used,
+    while np.vstack(np.split(...)) is likely to be less efficient.
+
+    :param a: Array with 2 dimensions
+    :type a: np.ndarray
+    :param n_bands: Number of bands to cut the columns into
+    :type n_bands: int
+    :param band_length: Length of each band
+    :type band_length: int
+    :returns: Array of shape (n_bands, n_items, band_length)
+    :rtype: np.ndarray
+    """
+    n_items, signature_length = a.shape
+
+    # stacked bands of each item, stacked together
+    stacked_bands = a.reshape(n_items*n_bands, band_length)
+    # reorder so that the first band of all items comes first, then the second band of all items, etc.
+    reordering_vector = np.arange(n_items*n_bands).reshape(n_items, n_bands).T.reshape(1, -1)
+
+    result = stacked_bands[reordering_vector, :].reshape(n_bands, n_items, band_length)
+    return result
+
+
+def group_unique_indices(a):
+    """Compute indices of matching rows.
+
+    In a 3-dimensional array, for each array (axis 0),
+    compute the indices of rows (axis=1) that are identical.
+    Based on the 1d version here: https://stackoverflow.com/questions/23268605/grouping-indices-of-unique-elements-in-numpy
+
+    :param a: 3-dimensional array
+    :type a: np.ndarray
+    :returns: List of lists. Outer lists correspond to bands.
+        Inner lists correspond to the row indices that
+        have the same values in their columns. An item
+        in the inner list is an np.array.
+    :rtype: list
+    """
+    n_bands, n_items, length_band = a.shape
+    a = cols_to_int_multidim(a).squeeze()
+
+    sort_idx = np.argsort(a, axis=1)  # necessary for later, need to calc anyway
+    a_sorted = np.sort(a, axis=1)  # faster alternative to np.take_along_axis(b, sort_idx, axis=1)
+
+    # indicators for where a sequence of different unique elements starts
+    indicators = a_sorted[:, 1:] != a_sorted[:, :-1]
+    first_element = np.tile([[True]], n_bands).T
+    unq_first = np.concatenate((first_element, indicators), axis=1)
+
+    # calculate number of unique items
+    unq_count = [np.diff(np.nonzero(row)[0]) for row in unq_first]  # iterate through rows.
+    # split sorted array into groups of identical items. only keep groups with more than one item.
+    unq_idx = [[a for a in np.split(sort_idx[i], np.cumsum(count)) if len(a) > 1] for i, count in enumerate(unq_count)]
+
+    return unq_idx
+
+
+# ## Here follow the classes
+
+class LSHBase:
+    """
+    Base class for locality-sensitive hashing.
+
+    Attributes
+    ----------
+    shingle_size
+        Size of shingles to be created from mentions
+    mentions
+        Mentions in which to search for similar items
+
+    Methods
+    -------
+    encode_binary()
+        One-hot encode mentions, based on shingles
+    """
+
+    def __init__(self, mentions, shingle_size):
+        """
+        Parameters
+        ----------
+        :param mentions: Mentions in which to search for similar items
+        :type mentions: list or dict
+        :param shingle_size: Length of substrings to be created from mentions ("shingles")
+        :type shingle_size: int
+        """
+        self.shingle_size = shingle_size
+        if isinstance(mentions, dict):
+            self.shingles = [k_shingle(m, shingle_size) for m in mentions.values()]
+        elif isinstance(mentions, list):
+            self.shingles = [k_shingle(m, shingle_size) for m in mentions]
+        self._rep_items_not_show = ["shingles"]  # do not show in __repr__ b/c too long
+
+    def __repr__(self):
+        items_dict_show = {k: v for k, v in self.__dict__.items()
+                           if k not in self._rep_items_not_show
+                           and k[0] != "_"  # omit private attributes
+                           }
+        items_dict_show = [f"{k}={v}" for k, v in items_dict_show.items()]
+        return f"<{type(self).__name__}() with {', '.join(items_dict_show)}>"
+
+    def _build_vocab(self):
+        "Make vocabulary of unique shingles in all mentions."
+        logging.debug("making vocabulary from shingles")
+        vocab = list(set([shingle for sublist in self.shingles for shingle in sublist]))
+        self.vocab = vocab
+
+    def encode_binary(self):
+        """Create sparse binary vectors for each mention.
+
+        :return: Indicator matrix.
+            Rows indicate mentions, columns indicate whether
+            the mention contains the shingle.
+        :rtype: scipy.sparse.csr_matrix
+        """
+        logging.debug("making one-hot vectors")
+        binarizer = MultiLabelBinarizer(sparse_output=True)
+        self.vectors = binarizer.fit_transform(self.shingles)
+
+
+class LSHRandomProjections(LSHBase):
+    """Class for locality-sensitive hashing with random projections.
+
+    Attributes
+    ----------
+    mentions
+        List or dict of mentions
+    shingle_size
+        Length of the shingles to be constructed from each string in `mentions`
+    n_bands, band_length
+        The signature of a mention will be n_bands*band_length.
+        Longer bands increase precision, more bands increase recall.
+        If band_length is `None`, it is set as log(len(mentions)), which
+        will guarantee O(log(N)) time complexity.
+    seed
+        Random seed for np.random.default_rng
+
+    Methods
+    -------
+    make_signature()
+        Create a dense signature vector with random projections.
+    get_candidates()
+        Find groups of mentions with overlapping signatures.
+    cluster()
+        End-to-end hashing from shingles to clusters.
+        This is the main functionality of the class.
+    summarise()
+        Summarise time and output of cluster()
+    efficiency_gain_comparisons()
+        Compare number of computations for coreference search with hashing
+        and without hashing.
+    """
+
+    def __init__(self, mentions, shingle_size, n_bands, band_length=None, seed=3):
+        """
+        Parameters
+        ----------
+        :param mentions: Mentions in which to search for similar items
+        :type mentions: list or dict
+        :param shingle_size: Length of substrings to be created from mentions ("shingles")
+        :type shingle_size: int
+        :param n_bands: Number of signature bands (equal-sized cuts of the full signature)
+        :type n_bands: int
+        :param band_length: Length of bands
+        :type band_length: int or None
+        :param seed: Random seed for random number generator from numpy
+        :type seed: int
+        """
+        super().__init__(mentions, shingle_size)
+        self.seed = seed
+        self.n_bands = n_bands
+        if band_length is None:
+            log_n_mentions = math.ceil(math.log(len(mentions)))  # for O(log(N)) complexity
+            self.band_length = max(1, log_n_mentions)  # use 1 if exp(log(n_mentions)) < 1
+        else:
+            self.band_length = band_length
+        self.signature_size = n_bands * self.band_length
+        self.rng = np.random.default_rng(seed=self.seed)
+        self._rep_items_not_show.extend(["signature_size", "rng"])
+
+    def make_signature(self):
+        "Create a matrix of signatures with random projections."
+        logging.debug(f"Making signature. vectors shape is {self.vectors.shape}")
+        n_rows = self.signature_size
+        n_cols = self.vectors.shape[1]
+        hyperplanes = sparse.csr_matrix(
+            self.rng.choice([-1, 1], (n_rows, n_cols))
+        )
+        products = self.vectors.dot(hyperplanes.transpose()).toarray()
+        sign = 1 + (products > 0)  # need +1 for cols_to_int_multidim
+        self.signature = sign
+
+    def _all_candidates_to_all(self):
+        """Assign all mentions as candidates to all other mentions.
+
+        For edge cases where no single mention is longer than the shingle size.
+        """
+        n_mentions = self.vectors.shape[0]
+        self.candidates = [set(range(n_mentions)) for _ in range(n_mentions)]
+
+    def get_candidates(self):
+        """Extract similar mentions from signature.
+
+        For each mention, extract similar mentions based on whether part
+        of their signatures overlap.
+
+        :return: Index of mentions that are similar to each other.
+            A list of the candidate set of similar mentions.
+        :rtype: list
+        """
+        logging.debug("getting candidates...")
+        if self.vectors.shape[0] == 1:
+            candidates = [set()]
+            candidates[0].add(0)
+        else:
+            candidates = [set() for _ in range(self.vectors.shape[0])]
+
+            bands = signature_to_3d_bands(self.signature, n_bands=self.n_bands, band_length=self.band_length)
+            buckets_by_band = group_unique_indices(bands)
+            groups = [tuple(i) for i in itertools.chain.from_iterable(buckets_by_band)]  # flatten group; use tuple for applying set()
+            groups = set(groups)  # we only need the unique groups
+
+            for group in groups:
+                for i in group:
+                    candidates[i].update(group)
+
+            [candidates[i].discard(i) for i in range(len(candidates))]
+        self.candidates = candidates
+
+    def cluster(self):
+        """End-to-end locality-sensitive hashing.
+
+        Cluster mentions together based on their similarity.
+
+        :return: Index of mentions that are similar to each other.
+            A list of the candidate set of similar mentions.
+        :rtype: list
+        """
+        start = time.time()
+        self._build_vocab()
+        self.encode_binary()
+
+        if self.vectors.shape[1] == 0:  # no signature possible b/c no mention is longer than the shingle size.
+            self._all_candidates_to_all()
+        else:
+            self.make_signature()
+            self.get_candidates()
+        self.time = time.time() - start
+
+    def summarise(self):
+        "Summarise the time taken and output from clustering one LSH instance."
+        sizes = [len(g) for g in self.candidates]
+        print(f"took {self.time} seconds for {len(self.candidates)} mentions")
+        print(f"average, min, max cluster size: {round(sum(sizes)/len(sizes),2)}, {min(sizes)}, {max(sizes)}")
+
+    def efficiency_gain_comparisons(self):
+        """Compare number of comparisons made for coreference search with option
+        "lsh" and option "all". Useful for understanding time complexity,
+        and to assess whether the number of comparisons is meaningfully reduced.
+        """
+        sizes = [len(g) for g in self.candidates]
+        runtime_all = len(self.candidates) * len(self.candidates)
+        runtime_lsh = len(self.candidates) * (sum(sizes)/len(sizes))
+        print(f"option 'lsh' makes fraction {round(runtime_lsh/runtime_all, 2)} of comparisons relative to option 'all'.")
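A usage sketch of the new module on a toy mention list (the exact candidate sets depend on the random hyperplanes, so the output is illustrative; `shingle_size=2` and `n_bands=15` mirror the values used by `with_coref` below):

from REL.lsh import LSHRandomProjections

mentions = ["Jimi Hendrix", "Hendrix", "Hendrix", "European Commission", "Brussels"]

hasher = LSHRandomProjections(mentions, shingle_size=2, n_bands=15)
hasher.cluster()

# candidates[i] holds indices of mentions whose signature matches mention i
# in at least one band; the "Hendrix" variants should usually land together
for mention, cands in zip(mentions, hasher.candidates):
    print(mention, sorted(cands))
hasher.summarise()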
diff --git a/src/REL/training_datasets.py b/src/REL/training_datasets.py
index 0e41d62..eef6ec8 100644
--- a/src/REL/training_datasets.py
+++ b/src/REL/training_datasets.py
@@ -1,6 +1,9 @@
+import logging
 import os
 import pickle
+import math
 
+from REL.lsh import LSHRandomProjections
 
 class TrainingEvaluationDatasets:
     """
@@ -9,11 +12,18 @@ class TrainingEvaluationDatasets:
     Reading dataset from CoNLL dataset, extracted by https://github.com/dalab/deep-ed/
     """
 
-    def __init__(self, base_url, wiki_version):
+    def __init__(self, base_url, wiki_version, search_corefs="all"):
+        """
+        Argument search_corefs: one of 'all' (default), 'lsh', 'off'.
+        If 'off', no coreference search is done.
+        Otherwise the value is passed on to the `search_corefs_in` argument of `with_coref`.
+        """
         self.person_names = self.__load_person_names(
             os.path.join(base_url, "generic/p_e_m_data/persons.txt")
         )
         self.base_url = os.path.join(base_url, wiki_version)
+        assert search_corefs in ['all', 'lsh', 'off']
+        self.search_corefs = search_corefs
 
     def load(self):
         """
@@ -44,7 +54,8 @@ def load(self):
             if "Jiří_Třanovský Jiří_Třanovský" in datasets[ds]:
                 del datasets[ds]["Jiří_Třanovský Jiří_Třanovský"]
 
-            self.with_coref(datasets[ds])
+            if self.search_corefs != "off":
+                self.with_coref(datasets[ds], search_corefs_in=self.search_corefs)
 
         return datasets
 
@@ -103,23 +114,64 @@ def __find_coref(self, ment, mentlist):
 
         return coref
 
-    def with_coref(self, dataset):
+    def with_coref(self, dataset, search_corefs_in="all"):
+        """Replace candidates of coreferring mentions with those of the main mention.
+
+        Check if there are coreferences in the given dataset, and replace
+        the candidate entities of a coreferring mention with the candidates
+        of the main mention.
+
+        Example
+        -------
+        If a document contains both "Jimi Hendrix" and "Hendrix" as a mention,
+        then the candidate entities of "Hendrix" will be replaced by the candidate
+        entities of "Jimi Hendrix".
+
+        Parameters
+        ----------
+        :param search_corefs_in: in which set to search for coreferences.
+            Either "lsh" or "all".
+            If 'all', search for coreferences among all mentions in the document.
+            If 'lsh', search for coreferences among a pre-selected set of candidates.
+            The set is calculated with LSH.
+        :type search_corefs_in: str
+        :return: dataset with updated candidate entities and p(e|m) scores.
         """
-        Parent function that checks if there are coreferences in the given dataset.
-
-        :return: dataset
-        """
-
+        logging.info(f"with_coref() is called with search_corefs_in={search_corefs_in}.")
+        assert search_corefs_in in ['lsh', 'all']
         for data_name, content in dataset.items():
-            for cur_m in content:
-                coref = self.__find_coref(cur_m, content)
-                if coref is not None and len(coref) > 0:
-                    cur_cands = {}
-                    for m in coref:
-                        for c, p in m["candidates"]:
-                            cur_cands[c] = cur_cands.get(c, 0) + p
-                    for c in cur_cands.keys():
-                        cur_cands[c] /= len(coref)
-                    cur_m["candidates"] = sorted(
-                        list(cur_cands.items()), key=lambda x: x[1]
-                    )[::-1]
+            if len(content) == 0:
+                pass
+            else:
+                if search_corefs_in == 'lsh':
+                    input_mentions = [m["mention"] for m in content]
+                    band_length = math.ceil(math.log(len(input_mentions)))
+                    lsh_corefs = LSHRandomProjections(
+                        mentions=input_mentions,
+                        shingle_size=2,
+                        n_bands=15,
+                        band_length=band_length
+                    )
+                    lsh_corefs.cluster()
+                    assert len(content) == len(lsh_corefs.candidates)
+                    # lsh_corefs.candidates is the input for the loop below; indices refer to positions in input_mentions
+                for idx_mention, cur_m in enumerate(content):
+                    if search_corefs_in == "lsh":
+                        idx_candidates = list(lsh_corefs.candidates[idx_mention])  # lsh returns the indices of the candidate coreferences
+                        candidates = [content[i] for i in idx_candidates]
+                    elif search_corefs_in == "all":
+                        candidates = content
+                    coref = self.__find_coref(cur_m, candidates)
+                    if coref is not None and len(coref) > 0:
+                        cur_cands = {}
+                        for m in coref:
+                            for c, p in m["candidates"]:
+                                cur_cands[c] = cur_cands.get(c, 0) + p
+                        for c in cur_cands.keys():
+                            cur_cands[c] /= len(coref)
+                        cur_m["candidates"] = sorted(
+                            list(cur_cands.items()), key=lambda x: x[1]
+                        )[::-1]
+                        cur_m["is_coref"] = 1
+                    else:
+                        cur_m["is_coref"] = 0
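The index bookkeeping in the `lsh` branch of `with_coref` is the easiest part to get wrong, so here is a standalone sketch of the mapping with a made-up two-mention document (`content` and the candidate sets are illustrative, not real data):

content = [
    {"mention": "Jimi Hendrix", "candidates": [["Jimi_Hendrix", 0.9]]},
    {"mention": "Hendrix", "candidates": [["Jimi_Hendrix", 0.4], ["Hendrix_(film)", 0.3]]},
]
lsh_candidates = [{1}, {0}]  # as returned by LSHRandomProjections(...).candidates

for idx_mention, cur_m in enumerate(content):
    idx_candidates = list(lsh_candidates[idx_mention])
    candidates = [content[i] for i in idx_candidates]
    # __find_coref(cur_m, candidates) then only compares cur_m against this
    # pre-selected subset instead of against every mention in the document
    print(cur_m["mention"], "->", [c["mention"] for c in candidates])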
3d bands" + +def test_group_unique_indices(): + a = np.array([[[1, 4], [1, 4], [5,3], [5, 3], [1 , 2]], + [[7,8], [2, 7], [2, 7], [7, 8], [10, 3]] + ]) + output = lsh.group_unique_indices(a) + + # build expected + groups_band0 = [[0, 1], [2, 3]] + groups_band1 = [[1, 2], [0, 3]] + # Notes: + # [1,2], [10,3] are not listed because their group is of size 1. + # [2,7] is before [7, 8] because 27 < 78 + groups_band0 = [np.array(i) for i in groups_band0] + groups_band1 = [np.array(i) for i in groups_band1] + expected = [groups_band0, groups_band1] + + o = itertools.chain.from_iterable(output) + e = itertools.chain.from_iterable(expected) + + # test + assert all([np.all(i==j) for i, j in zip(o, e)]), "unique indices not grouped correctly" + +def test_cluster_short_mentions(): + mentions = ['EEC', 'ABC'] + max_length = max([len(m) for m in mentions]) + mylsh = lsh.LSHRandomProjections( + mentions=mentions, + shingle_size=max_length + 1, + n_bands=15) + mylsh.cluster() + expected = [set((0, 1)), set((0, 1))] + assert expected == mylsh.candidates, \ + "lsh fails when shingle size longer than longest input mentions"