feat(tgi): Jetstream/Pytorch is now the default engine
Now that the engine is stable and tested, it is set as the default engine for TGI.
tengomucho committed Nov 22, 2024
1 parent 5298dbf commit 07f74ff
Showing 11 changed files with 27 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -100,7 +100,7 @@ tgi_test_jetstream: test_installs jetstream_requirements tgi_server
tgi_test: test_installs tgi_server
find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
-exec python -m pip install --force-reinstall {} \;
python -m pytest -sv text-generation-inference/tests
python -m pytest -sv text-generation-inference/tests -k "not jetstream"

tgi_docker_test: tpu-tgi
python -m pip install -r text-generation-inference/integration-tests/requirements.txt
6 changes: 4 additions & 2 deletions docs/source/howto/serving.mdx
@@ -50,9 +50,11 @@ curl localhost/generate_stream \
-H 'Content-Type: application/json'
```

### Using Jetstream Pytorch as backend
### Jetstream Pytorch and Pytorch XLA backends

[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. This engine is selected by default if the dependency is available.
If for some reason you want to use the Pytorch/XLA backend instead, you can set the `JETSTREAM_PT_DISABLE=1` environment variable.

[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. It is possible to use this engine by setting the `JETSTREAM_PT=1` environment variable.

When using the Jetstream Pytorch engine, it is possible to enable quantization to reduce the memory footprint and increase the throughput. To enable quantization, set the `QUANTIZATION=1` environment variable.

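As an illustration, here is a minimal sketch (not part of the documented commands) of checking from Python which backend TGI will pick up. It assumes `optimum-tpu` is installed so that `optimum.tpu.jetstream_pt_support.jetstream_pt_available` can be imported:

```python
import os

# Opt out of the default Jetstream Pytorch engine; leaving the variable unset
# (or using any value other than "1") keeps the default.
os.environ["JETSTREAM_PT_DISABLE"] = "1"

from optimum.tpu.jetstream_pt_support import jetstream_pt_available

# Expected to print False here, since the Jetstream engine was explicitly
# disabled and TGI would fall back to the Pytorch/XLA backend.
print(jetstream_pt_available())

# Quantization is a separate flag, read when the server is launched:
# os.environ["QUANTIZATION"] = "1"
```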
6 changes: 3 additions & 3 deletions optimum/tpu/jetstream_pt_support.py
@@ -8,9 +8,9 @@ def jetstream_pt_available() -> bool:
"""Check if the necessary imports to use jetstream_pt are available.
"""
try:
# For now Jetstream Pytorch is opt-in, it can be enabled with an ENV variable.
jetstream_pt_enabled = os.environ.get("JETSTREAM_PT", False) == "1"
if not jetstream_pt_enabled:
# Jetstream Pytorch is enabled by default, it can be disabled with an ENV variable.
jetstream_pt_disabled = os.environ.get("JETSTREAM_PT_DISABLE", False) == "1"
if jetstream_pt_disabled:
return False
# Torch XLA should not be imported before torch_xla2 to avoid conflicts.
if 'torch_xla2' not in sys.modules and 'torch_xla.core' in sys.modules:
7 changes: 7 additions & 0 deletions text-generation-inference/tests/helpers.py
@@ -1,5 +1,6 @@
import os

import pytest
from text_generation_server.pb.generate_pb2 import (
NextTokenChooserParameters,
Request,
@@ -9,6 +10,12 @@
from optimum.tpu.model import fetch_model


def skip_if_jetstream_pytorch_enabled(func):
    reason = "Skipping because Jetstream PyTorch is enabled"
    # Environment variables are strings, so compare against "1" rather than the integer 1.
    jetstream_enabled = os.getenv("JETSTREAM_PT_DISABLE") != "1"
    return pytest.mark.skipif(jetstream_enabled, reason=reason)(func)
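# Illustrative usage (a sketch; the test name below is hypothetical): tests that
# only make sense on the PyTorch/XLA backend are decorated so they are skipped
# whenever the Jetstream PyTorch engine, now the default, is active:
#
#   @skip_if_jetstream_pytorch_enabled
#   def test_xla_only_behaviour():
#       ...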


def prepare_model(model_id, sequence_length):
# Add variables to environment so they can be used in AutoModelForCausalLM
os.environ["HF_SEQUENCE_LENGTH"] = str(sequence_length)
3 changes: 3 additions & 0 deletions text-generation-inference/tests/test_decode.py
@@ -1,8 +1,10 @@

import pytest
from decode_tests_utils import DecodeTestParams, decode_single_test
from helpers import skip_if_jetstream_pytorch_enabled


@skip_if_jetstream_pytorch_enabled
@pytest.mark.parametrize("params",
[
DecodeTestParams(
@@ -21,6 +23,7 @@
def test_decode_single(params):
decode_single_test(params)

@skip_if_jetstream_pytorch_enabled
@pytest.mark.slow
@pytest.mark.parametrize("params",
[
6 changes: 0 additions & 6 deletions text-generation-inference/tests/test_decode_jetstream.py
@@ -2,8 +2,6 @@
import pytest
from decode_tests_utils import DecodeTestParams, decode_single_test

from optimum.tpu.jetstream_pt_support import jetstream_pt_available


@pytest.mark.slow
@pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
@@ -35,8 +33,6 @@
ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b", "Mixtral-8x7B"],
)
def test_decode_single_jetstream_pytorch_slow(params, do_sample):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
params.do_sample = do_sample
decode_single_test(params)

@@ -64,7 +60,5 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny"],
)
def test_decode_single_jetstream_pytorch(params, do_sample):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
params.do_sample = do_sample
decode_single_test(params)
6 changes: 0 additions & 6 deletions
@@ -2,8 +2,6 @@
import pytest
from decode_tests_utils import DecodeTestParams, decode_single_test

from optimum.tpu.jetstream_pt_support import jetstream_pt_available


@pytest.mark.parametrize("params",
[
@@ -22,8 +20,6 @@
ids=["gemma-2b", "TinyLLama-v0"],
)
def test_decode_jetstream_quantization(quantization_jetstream_int8, params):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
decode_single_test(params)


@@ -49,6 +45,4 @@ def test_decode_jetstream_quantization(quantization_jetstream_int8, params):
ids=["Mixtral-8x7B", "Meta-Llama-3-8B" ,"Meta-Llama-3-70B"],
)
def test_decode_jetstream_quantization_slow(quantization_jetstream_int8, params):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
decode_single_test(params)
7 changes: 2 additions & 5 deletions text-generation-inference/tests/test_generator_slot.py
@@ -1,11 +1,10 @@
import numpy as np
import pytest
import torch
from helpers import skip_if_jetstream_pytorch_enabled
from text_generation_server.pb.generate_pb2 import Request
from transformers import AutoTokenizer, GenerationConfig

from optimum.tpu.jetstream_pt_support import jetstream_pt_available


TOKENIZERS = ["NousResearch/Llama-2-7b-hf", "openai-community/gpt2"]

@@ -33,9 +32,6 @@ def tokenizer(request):
ids=["spaces", "chinese-utf8", "emojis"],
)
def test_decode_streaming_jetstream(tokenizer, input_text, generated_text):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")

from text_generation_server.jetstream_pt_support.generator import Slot

slot = Slot(0, tokenizer)
@@ -66,6 +62,7 @@ def test_decode_streaming_jetstream(tokenizer, input_text, generated_text):
assert decoded_text == regenerated_text


@skip_if_jetstream_pytorch_enabled
@pytest.mark.parametrize(
"input_text, generated_text",
[
8 changes: 2 additions & 6 deletions text-generation-inference/tests/test_prefill_truncate.py
@@ -1,10 +1,7 @@
import pytest
from helpers import create_request, prepare_model
from helpers import create_request, prepare_model, skip_if_jetstream_pytorch_enabled
from text_generation_server.auto_generator import AutoGenerator
from text_generation_server.pb.generate_pb2 import Batch

from optimum.tpu.jetstream_pt_support import jetstream_pt_available


def _test_prefill_truncate():
model_id = "Maykeye/TinyLLama-v0"
@@ -38,10 +35,9 @@ def _test_prefill_truncate():


def test_prefill_truncate_jetstream():
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
_test_prefill_truncate()


@skip_if_jetstream_pytorch_enabled
def test_prefill_truncate():
_test_prefill_truncate()
16 changes: 5 additions & 11 deletions text-generation-inference/tests/test_tinyllama.py
@@ -1,11 +1,9 @@
import pytest
from helpers import create_request, prepare_model
from helpers import create_request, prepare_model, skip_if_jetstream_pytorch_enabled
from text_generation_server.auto_generator import AutoGenerator
from text_generation_server.pb.generate_pb2 import Batch
from tqdm import tqdm

from optimum.tpu.jetstream_pt_support import jetstream_pt_available


MODEL_ID = "Maykeye/TinyLLama-v0"
SEQUENCE_LENGTH = 256
@@ -27,11 +25,10 @@ def _test_info(model_path, expected_device_type):


def test_jetstream_info(model_path):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
_test_info(model_path, "meta")


@skip_if_jetstream_pytorch_enabled
def test_info(model_path):
_test_info(model_path, "xla")

@@ -80,11 +77,10 @@ def _test_prefill(input_text, token_id, token_text, do_sample, batch_size, model
)
@pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"])
def test_jetstream_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
_test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path)


@skip_if_jetstream_pytorch_enabled
@pytest.mark.parametrize(
"input_text, token_id, token_text, do_sample",
[
@@ -139,11 +135,10 @@ def check_request(do_sample, expected_token_id):


def test_jetstream_prefill_change_sampling(model_path):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
_test_prefill_change_sampling(model_path, 347, 13)


@skip_if_jetstream_pytorch_enabled
def test_prefill_change_sampling(model_path):
_test_prefill_change_sampling(model_path, 571, 13)

@@ -222,8 +217,6 @@ def _test_continuous_batching_two_requests(model_path):


def test_jetstream_decode_multiple(model_path):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
_test_continuous_batching_two_requests(model_path)


@@ -232,5 +225,6 @@ def test_jetstream_decode_multiple(model_path):
similar outputs, but not identical).
"""
@pytest.mark.skip(reason="Test is not supported on PyTorch/XLA")
@skip_if_jetstream_pytorch_enabled
def test_decode_multiple(model_path):
_test_continuous_batching_two_requests(model_path)
5 changes: 0 additions & 5 deletions text-generation-inference/tests/test_warmup.py
@@ -2,17 +2,12 @@

from time import time

import pytest
from helpers import create_request, prepare_model
from text_generation_server.auto_generator import AutoGenerator
from text_generation_server.pb.generate_pb2 import Batch

from optimum.tpu.jetstream_pt_support import jetstream_pt_available


def test_warmup_jetstream_pytorch():
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
model_id = "Maykeye/TinyLLama-v0"
sequence_length = 256
