diff --git a/Makefile b/Makefile
index 06580e06..5c5fb429 100644
--- a/Makefile
+++ b/Makefile
@@ -100,7 +100,7 @@ tgi_test_jetstream: test_installs jetstream_requirements tgi_server
 tgi_test: test_installs tgi_server
 	find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
 	-exec python -m pip install --force-reinstall {} \;
-	python -m pytest -sv text-generation-inference/tests
+	python -m pytest -sv text-generation-inference/tests -k "not jetstream"
 
 tgi_docker_test: tpu-tgi
 	python -m pip install -r text-generation-inference/integration-tests/requirements.txt
diff --git a/docs/source/howto/serving.mdx b/docs/source/howto/serving.mdx
index df67ec38..3f1639b2 100644
--- a/docs/source/howto/serving.mdx
+++ b/docs/source/howto/serving.mdx
@@ -50,9 +50,11 @@ curl localhost/generate_stream \
     -H 'Content-Type: application/json'
 ```
 
-### Using Jetstream Pytorch as backend
+### Jetstream Pytorch and Pytorch XLA backends
+
+[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. This engine is selected by default if the dependency is available.
+To use the Pytorch/XLA backend instead, set the `JETSTREAM_PT_DISABLE=1` environment variable.
 
-[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. It is possible to use this engine by setting the `JETSTREAM_PT=1` environment variable.
 
 When using Jetstream Pytorch engine, it is possible to enable quantization to reduce the memory footprint and increase the throughput. To enable quantization, set the `QUANTIZATION=1` environment variable.
diff --git a/optimum/tpu/jetstream_pt_support.py b/optimum/tpu/jetstream_pt_support.py
index eac038ad..4b844c60 100644
--- a/optimum/tpu/jetstream_pt_support.py
+++ b/optimum/tpu/jetstream_pt_support.py
@@ -8,9 +8,9 @@ def jetstream_pt_available() -> bool:
     """Check if the necessary imports to use jetstream_pt are available.
    """
     try:
-        # For now Jetstream Pytorch is opt-in, it can be enabled with an ENV variable.
-        jetstream_pt_enabled = os.environ.get("JETSTREAM_PT", False) == "1"
-        if not jetstream_pt_enabled:
+        # Jetstream Pytorch is enabled by default; it can be disabled with an ENV variable.
+        jetstream_pt_disabled = os.environ.get("JETSTREAM_PT_DISABLE", False) == "1"
+        if jetstream_pt_disabled:
             return False
         # Torch XLA should not be imported before torch_xla2 to avoid conflicts.
         if 'torch_xla2' not in sys.modules and 'torch_xla.core' in sys.modules:
diff --git a/text-generation-inference/tests/helpers.py b/text-generation-inference/tests/helpers.py
index e36f997a..93206d1e 100644
--- a/text-generation-inference/tests/helpers.py
+++ b/text-generation-inference/tests/helpers.py
@@ -1,5 +1,6 @@
 import os
 
+import pytest
 from text_generation_server.pb.generate_pb2 import (
     NextTokenChooserParameters,
     Request,
@@ -9,6 +10,12 @@ from optimum.tpu.model import fetch_model
 
 
+def skip_if_jetstream_pytorch_enabled(func):
+    reason = "Skipping because Jetstream PyTorch is enabled"
+    jetstream_enabled = os.getenv("JETSTREAM_PT_DISABLE") != "1"
+    return pytest.mark.skipif(jetstream_enabled, reason=reason)(func)
+
+
 def prepare_model(model_id, sequence_length):
     # Add variables to environment so they can be used in AutoModelForCausalLM
     os.environ["HF_SEQUENCE_LENGTH"] = str(sequence_length)
diff --git a/text-generation-inference/tests/test_decode.py b/text-generation-inference/tests/test_decode.py
index 0dbe3ba1..31b64e3d 100644
--- a/text-generation-inference/tests/test_decode.py
+++ b/text-generation-inference/tests/test_decode.py
@@ -1,8 +1,10 @@
 import pytest
 
 from decode_tests_utils import DecodeTestParams, decode_single_test
+from helpers import skip_if_jetstream_pytorch_enabled
+
 
+@skip_if_jetstream_pytorch_enabled
 @pytest.mark.parametrize("params",
 [
     DecodeTestParams(
@@ -21,6 +23,7 @@ def test_decode_single(params):
     decode_single_test(params)
 
+
+@skip_if_jetstream_pytorch_enabled
 @pytest.mark.slow
 @pytest.mark.parametrize("params",
 [
diff --git a/text-generation-inference/tests/test_decode_jetstream.py b/text-generation-inference/tests/test_decode_jetstream.py
index d8df274e..8a297ffd 100644
--- a/text-generation-inference/tests/test_decode_jetstream.py
+++ b/text-generation-inference/tests/test_decode_jetstream.py
@@ -2,8 +2,6 @@
 import pytest
 
 from decode_tests_utils import DecodeTestParams, decode_single_test
-from optimum.tpu.jetstream_pt_support import jetstream_pt_available
-
 
 @pytest.mark.slow
 @pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
@@ -35,8 +33,6 @@
     ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b", "Mixtral-8x7B"],
 )
 def test_decode_single_jetstream_pytorch_slow(params, do_sample):
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
     params.do_sample = do_sample
     decode_single_test(params)
 
@@ -64,7 +60,5 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
     ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny"],
 )
 def test_decode_single_jetstream_pytorch(params, do_sample):
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
     params.do_sample = do_sample
     decode_single_test(params)
diff --git a/text-generation-inference/tests/test_decode_jetstream_quant.py b/text-generation-inference/tests/test_decode_jetstream_quant.py
index fe28f497..86f19365 100644
--- a/text-generation-inference/tests/test_decode_jetstream_quant.py
+++ b/text-generation-inference/tests/test_decode_jetstream_quant.py
@@ -2,8 +2,6 @@
 import pytest
 
 from decode_tests_utils import DecodeTestParams, decode_single_test
-from optimum.tpu.jetstream_pt_support import jetstream_pt_available
-
 
 @pytest.mark.parametrize("params",
 [
@@ -22,8 +20,6 @@
     ids=["gemma-2b", "TinyLLama-v0"],
 )
 def test_decode_jetstream_quantization(quantization_jetstream_int8, params):
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
     decode_single_test(params)
 
 
@@ -49,6 +45,4 @@ def test_decode_jetstream_quantization(quantization_jetstream_int8, params):
     ids=["Mixtral-8x7B", "Meta-Llama-3-8B" ,"Meta-Llama-3-70B"],
 )
 def test_decode_jetstream_quantization_slow(quantization_jetstream_int8, params):
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
     decode_single_test(params)
diff --git a/text-generation-inference/tests/test_generator_slot.py b/text-generation-inference/tests/test_generator_slot.py
index f0f23ed7..9860d535 100644
--- a/text-generation-inference/tests/test_generator_slot.py
+++ b/text-generation-inference/tests/test_generator_slot.py
@@ -1,5 +1,7 @@
+import numpy as np
 import pytest
 import torch
+from helpers import skip_if_jetstream_pytorch_enabled
 from text_generation_server.pb.generate_pb2 import Request
 from transformers import AutoTokenizer, GenerationConfig
 
@@ -15,6 +17,52 @@ def tokenizer(request):
     return t
 
+
+@pytest.mark.parametrize(
+    "input_text, generated_text",
+    [
+        [
+            "It was a bright cold day in April, and the clocks were striking thirteen.",
+            " Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,"
+            " slipped quickly through the glass doors of Victory Mansions, though not quickly enough"
+            " to prevent a swirl of gritty dust from entering along with him.",
+        ],
+        ["This sentence is written in chinese:", "我很感谢你的热情"],
+        ["Some text might contain a lot of emojis like 😃", "😍💪 👉 👀"],
+    ],
+    ids=["spaces", "chinese-utf8", "emojis"],
+)
+def test_decode_streaming_jetstream(tokenizer, input_text, generated_text):
+    from text_generation_server.jetstream_pt_support.generator import Slot
+
+    slot = Slot(0, tokenizer)
+    request = Request(id=0, inputs=input_text)
+    slot.assign(0, request, GenerationConfig())
+
+    inputs = tokenizer(input_text, padding="max_length", max_length=len(input_text) + 1, return_tensors="np")
+    input_ids = inputs["input_ids"][0]
+    generated_tokens = tokenizer(generated_text, add_special_tokens=False, return_tensors="np")["input_ids"][0]
+
+    # We need to regenerate the full text as the tokenizer might change it (extra spaces might be added)
+    all_input_ids = np.concatenate([input_ids, generated_tokens])
+    full_text = tokenizer.decode(all_input_ids, skip_special_tokens=True)
+    regenerated_text = full_text[len(input_text) :]
+
+    # Initialize the slot with the inputs
+    slot.reset(input_ids, selector=None)
+
+    assert slot.generated_tokens == 0
+
+    # Simulate an iterative generation (i.e. don't call select and use known tokens instead)
+    decoded_text = ""
+    for i in range(len(generated_tokens)):
+        text = slot.append(generated_tokens[i])
+        assert slot.generated_tokens == i + 1
+        decoded_text += text
+
+    assert decoded_text == regenerated_text
+
+
+@skip_if_jetstream_pytorch_enabled
 @pytest.mark.parametrize(
     "input_text, generated_text",
     [
@@ -31,6 +79,7 @@ def tokenizer(request):
 )
 def test_decode_streaming(tokenizer, input_text, generated_text):
     from text_generation_server.generator import Slot
+
     # Note: device used is cpu to make it faster
     slot = Slot(0, tokenizer, "cpu")
     request = Request(id=0, inputs=input_text)
diff --git a/text-generation-inference/tests/test_gpt2.py b/text-generation-inference/tests/test_gpt2.py
deleted file mode 100644
index 776d8418..00000000
--- a/text-generation-inference/tests/test_gpt2.py
+++ /dev/null
@@ -1,132 +0,0 @@
-
-import pytest
-from helpers import create_request, prepare_model
-from text_generation_server.auto_generator import AutoGenerator
-from text_generation_server.pb.generate_pb2 import Batch
-from tqdm import tqdm
-
-
-MODEL_ID = "openai-community/gpt2"
-SEQUENCE_LENGTH = 1024
-
-
-@pytest.fixture(scope="module")
-def model_path():
-    return prepare_model(MODEL_ID, SEQUENCE_LENGTH)
-
-
-def test_info(model_path):
-    generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1)
-    info = generator.info
-    assert info.requires_padding is True
-    assert info.device_type == "xla"
-    assert info.window_size == 0
-    assert info.speculate == 0
-
-
-@pytest.mark.parametrize(
-    "input_text, token_id, token_text, do_sample",
-    [
-        [
-            "It was a bright cold day in April, and the clocks were striking thirteen.",
-            383,
-            " The",
-            False,
-        ],
-        [
-            "It was a bright cold day in April, and the clocks were striking thirteen.",
-            775,
-            " We",
-            True,
-        ],
-    ],
-    ids=["greedy", "sample"],
-)
-@pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"])
-def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path):
-    generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH)
-    requests = []
-    max_new_tokens = 20
-    for i in range(batch_size):
-        requests.append(create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens))
-    # Let's be pessimistic when estimating max_tokens
-    batch_size * (len(input_text) + max_new_tokens)
-    batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH)
-    generations, next_batch = generator.prefill(batch)
-    assert next_batch.size == batch_size
-    # Whatever was passed as max_tokens, the server will correct it
-    # because of static batching
-    assert next_batch.max_tokens == batch_size * SEQUENCE_LENGTH
-    assert len(generations) == batch_size
-    for g in generations:
-        tokens = g.tokens
-        assert tokens.ids == [token_id]
-        assert tokens.texts == [token_text]
-
-
-def test_decode_multiple(model_path):
-    generator = AutoGenerator.from_pretrained(model_path,
-                                              revision="",
-                                              max_batch_size=2,
-                                              max_sequence_length=SEQUENCE_LENGTH)
-    input_text = "Once upon a time"
-    max_new_tokens = 20
-    # Prefill a single request, remembering the generated token
-    tokens = {0: [], 1: []}
-    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
-    batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
-    generations, next_batch = generator.prefill(batch)
-    assert next_batch.size == 1
-    assert len(generations) == 1
-    g = generations[0]
-    tokens[g.request_id].append(g.tokens.ids[0])
-    assert len(tokens[0]) == 1
-    # Decode a few tokens
-    gen_tokens = 4
-    for _ in tqdm(range(gen_tokens - 1), "Decoding tokens"):
-        generations, next_batch = generator.decode([next_batch])
-        assert len(generations) == 1
-        g = generations[0]
-        tokens[g.request_id].append(g.tokens.ids[0])
-    assert len(tokens[0]) == gen_tokens
-    assert next_batch.size == 1
-    # Add a second request
-    request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
-    batch = Batch(id=1, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
-    generations, next_batch_1 = generator.prefill(batch)
-    assert next_batch_1.size == 1
-    # We should have generated only a single token
-    assert len(generations) == 1
-    g = generations[0]
-    tokens[g.request_id].append(g.tokens.ids[0])
-    assert len(tokens[0]) == gen_tokens
-    assert len(tokens[1]) == 1
-    # Decode more tokens until we reach the maximum for the first request
-    batches = [next_batch, next_batch_1]
-    for _ in tqdm(range(max_new_tokens - gen_tokens), "Decoding tokens (2nd batch)"):
-        generations, next_batch = generator.decode(batches)
-        for g in generations:
-            tokens[g.request_id].append(g.tokens.ids[0])
-        batches = [next_batch]
-    # Verify we now only have one pending request
-    assert next_batch.size == 1
-    assert len(tokens[0]) == max_new_tokens
-    assert len(tokens[1]) == max_new_tokens - gen_tokens + 1
-    # Verify we have the output for the first request
-    for g in generations:
-        if g.request_id == 0:
-            output = g.generated_text
-            assert output.text != ""
-            assert output.generated_tokens == max_new_tokens
-            generated_text = output.text
-    # Continue decoding until the end of the second request
-    for _ in tqdm(range(gen_tokens - 1), "Decoding tokens (finishing)"):
-        generations, next_batch = generator.decode([next_batch])
-        assert len(generations) == 1
-        g = generations[0]
-        tokens[g.request_id].append(g.tokens.ids[0])
-    assert next_batch is None
-    output = generations[0].generated_text
-    assert output.generated_tokens == max_new_tokens
-    assert tokens[0] == tokens[1]
-    assert output.text == generated_text
diff --git a/text-generation-inference/tests/test_prefill_truncate.py b/text-generation-inference/tests/test_prefill_truncate.py
index 4ad78ab5..eba4bd69 100644
--- a/text-generation-inference/tests/test_prefill_truncate.py
+++ b/text-generation-inference/tests/test_prefill_truncate.py
@@ -1,11 +1,11 @@
-from helpers import create_request, prepare_model
+from helpers import create_request, prepare_model, skip_if_jetstream_pytorch_enabled
 from text_generation_server.auto_generator import AutoGenerator
 from text_generation_server.pb.generate_pb2 import Batch
 
 
-def test_prefill_truncate():
-    model_id="Maykeye/TinyLLama-v0"
-    sequence_length=1024
+def _test_prefill_truncate():
+    model_id = "Maykeye/TinyLLama-v0"
+    sequence_length = 1024
     model_path = prepare_model(model_id, sequence_length)
     max_new_tokens = 20
 
@@ -13,23 +13,31 @@ def test_prefill_truncate():
     generator = AutoGenerator.from_pretrained(
         model_path, revision="", max_batch_size=1, max_sequence_length=sequence_length
     )
-    input_text = "This is something I will tell by the end of the story"
+    input_text = "And to finish the story, I will say that"
     request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
     batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
     generations, _ = generator.prefill(batch)
     assert len(generations) == 1
-    assert generations[0].tokens.ids == [31843]
-    assert generations[0].tokens.texts == ["."]
-
+    assert generations[0].tokens.ids == [357]
+    assert generations[0].tokens.texts == [" it"]
     # Now re-test but with truncate
     generator.clear()
     request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
     # This will only leave last tokens
-    request.truncate = 6
+    request.truncate = 3
     batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
     generations, _ = generator.prefill(batch)
     assert len(generations) == 1
-    assert generations[0].tokens.ids == [291]
-    assert generations[0].tokens.texts == [" and"]
+    assert generations[0].tokens.ids == [266]
+    assert generations[0].tokens.texts == [" the"]
+
+
+def test_prefill_truncate_jetstream():
+    _test_prefill_truncate()
+
+
+@skip_if_jetstream_pytorch_enabled
+def test_prefill_truncate():
+    _test_prefill_truncate()
diff --git a/text-generation-inference/tests/test_tinyllama.py b/text-generation-inference/tests/test_tinyllama.py
index 5566bccd..c8ec58ea 100644
--- a/text-generation-inference/tests/test_tinyllama.py
+++ b/text-generation-inference/tests/test_tinyllama.py
@@ -1,12 +1,9 @@
-
 import pytest
-from helpers import create_request, prepare_model
+from helpers import create_request, prepare_model, skip_if_jetstream_pytorch_enabled
 from text_generation_server.auto_generator import AutoGenerator
 from text_generation_server.pb.generate_pb2 import Batch
 from tqdm import tqdm
 
-from optimum.tpu.jetstream_pt_support import jetstream_pt_available
-
 
 MODEL_ID = "Maykeye/TinyLLama-v0"
 SEQUENCE_LENGTH = 256
@@ -17,18 +14,49 @@ def model_path():
     return prepare_model(MODEL_ID, SEQUENCE_LENGTH)
 
 
-def test_jetstream_info(model_path):
+def _test_info(model_path, expected_device_type):
     """Verify the model info is correctly loaded and check expected results."""
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
     generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1)
     info = generator.info
     assert info.requires_padding is True
-    assert info.device_type == "meta"
+    assert info.device_type == expected_device_type
     assert info.window_size == 0
     assert info.speculate == 0
 
 
+def test_jetstream_info(model_path):
+    _test_info(model_path, "meta")
+
+
+@skip_if_jetstream_pytorch_enabled
+def test_info(model_path):
+    _test_info(model_path, "xla")
+
+
+def _test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path):
+    """Verify that prefilling a batch with a single request works with different sampling techniques."""
+    generator = AutoGenerator.from_pretrained(
+        model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH
+    )
+    requests = []
+    max_new_tokens = 20
+    for i in range(batch_size):
+        requests.append(create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens))
+    # Let's be pessimistic when estimating max_tokens
+    batch_size * (len(input_text) + max_new_tokens)
+    batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH)
+    generations, next_batch = generator.prefill(batch)
+    assert next_batch.size == batch_size
+    # Whatever was passed as max_tokens, the server will correct it
+    # because of static batching
+    assert next_batch.max_tokens == batch_size * SEQUENCE_LENGTH
+    assert len(generations) == batch_size
+    for g in generations:
+        tokens = g.tokens
+        assert tokens.ids == [token_id]
+        assert tokens.texts == [token_text]
+
+
 @pytest.mark.parametrize(
     "input_text, token_id, token_text, do_sample",
     [
@@ -49,73 +77,82 @@ def test_jetstream_info(model_path):
 )
 @pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"])
 def test_jetstream_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path):
-    """Verify that prefilling a batch with a single request with different sampling techniques.
-    """
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
-    generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH)
-    requests = []
-    max_new_tokens = 20
-    for i in range(batch_size):
-        requests.append(create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens))
-    # Let's be pessimistic when estimating max_tokens
-    batch_size * (len(input_text) + max_new_tokens)
-    batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH)
-    generations, next_batch = generator.prefill(batch)
-    assert next_batch.size == batch_size
-    # Whatever was passed as max_tokens, the server will correct it
-    # because of static batching
-    assert next_batch.max_tokens == batch_size * SEQUENCE_LENGTH
-    assert len(generations) == batch_size
-    for g in generations:
-        tokens = g.tokens
-        assert tokens.ids == [token_id]
-        assert tokens.texts == [token_text]
+    _test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path)
 
 
-def test_jetstream_prefill_change_sampling(model_path):
-    """Verify changing the sampling strategy between requests in the same batch works as expected.
-    """
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
+@skip_if_jetstream_pytorch_enabled
+@pytest.mark.parametrize(
+    "input_text, token_id, token_text, do_sample",
+    [
+        [
+            "It was a bright cold day in April, and the clocks were striking thirteen.",
+            571,
+            " It",
+            False,
+        ],
+        [
+            "It was a bright cold day in April, and the clocks were striking thirteen.",
+            13,
+            "\n",
+            True,
+        ],
+    ],
+    ids=["greedy", "sample"],
+)
+@pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"])
+def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path):
+    _test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path)
+
+
+def _test_prefill_change_sampling(
+    model_path,
+    greedy_expected_token_id,
+    sampling_expected_token_id,
+):
+    """Verify changing the sampling strategy between requests in the same batch works as expected."""
     input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
     batch_size = 1
-    greedy_expected_token_id = 347
-    greedy_expected_text = " The"
-    sampling_expected_token_id = 13
-    sampling_expected_text = "\n"
-    generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH)
+    generator = AutoGenerator.from_pretrained(
+        model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH
+    )
     max_new_tokens = 20
 
-    def check_request(do_sample, expected_token_id, expected_text):
+    def check_request(do_sample, expected_token_id):
         requests = [create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens)]
         batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH)
         generations, _ = generator.prefill(batch)
         tokens = generations[0].tokens
-        print(tokens)
         assert tokens.ids == [expected_token_id]
-        assert tokens.texts == [expected_text]
         generator.clear()
 
     # First request is greedy
-    check_request(False, greedy_expected_token_id, greedy_expected_text)
+    check_request(False, greedy_expected_token_id)
     # Second request is sampling
-    check_request(True, sampling_expected_token_id, sampling_expected_text)
+    check_request(True, sampling_expected_token_id)
     # Third request is greedy again
-    check_request(False, greedy_expected_token_id, greedy_expected_text)
+    check_request(False, greedy_expected_token_id)
 
 
-def test_jetstream_decode_multiple(model_path):
+def test_jetstream_prefill_change_sampling(model_path):
+    _test_prefill_change_sampling(model_path, 347, 13)
+
+
+@skip_if_jetstream_pytorch_enabled
+def test_prefill_change_sampling(model_path):
+    _test_prefill_change_sampling(model_path, 571, 13)
+
+
+def _test_continuous_batching_two_requests(model_path):
     """Verify that two requests added to the batch at different generation steps
     generate the same outputs (continuous batching).
     """
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
-    generator = AutoGenerator.from_pretrained(model_path,
-                                              revision="",
-                                              max_batch_size=2,
-                                              max_sequence_length=SEQUENCE_LENGTH)
+    generator = AutoGenerator.from_pretrained(
+        model_path,
+        revision="",
+        max_batch_size=2,
+        max_sequence_length=SEQUENCE_LENGTH,
+    )
     input_text = "Once upon a time"
     max_new_tokens = 20
     # Prefill a single request, remembering the generated token
@@ -177,3 +214,17 @@ def test_jetstream_decode_multiple(model_path):
     assert output.generated_tokens == max_new_tokens
     assert tokens[0] == tokens[1]
     assert output.text == generated_text
+
+
+def test_jetstream_decode_multiple(model_path):
+    _test_continuous_batching_two_requests(model_path)
+
+
+"""NOTE: This test does not work on PyTorch/XLA, because of the way
+calculations are done in torch/xla and the effect of KV cache (they produce
+similar outputs, but not identical).
+""" +@pytest.mark.skip(reason="Test is not supported on PyTorch/XLA") +@skip_if_jetstream_pytorch_enabled +def test_decode_multiple(model_path): + _test_continuous_batching_two_requests(model_path) diff --git a/text-generation-inference/tests/test_warmup.py b/text-generation-inference/tests/test_warmup.py index bb16d785..102b6d00 100644 --- a/text-generation-inference/tests/test_warmup.py +++ b/text-generation-inference/tests/test_warmup.py @@ -2,17 +2,12 @@ from time import time -import pytest from helpers import create_request, prepare_model from text_generation_server.auto_generator import AutoGenerator from text_generation_server.pb.generate_pb2 import Batch -from optimum.tpu.jetstream_pt_support import jetstream_pt_available - def test_warmup_jetstream_pytorch(): - if not jetstream_pt_available(): - pytest.skip("Jetstream PyTorch is not available") model_id = "Maykeye/TinyLLama-v0" sequence_length = 256