Jetstream by default #118

Open · wants to merge 5 commits into base: main
2 changes: 1 addition & 1 deletion Makefile
@@ -100,7 +100,7 @@ tgi_test_jetstream: test_installs jetstream_requirements tgi_server
tgi_test: test_installs tgi_server
find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
-exec python -m pip install --force-reinstall {} \;
-python -m pytest -sv text-generation-inference/tests
+python -m pytest -sv text-generation-inference/tests -k "not jetstream"

tgi_docker_test: tpu-tgi
python -m pip install -r text-generation-inference/integration-tests/requirements.txt
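For context: pytest's `-k "not jetstream"` expression deselects every test whose id contains "jetstream", so the plain `tgi_test` target now exercises only the Torch/XLA code paths, while the `tgi_test_jetstream` target (visible in the hunk header above) keeps covering the Jetstream-specific tests.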
6 changes: 4 additions & 2 deletions docs/source/howto/serving.mdx
@@ -50,9 +50,11 @@ curl localhost/generate_stream \
-H 'Content-Type: application/json'
```

-### Using Jetstream Pytorch as backend
+### Jetstream Pytorch and Pytorch XLA backends

+[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. This engine is selected by default when its dependencies are available.
+To use the Pytorch/XLA backend instead, set the `JETSTREAM_PT_DISABLE=1` environment variable.

-[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. It is possible to use this engine by setting the `JETSTREAM_PT=1` environment variable.

When using the Jetstream Pytorch engine, it is possible to enable quantization to reduce the memory footprint and increase throughput. To enable quantization, set the `QUANTIZATION=1` environment variable.
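As a quick illustration of the behaviour described above — a minimal sketch, assuming `optimum-tpu` is installed on a TPU host and using the `jetstream_pt_available()` helper touched further down in this diff:

```python
import os

from optimum.tpu.jetstream_pt_support import jetstream_pt_available

# With no environment variable set, the check returns True whenever the
# jetstream_pt dependencies import cleanly, and the Jetstream backend is used.
print(jetstream_pt_available())

# Opting back into the Pytorch/XLA backend: the flag is read when the check
# runs, so set it before starting the server (or calling the check).
os.environ["JETSTREAM_PT_DISABLE"] = "1"
print(jetstream_pt_available())  # False: the Jetstream backend is skipped
```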

6 changes: 3 additions & 3 deletions optimum/tpu/jetstream_pt_support.py
@@ -8,9 +8,9 @@ def jetstream_pt_available() -> bool:
"""Check if the necessary imports to use jetstream_pt are available.
"""
try:
-# For now Jetstream Pytorch is opt-in, it can be enabled with an ENV variable.
-jetstream_pt_enabled = os.environ.get("JETSTREAM_PT", False) == "1"
-if not jetstream_pt_enabled:
+# Jetstream Pytorch is enabled by default, it can be disabled with an ENV variable.
+jetstream_pt_disabled = os.environ.get("JETSTREAM_PT_DISABLE", False) == "1"
+if jetstream_pt_disabled:
return False
# Torch XLA should not be imported before torch_xla2 to avoid conflicts.
if 'torch_xla2' not in sys.modules and 'torch_xla.core' in sys.modules:
7 changes: 7 additions & 0 deletions text-generation-inference/tests/helpers.py
@@ -1,5 +1,6 @@
import os

import pytest
from text_generation_server.pb.generate_pb2 import (
NextTokenChooserParameters,
Request,
@@ -9,6 +10,12 @@
from optimum.tpu.model import fetch_model


def skip_if_jetstream_pytorch_enabled(func):
    """Skip the decorated test when the Jetstream PyTorch backend is active."""
    reason = "Skipping because Jetstream PyTorch is enabled"
    jetstream_enabled = os.getenv("JETSTREAM_PT_DISABLE") != "1"
    return pytest.mark.skipif(jetstream_enabled, reason=reason)(func)


def prepare_model(model_id, sequence_length):
# Add variables to environment so they can be used in AutoModelForCausalLM
os.environ["HF_SEQUENCE_LENGTH"] = str(sequence_length)
3 changes: 3 additions & 0 deletions text-generation-inference/tests/test_decode.py
@@ -1,8 +1,10 @@

import pytest
from decode_tests_utils import DecodeTestParams, decode_single_test
from helpers import skip_if_jetstream_pytorch_enabled


@skip_if_jetstream_pytorch_enabled
@pytest.mark.parametrize("params",
[
DecodeTestParams(
@@ -21,6 +23,7 @@
def test_decode_single(params):
decode_single_test(params)

@skip_if_jetstream_pytorch_enabled
@pytest.mark.slow
@pytest.mark.parametrize("params",
[
6 changes: 0 additions & 6 deletions text-generation-inference/tests/test_decode_jetstream.py
@@ -2,8 +2,6 @@
import pytest
from decode_tests_utils import DecodeTestParams, decode_single_test

-from optimum.tpu.jetstream_pt_support import jetstream_pt_available


@pytest.mark.slow
@pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
@@ -35,8 +33,6 @@
ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b", "Mixtral-8x7B"],
)
def test_decode_single_jetstream_pytorch_slow(params, do_sample):
-if not jetstream_pt_available():
-pytest.skip("Jetstream PyTorch is not available")
params.do_sample = do_sample
decode_single_test(params)

@@ -64,7 +60,5 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny"],
)
def test_decode_single_jetstream_pytorch(params, do_sample):
-if not jetstream_pt_available():
-pytest.skip("Jetstream PyTorch is not available")
params.do_sample = do_sample
decode_single_test(params)
@@ -2,8 +2,6 @@
import pytest
from decode_tests_utils import DecodeTestParams, decode_single_test

-from optimum.tpu.jetstream_pt_support import jetstream_pt_available


@pytest.mark.parametrize("params",
[
@@ -22,8 +20,6 @@
ids=["gemma-2b", "TinyLLama-v0"],
)
def test_decode_jetstream_quantization(quantization_jetstream_int8, params):
-if not jetstream_pt_available():
-pytest.skip("Jetstream PyTorch is not available")
decode_single_test(params)


@@ -49,6 +45,4 @@ def test_decode_jetstream_quantization(quantization_jetstream_int8, params):
ids=["Mixtral-8x7B", "Meta-Llama-3-8B" ,"Meta-Llama-3-70B"],
)
def test_decode_jetstream_quantization_slow(quantization_jetstream_int8, params):
-if not jetstream_pt_available():
-pytest.skip("Jetstream PyTorch is not available")
decode_single_test(params)
49 changes: 49 additions & 0 deletions text-generation-inference/tests/test_generator_slot.py
@@ -1,5 +1,7 @@
import numpy as np
import pytest
import torch
from helpers import skip_if_jetstream_pytorch_enabled
from text_generation_server.pb.generate_pb2 import Request
from transformers import AutoTokenizer, GenerationConfig

@@ -15,6 +17,52 @@ def tokenizer(request):
return t


@pytest.mark.parametrize(
"input_text, generated_text",
[
[
"It was a bright cold day in April, and the clocks were striking thirteen.",
" Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,"
" slipped quickly through the glass doors of Victory Mansions, though not quickly enough"
" to prevent a swirl of gritty dust from entering along with him.",
],
["This sentence is written in chinese:", "我很感谢你的热情"],
["Some text might contain a lot of emojis like 😃", "😍💪 👉 👀"],
],
ids=["spaces", "chinese-utf8", "emojis"],
)
def test_decode_streaming_jetstream(tokenizer, input_text, generated_text):
from text_generation_server.jetstream_pt_support.generator import Slot

slot = Slot(0, tokenizer)
request = Request(id=0, inputs=input_text)
slot.assign(0, request, GenerationConfig())

inputs = tokenizer(input_text, padding="max_length", max_length=len(input_text) + 1, return_tensors="np")
input_ids = inputs["input_ids"][0]
generated_tokens = tokenizer(generated_text, add_special_tokens=False, return_tensors="np")["input_ids"][0]

# We need to regenerate the full text as the tokenizer might change it (extra spaces might be added)
all_input_ids = np.concatenate([input_ids, generated_tokens])
full_text = tokenizer.decode(all_input_ids, skip_special_tokens=True)
regenerated_text = full_text[len(input_text) :]

# Initialize the slot with the inputs
slot.reset(input_ids, selector=None)

assert slot.generated_tokens == 0

# Simulate an iterative generation (i.e. don't call select and use known tokens instead)
decoded_text = ""
for i in range(len(generated_tokens)):
text = slot.append(generated_tokens[i])
assert slot.generated_tokens == i + 1
decoded_text += text

assert decoded_text == regenerated_text


@skip_if_jetstream_pytorch_enabled
@pytest.mark.parametrize(
"input_text, generated_text",
[
@@ -31,6 +79,7 @@ def tokenizer(request):
)
def test_decode_streaming(tokenizer, input_text, generated_text):
from text_generation_server.generator import Slot

# Note: device used is cpu to make it faster
slot = Slot(0, tokenizer, "cpu")
request = Request(id=0, inputs=input_text)
132 changes: 0 additions & 132 deletions text-generation-inference/tests/test_gpt2.py

This file was deleted.

30 changes: 19 additions & 11 deletions text-generation-inference/tests/test_prefill_truncate.py
@@ -1,35 +1,43 @@
-from helpers import create_request, prepare_model
+from helpers import create_request, prepare_model, skip_if_jetstream_pytorch_enabled
from text_generation_server.auto_generator import AutoGenerator
from text_generation_server.pb.generate_pb2 import Batch


-def test_prefill_truncate():
-model_id="Maykeye/TinyLLama-v0"
-sequence_length=1024
+def _test_prefill_truncate():
+model_id = "Maykeye/TinyLLama-v0"
+sequence_length = 1024

model_path = prepare_model(model_id, sequence_length)
max_new_tokens = 20

generator = AutoGenerator.from_pretrained(
model_path, revision="", max_batch_size=1, max_sequence_length=sequence_length
)
input_text = "This is something I will tell by the end of the story"
input_text = "And to finish the story, I will say that"

request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
generations, _ = generator.prefill(batch)
assert len(generations) == 1
-assert generations[0].tokens.ids == [31843]
-assert generations[0].tokens.texts == ["."]
+assert generations[0].tokens.ids == [357]
+assert generations[0].tokens.texts == [" it"]

# Now re-test but with truncate
generator.clear()

request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
# This will only leave last tokens
-request.truncate = 6
+request.truncate = 3
batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
generations, _ = generator.prefill(batch)
assert len(generations) == 1
-assert generations[0].tokens.ids == [291]
-assert generations[0].tokens.texts == [" and"]
+assert generations[0].tokens.ids == [266]
+assert generations[0].tokens.texts == [" the"]


+def test_prefill_truncate_jetstream():
+_test_prefill_truncate()


+@skip_if_jetstream_pytorch_enabled
+def test_prefill_truncate():
+_test_prefill_truncate()
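To make the truncation assertions easier to follow, here is a hypothetical illustration of what `request.truncate` does to the prompt (the token ids are made up; only the slicing behaviour matters):

```python
# Made-up token ids standing in for "And to finish the story, I will say that"
prompt_ids = [312, 509, 266, 3572, 262, 1202, 357]
truncate = 3
kept_ids = prompt_ids[-truncate:]  # the server keeps only the last `truncate` tokens
# Prefill is then conditioned on a 3-token suffix instead of the full prompt,
# which is why the expected first generated token differs between the two calls above.
```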