From abeace69d5056677fa11666c83b90e486d288a4e Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Thu, 21 Nov 2024 10:58:08 +0000 Subject: [PATCH 1/5] test(slots): add unit tests for slots for jetstream too Implementation is slightly different, so a separate test is added. --- .../tests/test_generator_slot.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/text-generation-inference/tests/test_generator_slot.py b/text-generation-inference/tests/test_generator_slot.py index f0f23ed7..c63d6920 100644 --- a/text-generation-inference/tests/test_generator_slot.py +++ b/text-generation-inference/tests/test_generator_slot.py @@ -1,8 +1,11 @@ +import numpy as np import pytest import torch from text_generation_server.pb.generate_pb2 import Request from transformers import AutoTokenizer, GenerationConfig +from optimum.tpu.jetstream_pt_support import jetstream_pt_available + TOKENIZERS = ["NousResearch/Llama-2-7b-hf", "openai-community/gpt2"] @@ -15,6 +18,54 @@ def tokenizer(request): return t +@pytest.mark.parametrize( + "input_text, generated_text", + [ + [ + "It was a bright cold day in April, and the clocks were striking thirteen.", + " Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind," + " slipped quickly through the glass doors of Victory Mansions, though not quickly enough" + " to prevent a swirl of gritty dust from entering along with him.", + ], + ["This sentence is written in chinese:", "我很感谢你的热情"], + ["Some text might contain a lot of emojis like 😃", "😍💪 👉 👀"], + ], + ids=["spaces", "chinese-utf8", "emojis"], +) +def test_decode_streaming_jetstream(tokenizer, input_text, generated_text): + if not jetstream_pt_available(): + pytest.skip("Jetstream PyTorch is not available") + + from text_generation_server.jetstream_pt_support.generator import Slot + + slot = Slot(0, tokenizer) + request = Request(id=0, inputs=input_text) + slot.assign(0, request, GenerationConfig()) + + inputs = tokenizer(input_text, padding="max_length", max_length=len(input_text) + 1, return_tensors="np") + input_ids = inputs["input_ids"][0] + generated_tokens = tokenizer(generated_text, add_special_tokens=False, return_tensors="np")["input_ids"][0] + + # We need to regenerate the full text as the tokenizer might change it (extra spaces might be added) + all_input_ids = np.concatenate([input_ids, generated_tokens]) + full_text = tokenizer.decode(all_input_ids, skip_special_tokens=True) + regenerated_text = full_text[len(input_text) :] + + # Initialize the slot with the inputs + slot.reset(input_ids, selector=None) + + assert slot.generated_tokens == 0 + + # Simulate an iterative generation (i.e. 
don't call select and use known tokens instead) + decoded_text = "" + for i in range(len(generated_tokens)): + text = slot.append(generated_tokens[i]) + assert slot.generated_tokens == i + 1 + decoded_text += text + + assert decoded_text == regenerated_text + + @pytest.mark.parametrize( "input_text, generated_text", [ @@ -31,6 +82,7 @@ def tokenizer(request): ) def test_decode_streaming(tokenizer, input_text, generated_text): from text_generation_server.generator import Slot + # Note: device used is cpu to make it faster slot = Slot(0, tokenizer, "cpu") request = Request(id=0, inputs=input_text) From b15f9524c83597a8f3d676d00856c2a93112ff3f Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Thu, 21 Nov 2024 14:53:19 +0000 Subject: [PATCH 2/5] test(truncate): adapt test for jetstream too --- .../tests/test_prefill_truncate.py | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/text-generation-inference/tests/test_prefill_truncate.py b/text-generation-inference/tests/test_prefill_truncate.py index 4ad78ab5..43e8d2a4 100644 --- a/text-generation-inference/tests/test_prefill_truncate.py +++ b/text-generation-inference/tests/test_prefill_truncate.py @@ -1,11 +1,14 @@ +import pytest from helpers import create_request, prepare_model from text_generation_server.auto_generator import AutoGenerator from text_generation_server.pb.generate_pb2 import Batch +from optimum.tpu.jetstream_pt_support import jetstream_pt_available -def test_prefill_truncate(): - model_id="Maykeye/TinyLLama-v0" - sequence_length=1024 + +def _test_prefill_truncate(): + model_id = "Maykeye/TinyLLama-v0" + sequence_length = 1024 model_path = prepare_model(model_id, sequence_length) max_new_tokens = 20 @@ -13,23 +16,32 @@ def test_prefill_truncate(): generator = AutoGenerator.from_pretrained( model_path, revision="", max_batch_size=1, max_sequence_length=sequence_length ) - input_text = "This is something I will tell by the end of the story" + input_text = "And to finish the story, I will say that" request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False) batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length) generations, _ = generator.prefill(batch) assert len(generations) == 1 - assert generations[0].tokens.ids == [31843] - assert generations[0].tokens.texts == ["."] - + assert generations[0].tokens.ids == [357] + assert generations[0].tokens.texts == [" it"] # Now re-test but with truncate generator.clear() request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False) # This will only leave last tokens - request.truncate = 6 + request.truncate = 3 batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length) generations, _ = generator.prefill(batch) assert len(generations) == 1 - assert generations[0].tokens.ids == [291] - assert generations[0].tokens.texts == [" and"] + assert generations[0].tokens.ids == [266] + assert generations[0].tokens.texts == [" the"] + + +def test_prefill_truncate_jetstream(): + if not jetstream_pt_available(): + pytest.skip("Jetstream PyTorch is not available") + _test_prefill_truncate() + + +def test_prefill_truncate(): + _test_prefill_truncate() From a5214be96be5a7e29c7a6e3755b44c27fb7b3b98 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Fri, 22 Nov 2024 13:39:07 +0000 Subject: [PATCH 3/5] refactor(test): make tinyllama test work for Jetstream and Torch/XLA Most tests work for both, except for the continuous batching one. 
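As an illustration, the shared-helper layout used throughout the refactored file looks roughly like this (a sketch only: it reuses the names that appear in the diff below, including the module-level `model_path` fixture, and keeps just one of the shared assertions):

```python
# Sketch of the test layout introduced by this patch: one backend-agnostic helper
# plus a thin wrapper per backend, so the same assertions run on both engines.
import pytest

from optimum.tpu.jetstream_pt_support import jetstream_pt_available
from text_generation_server.auto_generator import AutoGenerator


def _test_info(model_path, expected_device_type):
    # Backend-agnostic assertions, parametrized by the device type each engine reports.
    generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1)
    assert generator.info.device_type == expected_device_type


def test_jetstream_info(model_path):
    # Jetstream PyTorch flavour: skipped when the engine is not installed.
    if not jetstream_pt_available():
        pytest.skip("Jetstream PyTorch is not available")
    _test_info(model_path, "meta")


def test_info(model_path):
    # PyTorch/XLA flavour.
    _test_info(model_path, "xla")
```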
This allows to remove the old GPT2 based tests, that are quite slow and do not use any sharding or KV cache, so they might not really be representative of most relevant models on TGI. --- .../tests/test_tinyllama.py | 153 ++++++++++++------ 1 file changed, 105 insertions(+), 48 deletions(-) diff --git a/text-generation-inference/tests/test_tinyllama.py b/text-generation-inference/tests/test_tinyllama.py index 5566bccd..434f532f 100644 --- a/text-generation-inference/tests/test_tinyllama.py +++ b/text-generation-inference/tests/test_tinyllama.py @@ -1,4 +1,3 @@ - import pytest from helpers import create_request, prepare_model from text_generation_server.auto_generator import AutoGenerator @@ -17,18 +16,50 @@ def model_path(): return prepare_model(MODEL_ID, SEQUENCE_LENGTH) -def test_jetstream_info(model_path): +def _test_info(model_path, expected_device_type): """Verify the model info is correctly loaded and check expected results.""" - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1) info = generator.info assert info.requires_padding is True - assert info.device_type == "meta" + assert info.device_type == expected_device_type assert info.window_size == 0 assert info.speculate == 0 +def test_jetstream_info(model_path): + if not jetstream_pt_available(): + pytest.skip("Jetstream PyTorch is not available") + _test_info(model_path, "meta") + + +def test_info(model_path): + _test_info(model_path, "xla") + + +def _test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path): + """Verify that prefilling a batch with a single request with different sampling techniques.""" + generator = AutoGenerator.from_pretrained( + model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH + ) + requests = [] + max_new_tokens = 20 + for i in range(batch_size): + requests.append(create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens)) + # Let's be pessimistic when estimating max_tokens + batch_size * (len(input_text) + max_new_tokens) + batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH) + generations, next_batch = generator.prefill(batch) + assert next_batch.size == batch_size + # Whatever was passed as max_tokens, the server will correct it + # because of static batching + assert next_batch.max_tokens == batch_size * SEQUENCE_LENGTH + assert len(generations) == batch_size + for g in generations: + tokens = g.tokens + assert tokens.ids == [token_id] + assert tokens.texts == [token_text] + + @pytest.mark.parametrize( "input_text, token_id, token_text, do_sample", [ @@ -49,73 +80,84 @@ def test_jetstream_info(model_path): ) @pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"]) def test_jetstream_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path): - """Verify that prefilling a batch with a single request with different sampling techniques. 
- """ if not jetstream_pt_available(): pytest.skip("Jetstream PyTorch is not available") - generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH) - requests = [] - max_new_tokens = 20 - for i in range(batch_size): - requests.append(create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens)) - # Let's be pessimistic when estimating max_tokens - batch_size * (len(input_text) + max_new_tokens) - batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH) - generations, next_batch = generator.prefill(batch) - assert next_batch.size == batch_size - # Whatever was passed as max_tokens, the server will correct it - # because of static batching - assert next_batch.max_tokens == batch_size * SEQUENCE_LENGTH - assert len(generations) == batch_size - for g in generations: - tokens = g.tokens - assert tokens.ids == [token_id] - assert tokens.texts == [token_text] + _test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path) -def test_jetstream_prefill_change_sampling(model_path): - """Verify changing the sampling strategy between requests in the same batch works as expected. - """ - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") +@pytest.mark.parametrize( + "input_text, token_id, token_text, do_sample", + [ + [ + "It was a bright cold day in April, and the clocks were striking thirteen.", + 571, + " It", + False, + ], + [ + "It was a bright cold day in April, and the clocks were striking thirteen.", + 13, + "\n", + True, + ], + ], + ids=["greedy", "sample"], +) +@pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"]) +def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path): + _test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path) + + +def _test_prefill_change_sampling( + model_path, + greedy_expected_token_id, + sampling_expected_token_id, +): + """Verify changing the sampling strategy between requests in the same batch works as expected.""" input_text = "It was a bright cold day in April, and the clocks were striking thirteen." 
batch_size = 1 - greedy_expected_token_id = 347 - greedy_expected_text = " The" - sampling_expected_token_id = 13 - sampling_expected_text = "\n" - generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH) + generator = AutoGenerator.from_pretrained( + model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH + ) max_new_tokens = 20 - def check_request(do_sample, expected_token_id, expected_text): + def check_request(do_sample, expected_token_id): requests = [create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens)] batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH) generations, _ = generator.prefill(batch) tokens = generations[0].tokens - print(tokens) assert tokens.ids == [expected_token_id] - assert tokens.texts == [expected_text] generator.clear() # First request is greedy - check_request(False, greedy_expected_token_id, greedy_expected_text) + check_request(False, greedy_expected_token_id) # Second request is sampling - check_request(True, sampling_expected_token_id, sampling_expected_text) + check_request(True, sampling_expected_token_id) # Third request is greedy again - check_request(False, greedy_expected_token_id, greedy_expected_text) + check_request(False, greedy_expected_token_id) -def test_jetstream_decode_multiple(model_path): +def test_jetstream_prefill_change_sampling(model_path): + if not jetstream_pt_available(): + pytest.skip("Jetstream PyTorch is not available") + _test_prefill_change_sampling(model_path, 347, 13) + + +def test_prefill_change_sampling(model_path): + _test_prefill_change_sampling(model_path, 571, 13) + + +def _test_continuous_batching_two_requests(model_path): """Verify that two requests added to the batch at different generation steps generate the same outputs (continuous batching). """ - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") - generator = AutoGenerator.from_pretrained(model_path, - revision="", - max_batch_size=2, - max_sequence_length=SEQUENCE_LENGTH) + generator = AutoGenerator.from_pretrained( + model_path, + revision="", + max_batch_size=2, + max_sequence_length=SEQUENCE_LENGTH, + ) input_text = "Once upon a time" max_new_tokens = 20 # Prefill a single request, remembering the generated token @@ -177,3 +219,18 @@ def test_jetstream_decode_multiple(model_path): assert output.generated_tokens == max_new_tokens assert tokens[0] == tokens[1] assert output.text == generated_text + + +def test_jetstream_decode_multiple(model_path): + if not jetstream_pt_available(): + pytest.skip("Jetstream PyTorch is not available") + _test_continuous_batching_two_requests(model_path) + + +"""NOTE: This test does not work on PyTorch/XLA, because of the way +calculations are done in torch/xla and the effect of KV cache (they produce +similar outputs, but not identical). +""" +@pytest.mark.skip(reason="Test is not supported on PyTorch/XLA") +def test_decode_multiple(model_path): + _test_continuous_batching_two_requests(model_path) From 5298dbfdb3d25d25c68692cdbc63415bced2c90c Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Fri, 22 Nov 2024 13:51:30 +0000 Subject: [PATCH 4/5] test(gpt2): remove old test There are equivalent tests now on the TinyLlama model, that run faster, use the KV cache and sharding. 
The only test that does not have an equivalence is the continuous batching one, but the test was not working for most other models, so I prefer to remove it anyway, as having it passing was not representative anyway of the current state. --- text-generation-inference/tests/test_gpt2.py | 132 ------------------- 1 file changed, 132 deletions(-) delete mode 100644 text-generation-inference/tests/test_gpt2.py diff --git a/text-generation-inference/tests/test_gpt2.py b/text-generation-inference/tests/test_gpt2.py deleted file mode 100644 index 776d8418..00000000 --- a/text-generation-inference/tests/test_gpt2.py +++ /dev/null @@ -1,132 +0,0 @@ - -import pytest -from helpers import create_request, prepare_model -from text_generation_server.auto_generator import AutoGenerator -from text_generation_server.pb.generate_pb2 import Batch -from tqdm import tqdm - - -MODEL_ID = "openai-community/gpt2" -SEQUENCE_LENGTH = 1024 - - -@pytest.fixture(scope="module") -def model_path(): - return prepare_model(MODEL_ID, SEQUENCE_LENGTH) - - -def test_info(model_path): - generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1) - info = generator.info - assert info.requires_padding is True - assert info.device_type == "xla" - assert info.window_size == 0 - assert info.speculate == 0 - - -@pytest.mark.parametrize( - "input_text, token_id, token_text, do_sample", - [ - [ - "It was a bright cold day in April, and the clocks were striking thirteen.", - 383, - " The", - False, - ], - [ - "It was a bright cold day in April, and the clocks were striking thirteen.", - 775, - " We", - True, - ], - ], - ids=["greedy", "sample"], -) -@pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"]) -def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path): - generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH) - requests = [] - max_new_tokens = 20 - for i in range(batch_size): - requests.append(create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens)) - # Let's be pessimistic when estimating max_tokens - batch_size * (len(input_text) + max_new_tokens) - batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH) - generations, next_batch = generator.prefill(batch) - assert next_batch.size == batch_size - # Whatever was passed as max_tokens, the server will correct it - # because of static batching - assert next_batch.max_tokens == batch_size * SEQUENCE_LENGTH - assert len(generations) == batch_size - for g in generations: - tokens = g.tokens - assert tokens.ids == [token_id] - assert tokens.texts == [token_text] - - -def test_decode_multiple(model_path): - generator = AutoGenerator.from_pretrained(model_path, - revision="", - max_batch_size=2, - max_sequence_length=SEQUENCE_LENGTH) - input_text = "Once upon a time" - max_new_tokens = 20 - # Prefill a single request, remembering the generated token - tokens = {0: [], 1: []} - request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens) - batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH) - generations, next_batch = generator.prefill(batch) - assert next_batch.size == 1 - assert len(generations) == 1 - g = generations[0] - tokens[g.request_id].append(g.tokens.ids[0]) - assert len(tokens[0]) == 1 - # Decode a few tokens - gen_tokens = 4 - for _ in tqdm(range(gen_tokens - 1), "Decoding tokens"): - 
generations, next_batch = generator.decode([next_batch]) - assert len(generations) == 1 - g = generations[0] - tokens[g.request_id].append(g.tokens.ids[0]) - assert len(tokens[0]) == gen_tokens - assert next_batch.size == 1 - # Add a second request - request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens) - batch = Batch(id=1, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH) - generations, next_batch_1 = generator.prefill(batch) - assert next_batch_1.size == 1 - # We should have generated only a single token - assert len(generations) == 1 - g = generations[0] - tokens[g.request_id].append(g.tokens.ids[0]) - assert len(tokens[0]) == gen_tokens - assert len(tokens[1]) == 1 - # Decode more tokens until we reach the maximum for the first request - batches = [next_batch, next_batch_1] - for _ in tqdm(range(max_new_tokens - gen_tokens), "Decoding tokens (2nd batch)"): - generations, next_batch = generator.decode(batches) - for g in generations: - tokens[g.request_id].append(g.tokens.ids[0]) - batches = [next_batch] - # Verify we now only have one pending request - assert next_batch.size == 1 - assert len(tokens[0]) == max_new_tokens - assert len(tokens[1]) == max_new_tokens - gen_tokens + 1 - # Verify we have the output for the first request - for g in generations: - if g.request_id == 0: - output = g.generated_text - assert output.text != "" - assert output.generated_tokens == max_new_tokens - generated_text = output.text - # Continue decoding until the end of the second request - for _ in tqdm(range(gen_tokens - 1), "Decoding tokens (finishing)"): - generations, next_batch = generator.decode([next_batch]) - assert len(generations) == 1 - g = generations[0] - tokens[g.request_id].append(g.tokens.ids[0]) - assert next_batch is None - output = generations[0].generated_text - assert output.generated_tokens == max_new_tokens - assert tokens[0] == tokens[1] - assert output.text == generated_text From 07f74ff4a8dfe7066c2e571aabeec43f43653efb Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Fri, 22 Nov 2024 14:32:18 +0000 Subject: [PATCH 5/5] feat(tgi): Jetstream/Pytorch is now the default engine Now that the engine is stable and tested, its engine is set as the default one for TGI. 
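A minimal usage sketch of the new behaviour (not part of the patch itself; it relies only on the `jetstream_pt_available()` helper updated below, and the first result depends on whether Jetstream PyTorch is importable in the host environment):

```python
# Check which backend TGI will pick on this host, before and after opting out.
import os

from optimum.tpu.jetstream_pt_support import jetstream_pt_available

# Default after this patch: Jetstream PyTorch is used whenever it imports cleanly.
print(jetstream_pt_available())

# Explicit opt-out falls back to the PyTorch/XLA backend
# (the environment variable is re-read on each call).
os.environ["JETSTREAM_PT_DISABLE"] = "1"
print(jetstream_pt_available())  # False once disabled
```

Quantization on top of the Jetstream backend remains opt-in via `QUANTIZATION=1`, as described in the serving documentation changes below.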
--- Makefile | 2 +- docs/source/howto/serving.mdx | 6 ++++-- optimum/tpu/jetstream_pt_support.py | 6 +++--- text-generation-inference/tests/helpers.py | 7 +++++++ text-generation-inference/tests/test_decode.py | 3 +++ .../tests/test_decode_jetstream.py | 6 ------ .../tests/test_decode_jetstream_quant.py | 6 ------ .../tests/test_generator_slot.py | 7 ++----- .../tests/test_prefill_truncate.py | 8 ++------ .../tests/test_tinyllama.py | 16 +++++----------- text-generation-inference/tests/test_warmup.py | 5 ----- 11 files changed, 27 insertions(+), 45 deletions(-) diff --git a/Makefile b/Makefile index 06580e06..5c5fb429 100644 --- a/Makefile +++ b/Makefile @@ -100,7 +100,7 @@ tgi_test_jetstream: test_installs jetstream_requirements tgi_server tgi_test: test_installs tgi_server find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \ -exec python -m pip install --force-reinstall {} \; - python -m pytest -sv text-generation-inference/tests + python -m pytest -sv text-generation-inference/tests -k "not jetstream" tgi_docker_test: tpu-tgi python -m pip install -r text-generation-inference/integration-tests/requirements.txt diff --git a/docs/source/howto/serving.mdx b/docs/source/howto/serving.mdx index df67ec38..3f1639b2 100644 --- a/docs/source/howto/serving.mdx +++ b/docs/source/howto/serving.mdx @@ -50,9 +50,11 @@ curl localhost/generate_stream \ -H 'Content-Type: application/json' ``` -### Using Jetstream Pytorch as backend +### Jetstream Pytorch and Pytorch XLA backends + +[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. This engine is selected by default if the dependency is available. +If for some reason you want to use the Pytorch/XLA backend instead, you can set the `JETSTREAM_PT_DISABLE=1` environment variable. -[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. It is possible to use this engine by setting the `JETSTREAM_PT=1` environment variable. When using Jetstream Pytorch engine, it is possible to enable quantization to reduce the memory footprint and increase the throughput. To enable quantization, set the `QUANTIZATION=1` environment variable. diff --git a/optimum/tpu/jetstream_pt_support.py b/optimum/tpu/jetstream_pt_support.py index eac038ad..4b844c60 100644 --- a/optimum/tpu/jetstream_pt_support.py +++ b/optimum/tpu/jetstream_pt_support.py @@ -8,9 +8,9 @@ def jetstream_pt_available() -> bool: """Check if the necessary imports to use jetstream_pt are available. """ try: - # For now Jetstream Pytorch is opt-in, it can be enabled with an ENV variable. - jetstream_pt_enabled = os.environ.get("JETSTREAM_PT", False) == "1" - if not jetstream_pt_enabled: + # Jetstream Pytorch is enabled by default, it can be disabled with an ENV variable. + jetstream_pt_disabled = os.environ.get("JETSTREAM_PT_DISABLE", False) == "1" + if jetstream_pt_disabled: return False # Torch XLA should not be imported before torch_xla2 to avoid conflicts. 
if 'torch_xla2' not in sys.modules and 'torch_xla.core' in sys.modules: diff --git a/text-generation-inference/tests/helpers.py b/text-generation-inference/tests/helpers.py index e36f997a..93206d1e 100644 --- a/text-generation-inference/tests/helpers.py +++ b/text-generation-inference/tests/helpers.py @@ -1,5 +1,6 @@ import os +import pytest from text_generation_server.pb.generate_pb2 import ( NextTokenChooserParameters, Request, @@ -9,6 +10,12 @@ from optimum.tpu.model import fetch_model +def skip_if_jetstream_pytorch_enabled(func): + reason = "Skipping because Jetstream PyTorch is enabled" + jetstream_disabled = os.getenv("JETSTREAM_PT_DISABLE") != 1 + return pytest.mark.skipif(jetstream_disabled, reason=reason)(func) + + def prepare_model(model_id, sequence_length): # Add variables to environment so they can be used in AutoModelForCausalLM os.environ["HF_SEQUENCE_LENGTH"] = str(sequence_length) diff --git a/text-generation-inference/tests/test_decode.py b/text-generation-inference/tests/test_decode.py index 0dbe3ba1..31b64e3d 100644 --- a/text-generation-inference/tests/test_decode.py +++ b/text-generation-inference/tests/test_decode.py @@ -1,8 +1,10 @@ import pytest from decode_tests_utils import DecodeTestParams, decode_single_test +from helpers import skip_if_jetstream_pytorch_enabled +@skip_if_jetstream_pytorch_enabled @pytest.mark.parametrize("params", [ DecodeTestParams( @@ -21,6 +23,7 @@ def test_decode_single(params): decode_single_test(params) +@skip_if_jetstream_pytorch_enabled @pytest.mark.slow @pytest.mark.parametrize("params", [ diff --git a/text-generation-inference/tests/test_decode_jetstream.py b/text-generation-inference/tests/test_decode_jetstream.py index d8df274e..8a297ffd 100644 --- a/text-generation-inference/tests/test_decode_jetstream.py +++ b/text-generation-inference/tests/test_decode_jetstream.py @@ -2,8 +2,6 @@ import pytest from decode_tests_utils import DecodeTestParams, decode_single_test -from optimum.tpu.jetstream_pt_support import jetstream_pt_available - @pytest.mark.slow @pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"]) @@ -35,8 +33,6 @@ ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b", "Mixtral-8x7B"], ) def test_decode_single_jetstream_pytorch_slow(params, do_sample): - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") params.do_sample = do_sample decode_single_test(params) @@ -64,7 +60,5 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample): ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny"], ) def test_decode_single_jetstream_pytorch(params, do_sample): - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") params.do_sample = do_sample decode_single_test(params) diff --git a/text-generation-inference/tests/test_decode_jetstream_quant.py b/text-generation-inference/tests/test_decode_jetstream_quant.py index fe28f497..86f19365 100644 --- a/text-generation-inference/tests/test_decode_jetstream_quant.py +++ b/text-generation-inference/tests/test_decode_jetstream_quant.py @@ -2,8 +2,6 @@ import pytest from decode_tests_utils import DecodeTestParams, decode_single_test -from optimum.tpu.jetstream_pt_support import jetstream_pt_available - @pytest.mark.parametrize("params", [ @@ -22,8 +20,6 @@ ids=["gemma-2b", "TinyLLama-v0"], ) def test_decode_jetstream_quantization(quantization_jetstream_int8, params): - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") decode_single_test(params) @@ -49,6 +45,4 @@ def 
test_decode_jetstream_quantization(quantization_jetstream_int8, params): ids=["Mixtral-8x7B", "Meta-Llama-3-8B" ,"Meta-Llama-3-70B"], ) def test_decode_jetstream_quantization_slow(quantization_jetstream_int8, params): - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") decode_single_test(params) diff --git a/text-generation-inference/tests/test_generator_slot.py b/text-generation-inference/tests/test_generator_slot.py index c63d6920..9860d535 100644 --- a/text-generation-inference/tests/test_generator_slot.py +++ b/text-generation-inference/tests/test_generator_slot.py @@ -1,11 +1,10 @@ import numpy as np import pytest import torch +from helpers import skip_if_jetstream_pytorch_enabled from text_generation_server.pb.generate_pb2 import Request from transformers import AutoTokenizer, GenerationConfig -from optimum.tpu.jetstream_pt_support import jetstream_pt_available - TOKENIZERS = ["NousResearch/Llama-2-7b-hf", "openai-community/gpt2"] @@ -33,9 +32,6 @@ def tokenizer(request): ids=["spaces", "chinese-utf8", "emojis"], ) def test_decode_streaming_jetstream(tokenizer, input_text, generated_text): - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") - from text_generation_server.jetstream_pt_support.generator import Slot slot = Slot(0, tokenizer) @@ -66,6 +62,7 @@ def test_decode_streaming_jetstream(tokenizer, input_text, generated_text): assert decoded_text == regenerated_text +@skip_if_jetstream_pytorch_enabled @pytest.mark.parametrize( "input_text, generated_text", [ diff --git a/text-generation-inference/tests/test_prefill_truncate.py b/text-generation-inference/tests/test_prefill_truncate.py index 43e8d2a4..eba4bd69 100644 --- a/text-generation-inference/tests/test_prefill_truncate.py +++ b/text-generation-inference/tests/test_prefill_truncate.py @@ -1,10 +1,7 @@ -import pytest -from helpers import create_request, prepare_model +from helpers import create_request, prepare_model, skip_if_jetstream_pytorch_enabled from text_generation_server.auto_generator import AutoGenerator from text_generation_server.pb.generate_pb2 import Batch -from optimum.tpu.jetstream_pt_support import jetstream_pt_available - def _test_prefill_truncate(): model_id = "Maykeye/TinyLLama-v0" @@ -38,10 +35,9 @@ def _test_prefill_truncate(): def test_prefill_truncate_jetstream(): - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") _test_prefill_truncate() +@skip_if_jetstream_pytorch_enabled def test_prefill_truncate(): _test_prefill_truncate() diff --git a/text-generation-inference/tests/test_tinyllama.py b/text-generation-inference/tests/test_tinyllama.py index 434f532f..c8ec58ea 100644 --- a/text-generation-inference/tests/test_tinyllama.py +++ b/text-generation-inference/tests/test_tinyllama.py @@ -1,11 +1,9 @@ import pytest -from helpers import create_request, prepare_model +from helpers import create_request, prepare_model, skip_if_jetstream_pytorch_enabled from text_generation_server.auto_generator import AutoGenerator from text_generation_server.pb.generate_pb2 import Batch from tqdm import tqdm -from optimum.tpu.jetstream_pt_support import jetstream_pt_available - MODEL_ID = "Maykeye/TinyLLama-v0" SEQUENCE_LENGTH = 256 @@ -27,11 +25,10 @@ def _test_info(model_path, expected_device_type): def test_jetstream_info(model_path): - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") _test_info(model_path, "meta") +@skip_if_jetstream_pytorch_enabled def test_info(model_path): 
_test_info(model_path, "xla") @@ -80,11 +77,10 @@ def _test_prefill(input_text, token_id, token_text, do_sample, batch_size, model ) @pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"]) def test_jetstream_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path): - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") _test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path) +@skip_if_jetstream_pytorch_enabled @pytest.mark.parametrize( "input_text, token_id, token_text, do_sample", [ @@ -139,11 +135,10 @@ def check_request(do_sample, expected_token_id): def test_jetstream_prefill_change_sampling(model_path): - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") _test_prefill_change_sampling(model_path, 347, 13) +@skip_if_jetstream_pytorch_enabled def test_prefill_change_sampling(model_path): _test_prefill_change_sampling(model_path, 571, 13) @@ -222,8 +217,6 @@ def _test_continuous_batching_two_requests(model_path): def test_jetstream_decode_multiple(model_path): - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") _test_continuous_batching_two_requests(model_path) @@ -232,5 +225,6 @@ def test_jetstream_decode_multiple(model_path): similar outputs, but not identical). """ @pytest.mark.skip(reason="Test is not supported on PyTorch/XLA") +@skip_if_jetstream_pytorch_enabled def test_decode_multiple(model_path): _test_continuous_batching_two_requests(model_path) diff --git a/text-generation-inference/tests/test_warmup.py b/text-generation-inference/tests/test_warmup.py index bb16d785..102b6d00 100644 --- a/text-generation-inference/tests/test_warmup.py +++ b/text-generation-inference/tests/test_warmup.py @@ -2,17 +2,12 @@ from time import time -import pytest from helpers import create_request, prepare_model from text_generation_server.auto_generator import AutoGenerator from text_generation_server.pb.generate_pb2 import Batch -from optimum.tpu.jetstream_pt_support import jetstream_pt_available - def test_warmup_jetstream_pytorch(): - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") model_id = "Maykeye/TinyLLama-v0" sequence_length = 256