Fix torch.cat() issue when processing large number of documents with TransformersModelForTokenClassificationNerStep #80

Open · paluchasz wants to merge 5 commits into main
Conversation

@paluchasz (Collaborator) commented Dec 11, 2024

The Issue

I noticed a strange problem in the evaluation script when naively processing a large number (365) of Kazu documents with the TransformersModelForTokenClassificationNerStep step on the MPS device.

The 365 documents totalled over 14k sections and were processed with a newly trained 400MB model. Running this on a Mac M3 with MPS, I saw Python's memory usage peak at 18GB:
[Screenshot: Python memory usage peaking at ~18GB]
In the end the step failed to predict any entities, without any exception being thrown. The culprit turned out to be this method in https://github.com/AstraZeneca/KAZU/blob/main/kazu/steps/ner/hf_token_classification.py:

    def get_multilabel_activations(self, loader: DataLoader) -> Tensor:
        """Get a tensor consisting of confidences for labels in a multi label
        classification context.

        :param loader:
        :return:
        """
        with torch.no_grad():
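            # logits from every batch are concatenated into one tensor on the device (e.g. MPS)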
            results = torch.cat(
                tuple(self.model(**batch.to(self.device)).logits for batch in loader)
            ).to(self.device)
        return results.heaviside(torch.tensor([0.0]).to(self.device)).int().to("cpu")

Here torch.cat was producing a tensor full of zeros, which the step interprets as the model finding no entities. This is most likely caused by torch.cat exceeding the memory allocated to the device.

The Fix

The fix is in two places. First, the evaluation script now processes the documents in batches through the pipeline. However, to stop a user who naively processes many documents with the Kazu pipeline from hitting this issue, there is also a fix inside TransformersModelForTokenClassificationNerStep: the model logits are offloaded onto the CPU before concatenation. A sketch of both ideas is given below.
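For illustration only, here is a minimal sketch of the offloading idea (not necessarily the exact code in this PR): each batch's logits are moved to the CPU before torch.cat, so the concatenated tensor never has to fit in MPS memory, and the thresholding also runs on the CPU.

    import torch
    from torch import Tensor
    from torch.utils.data import DataLoader

    def get_multilabel_activations_cpu_offload(model, loader: DataLoader, device: str) -> Tensor:
        """Sketch: move each batch's logits to CPU before concatenation so the
        full logits tensor never has to live on the accelerator device."""
        with torch.no_grad():
            results = torch.cat(
                tuple(model(**batch.to(device)).logits.to("cpu") for batch in loader)
            )
        # binarise the logits on CPU: positive logits -> 1, non-positive -> 0
        return results.heaviside(torch.tensor([0.0])).int()

The batching change in the evaluation script can be sketched similarly. process_in_batches and batch_size are illustrative names only, and the sketch assumes the Kazu pipeline is called on a list of Documents and returns the processed documents:

    from itertools import islice

    def process_in_batches(pipeline, docs, batch_size: int = 50) -> list:
        """Sketch: run the pipeline over fixed-size batches of documents
        instead of passing all of them at once."""
        doc_iter = iter(docs)
        processed = []
        while batch := list(islice(doc_iter, batch_size)):
            processed.extend(pipeline(batch))
        return processed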

Testing Performance

The test repeats the naive call to the step with all the documents at once, as before, and compares TransformersModelForTokenClassificationNerStep before and after the change.

Before the change, peak memory usage was 18GB and processing all the documents took 690s. With the new implementation, peak memory usage was 4GB and processing took 680s, and the zero-predictions issue no longer occurs. There therefore appears to be no performance degradation from executing torch.cat on the CPU rather than on MPS. A CUDA device was not tested, however.

General Test for Single Label Classification

A test script with the default model pipeline and Kazu model pack was run as a sanity check. The integration tests will now also be run.

Note

There is also a small refactor moving some functions from train_multilabel_ner to modelling_utils. Individual changes can be seen at the commit level.

stops memory issue and null results
This saves memory early by offloading model logits onto CPU before concatenation and fixes a weird bug likely caused by memory issues
@paluchasz requested review from EFord36 and removed the request for EFord36 on December 12, 2024, 15:38
@mariosaenger (Collaborator) left a comment


LGTM. I only added minor comments

Review threads on kazu/training/modelling_utils.py (resolved)
@mariosaenger (Collaborator) left a comment


LGTM
