diff --git a/kazu/steps/ner/hf_token_classification.py b/kazu/steps/ner/hf_token_classification.py
index d785827d..fc793ac5 100644
--- a/kazu/steps/ner/hf_token_classification.py
+++ b/kazu/steps/ner/hf_token_classification.py
@@ -1,27 +1,29 @@
 import logging
 from collections.abc import Iterator
-from typing import Optional, cast, Any, Iterable
+from typing import Any, Iterable, Optional, cast
 
 import torch
 from torch import Tensor, softmax
 from torch.utils.data import DataLoader, IterableDataset
 from transformers import (
-    AutoModelForTokenClassification,
     AutoConfig,
+    AutoModelForTokenClassification,
     AutoTokenizer,
+    BatchEncoding,
     DataCollatorWithPadding,
     PreTrainedTokenizerBase,
-    BatchEncoding,
 )
 from transformers.file_utils import PaddingStrategy
 
-from kazu.data import Section, Document
+from kazu.data import Document, Section
 from kazu.steps import Step, document_batch_step
 from kazu.steps.ner.entity_post_processing import NonContiguousEntitySplitter
-from kazu.steps.ner.tokenized_word_processor import TokenizedWordProcessor, TokenizedWord
+from kazu.steps.ner.tokenized_word_processor import (
+    TokenizedWord,
+    TokenizedWordProcessor,
+)
 from kazu.utils.utils import documents_to_document_section_batch_encodings_map
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -288,26 +290,28 @@ def get_list_of_batch_encoding_frames_for_section(
 
     def get_multilabel_activations(self, loader: DataLoader) -> Tensor:
         """Get a tensor consisting of confidences for labels in a multi label
-        classification context.
+        classification context. Output tensor is of shape (n_samples,
+        max_sequence_length, n_labels).
 
         :param loader:
         :return:
         """
         with torch.no_grad():
             results = torch.cat(
-                tuple(self.model(**batch.to(self.device)).logits for batch in loader)
+                tuple(self.model(**batch.to(self.device)).logits.to("cpu") for batch in loader)
             ).to(self.device)
         return results.heaviside(torch.tensor([0.0]).to(self.device)).int().to("cpu")
 
     def get_single_label_activations(self, loader: DataLoader) -> Tensor:
         """Get a tensor consisting of one hot binary classifications in a single label
-        classification context.
+        classification context. Output tensor is of shape (n_samples,
+        max_sequence_length, n_labels).
 
         :param loader:
         :return:
         """
         with torch.no_grad():
             results = torch.cat(
-                tuple(self.model(**batch.to(self.device)).logits for batch in loader)
+                tuple(self.model(**batch.to(self.device)).logits.to("cpu") for batch in loader)
             )
         return softmax(results, dim=-1).to("cpu")
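
A minimal sketch (not part of the diff) of the pattern both hunks adopt: each batch's logits are moved to CPU as soon as they are produced, so the GPU only ever holds one batch of activations rather than the whole corpus. `model`, `loader`, and `device` are illustrative stand-ins for the step's attributes, and batches are assumed to be `transformers.BatchEncoding` objects, whose `.to(device)` moves every contained tensor.

import torch
from torch.utils.data import DataLoader


def collect_logits(model: torch.nn.Module, loader: DataLoader, device: str) -> torch.Tensor:
    # Sketch only: assumes each `batch` is a transformers.BatchEncoding and
    # that the model returns an output object with a `.logits` attribute.
    model.eval()
    with torch.no_grad():
        # Offload each batch's logits to CPU before concatenating, so peak
        # GPU memory is bounded by a single batch.
        return torch.cat(
            tuple(model(**batch.to(device)).logits.to("cpu") for batch in loader)
        )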