Multi label models #67

Merged (11 commits) on Sep 18, 2024

@@ -3,6 +3,10 @@ path: ${oc.env:KAZU_MODEL_PACK}/tinybern
batch_size: 4
stride: 16
max_sequence_length: 128
keys_to_use: # distilbert for token classification doesn't use token_type_ids
- input_ids
- attention_mask
- token_type_ids
entity_splitter:
_target_: kazu.steps.ner.entity_post_processing.NonContiguousEntitySplitter
entity_conditions:
@@ -13,21 +17,22 @@ entity_splitter:
disease:
- _target_: kazu.steps.ner.entity_post_processing.SplitOnConjunctionPattern
path: ${SciSpacyPipeline.path}
detect_subspans: False
threshold: ~
labels:
- 'B-cell_line'
- 'B-cell_type'
- 'B-disease'
- 'B-drug'
- 'B-gene'
- 'B-species'
- 'I-cell_line'
- 'I-cell_type'
- 'I-disease'
- 'I-drug'
- 'I-gene'
- 'I-species'
- 'O'
strip_re:
gene: "( (gene|protein)s?)+$"
tokenized_word_processor:
_target_: kazu.steps.ner.tokenized_word_processor.TokenizedWordProcessor
labels:
- 'B-cell_line'
- 'B-cell_type'
- 'B-disease'
- 'B-drug'
- 'B-gene'
- 'B-species'
- 'I-cell_line'
- 'I-cell_type'
- 'I-disease'
- 'I-drug'
- 'I-gene'
- 'I-species'
- 'O'
strip_re:
gene: "( (gene|protein)s?)+$"
use_multilabel: false
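For orientation, the new config nests the label set and `strip_re` under `tokenized_word_processor`, and adds `keys_to_use` at the step level. A minimal sketch of the equivalent direct instantiation, assuming the constructor signatures shown in this diff (the model path and the abbreviated label list are placeholders):

```python
# Sketch only: mirrors the Hydra config above with placeholder values.
from kazu.steps.ner.hf_token_classification import (
    TransformersModelForTokenClassificationNerStep,
)
from kazu.steps.ner.tokenized_word_processor import TokenizedWordProcessor

processor = TokenizedWordProcessor(
    labels=["B-disease", "I-disease", "O"],  # abbreviated BIO label set
    strip_re={"gene": "( (gene|protein)s?)+$"},
    use_multilabel=False,
)
step = TransformersModelForTokenClassificationNerStep(
    path="/path/to/model_pack/tinybern",  # placeholder for the real model pack path
    batch_size=4,
    stride=16,
    max_sequence_length=128,
    tokenized_word_processor=processor,
    keys_to_use=["input_ids", "attention_mask", "token_type_ids"],
)
```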

This file was deleted.

kazu/data/__init__.py (1 change: 0 additions & 1 deletion)
@@ -25,7 +25,6 @@
from kazu.utils.string_normalizer import StringNormalizer
from numpy import ndarray, float32, float16

IS_SUBSPAN = "is_subspan"
# BIO schema
ENTITY_START_SYMBOL = "B"
ENTITY_INSIDE_SYMBOL = "I"
kazu/steps/ner/hf_token_classification.py (103 changes: 53 additions & 50 deletions)
@@ -1,55 +1,50 @@
import logging
from collections.abc import Iterable, Iterator, Callable
from functools import partial
from typing import Optional, cast, Any
from collections.abc import Iterator
from typing import Optional, cast, Any, Iterable

import torch
from torch import Tensor, sigmoid, softmax
from torch import Tensor, softmax
from torch.utils.data import DataLoader, IterableDataset
from transformers import (
AutoModelForTokenClassification,
AutoConfig,
AutoTokenizer,
DataCollatorWithPadding,
BatchEncoding,
PreTrainedTokenizerBase,
BatchEncoding,
)
from transformers.file_utils import PaddingStrategy

from kazu.data import Document, Section
from kazu.data import Section, Document
from kazu.steps import Step, document_batch_step
from kazu.steps.ner.entity_post_processing import NonContiguousEntitySplitter
from kazu.steps.ner.tokenized_word_processor import TokenizedWordProcessor, TokenizedWord
from kazu.utils.utils import documents_to_document_section_batch_encodings_map


logger = logging.getLogger(__name__)


class HFDataset(IterableDataset[dict[str, Any]]):
def __getitem__(self, index: int) -> dict[str, Any]:
return {
"input_ids": self.encodings.data["input_ids"][index],
"attention_mask": self.encodings.data["attention_mask"][index],
"token_type_ids": self.encodings.data["token_type_ids"][index],
}
return {key: self.encodings.data[key][index] for key in self.keys_to_use}

def __init__(self, encodings: BatchEncoding):
def __init__(self, encodings: BatchEncoding, keys_to_use: Iterable[str]):
"""Simple implementation of :class:`torch.utils.data.IterableDataset`\\ ,
producing HF tokenizer input_ids.

:param encodings:
:param keys_to_use: the keys to use from the encodings (not all models require
token_type_ids)
"""
self.keys_to_use = set(keys_to_use)
self.encodings = encodings
self.dataset_size = len(encodings.data["input_ids"])

def __iter__(self) -> Iterator[dict[str, Any]]:

for i in range(self.dataset_size):
yield {
"input_ids": self.encodings.data["input_ids"][i],
"attention_mask": self.encodings.data["attention_mask"][i],
"token_type_ids": self.encodings.data["token_type_ids"][i],
}
yield {key: self.encodings.data[key][i] for key in self.keys_to_use}
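With `keys_to_use` parameterised, the dataset only surfaces the tensors a given model actually consumes. A quick usage sketch, assuming the class as defined above (the tokeniser name and text are placeholders):

```python
# Sketch only: drive HFDataset with a restricted key set.
from transformers import AutoTokenizer

from kazu.steps.ner.hf_token_classification import HFDataset

tokeniser = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # placeholder model
encodings = tokeniser(["EGFR is a gene"], return_tensors="pt")
# distilbert tokenisers emit no token_type_ids, so request only these two keys
dataset = HFDataset(encodings, keys_to_use=["input_ids", "attention_mask"])
batch = next(iter(dataset))
assert set(batch) == {"input_ids", "attention_mask"}
```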


class TransformersModelForTokenClassificationNerStep(Step):
@@ -67,24 +62,25 @@ def __init__(
batch_size: int,
stride: int,
max_sequence_length: int,
labels: list[str],
detect_subspans: bool = False,
threshold: Optional[float] = None,
tokenized_word_processor: TokenizedWordProcessor,
keys_to_use: Iterable[str],
entity_splitter: Optional[NonContiguousEntitySplitter] = None,
strip_re: Optional[dict[str, str]] = None,
device: str = "cpu",
):
"""

:param path: path to HF model, config and tokenizer. Passed to HF .from_pretrained()
:param batch_size: batch size for dataloader
:param stride: passed to HF tokenizers (for splitting long docs)
:param max_sequence_length: passed to HF tokenizers (for splitting long docs)
:param labels:
:param detect_subspans: attempt to detect nested entities (threshold must be configured)
:param threshold: the confidence threshold used to detect nested entities
:param tokenized_word_processor:
:param keys_to_use: keys to use from the encodings. Note that this varies depending on the flavour of BERT model (e.g. distilbert doesn't use token_type_ids)
:param entity_splitter: to detect non-contiguous entities if provided
:param strip_re: passed to :class:`~kazu.steps.ner.tokenized_word_processor.TokenizedWordProcessor`
:param device: device to run the model on. Defaults to "cpu"
"""

self.keys_to_use = set(keys_to_use)
self.device = device
self.entity_splitter = entity_splitter
if max_sequence_length % 2 != 0:
raise RuntimeError(
@@ -102,23 +98,19 @@ def __init__(
self.model = AutoModelForTokenClassification.from_pretrained(
path, config=self.config
).eval()
self.activation_fn = cast(
Callable[[Tensor], Tensor], sigmoid if detect_subspans else partial(softmax, dim=-1)
)
self.tokenized_word_processor = TokenizedWordProcessor(
detect_subspans=detect_subspans,
confidence_threshold=threshold,
id2label=self.id2labels_from_label_list(labels),
strip_re=strip_re,
)
self.tokenized_word_processor = tokenized_word_processor
self.model.to(device)

@document_batch_step
def __call__(self, docs: list[Document]) -> None:
loader, id_section_map = self.get_dataloader(docs)
# need this so mypy knows to expect the dataset to have encodings
dataset = cast(HFDataset, loader.dataset)
# run the transformer and get results
activations = self.get_activations(loader)
if self.tokenized_word_processor.use_multilabel:
activations = self.get_multilabel_activations(loader)
else:
activations = self.get_single_label_activations(loader)
for section_index, section in id_section_map.items():
words = self.section_frames_to_tokenised_words(
section_index=section_index,
@@ -133,17 +125,6 @@ def __call__(self, docs: list[Document]) -> None:
for ent in entities:
section.entities.extend(self.entity_splitter(ent, section.text))

def get_activations(self, loader: DataLoader) -> Tensor:
"""Get a namedtuple_values_indices consisting of confidence and labels for a
given dataloader (i.e. run bert)

:param loader:
:return:
"""
with torch.no_grad():
results = torch.cat(tuple(self.model(**batch).logits for batch in loader))
return self.activation_fn(results)

def get_dataloader(self, docs: list[Document]) -> tuple[DataLoader, dict[int, Section]]:
"""Get a dataloader from a List of :class:`kazu.data.Document`. Collation is
handled via :class:`transformers.DataCollatorWithPadding`\\ .
@@ -156,7 +137,7 @@ def get_dataloader(self, docs: list[Document]) -> tuple[DataLoader, dict[int, Section]]:
batch_encoding, id_section_map = documents_to_document_section_batch_encodings_map(
docs, self.tokeniser, stride=self.stride, max_length=self.max_sequence_length
)
dataset = HFDataset(batch_encoding)
dataset = HFDataset(batch_encoding, keys_to_use=self.keys_to_use)
collate_func = DataCollatorWithPadding(
tokenizer=self.tokeniser, padding=PaddingStrategy.MAX_LENGTH
)
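As the docstring notes, collation is handled by `DataCollatorWithPadding` with `PaddingStrategy.MAX_LENGTH`, so every frame in a batch is padded to the same width regardless of section length. A standalone sketch of that wiring (the tokeniser name and the explicit `max_length` are assumptions for illustration; the real step takes its values from config):

```python
# Sketch only: batch HFDataset frames through DataCollatorWithPadding.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers.file_utils import PaddingStrategy

from kazu.steps.ner.hf_token_classification import HFDataset

tokeniser = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # placeholder
encodings = tokeniser(
    ["EGFR is a gene", "BRCA1 and BRCA2 mutations"],
    padding=True,
    return_tensors="pt",
)
dataset = HFDataset(encodings, keys_to_use=["input_ids", "attention_mask"])
collate_func = DataCollatorWithPadding(
    tokenizer=tokeniser, padding=PaddingStrategy.MAX_LENGTH, max_length=128
)
loader = DataLoader(dataset=dataset, batch_size=2, collate_fn=collate_func)
for batch in loader:
    print(batch["input_ids"].shape)  # torch.Size([2, 128]), padded to max_length
```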
@@ -305,6 +286,28 @@ def get_list_of_batch_encoding_frames_for_section(
]
return section_frame_indices

@staticmethod
def id2labels_from_label_list(labels: Iterable[str]) -> dict[int, str]:
return {idx: label for idx, label in enumerate(labels)}
def get_multilabel_activations(self, loader: DataLoader) -> Tensor:
"""Get a tensor consisting of confidences for labels in a multi label
classification context.

:param loader:
:return:
"""
with torch.no_grad():
results = torch.cat(
tuple(self.model(**batch.to(self.device)).logits for batch in loader)
).to(self.device)
return results.heaviside(torch.tensor([0.0]).to(self.device)).int().to("cpu")
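A Heaviside step at logit 0 is the same decision rule as `sigmoid(logit) >= 0.5` applied independently per label, which is what makes this the multi-label path: any number of labels can fire for a single token. A small check of that equivalence (logit values are illustrative):

```python
# Sketch only: heaviside at logit 0 matches sigmoid >= 0.5, label by label.
import torch

logits = torch.tensor([[-1.2, 0.3, 2.5]])  # illustrative per-label logits
via_heaviside = logits.heaviside(torch.tensor([0.0])).int()
via_sigmoid = (torch.sigmoid(logits) >= 0.5).int()
assert torch.equal(via_heaviside, via_sigmoid)  # both tensor([[0, 1, 1]])
```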

def get_single_label_activations(self, loader: DataLoader) -> Tensor:
"""Get a tensor consisting of one hot binary classifications in a single label
classification context.

:param loader:
:return:
"""
with torch.no_grad():
results = torch.cat(
tuple(self.model(**batch.to(self.device)).logits for batch in loader)
)
return softmax(results, dim=-1).to("cpu")
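By contrast, the single-label path returns a full softmax distribution per token, from which the `TokenizedWordProcessor` later selects labels. A toy decode, assuming simple argmax selection (the label list and logits are illustrative):

```python
# Sketch only: turn single-label activations into BIO tags via argmax.
import torch
from torch import softmax

labels = ["B-gene", "I-gene", "O"]  # illustrative label list
logits = torch.tensor([[[4.0, 0.1, 0.2], [0.3, 3.5, 0.1], [0.1, 0.2, 5.0]]])
confidences = softmax(logits, dim=-1)  # one probability distribution per token
tags = [labels[int(i)] for i in confidences.argmax(dim=-1)[0]]
print(tags)  # ['B-gene', 'I-gene', 'O']
```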