fix: torch.cat returning zero tensor for large datasets
Offloads model logits onto the CPU before concatenation. This frees GPU memory earlier and fixes a bug where torch.cat returned an all-zero tensor on large datasets, likely caused by running out of GPU memory.
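For a sense of scale (illustrative numbers, not from the commit): with 100,000 sequences, a padded length of 512, and 50 labels, float32 logits occupy 100,000 × 512 × 50 × 4 bytes ≈ 10.2 GB. torch.cat also allocates the concatenated output while the per-batch inputs are still alive, so peak usage on the concatenating device is roughly double that; moving each batch's logits to the CPU as it is produced shifts both costs into host memory and leaves the GPU holding only one batch of activations at a time.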
paluchasz committed Dec 11, 2024
1 parent a8bf7a7 commit 28db0e5
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions kazu/steps/ner/hf_token_classification.py
@@ -1,27 +1,29 @@
 import logging
 from collections.abc import Iterator
-from typing import Optional, cast, Any, Iterable
+from typing import Any, Iterable, Optional, cast

 import torch
 from torch import Tensor, softmax
 from torch.utils.data import DataLoader, IterableDataset
 from transformers import (
-    AutoModelForTokenClassification,
     AutoConfig,
+    AutoModelForTokenClassification,
     AutoTokenizer,
+    BatchEncoding,
     DataCollatorWithPadding,
     PreTrainedTokenizerBase,
-    BatchEncoding,
 )
 from transformers.file_utils import PaddingStrategy

-from kazu.data import Section, Document
+from kazu.data import Document, Section
 from kazu.steps import Step, document_batch_step
 from kazu.steps.ner.entity_post_processing import NonContiguousEntitySplitter
-from kazu.steps.ner.tokenized_word_processor import TokenizedWordProcessor, TokenizedWord
+from kazu.steps.ner.tokenized_word_processor import (
+    TokenizedWord,
+    TokenizedWordProcessor,
+)
 from kazu.utils.utils import documents_to_document_section_batch_encodings_map


 logger = logging.getLogger(__name__)

@@ -288,26 +290,28 @@ def get_list_of_batch_encoding_frames_for_section(

     def get_multilabel_activations(self, loader: DataLoader) -> Tensor:
         """Get a tensor consisting of confidences for labels in a multi label
-        classification context.
+        classification context. Output tensor is of shape (n_samples,
+        max_sequence_length, n_labels).

         :param loader:
         :return:
         """
         with torch.no_grad():
             results = torch.cat(
-                tuple(self.model(**batch.to(self.device)).logits for batch in loader)
-            )
+                tuple(self.model(**batch.to(self.device)).logits.to("cpu") for batch in loader)
+            ).to(self.device)
         return results.heaviside(torch.tensor([0.0]).to(self.device)).int().to("cpu")

     def get_single_label_activations(self, loader: DataLoader) -> Tensor:
         """Get a tensor consisting of one hot binary classifications in a single label
-        classification context.
+        classification context. Output tensor is of shape (n_samples,
+        max_sequence_length, n_labels).

         :param loader:
         :return:
         """
         with torch.no_grad():
             results = torch.cat(
-                tuple(self.model(**batch.to(self.device)).logits for batch in loader)
+                tuple(self.model(**batch.to(self.device)).logits.to("cpu") for batch in loader)
             )
         return softmax(results, dim=-1).to("cpu")
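
The fix generalizes beyond this file. A minimal standalone sketch of the same pattern, assuming a transformers-style model whose output exposes a .logits attribute and a DataLoader yielding BatchEncoding batches (the helper name collect_logits_on_cpu is illustrative, not part of kazu):

import torch
from torch.utils.data import DataLoader


def collect_logits_on_cpu(model: torch.nn.Module, loader: DataLoader, device: torch.device) -> torch.Tensor:
    """Run inference batch by batch, offloading logits to the CPU before torch.cat.

    Keeping every batch's logits on the GPU until the final concatenation can
    exhaust device memory on large datasets; offloading caps GPU usage at a
    single batch of activations.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in loader:
            logits = model(**batch.to(device)).logits
            chunks.append(logits.to("cpu"))  # release GPU memory before the next batch
    return torch.cat(chunks)  # concatenation happens in host memory

If a downstream op needs the result back on the GPU (as the heaviside thresholding in get_multilabel_activations does), a final .to(device) on the concatenated tensor moves it across in one transfer, mirroring the ).to(self.device) added in the diff above.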
