Commit

WIP

tengomucho committed Nov 21, 2024
1 parent 7c58951 commit 9a2b339
Showing 9 changed files with 17 additions and 18 deletions.
Makefile (2 changes: 1 addition & 1 deletion)
@@ -81,7 +81,7 @@ test_installs:
 	python -m pip install .[tests] -f https://storage.googleapis.com/libtpu-releases/index.html

 tests: test_installs
-	python -m pytest -sv tests
+	python -m pytest -sv tests -k "not jetstream"

 # Stand-alone TGI server for unit tests outside of TGI container
 tgi_server:
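The new `-k "not jetstream"` filter makes the default `make tests` target deselect every test whose name matches "jetstream". A minimal sketch of the same selection through pytest's programmatic entry point, assuming it is run from the repository root:

    import pytest

    # Deselect tests whose names match "jetstream"; equivalent to running
    # `python -m pytest -sv tests -k "not jetstream"` from the shell.
    raise SystemExit(pytest.main(["-sv", "tests", "-k", "not jetstream"]))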
docs/source/howto/serving.mdx (4 changes: 3 additions & 1 deletion)
@@ -52,7 +52,9 @@ curl localhost/generate_stream \

 ### Using Jetstream Pytorch as backend

-[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. It is possible to use this engine by setting the `JETSTREAM_PT=1` environment variable.
+[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. This engine is selected by default if the dependency is available.
+If for some reason you want to use the Pytorch/XLA backend, you can set the `JETSTREAM_PT_DISABLE=1` environment variable.


 When using the Jetstream Pytorch engine, it is possible to enable quantization to reduce the memory footprint and increase the throughput. To enable quantization, set the `QUANTIZATION=1` environment variable.

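For illustration, a minimal Python sketch of how these variables could be set before the server loads a model; the variable names come from the documentation above, and each comment marks which backend the setting selects:

    import os

    # Opt out of Jetstream Pytorch and fall back to the Pytorch/XLA backend.
    os.environ["JETSTREAM_PT_DISABLE"] = "1"

    # Or, with Jetstream left enabled, reduce the memory footprint via quantization.
    # os.environ["QUANTIZATION"] = "1"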
optimum/tpu/jetstream_pt_support.py (6 changes: 3 additions & 3 deletions)
@@ -8,9 +8,9 @@ def jetstream_pt_available() -> bool:
     """Check if the necessary imports to use jetstream_pt are available.
     """
     try:
-        # For now Jetstream Pytorch is opt-in, it can be enabled with an ENV variable.
-        jetstream_pt_enabled = os.environ.get("JETSTREAM_PT", False) == "1"
-        if not jetstream_pt_enabled:
+        # Jetstream Pytorch is enabled by default, it can be disabled with an ENV variable.
+        jetstream_pt_disabled = os.environ.get("JETSTREAM_PT_DISABLE", False) == "1"
+        if jetstream_pt_disabled:
             return False
         # Torch XLA should not be imported before torch_xla2 to avoid conflicts.
         if 'torch_xla2' not in sys.modules and 'torch_xla.core' in sys.modules:
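With this change the check is inverted: `jetstream_pt_available()` now returns True unless the user opts out. A small sketch of the resulting behavior, assuming the jetstream_pt dependency is installed:

    import os

    from optimum.tpu.jetstream_pt_support import jetstream_pt_available

    print(jetstream_pt_available())  # True: the engine is on by default

    os.environ["JETSTREAM_PT_DISABLE"] = "1"
    print(jetstream_pt_available())  # False: the disable flag short-circuits the check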
text-generation-inference/tests/conftest.py (1 change: 1 addition & 0 deletions)
@@ -1,6 +1,7 @@
 import os

 import pytest
+import os


 # See https://stackoverflow.com/a/61193490/217945 for run_slow
text-generation-inference/tests/helpers.py (7 changes: 7 additions & 0 deletions)
@@ -1,3 +1,4 @@
+import pytest
 import os

 from text_generation_server.pb.generate_pb2 import (
@@ -9,6 +10,12 @@
 from optimum.tpu.model import fetch_model


+def skip_if_jetstream_pytorch_enabled(func):
+    reason = "Skipping because Jetstream PyTorch is enabled"
+    jetstream_enabled = os.getenv("JETSTREAM_PT_DISABLE") != "1"
+    return pytest.mark.skipif(jetstream_enabled, reason=reason)(func)
+
+
 def prepare_model(model_id, sequence_length):
     # Add variables to environment so they can be used in AutoModelForCausalLM
     os.environ["HF_SEQUENCE_LENGTH"] = str(sequence_length)
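Note that `os.getenv` returns a string, so the comparison must be against "1": the decorator skips a test whenever Jetstream PyTorch is active, i.e. whenever `JETSTREAM_PT_DISABLE` is not set. Its real use appears in test_prefill_truncate.py below; a hypothetical usage sketch (the test name is illustrative):

    from helpers import skip_if_jetstream_pytorch_enabled

    @skip_if_jetstream_pytorch_enabled
    def test_pytorch_xla_only_path():
        # Runs only when JETSTREAM_PT_DISABLE=1 selects the Pytorch/XLA backend.
        ...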
text-generation-inference/tests/test_decode_jetstream.py (6 changes: 0 additions & 6 deletions)
@@ -2,8 +2,6 @@
 import pytest
 from decode_tests_utils import DecodeTestParams, decode_single_test

-from optimum.tpu.jetstream_pt_support import jetstream_pt_available
-

 @pytest.mark.slow
 @pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
@@ -35,8 +33,6 @@
     ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b", "Mixtral-8x7B"],
 )
 def test_decode_single_jetstream_pytorch_slow(params, do_sample):
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
     params.do_sample = do_sample
     decode_single_test(params)
@@ -64,7 +60,5 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
     ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny"],
 )
 def test_decode_single_jetstream_pytorch(params, do_sample):
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
     params.do_sample = do_sample
     decode_single_test(params)
text-generation-inference/tests/test_prefill_truncate.py (3 changes: 2 additions & 1 deletion)
@@ -1,5 +1,5 @@
 import pytest
-from helpers import create_request, prepare_model
+from helpers import create_request, prepare_model, skip_if_jetstream_pytorch_enabled
 from text_generation_server.auto_generator import AutoGenerator
 from text_generation_server.pb.generate_pb2 import Batch

@@ -43,5 +43,6 @@ def test_prefill_truncate_jetstream():
     _test_prefill_truncate()


+@skip_if_jetstream_pytorch_enabled
 def test_prefill_truncate():
     _test_prefill_truncate()
text-generation-inference/tests/test_tinyllama.py (2 changes: 0 additions & 2 deletions)
@@ -4,8 +4,6 @@
 from text_generation_server.pb.generate_pb2 import Batch
 from tqdm import tqdm

-from optimum.tpu.jetstream_pt_support import jetstream_pt_available
-

 MODEL_ID = "Maykeye/TinyLLama-v0"
 SEQUENCE_LENGTH = 256
text-generation-inference/tests/test_warmup.py (4 changes: 0 additions & 4 deletions)
@@ -7,12 +7,8 @@
 from text_generation_server.auto_generator import AutoGenerator
 from text_generation_server.pb.generate_pb2 import Batch

-from optimum.tpu.jetstream_pt_support import jetstream_pt_available
-

 def test_warmup_jetstream_pytorch():
-    if not jetstream_pt_available():
-        pytest.skip("Jetstream PyTorch is not available")
     model_id = "Maykeye/TinyLLama-v0"
     sequence_length = 256

