Skip to content

Commit

Permalink
FIX TST Scalings logging test latest transformers (huggingface#2042)
Browse files Browse the repository at this point in the history
Fix test for latest transformers, skip for earlier versions.
  • Loading branch information
EricLBuehler authored Sep 5, 2024
1 parent c9f7240 commit 31fbbd2
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions tests/test_xlora.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
import os

import huggingface_hub
import packaging
import pytest
import torch
import transformers
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer

Expand All @@ -25,6 +27,9 @@
from peft.utils import infer_device


uses_transformers_4_45 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.45.0")


class TestXlora:
torch_device = infer_device()

Expand Down Expand Up @@ -128,6 +133,8 @@ def test_functional(self, tokenizer, model):
)
assert torch.isfinite(outputs[: inputs.shape[1] :]).all()

# TODO: remove the skip when 4.45 is released!
@pytest.mark.skipif(not uses_transformers_4_45, reason="Requires transformers >= 4.45")
def test_scalings_logging_methods(self, tokenizer, model):
model.enable_scalings_logging()

Expand Down Expand Up @@ -155,16 +162,13 @@ def test_scalings_logging_methods(self, tokenizer, model):

bucketed = model.get_bucketed_scalings_log()
keys = bucketed.keys()
# One bucket for prompt (seqlen=...) and one for the completion (seqlen=1)
assert len(bucketed) == 2
# One bucket for prompt (which has 1 elem)
assert len(bucketed[max(keys)][0]) == 1
assert len(bucketed[max(keys)][1]) == 1
assert bucketed[max(keys)][0][0] == 0
# One bucket for completions with bucket name 1
assert len(bucketed[1][0]) > 1
assert len(bucketed[1][1]) > 1
assert bucketed[1][0][0] > 0
# Once bucket for each token as we aren't using cache
assert len(bucketed) == 32 == len(keys)
seq_len = inputs.shape[1]
for key in keys:
assert len(bucketed[key][0]) == 1
assert len(bucketed[key][1]) == 1
assert bucketed[key][0][0] == key - seq_len

model.clear_scalings_log()
assert len(model.get_scalings_log()) == 0
Expand Down

0 comments on commit 31fbbd2

Please sign in to comment.