feat(tests): use pytest markers to filter jetstream and torch xla tests
Until now, filtering was done using the test name. Selection is now done with custom markers, which allows for clearer filtering, as in the sketch below.
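
For illustration, a minimal sketch of marker-based selection (the module and test names here are hypothetical, not part of this commit):

# test_example.py -- hypothetical module, for illustration only
import pytest

# Mark every test in this module as a Jetstream test; the pytest_runtest_setup
# hook in conftest.py skips them when Jetstream PyTorch is not enabled.
pytestmark = pytest.mark.jetstream

def test_decode_greedy():
    ...

Such tests are then selected with "python -m pytest -m jetstream", optionally narrowed further with -k, as the updated workflow does.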
tengomucho committed Nov 25, 2024
1 parent a4cacf6 commit 33fa858
Showing 12 changed files with 53 additions and 37 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml
@@ -34,28 +34,28 @@ jobs:
- name: Run TGI Jetstream Pytorch - Llama
run: |
python -m \
pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -k "slow and Llama"
pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and Llama"
- name: Run TGI Jetstream Pytorch - Gemma
run: |
python -m \
pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -k "slow and gemma"
pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and gemma"
- name: Run TGI Jetstream Pytorch - Mixtral greedy
run: |
python -m \
pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -k "slow and Mixtral and greedy"
pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and Mixtral and greedy"
- name: Run TGI Jetstream Pytorch - Quantization Mixtral
run: |
python -m \
pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -k "slow and Mixtral"
pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Mixtral"
- name: Run TGI Jetstream Pytorch - Quantization Llama-3 8B
run: |
python -m \
pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -k "slow and Llama-3-8B"
pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Llama-3-8B"
- name: Run TGI Jetstream Pytorch - Quantization Llama 3 70B
run: |
python -m \
pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -k "slow and Llama-3-70B"
pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Llama-3-70B"
- name: Run TGI Jetstream Pytorch - Other tests
run: |
python -m \
pytest -sv text-generation-inference/tests --runslow -k "jetstream and not decode and not quant"
pytest -sv text-generation-inference/tests --runslow -m jetstream -k "not decode"
1 change: 1 addition & 0 deletions .github/workflows/test-pytorch-xla-tpu.yml
@@ -25,6 +25,7 @@ jobs:
env:
PJRT_DEVICE: TPU
HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface
JETSTREAM_PT_DISABLE: 1 # Disable Jetstream Pytorch tests
steps:
- name: Checkout
uses: actions/checkout@v4
4 changes: 2 additions & 2 deletions Makefile
@@ -95,12 +95,12 @@ jetstream_requirements: test_installs
tgi_test_jetstream: test_installs jetstream_requirements tgi_server
find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
-exec python -m pip install --force-reinstall {} \;
JETSTREAM_PT=1 python -m pytest -sv text-generation-inference/tests -k jetstream
python -m pytest -sv text-generation-inference/tests -m jetstream

tgi_test: test_installs tgi_server
find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
-exec python -m pip install --force-reinstall {} \;
python -m pytest -sv text-generation-inference/tests -k "not jetstream"
python -m pytest -sv text-generation-inference/tests -m torch_xla

tgi_docker_test: tpu-tgi
python -m pip install -r text-generation-inference/integration-tests/requirements.txt
13 changes: 13 additions & 0 deletions text-generation-inference/tests/conftest.py
@@ -2,6 +2,8 @@

import pytest

from optimum.tpu import jetstream_pt_available


# See https://stackoverflow.com/a/61193490/217945 for run_slow
def pytest_addoption(parser):
@@ -33,3 +35,14 @@ def quantization_jetstream_int8():
# Clean up
os.environ.clear()
os.environ.update(old_environ)


def pytest_runtest_setup(item):
marker_names = [marker.name for marker in item.own_markers]
jetstream_pt_enabled = jetstream_pt_available()
# Skip tests that require torch xla but not jetstream
if "torch_xla" in marker_names and "jetstream" not in marker_names:
if jetstream_pt_enabled:
pytest.skip("Jetstream PyTorch must be disabled")
elif "jetstream" in marker_names and not jetstream_pt_enabled:
pytest.skip("Jetstream PyTorch must be enabled")
7 changes: 0 additions & 7 deletions text-generation-inference/tests/helpers.py
@@ -1,6 +1,5 @@
import os

import pytest
from text_generation_server.pb.generate_pb2 import (
NextTokenChooserParameters,
Request,
@@ -10,12 +9,6 @@
from optimum.tpu.model import fetch_model


def skip_if_jetstream_pytorch_enabled(func):
reason = "Skipping because Jetstream PyTorch is enabled"
jetstream_enabled = os.getenv("JETSTREAM_PT_DISABLE") != "1"
return pytest.mark.skipif(jetstream_enabled, reason=reason)(func)


def prepare_model(model_id, sequence_length):
# Add variables to environment so they can be used in AutoModelForCausalLM
os.environ["HF_SEQUENCE_LENGTH"] = str(sequence_length)
4 changes: 4 additions & 0 deletions text-generation-inference/tests/pytest.ini
@@ -0,0 +1,4 @@
[pytest]
markers =
jetstream: mark a test as a test that uses jetstream backend
torch_xla: mark a test as a test that uses torch_xla backend
7 changes: 4 additions & 3 deletions text-generation-inference/tests/test_decode.py
@@ -1,10 +1,11 @@

import pytest
from decode_tests_utils import DecodeTestParams, decode_single_test
from helpers import skip_if_jetstream_pytorch_enabled


@skip_if_jetstream_pytorch_enabled
# All tests in this file are for torch xla
pytestmark = pytest.mark.torch_xla

@pytest.mark.parametrize("params",
[
DecodeTestParams(
@@ -23,7 +24,7 @@
def test_decode_single(params):
decode_single_test(params)

@skip_if_jetstream_pytorch_enabled

@pytest.mark.slow
@pytest.mark.parametrize("params",
[
3 changes: 3 additions & 0 deletions text-generation-inference/tests/test_decode_jetstream.py
@@ -3,6 +3,9 @@
from decode_tests_utils import DecodeTestParams, decode_single_test


# All tests in this file are for jetstream
pytestmark = pytest.mark.jetstream

@pytest.mark.slow
@pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
@pytest.mark.parametrize("params",
3 changes: 3 additions & 0 deletions text-generation-inference/tests/test_decode_jetstream_quant.py
@@ -3,6 +3,9 @@
from decode_tests_utils import DecodeTestParams, decode_single_test


# All tests in this file are for jetstream
pytestmark = pytest.mark.jetstream

@pytest.mark.parametrize("params",
[
DecodeTestParams(
16 changes: 5 additions & 11 deletions text-generation-inference/tests/test_prefill_truncate.py
@@ -1,9 +1,12 @@
from helpers import create_request, prepare_model, skip_if_jetstream_pytorch_enabled
import pytest
from helpers import create_request, prepare_model
from text_generation_server.auto_generator import AutoGenerator
from text_generation_server.pb.generate_pb2 import Batch


def _test_prefill_truncate():
@pytest.mark.jetstream
@pytest.mark.torch_xla
def test_prefill_truncate():
model_id = "Maykeye/TinyLLama-v0"
sequence_length = 1024

@@ -32,12 +35,3 @@ def _test_prefill_truncate():
assert len(generations) == 1
assert generations[0].tokens.ids == [266]
assert generations[0].tokens.texts == [" the"]


def test_prefill_truncate_jetstream():
_test_prefill_truncate()


@skip_if_jetstream_pytorch_enabled
def test_prefill_truncate():
_test_prefill_truncate()
14 changes: 9 additions & 5 deletions text-generation-inference/tests/test_tinyllama.py
@@ -1,5 +1,5 @@
import pytest
from helpers import create_request, prepare_model, skip_if_jetstream_pytorch_enabled
from helpers import create_request, prepare_model
from text_generation_server.auto_generator import AutoGenerator
from text_generation_server.pb.generate_pb2 import Batch
from tqdm import tqdm
@@ -24,11 +24,12 @@ def _test_info(model_path, expected_device_type):
assert info.speculate == 0


@pytest.mark.jetstream
def test_jetstream_info(model_path):
_test_info(model_path, "meta")


@skip_if_jetstream_pytorch_enabled
@pytest.mark.torch_xla
def test_info(model_path):
_test_info(model_path, "xla")

@@ -57,6 +58,7 @@ def _test_prefill(input_text, token_id, token_text, do_sample, batch_size, model
assert tokens.texts == [token_text]


@pytest.mark.jetstream
@pytest.mark.parametrize(
"input_text, token_id, token_text, do_sample",
[
@@ -80,7 +82,7 @@ def test_jetstream_prefill(input_text, token_id, token_text, do_sample, batch_si
_test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path)


@skip_if_jetstream_pytorch_enabled
@pytest.mark.torch_xla
@pytest.mark.parametrize(
"input_text, token_id, token_text, do_sample",
[
@@ -134,11 +136,12 @@ def check_request(do_sample, expected_token_id):
check_request(False, greedy_expected_token_id)


@pytest.mark.jetstream
def test_jetstream_prefill_change_sampling(model_path):
_test_prefill_change_sampling(model_path, 347, 13)


@skip_if_jetstream_pytorch_enabled
@pytest.mark.torch_xla
def test_prefill_change_sampling(model_path):
_test_prefill_change_sampling(model_path, 571, 13)

@@ -216,6 +219,7 @@ def _test_continuous_batching_two_requests(model_path):
assert output.text == generated_text


@pytest.mark.jetstream
def test_jetstream_decode_multiple(model_path):
_test_continuous_batching_two_requests(model_path)

@@ -225,6 +229,6 @@ def test_jetstream_decode_multiple(model_path):
similar outputs, but not identical).
"""
@pytest.mark.skip(reason="Test is not supported on PyTorch/XLA")
@skip_if_jetstream_pytorch_enabled
@pytest.mark.torch_xla
def test_decode_multiple(model_path):
_test_continuous_batching_two_requests(model_path)
4 changes: 2 additions & 2 deletions text-generation-inference/tests/test_warmup.py
@@ -1,12 +1,12 @@


import pytest
from time import time

from helpers import create_request, prepare_model
from text_generation_server.auto_generator import AutoGenerator
from text_generation_server.pb.generate_pb2 import Batch


@pytest.mark.jetstream
def test_warmup_jetstream_pytorch():
model_id = "Maykeye/TinyLLama-v0"
sequence_length = 256
