From 33fa858dd6a4f51a06f69a6cccf767aa4a6f080b Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Mon, 25 Nov 2024 13:51:19 +0000
Subject: [PATCH] feat(tests): use pytest markers to filter jetstream and torch
 xla tests

Until now, filtering was done using the test name. Selection is now done
with a custom marker, which allows for clearer filtering.
---
 ...est-pytorch-xla-tpu-tgi-nightly-jetstream.yml | 14 +++++++-------
 .github/workflows/test-pytorch-xla-tpu.yml       |  1 +
 Makefile                                         |  4 ++--
 text-generation-inference/tests/conftest.py      | 13 +++++++++++++
 text-generation-inference/tests/helpers.py       |  7 -------
 text-generation-inference/tests/pytest.ini       |  4 ++++
 text-generation-inference/tests/test_decode.py   |  7 ++++---
 .../tests/test_decode_jetstream.py               |  3 +++
 .../tests/test_decode_jetstream_quant.py         |  3 +++
 .../tests/test_prefill_truncate.py               | 16 +++++-----------
 .../tests/test_tinyllama.py                      | 14 +++++++++++-----
 text-generation-inference/tests/test_warmup.py   |  4 ++--
 12 files changed, 53 insertions(+), 37 deletions(-)
 create mode 100644 text-generation-inference/tests/pytest.ini

diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml
index d7837e3d..3e4859e1 100644
--- a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml
+++ b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml
@@ -34,28 +34,28 @@ jobs:
       - name: Run TGI Jetstream Pytorch - Llama
         run: |
           python -m \
-          pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -k "slow and Llama"
+          pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and Llama"
       - name: Run TGI Jetstream Pytorch - Gemma
         run: |
           python -m \
-          pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -k "slow and gemma"
+          pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and gemma"
       - name: Run TGI Jetstream Pytorch - Mixtral greedy
         run: |
           python -m \
-          pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -k "slow and Mixtral and greedy"
+          pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and Mixtral and greedy"
      - name: Run TGI Jetstream Pytorch - Quantization Mixtral
         run: |
           python -m \
-          pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -k "slow and Mixtral"
+          pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Mixtral"
       - name: Run TGI Jetstream Pytorch - Quantization Llama-3 8B
         run: |
           python -m \
-          pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -k "slow and Llama-3-8B"
+          pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Llama-3-8B"
       - name: Run TGI Jetstream Pytorch - Quantization Llama 3 70B
         run: |
           python -m \
-          pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -k "slow and Llama-3-70B"
+          pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Llama-3-70B"
       - name: Run TGI Jetstream Pytorch - Other tests
         run: |
           python -m \
-          pytest -sv text-generation-inference/tests --runslow -k "jetstream and not decode and not quant"
+          pytest -sv text-generation-inference/tests --runslow -m jetstream -k "not decode"
diff --git a/.github/workflows/test-pytorch-xla-tpu.yml b/.github/workflows/test-pytorch-xla-tpu.yml
index 46f6f414..a91ebf93 100644
--- a/.github/workflows/test-pytorch-xla-tpu.yml
+++ b/.github/workflows/test-pytorch-xla-tpu.yml
@@ -25,6 +25,7 @@ jobs:
     env:
       PJRT_DEVICE: TPU
       HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface
+      JETSTREAM_PT_DISABLE: 1 # Disable Jetstream Pytorch tests
     steps:
       - name: Checkout
         uses: actions/checkout@v4
diff --git a/Makefile b/Makefile
index 5c5fb429..7448a46e 100644
--- a/Makefile
+++ b/Makefile
@@ -95,12 +95,12 @@ jetstream_requirements: test_installs
 tgi_test_jetstream: test_installs jetstream_requirements tgi_server
 	find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
 	-exec python -m pip install --force-reinstall {} \;
-	JETSTREAM_PT=1 python -m pytest -sv text-generation-inference/tests -k jetstream
+	python -m pytest -sv text-generation-inference/tests -m jetstream
 
 tgi_test: test_installs tgi_server
 	find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
 	-exec python -m pip install --force-reinstall {} \;
-	python -m pytest -sv text-generation-inference/tests -k "not jetstream"
+	python -m pytest -sv text-generation-inference/tests -m torch_xla
 
 tgi_docker_test: tpu-tgi
 	python -m pip install -r text-generation-inference/integration-tests/requirements.txt
diff --git a/text-generation-inference/tests/conftest.py b/text-generation-inference/tests/conftest.py
index d5bef29f..b3a825d7 100644
--- a/text-generation-inference/tests/conftest.py
+++ b/text-generation-inference/tests/conftest.py
@@ -2,6 +2,8 @@
 
 import pytest
 
+from optimum.tpu import jetstream_pt_available
+
 
 # See https://stackoverflow.com/a/61193490/217945 for run_slow
 def pytest_addoption(parser):
@@ -33,3 +35,14 @@ def quantization_jetstream_int8():
     # Clean up
     os.environ.clear()
     os.environ.update(old_environ)
+
+
+def pytest_runtest_setup(item):
+    marker_names = [marker.name for marker in item.own_markers]
+    jetstream_pt_enabled = jetstream_pt_available()
+    # Skip tests that require torch xla but not jetstream
+    if "torch_xla" in marker_names and "jetstream" not in marker_names:
+        if jetstream_pt_enabled:
+            pytest.skip("Jetstream PyTorch must be disabled")
+    elif "jetstream" in marker_names and not jetstream_pt_enabled:
+        pytest.skip("Jetstream PyTorch must be enabled")
diff --git a/text-generation-inference/tests/helpers.py b/text-generation-inference/tests/helpers.py
index 5c1e320e..e36f997a 100644
--- a/text-generation-inference/tests/helpers.py
+++ b/text-generation-inference/tests/helpers.py
@@ -1,6 +1,5 @@
 import os
 
-import pytest
 from text_generation_server.pb.generate_pb2 import (
     NextTokenChooserParameters,
     Request,
@@ -10,12 +9,6 @@
 from optimum.tpu.model import fetch_model
 
 
-def skip_if_jetstream_pytorch_enabled(func):
-    reason = "Skipping because Jetstream PyTorch is enabled"
-    jetstream_enabled = os.getenv("JETSTREAM_PT_DISABLE") != "1"
-    return pytest.mark.skipif(jetstream_enabled, reason=reason)(func)
-
-
 def prepare_model(model_id, sequence_length):
     # Add variables to environment so they can be used in AutoModelForCausalLM
     os.environ["HF_SEQUENCE_LENGTH"] = str(sequence_length)
diff --git a/text-generation-inference/tests/pytest.ini b/text-generation-inference/tests/pytest.ini
new file mode 100644
index 00000000..e93b1b11
--- /dev/null
+++ b/text-generation-inference/tests/pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+markers =
+    jetstream: mark a test as a test that uses jetstream backend
+    torch_xla: mark a test as a test that uses torch_xla backend
diff --git a/text-generation-inference/tests/test_decode.py b/text-generation-inference/tests/test_decode.py
index 31b64e3d..99343da6 100644
--- a/text-generation-inference/tests/test_decode.py
+++ b/text-generation-inference/tests/test_decode.py
@@ -1,10 +1,11 @@
 import pytest
 
 from decode_tests_utils import DecodeTestParams, decode_single_test
-from helpers import skip_if_jetstream_pytorch_enabled
 
 
-@skip_if_jetstream_pytorch_enabled
+# All tests in this file are for torch xla
+pytestmark = pytest.mark.torch_xla
+
 @pytest.mark.parametrize("params",
 [
     DecodeTestParams(
@@ -23,7 +24,7 @@ def test_decode_single(params):
     decode_single_test(params)
 
 
-@skip_if_jetstream_pytorch_enabled
+
 @pytest.mark.slow
 @pytest.mark.parametrize("params",
 [
diff --git a/text-generation-inference/tests/test_decode_jetstream.py b/text-generation-inference/tests/test_decode_jetstream.py
index 8a297ffd..6cee30c8 100644
--- a/text-generation-inference/tests/test_decode_jetstream.py
+++ b/text-generation-inference/tests/test_decode_jetstream.py
@@ -3,6 +3,9 @@
 from decode_tests_utils import DecodeTestParams, decode_single_test
 
 
+# All tests in this file are for jetstream
+pytestmark = pytest.mark.jetstream
+
 @pytest.mark.slow
 @pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
 @pytest.mark.parametrize("params",
diff --git a/text-generation-inference/tests/test_decode_jetstream_quant.py b/text-generation-inference/tests/test_decode_jetstream_quant.py
index 86f19365..832cf4c3 100644
--- a/text-generation-inference/tests/test_decode_jetstream_quant.py
+++ b/text-generation-inference/tests/test_decode_jetstream_quant.py
@@ -3,6 +3,9 @@
 from decode_tests_utils import DecodeTestParams, decode_single_test
 
 
+# All tests in this file are for jetstream
+pytestmark = pytest.mark.jetstream
+
 @pytest.mark.parametrize("params",
 [
     DecodeTestParams(
diff --git a/text-generation-inference/tests/test_prefill_truncate.py b/text-generation-inference/tests/test_prefill_truncate.py
index eba4bd69..b90a16ca 100644
--- a/text-generation-inference/tests/test_prefill_truncate.py
+++ b/text-generation-inference/tests/test_prefill_truncate.py
@@ -1,9 +1,12 @@
-from helpers import create_request, prepare_model, skip_if_jetstream_pytorch_enabled
+import pytest
+from helpers import create_request, prepare_model
 from text_generation_server.auto_generator import AutoGenerator
 from text_generation_server.pb.generate_pb2 import Batch
 
 
-def _test_prefill_truncate():
+@pytest.mark.jetstream
+@pytest.mark.torch_xla
+def test_prefill_truncate():
     model_id = "Maykeye/TinyLLama-v0"
     sequence_length = 1024
 
@@ -32,12 +35,3 @@
     assert len(generations) == 1
     assert generations[0].tokens.ids == [266]
     assert generations[0].tokens.texts == [" the"]
-
-
-def test_prefill_truncate_jetstream():
-    _test_prefill_truncate()
-
-
-@skip_if_jetstream_pytorch_enabled
-def test_prefill_truncate():
-    _test_prefill_truncate()
diff --git a/text-generation-inference/tests/test_tinyllama.py b/text-generation-inference/tests/test_tinyllama.py
index c8ec58ea..d9343073 100644
--- a/text-generation-inference/tests/test_tinyllama.py
+++ b/text-generation-inference/tests/test_tinyllama.py
@@ -1,5 +1,5 @@
 import pytest
-from helpers import create_request, prepare_model, skip_if_jetstream_pytorch_enabled
+from helpers import create_request, prepare_model
 from text_generation_server.auto_generator import AutoGenerator
 from text_generation_server.pb.generate_pb2 import Batch
 from tqdm import tqdm
@@ -24,11 +24,12 @@ def _test_info(model_path, expected_device_type):
     assert info.speculate == 0
 
 
+@pytest.mark.jetstream
 def test_jetstream_info(model_path):
     _test_info(model_path, "meta")
 
 
-@skip_if_jetstream_pytorch_enabled
+@pytest.mark.torch_xla
 def test_info(model_path):
     _test_info(model_path, "xla")
 
@@ -57,6 +58,7 @@ def _test_prefill(input_text, token_id, token_text, do_sample, batch_size, model
     assert tokens.texts == [token_text]
 
 
+@pytest.mark.jetstream
 @pytest.mark.parametrize(
     "input_text, token_id, token_text, do_sample",
     [
@@ -80,7 +82,7 @@ def test_jetstream_prefill(input_text, token_id, token_text, do_sample, batch_si
     _test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path)
 
 
-@skip_if_jetstream_pytorch_enabled
+@pytest.mark.torch_xla
 @pytest.mark.parametrize(
     "input_text, token_id, token_text, do_sample",
     [
@@ -134,11 +136,12 @@ def check_request(do_sample, expected_token_id):
     check_request(False, greedy_expected_token_id)
 
 
+@pytest.mark.jetstream
 def test_jetstream_prefill_change_sampling(model_path):
     _test_prefill_change_sampling(model_path, 347, 13)
 
 
-@skip_if_jetstream_pytorch_enabled
+@pytest.mark.torch_xla
 def test_prefill_change_sampling(model_path):
     _test_prefill_change_sampling(model_path, 571, 13)
 
@@ -216,6 +219,7 @@ def _test_continuous_batching_two_requests(model_path):
     assert output.text == generated_text
 
 
+@pytest.mark.jetstream
 def test_jetstream_decode_multiple(model_path):
     _test_continuous_batching_two_requests(model_path)
 
@@ -225,6 +229,6 @@ def test_jetstream_decode_multiple(model_path):
 similar outputs, but not identical).
 """
 @pytest.mark.skip(reason="Test is not supported on PyTorch/XLA")
-@skip_if_jetstream_pytorch_enabled
+@pytest.mark.torch_xla
 def test_decode_multiple(model_path):
     _test_continuous_batching_two_requests(model_path)
diff --git a/text-generation-inference/tests/test_warmup.py b/text-generation-inference/tests/test_warmup.py
index 102b6d00..3aa026f2 100644
--- a/text-generation-inference/tests/test_warmup.py
+++ b/text-generation-inference/tests/test_warmup.py
@@ -1,5 +1,4 @@
-
-
+import pytest
 from time import time
 
 from helpers import create_request, prepare_model
@@ -7,6 +6,7 @@
 from text_generation_server.pb.generate_pb2 import Batch
 
 
+@pytest.mark.jetstream
 def test_warmup_jetstream_pytorch():
     model_id = "Maykeye/TinyLLama-v0"
     sequence_length = 256
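
Usage note (a sketch, not part of the commit): with the markers registered in pytest.ini,
backend selection reduces to pytest's -m flag, and the pytest_runtest_setup hook in
conftest.py skips tests whose marker does not match the available backend. The commands
below are taken from the Makefile targets and workflow env in this patch; whether the
JETSTREAM_PT_DISABLE toggle is needed outside CI is an assumption.

    # Jetstream PyTorch backend tests only (as in the tgi_test_jetstream target)
    python -m pytest -sv text-generation-inference/tests -m jetstream

    # torch_xla backend tests only (as in the tgi_test target); the env var mirrors
    # the torch_xla workflow above and may be unnecessary if Jetstream is not installed
    JETSTREAM_PT_DISABLE=1 python -m pytest -sv text-generation-inference/tests -m torch_xla

    # Nightly jobs combine the marker with -k, e.g.
    python -m pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and Llama"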