diff --git a/CHANGELOG.md b/CHANGELOG.md
index b9f1935..6665650 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,12 @@
 # Changelog
 
+## 2020.4.23 - Pre-release
+
+* Add multitask split_parse command and tests, called with `python -m deep_reference_parser split_parse`
+* Fix issues with training data creation
+* Output predictions of validation data by default
+* Various improvements: adopt tox for testing, refactoring, better error messages, and updates to the README and tests
+
 ## 2020.3.3 - Pre-release
 
 NOTE: This version includes changes to both the way that model artefacts are packaged and saved, and the way that data are loaded and parsed from tsv files. This results in a significantly faster training time (c.14 hours -> c.0.5 hours), but older models will no longer be compatible. For compatibility you must use multitask models > 2020.3.19, splitting models > 2020.3.6, and parsing models > 2020.3.8. These models currently perform less well than previous versions, but performance is expected to improve with more data and experimentation, predominantly around sequence length.
diff --git a/Makefile b/Makefile
index ea9db32..fd269e8 100644
--- a/Makefile
+++ b/Makefile
@@ -20,7 +20,7 @@ WORD_EMBEDDING := 2020.1.1-wellcome-embeddings-300
 WORD_EMBEDDING_TEST := 2020.1.1-wellcome-embeddings-10-test
 
 MODEL_PATH := models
-MODEL_VERSION := 2019.12.0
+MODEL_VERSION := multitask/2020.4.5_multitask
 
 #
 # S3 Bucket
diff --git a/README.md b/README.md
index 3ba0aef..94016e3 100644
--- a/README.md
+++ b/README.md
@@ -59,20 +59,20 @@ Current model version: *2020.3.8_parsing*
 
 #### Multitask model (splitting and parsing)
 
-Current model version: *2020.3.19_multitask*
+Current model version: *2020.4.5_multitask*
 
 |token|f1|
 |---|---|
-|author|0.9102|
-|title|0.8809|
-|year|0.7469|
-|o|0.8892|
-|parsing weighted avg|0.8869|
-|b-r|0.8254|
-|e-r|0.7908|
-|i-r|0.9563|
-|o|0.7560|
-|weighted avg|0.9240|
+|author|0.9458|
+|title|0.9002|
+|year|0.8704|
+|o|0.9407|
+|parsing weighted avg|0.9285|
+|b-r|0.9111|
+|e-r|0.8788|
+|i-r|0.9726|
+|o|0.9332|
+|weighted avg|0.9591|
 
 #### Computing requirements
@@ -82,7 +82,7 @@ Models are trained on AWS instances using CPU only.
 
 |---|---|---|---|---|
 |Span detection|00:26:41|m4.4xlarge|$0.88|$0.39|
 |Components|00:17:22|m4.4xlarge|$0.88|$0.25|
-|MultiTask|00:19:56|m4.4xlarge|$0.88|$0.29|
+|MultiTask|00:42:43|c4.4xlarge|$0.91|$0.63|
 
 ## tl;dr: Just get me to the references!
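
For orientation, the sketch below shows roughly how the new multitask model added in this change is driven from Python. It mirrors the `SplitParser` fixture and `split_parse` calls in the tests later in this diff; the config path and input text are illustrative only, not a prescribed setup:

```python
# Minimal sketch of the new multitask entry point, based on the SplitParser
# usage in the tests below. The config path and example text are invented
# for illustration; point at your own trained multitask config.
from deep_reference_parser.split_parse import SplitParser

split_parser = SplitParser("deep_reference_parser/configs/2020.4.5_multitask.ini")

text = "1. Upson, M. (2020). Deep reference parsing. Wellcome Trust."

# With return_tokens=False each found reference comes back as a dict pairing
# the joined reference string with per-token (token, attribute) predictions.
references = split_parser.split_parse(text, return_tokens=False, verbose=True)

for ref in references:
    print(ref["Reference"])   # the reference as a single string
    print(ref["Attributes"])  # [(token, component), ...]
```

The same call with `return_tokens=True` returns (token, span, component) triples instead, as asserted in the updated entrypoint tests at the end of this diff.
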
diff --git a/deep_reference_parser/configs/2020.4.5_multitask.ini b/deep_reference_parser/configs/2020.4.5_multitask.ini
index bfb82b5..ab8f3c6 100644
--- a/deep_reference_parser/configs/2020.4.5_multitask.ini
+++ b/deep_reference_parser/configs/2020.4.5_multitask.ini
@@ -18,7 +18,7 @@ policy_valid = data/processed/annotated/deep_reference_parser/multitask/2020.3.1
 s3_slug = https://datalabs-public.s3.eu-west-2.amazonaws.com/deep_reference_parser/
 
 [build]
-output_path = data/models/multitask/2020.4.5_multitask/
+output_path = models/multitask/2020.4.5_multitask/
 output = crf
 word_embeddings = embeddings/2020.1.1-wellcome-embeddings-300.txt
 pretrained_embedding = 0
diff --git a/deep_reference_parser/split_parse.py b/deep_reference_parser/split_parse.py
index 390ee11..274cb2c 100644
--- a/deep_reference_parser/split_parse.py
+++ b/deep_reference_parser/split_parse.py
@@ -25,7 +25,7 @@
 from deep_reference_parser.logger import logger
 from deep_reference_parser.model_utils import get_config
 from deep_reference_parser.reference_utils import break_into_chunks
-from deep_reference_parser.tokens_to_references import tokens_to_references
+from deep_reference_parser.tokens_to_references import tokens_to_reference_lists
 
 msg = wasabi.Printer(icons={"check": "\u2023"})
@@ -138,35 +138,31 @@ def split_parse(self, text, return_tokens=False, verbose=False):
 
         else:
 
-            # TODO: return references with attributes (author, title, year)
-            # in json format. For now just return predictions as they are to
-            # allow testing of endpoints.
+            # Return references with attributes (author, title, year) in
+            # json format: one list per reference, where each reference list
+            # pairs every token with its predicted attribute, i.e.
+            # [[(token, attribute), ..., (token, attribute)], ...]
 
-            return preds
-
-            # # Otherwise convert the tokens into references and return
-
-            # refs = tokens_to_references(tokens, preds)
-
-            # if verbose:
+            references_components = tokens_to_reference_lists(tokens, spans=preds[1], components=preds[0])
 
+            if verbose:
 
-            #     msg.divider("Results")
+                msg.divider("Results")
 
-            #     if refs:
+                if references_components:
 
-            #         msg.good(f"Found {len(refs)} references.")
-            #         msg.info("Printing found references:")
+                    msg.good(f"Found {len(references_components)} references.")
+                    msg.info("Printing found references:")
 
-            #         for ref in refs:
-            #             msg.text(ref, icon="check", spaced=True)
+                    for ref in references_components:
+                        msg.text(ref['Reference'], icon="check", spaced=True)
 
-            #     else:
+                else:
 
-            #         msg.fail("Failed to find any references.")
+                    msg.fail("Failed to find any references.")
 
-            #     out = refs
+            out = references_components
 
-            # return out
+            return out
 
 
 @plac.annotations(
diff --git a/deep_reference_parser/tokens_to_references.py b/deep_reference_parser/tokens_to_references.py
index caa8595..ff0dd76 100644
--- a/deep_reference_parser/tokens_to_references.py
+++ b/deep_reference_parser/tokens_to_references.py
@@ -10,21 +10,12 @@
 from .deep_reference_parser import logger
 
 
-def tokens_to_references(tokens, labels):
-    """
-    Given a corresponding list of tokens and a list of labels: split the tokens
-    and return a list of references.
-
-    Args:
-        tokens(list): A list of tokens.
-        labels(list): A corresponding list of labels.
-
-    """
+def get_reference_spans(tokens, spans):
 
     # Flatten the lists of tokens and predictions into a single list.
 
     flat_tokens = list(itertools.chain.from_iterable(tokens))
-    flat_predictions = list(itertools.chain.from_iterable(labels))
+    flat_predictions = list(itertools.chain.from_iterable(spans))
 
     # Find all b-r and e-r tokens.
 
@@ -37,25 +28,67 @@ def tokens_to_references(tokens, labels):
 
     logger.debug("Found %s b-r tokens", len(ref_starts))
     logger.debug("Found %s e-r tokens", len(ref_ends))
 
-    references = []
-
     n_refs = len(ref_starts)
 
     # Split on each b-r.
 
-    # TODO: It may be worth including some simple post processing steps here
-    # to pick up false positives, for instance cutting short a reference
-    # after n tokens.
+    token_starts = []
+    token_ends = []
 
     for i in range(0, n_refs):
 
-        token_start = ref_starts[i]
+        token_starts.append(ref_starts[i])
 
         if i + 1 < n_refs:
-
-            token_end = ref_starts[i + 1] - 1
+            token_ends.append(ref_starts[i + 1] - 1)
         else:
-            token_end = len(flat_tokens)
+            token_ends.append(len(flat_tokens))
+
+    return token_starts, token_ends, flat_tokens
+
+
+def tokens_to_references(tokens, labels):
+    """
+    Given a corresponding list of tokens and a list of labels: split the tokens
+    and return a list of references.
+
+    Args:
+        tokens(list): A list of tokens.
+        labels(list): A corresponding list of labels.
+
+    """
+
+    token_starts, token_ends, flat_tokens = get_reference_spans(tokens, labels)
+
+    references = []
+
+    for token_start, token_end in zip(token_starts, token_ends):
 
         ref = flat_tokens[token_start : token_end + 1]
         flat_ref = " ".join(ref)
         references.append(flat_ref)
 
     return references
+
+
+def tokens_to_reference_lists(tokens, spans, components):
+    """
+    Given a corresponding list of tokens, a list of
+    reference spans (e.g. 'b-r') and components (e.g. 'author'):
+    split the tokens according to the spans and return a
+    list of reference components for each reference.
+
+    Args:
+        tokens(list): A list of tokens.
+        spans(list): A corresponding list of reference spans.
+        components(list): A corresponding list of reference components.
+ + """ + + token_starts, token_ends, flat_tokens = get_reference_spans(tokens, spans) + flat_components = list(itertools.chain.from_iterable(components)) + + references_components = [] + for token_start, token_end in zip(token_starts, token_ends): + + ref_tokens = flat_tokens[token_start : token_end + 1] + ref_components = flat_components[token_start : token_end + 1] + flat_ref = " ".join(ref_tokens) + + references_components.append({'Reference': flat_ref, 'Attributes': list(zip(ref_tokens, ref_components))}) + + return references_components diff --git a/tests/common.py b/tests/common.py index 2bf6107..d5a4735 100644 --- a/tests/common.py +++ b/tests/common.py @@ -9,6 +9,7 @@ def get_path(p): TEST_CFG = get_path("test_data/test_config.ini") +TEST_CFG_MULTITASK = get_path("test_data/test_config_multitask.ini") TEST_JSONL = get_path("test_data/test_jsonl.jsonl") TEST_REFERENCES = get_path("test_data/test_references.txt") TEST_TSV_PREDICT = get_path("test_data/test_tsv_predict.tsv") diff --git a/tests/test_data/test_config_multitask.ini b/tests/test_data/test_config_multitask.ini new file mode 100644 index 0000000..d057ffb --- /dev/null +++ b/tests/test_data/test_config_multitask.ini @@ -0,0 +1,39 @@ +[DEFAULT] +version = test + +[data] +test_proportion = 0.25 +valid_proportion = 0.25 +data_path = data/ +respect_line_endings = 0 +respect_doc_endings = 1 +line_limit = 150 +rodrigues_train = data/rodrigues/clean_test.txt +rodrigues_test = +rodrigues_valid = +policy_train = data/processed/annotated/deep_reference_parser/multitask/2020.3.18_multitask_test.tsv +policy_test = data/processed/annotated/deep_reference_parser/multitask/2020.3.18_multitask_test.tsv +policy_valid = data/processed/annotated/deep_reference_parser/multitask/2020.3.18_multitask_test.tsv +# This needs to have a trailing slash! +s3_slug = https://datalabs-public.s3.eu-west-2.amazonaws.com/deep_reference_parser/ + +[build] +output_path = models/multitask/2020.4.5_multitask/ +output = crf +word_embeddings = embeddings/2020.1.1-wellcome-embeddings-300-test.txt +pretrained_embedding = 0 +dropout = 0.5 +lstm_hidden = 400 +word_embedding_size = 300 +char_embedding_size = 100 +char_embedding_type = BILSTM +optimizer = adam + +[train] +epochs = 60 +batch_size = 100 +early_stopping_patience = 5 +metric = val_f1 + +[evaluate] +out_file = evaluation_data.tsv diff --git a/tests/test_deep_reference_parser_entrypoints.py b/tests/test_deep_reference_parser_entrypoints.py index 6936cbf..dd61664 100644 --- a/tests/test_deep_reference_parser_entrypoints.py +++ b/tests/test_deep_reference_parser_entrypoints.py @@ -7,7 +7,7 @@ from deep_reference_parser.split import Splitter from deep_reference_parser.split_parse import SplitParser -from .common import TEST_CFG, TEST_REFERENCES +from .common import TEST_CFG, TEST_CFG_MULTITASK, TEST_REFERENCES @pytest.fixture @@ -22,7 +22,7 @@ def parser(): @pytest.fixture def split_parser(): - return SplitParser(TEST_CFG) + return SplitParser(TEST_CFG_MULTITASK) @pytest.fixture @@ -67,7 +67,7 @@ def test_split_parser_list_output(text, split_parser): If the model artefacts and embeddings are not present this test will downloaded them, which can be slow. 
""" - out = split_parser.split_parse(text, verbose=False) + out = split_parser.split_parse(text, return_tokens=False, verbose=False) print(out) assert isinstance(out, list) @@ -100,13 +100,10 @@ def test_parser_tokens_output(text, parser): def test_split_parser_tokens_output(text, split_parser): """ """ - out = split_parser.split_parse(text, verbose=False) + out = split_parser.split_parse(text, return_tokens=True, verbose=False) - assert isinstance(out, list) - - # NOTE: full functionality of split_parse is not yet implemented. - - # assert isinstance(out[0], tuple) - # assert len(out[0]) == 2 - # assert isinstance(out[0][0], str) - # assert isinstance(out[0][1], str) + assert isinstance(out[0], tuple) + assert len(out[0]) == 3 + assert isinstance(out[0][0], str) + assert isinstance(out[0][1], str) + assert isinstance(out[0][2], str)