Commit

updated tests for TokenizedWordProcessor
RichJackson committed Sep 11, 2024
1 parent 54f0fc4 commit 888f011
Showing 1 changed file with 21 additions and 197 deletions.
218 changes: 21 additions & 197 deletions kazu/tests/test_tokenized_word_processor.py
@@ -4,18 +4,14 @@
 from kazu.steps.ner.tokenized_word_processor import TokenizedWord, TokenizedWordProcessor


-@pytest.mark.parametrize(
-    "detect_subspans",
-    (True, False),
-)
-def test_tokenized_word_processor_with_subspan_detection(detect_subspans):
+def test_tokenized_word_processor_single_label():
     text = "hello to you"
     # should produce one ent
     word1 = TokenizedWord(
         word_id=0,
         token_ids=[0],
         tokens=["hello"],
-        token_confidences=torch.Tensor([[0.99, 0.01]]),
+        token_confidences=torch.Tensor([[0.70, 0.20, 0.10]]),
         token_offsets=[(0, 5)],
         word_char_start=0,
         word_char_end=5,
@@ -24,7 +20,7 @@ def test_tokenized_word_processor_with_subspan_detection(detect_subspans):
         word_id=1,
         token_ids=[1],
         tokens=["to"],
-        token_confidences=torch.Tensor([[0.01, 0.99]]),
+        token_confidences=torch.Tensor([[0.01, 0.98, 0.01]]),
         token_offsets=[(6, 8)],
         word_char_start=6,
         word_char_end=8,
@@ -33,181 +29,28 @@ def test_tokenized_word_processor_with_subspan_detection(detect_subspans):
         word_id=2,
         token_ids=[2],
         tokens=["you"],
-        token_confidences=torch.Tensor([[0.01, 0.99]]),
-        token_offsets=[(9, 11)],
-        word_char_start=9,
-        word_char_end=11,
-    )
-
-    processor = TokenizedWordProcessor(
-        confidence_threshold=0.2, id2label={0: "B-hello", 1: "O"}, detect_subspans=detect_subspans
-    )
-    ents = processor(words=[word1, word2, word3], text=text, namespace="test")
-    assert len(ents) == 1
-    assert ents[0].match == "hello"
-
-
-@pytest.mark.parametrize(
-    "detect_subspans",
-    (True, False),
-)
-def test_tokenized_word_processor_with_subspan_detection_2(detect_subspans):
-    text = "hello to you"
-    # also check this works if the word hello is composed of two B tokens
-    word1 = TokenizedWord(
-        word_id=0,
-        token_ids=[0, 1],
-        tokens=["hel", "lo"],
-        token_confidences=torch.Tensor([[0.99, 0.01], [0.99, 0.01]]),
-        token_offsets=[(0, 3), (3, 5)],
-        word_char_start=0,
-        word_char_end=5,
-    )
-    word2 = TokenizedWord(
-        word_id=1,
-        token_ids=[2],
-        tokens=["to"],
-        token_confidences=torch.Tensor([[0.01, 0.99]]),
-        token_offsets=[(6, 8)],
-        word_char_start=6,
-        word_char_end=8,
-    )
-    word3 = TokenizedWord(
-        word_id=2,
-        token_ids=[3],
-        tokens=["you"],
-        token_confidences=torch.Tensor([[0.01, 0.99]]),
-        token_offsets=[(9, 11)],
-        word_char_start=9,
-        word_char_end=11,
-    )
-    processor = TokenizedWordProcessor(
-        confidence_threshold=0.2, id2label={0: "B-hello", 1: "O"}, detect_subspans=detect_subspans
-    )
-    ents = processor(words=[word1, word2, word3], text=text, namespace="test")
-    assert len(ents) == 1
-    assert ents[0].match == "hello"
-
-
-@pytest.mark.parametrize(
-    "detect_subspans",
-    (True, False),
-)
-def test_tokenized_word_processor_with_subspan_detection_3(detect_subspans: bool):
-    text = "hello-to you"
-    word1 = TokenizedWord(
-        word_id=0,
-        token_ids=[0],
-        tokens=["hello"],
-        token_confidences=torch.Tensor([[0.99, 0.01]]),
-        token_offsets=[(0, 5)],
-        word_char_start=0,
-        word_char_end=5,
-    )
-    word2 = TokenizedWord(
-        word_id=1,
-        token_ids=[1],
-        tokens=["-"],
-        token_confidences=torch.Tensor([[0.01, 0.99]]),
-        token_offsets=[(5, 6)],
-        word_char_start=5,
-        word_char_end=6,
-    )
-    word3 = TokenizedWord(
-        word_id=2,
-        token_ids=[2],
-        tokens=["to"],
-        token_confidences=torch.Tensor([[0.99, 0.01]]),
-        token_offsets=[(6, 8)],
-        word_char_start=6,
-        word_char_end=8,
-    )
-    word4 = TokenizedWord(
-        word_id=3,
-        token_ids=[3],
-        tokens=["you"],
-        token_confidences=torch.Tensor([[0.01, 0.99]]),
+        token_confidences=torch.Tensor([[0.01, 0.01, 0.98]]),
         token_offsets=[(9, 11)],
         word_char_start=9,
         word_char_end=11,
     )

-    processor = TokenizedWordProcessor(
-        confidence_threshold=0.2,
-        id2label={0: "B-greeting", 1: "O"},
-        detect_subspans=detect_subspans,
-    )
-    ents = processor(words=[word1, word2, word3, word4], text=text, namespace="test")
-    if detect_subspans:
-        # should produce three ents, since '-' is non breaking
-        assert len(ents) == 3
-        assert ents[0].match == "hello-"
-        assert ents[0].entity_class == "greeting"
-        assert ents[1].match == "hello-to"
-        assert ents[1].entity_class == "greeting"
-        assert ents[2].match == "to"
-        assert ents[2].entity_class == "greeting"
-    else:
-        # should produce two ents
-        assert len(ents) == 2
-        assert ents[0].match == "hello-"
-        assert ents[0].entity_class == "greeting"
-        assert ents[1].match == "to"
-        assert ents[1].entity_class == "greeting"
-
-
-@pytest.mark.parametrize(
-    "detect_subspans",
-    (True, False),
-)
-def test_tokenized_word_processor_with_subspan_detection_4(detect_subspans):
-    # should produce two ent as " " is span breaking
-    text = "hello to you"
-    word1 = TokenizedWord(
-        word_id=0,
-        token_ids=[0],
-        tokens=["hello"],
-        token_confidences=torch.Tensor([[0.99, 0.01]]),
-        token_offsets=[(0, 5)],
-        word_char_start=0,
-        word_char_end=5,
-    )
-    word2 = TokenizedWord(
-        word_id=1,
-        token_ids=[1],
-        tokens=["to"],
-        token_confidences=torch.Tensor([[0.01, 0.99]]),
-        token_offsets=[(6, 8)],
-        word_char_start=6,
-        word_char_end=8,
-    )
-    word3 = TokenizedWord(
-        word_id=2,
-        token_ids=[2],
-        tokens=["you"],
-        token_confidences=torch.Tensor([[0.99, 0.01]]),
-        token_offsets=[(9, 11)],
-        word_char_start=9,
-        word_char_end=12,
-    )
-
-    processor = TokenizedWordProcessor(
-        confidence_threshold=0.2, id2label={0: "B-hello", 1: "O"}, detect_subspans=detect_subspans
-    )
+    processor = TokenizedWordProcessor(labels=["B-class1", "O", "B-class2"], use_multilabel=False)
     ents = processor(words=[word1, word2, word3], text=text, namespace="test")
     assert len(ents) == 2
-    assert ents[0].match == "hello"
-    assert ents[1].match == "you"
+    detected_ent_classes = [ent.entity_class for ent in ents]
+    assert "class1" in detected_ent_classes
+    assert "class2" in detected_ent_classes


-def test_tokenized_word_processor_with_threshold():
+def test_tokenized_word_processor_multi_label():
     text = "hello to you"
     # should produce one ent
     word1 = TokenizedWord(
         word_id=0,
         token_ids=[0],
         tokens=["hello"],
-        token_confidences=torch.Tensor([[0.70, 0.20, 0.10]]),
+        token_confidences=torch.Tensor([[[1, 0, 0], [0, 0, 0], [0, 0, 0]]]),
         token_offsets=[(0, 5)],
         word_char_start=0,
         word_char_end=5,
@@ -216,7 +59,7 @@ def test_tokenized_word_processor_with_threshold():
         word_id=1,
         token_ids=[1],
         tokens=["to"],
-        token_confidences=torch.Tensor([[0.01, 0.01, 0.98]]),
+        token_confidences=torch.Tensor([[[1, 0, 0], [0, 0, 0], [0, 0, 1]]]),
         token_offsets=[(6, 8)],
         word_char_start=6,
         word_char_end=8,
@@ -225,50 +68,31 @@
         word_id=2,
         token_ids=[2],
         tokens=["you"],
-        token_confidences=torch.Tensor([[0.01, 0.01, 0.98]]),
+        token_confidences=torch.Tensor([[[0, 0, 0], [0, 0, 0], [0, 0, 0]]]),
         token_offsets=[(9, 11)],
         word_char_start=9,
         word_char_end=11,
     )

-    processor = TokenizedWordProcessor(
-        confidence_threshold=0.1,
-        id2label={0: "B-class1", 1: "B-class2", 2: "O"},
-        detect_subspans=True,
-    )
+    processor = TokenizedWordProcessor(labels=["class1", "O", "class2"], use_multilabel=True)
     ents = processor(words=[word1, word2, word3], text=text, namespace="test")
     assert len(ents) == 2
-    detected_ent_classes = [ent.entity_class for ent in ents]
+    detected_ent_classes = set()
+    detected_ent_matches = set()
+    for ent in ents:
+        detected_ent_classes.add(ent.entity_class)
+        detected_ent_matches.add(ent.match)
     assert "class1" in detected_ent_classes
     assert "class2" in detected_ent_classes
-
-
-def test_tokenized_word_processor_no_threshold():
-    with pytest.raises(ValueError):
-        text = "hello"
-        word1 = TokenizedWord(
-            word_id=0,
-            token_ids=[0],
-            tokens=["hello"],
-            token_confidences=torch.Tensor([[0.70, 0.20, 0.10]]),
-            token_offsets=[(0, 5)],
-            word_char_start=0,
-            word_char_end=5,
-        )
-        processor = TokenizedWordProcessor(
-            confidence_threshold=None,
-            id2label={0: "B-class1", 1: "B-class2", 2: "O"},
-            detect_subspans=True,
-        )
-        processor(words=[word1], text=text, namespace="test")
+    assert "to" in detected_ent_matches
+    assert "hello to" in detected_ent_matches


 @pytest.mark.parametrize("query", ["COX2 protein", "COX2 gene", "COX2 gene protein protein gene"])
 def test_tokenized_word_processor_strip_re(query):
     processor = TokenizedWordProcessor(
-        confidence_threshold=None, id2label={}, strip_re={"gene": "( (gene|protein)s?)+$"}
+        labels=["B-hello", "O"], use_multilabel=False, strip_re={"gene": "( (gene|protein)s?)+$"}
     )

     expected_str = "COX2"
     expected_end = 4
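Note: the updated tests double as a reference for the constructor change in this commit: the old confidence_threshold, id2label and detect_subspans arguments are gone, replaced by a positional labels list and a use_multilabel flag. Below is a minimal sketch of the new single-label usage, pieced together from the diff; the torch import and the entity attributes match and entity_class are as exercised in the tests above, while the standalone text/namespace values and the printed output are illustrative assumptions, not taken from this commit.

import torch

from kazu.steps.ner.tokenized_word_processor import TokenizedWord, TokenizedWordProcessor

# A single word ("hello", chars 0-5) backed by one token. In single-label mode
# token_confidences is 2D: one row per token, one column per entry in `labels`.
word = TokenizedWord(
    word_id=0,
    token_ids=[0],
    tokens=["hello"],
    token_confidences=torch.Tensor([[0.70, 0.20, 0.10]]),  # highest score at index 0 -> "B-class1"
    token_offsets=[(0, 5)],
    word_char_start=0,
    word_char_end=5,
)

# New-style constructor: `labels` is index-aligned with the confidence columns
# (replacing the old id2label mapping), and use_multilabel switches between
# 2D (single-label) and 3D (multi-label) confidence decoding.
processor = TokenizedWordProcessor(labels=["B-class1", "O", "B-class2"], use_multilabel=False)
ents = processor(words=[word], text="hello", namespace="example")
for ent in ents:
    print(ent.match, ent.entity_class)  # expected, per the tests above: "hello class1"

As the assertions in the tests suggest, the BIO prefix is stripped from the reported entity_class ("B-class1" surfaces as "class1"), which is why they check for the bare class names.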