From fa24cc4710b81b54759eb75e57e32cc8a3d78269 Mon Sep 17 00:00:00 2001 From: Alvaro Moran <6949769+tengomucho@users.noreply.github.com> Date: Mon, 9 Sep 2024 11:59:07 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=88=EF=B8=8F=20Introduce=20Jetstream/Pyto?= =?UTF-8?q?rch=20in=20TGI=20(#88)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(tgi): add functions to load Jetstream Pytorch engine for Llama2 * chore(TokenSelector): remove XLA xm rng seed set * fix(version): remove warning on deprecated API use packaging.version's parse instead of pkg_resources' parse_version. * fix(generator): use pad_token_id for padding * fix(decode): clear unrequested slots * feat(imports): add function to check if Jetstream Pytorch can be used * feat(Jetstream): improved support for engine load The custom HfEngine contains functions that will allow for prefill and generate functions to use custom sampling functions. * feat(TGI): Added Jetstream/Pytorch generator This implementation is equivalent to the torch_xla one, but uses the Jetstream/Pytorch engine instead. * chore(fsdp v2): avoid importing PretrainedModel This way we can aboid trying to import torch xla. * feat(tgi): introduce AutoGenerator This is just a way to provide a factory class method to create Jetstream/Pytorch or Pytorch XLA generator. * feat(Jetstream PT): Enable support only if env var is set There are still some issues related to some fine-tuned models, so for now just enable only when JETSTREAM_PT is set. * feat(TGI): use AutoGenerator in model server * feat(package): add optional dependency on Jetstream/Pytorch For now it is possible to install dependency after optimum-tpu has been instelled, issuing this command: pip install "optimum-tpu[jetstream-pt]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html * test(Jetstream Pytorch): added a simple decode test Also adapted other tests to avoid torch-xla generator implementaion, to avoid conflict. I also added the Jetstream/pytorch test to workflow in CI. * test(decode): added a variant with do_sample=True with Jetstream PT * fix(README): correct link * doc(README): add mention on how to install and enable Pytorch/Jetstream * feat(build): make clean removes old TGI builds too * review: comply to comments requests - Added warning when trying to load torch_xla2 adter torch_xla - renamed jetstream_pt_support.check to model_can_use_jetstream_pt * review(AutoGenerator): log if using Jetstream/PT or torch xla --- .../workflows/test-pytorch-xla-tpu-tgi.yml | 10 + Makefile | 1 + README.md | 15 +- optimum/tpu/__init__.py | 1 + optimum/tpu/fsdp_v2.py | 18 +- optimum/tpu/generation/token_selector.py | 2 - optimum/tpu/jetstream_pt_support.py | 26 + optimum/tpu/version.py | 4 +- pyproject.toml | 10 +- text-generation-inference/server/Makefile | 4 +- .../server/pyproject.toml | 2 +- .../text_generation_server/auto_generator.py | 39 + .../text_generation_server/generator.py | 15 +- .../jetstream_pt_support/__init__.py | 15 + .../jetstream_pt_support/compatibility.py | 53 ++ .../jetstream_pt_support/engine.py | 83 +++ .../jetstream_pt_support/engine_loader.py | 147 ++++ .../jetstream_pt_support/generator.py | 666 ++++++++++++++++++ .../llama_model_exportable_hf.py | 297 ++++++++ .../jetstream_pt_support/logits_process.py | 80 +++ .../jetstream_pt_support/token_selector.py | 199 ++++++ .../server/text_generation_server/server.py | 4 +- .../tests/test_decode.py | 32 +- .../tests/test_generator_slot.py | 2 +- text-generation-inference/tests/test_gpt2.py | 8 +- .../tests/test_prefill_truncate.py | 4 +- 26 files changed, 1709 insertions(+), 28 deletions(-) create mode 100644 optimum/tpu/jetstream_pt_support.py create mode 100644 text-generation-inference/server/text_generation_server/auto_generator.py create mode 100644 text-generation-inference/server/text_generation_server/jetstream_pt_support/__init__.py create mode 100644 text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py create mode 100644 text-generation-inference/server/text_generation_server/jetstream_pt_support/engine.py create mode 100644 text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py create mode 100644 text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py create mode 100644 text-generation-inference/server/text_generation_server/jetstream_pt_support/llama_model_exportable_hf.py create mode 100644 text-generation-inference/server/text_generation_server/jetstream_pt_support/logits_process.py create mode 100644 text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi.yml b/.github/workflows/test-pytorch-xla-tpu-tgi.yml index f9949ce9..4c681941 100644 --- a/.github/workflows/test-pytorch-xla-tpu-tgi.yml +++ b/.github/workflows/test-pytorch-xla-tpu-tgi.yml @@ -31,3 +31,13 @@ jobs: - name: Build and test TGI server run: | HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test + + # Use a different step to test the Jetstream Pytorch version, to avoid conflicts with torch-xla[tpu] + - name: Install and test TGI server (Jetstream Pytorch) + run: | + pip install -U .[jetstream-pt] \ + -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ + -f https://storage.googleapis.com/libtpu-releases/index.html + JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \ + pytest -sv text-generation-inference/tests -k jetstream diff --git a/Makefile b/Makefile index 026b309c..b7723d34 100644 --- a/Makefile +++ b/Makefile @@ -40,6 +40,7 @@ $(PACKAGE_DIST) $(PACKAGE_WHEEL): $(PACKAGE_FILES) clean: rm -rf dist + make -C text-generation-inference/server/ clean tpu-tgi: docker build --rm -f text-generation-inference/docker/Dockerfile \ diff --git a/README.md b/README.md index 4336a006..1ab94283 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,20 @@ Other TPU versions will be supported along the way. As part of the integration, we do support a [text-generation-inference (TGI)](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference) backend allowing to deploy and serve incoming HTTP requests and execute them on Cloud TPUs. -Please see the [TGI specific documentation]() on how to get started +Please see the [TGI specific documentation](text-generation-inference) on how to get started. + +### JetStream Pytorch Engine + +`optimum-tpu` provides an optional support of JetStream Pytorch engine inside of TGI. This support can be installed using the dedicated command: + +```shell +pip install "optimum-tpu[jetstream-pt]" \ + -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ + -f https://storage.googleapis.com/libtpu-releases/index.html +``` + +To enable the support, export the environment variable `JETSTREAM_PT=1`. ## Training diff --git a/optimum/tpu/__init__.py b/optimum/tpu/__init__.py index 1982be57..4f14dfc8 100644 --- a/optimum/tpu/__init__.py +++ b/optimum/tpu/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .jetstream_pt_support import jetstream_pt_available # isort:skip from .fsdp_v2 import get_fsdp_config, use_fsdp_v2 from .modeling import AutoModelForCausalLM from .version import VERSION, __version__ diff --git a/optimum/tpu/fsdp_v2.py b/optimum/tpu/fsdp_v2.py index 8a138793..9f4a5ad1 100644 --- a/optimum/tpu/fsdp_v2.py +++ b/optimum/tpu/fsdp_v2.py @@ -15,12 +15,16 @@ """ Utility functions to provide FSDPv2 configuration for TPU training. """ -from typing import Dict, List, Union +from typing import Any, Dict, List, Union -from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging +PreTrainedModel = Any +# NOTE: instead of the above, modeling_utils.PreTrainedModel should be used, but since the usage is only for type +# hinting, it is not imported here, so to avoid pulling imports of torch_xla. + + def use_fsdp_v2(): """ Enable FSDPv2 for TPU training. @@ -61,6 +65,7 @@ def _unwrap_model(model: PreTrainedModel) -> PreTrainedModel: """ try: from peft.peft_model import LoraModel, PeftModel + if isinstance(model, PeftModel) and isinstance(model.base_model, LoraModel): return model.base_model.model return model @@ -89,10 +94,13 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict: if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM): logger = logging.get_logger(__name__) from torch_xla import __version__ as xla_version + if xla_version == "2.3.0": - logger.warning_once("Fine-tuning Gemma on Pytorch XLA 2.3.0 might raise some issues. In case of any " - "issues consider using the nightly version, and report the issue on the optimum-tpu " - "GitHub repository: https://github.com/huggingface/optimum-tpu/issues/new.") + logger.warning_once( + "Fine-tuning Gemma on Pytorch XLA 2.3.0 might raise some issues. In case of any " + "issues consider using the nightly version, and report the issue on the optimum-tpu " + "GitHub repository: https://github.com/huggingface/optimum-tpu/issues/new." + ) cls_to_wrap = "GemmaDecoderLayer" matched_model = True elif model_type == "llama": diff --git a/optimum/tpu/generation/token_selector.py b/optimum/tpu/generation/token_selector.py index cc93aa1e..33643844 100644 --- a/optimum/tpu/generation/token_selector.py +++ b/optimum/tpu/generation/token_selector.py @@ -3,7 +3,6 @@ from typing import List, Optional, Union import torch -import torch_xla.core.xla_model as xm from transformers.generation import ( GenerationConfig, GenerationMixin, @@ -53,7 +52,6 @@ def __init__( self.eos_token_ids = eos_token_ids self.pad_token_id = pad_token_id self.logits_warper = logits_warper - xm.set_rng_state(seed) self.generator = torch.Generator() self.generator.manual_seed(seed) diff --git a/optimum/tpu/jetstream_pt_support.py b/optimum/tpu/jetstream_pt_support.py new file mode 100644 index 00000000..eac038ad --- /dev/null +++ b/optimum/tpu/jetstream_pt_support.py @@ -0,0 +1,26 @@ +import os +import sys + +from loguru import logger + + +def jetstream_pt_available() -> bool: + """Check if the necessary imports to use jetstream_pt are available. + """ + try: + # For now Jetstream Pytorch is opt-in, it can be enabled with an ENV variable. + jetstream_pt_enabled = os.environ.get("JETSTREAM_PT", False) == "1" + if not jetstream_pt_enabled: + return False + # Torch XLA should not be imported before torch_xla2 to avoid conflicts. + if 'torch_xla2' not in sys.modules and 'torch_xla.core' in sys.modules: + logger.warning("torch_xla2 cannot be imported after torch_xla, disabling Jetstream PyTorch support.") + return False + # Import torch_xla2 first! + import torch_xla2 # noqa: F401, isort:skip + + import jetstream_pt # noqa: F401 + + return True + except ImportError: + return False diff --git a/optimum/tpu/version.py b/optimum/tpu/version.py index 2343b3ba..b6b27852 100644 --- a/optimum/tpu/version.py +++ b/optimum/tpu/version.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pkg_resources import parse_version +from packaging.version import parse __version__ = "0.1.5" -VERSION = parse_version(__version__) +VERSION = parse(__version__) diff --git a/pyproject.toml b/pyproject.toml index eef4a9ac..571dd7a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,8 +43,8 @@ keywords = [ dependencies = [ "transformers == 4.41.1", - "torch >= 2.3.0, <= 2.4.0", - "torch-xla[tpu] >= 2.3.0, <= 2.4.0", + "torch == 2.4.0", + "torch-xla[tpu] == 2.4.0", "loguru == 0.6.0", "sentencepiece == 0.2.0", ] @@ -58,6 +58,12 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] tests = ["pytest", "safetensors"] quality = ["black", "ruff", "isort"] +# Jetstream/Pytorch support is experimental for now, requires installation from fixed commit. +# Pallas is pulled because it will install a compatible version of jax[tpu]. +jetstream-pt = [ + "jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git#df92015289953c506004e674d57651b03e4e89f2", + "torch-xla[pallas] == 2.4.0" +] [project.urls] Homepage = "https://hf.co/hardware" diff --git a/text-generation-inference/server/Makefile b/text-generation-inference/server/Makefile index 9a2d3624..b26225d3 100644 --- a/text-generation-inference/server/Makefile +++ b/text-generation-inference/server/Makefile @@ -12,12 +12,14 @@ clean: # List static sources to be deployed in the package src_dir := $(mkfile_dir)/$(pkg_name) -sources := $(wildcard $(src_dir)/*.py) +rwildcard_py = $(wildcard $(1)/*.py) $(foreach d,$(wildcard $(1)/*),$(call rwildcard_py,$d)) +sources := $(call rwildcard_py,$(src_dir)) deployed_sources := $(subst $(src_dir), $(pkg_dir), $(sources)) # Static files are just copied define COPY + mkdir -p $(dir $@) cp -f $< $@ endef diff --git a/text-generation-inference/server/pyproject.toml b/text-generation-inference/server/pyproject.toml index d2f2fbca..5a3d4070 100644 --- a/text-generation-inference/server/pyproject.toml +++ b/text-generation-inference/server/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ ] [tool.setuptools] -packages = ["text_generation_server", "text_generation_server.pb"] +packages = ["text_generation_server", "text_generation_server.pb", "text_generation_server.jetstream_pt_support"] [tool.setuptools.dynamic] version = {attr = "text_generation_server.version.__version__"} diff --git a/text-generation-inference/server/text_generation_server/auto_generator.py b/text-generation-inference/server/text_generation_server/auto_generator.py new file mode 100644 index 00000000..23e5631b --- /dev/null +++ b/text-generation-inference/server/text_generation_server/auto_generator.py @@ -0,0 +1,39 @@ +from loguru import logger + +from .generator_base import Generator +from .jetstream_pt_support import model_can_use_jetstream_pt + + +class AutoGenerator: + + @staticmethod + def from_pretrained( + model_path: str, revision: str, max_batch_size: int, max_sequence_length: int + ) -> Generator: + """Instantiate a Generator for TPU using Jetstream Pytorch or Pytorch/XLA. + + Args: + model_path (`str`): + The path to a local model. This path must also contain a Tokenizer. + revision (`str`): + The revision of the model. + max_batch_size (`int`): + The maximum batch size. + max_sequence_length (`int`): + The maximum sequence length. + + Returns: + A TpuGenerator. + """ + if model_can_use_jetstream_pt(model_path): + logger.debug("Using Jetstream PyTorch generator.") + from .jetstream_pt_support.generator import TpuGeneratorJetStream + return TpuGeneratorJetStream.from_pretrained( + model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length + ) + else: + logger.debug("Using PyTorch/XLA generator.") + from .generator import TpuGenerator + return TpuGenerator.from_pretrained( + model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length + ) diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py index facd511d..4fdcdd19 100644 --- a/text-generation-inference/server/text_generation_server/generator.py +++ b/text-generation-inference/server/text_generation_server/generator.py @@ -590,6 +590,19 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa # just carry on with decoding. We adopt the id of the first # batch in the list as our next batch id. next_batch_id = batches[0].id + request_ids = [] + for batch in batches: + request_ids += batch.request_ids + cleared_request_ids = [] + for slot in self.slots: + if slot.state == slot.State.READY and slot.request_id not in request_ids: + cleared_request_ids.append(slot.request_id) + slot.clear() + if len(cleared_request_ids) > 0: + logger.info(f"Clearing slot for requests {cleared_request_ids} as they are not requested.") + active_slots = [slot for slot in self.slots if slot.state == slot.State.READY] + if len(active_slots) < len(request_ids): + logger.error("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)") # Reconstruct input_ids and attention_mask from slots input_ids = None attention_mask = None @@ -608,7 +621,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa # Create blank inputs covering all slots (even empty ones) input_ids = torch.full( [batch_size, 1], - fill_value=self.tokenizer.eos_token_id, + fill_value=pad_token_id, dtype=torch.int64, ) cache_position = torch.zeros([1], dtype=torch.int64) diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/__init__.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/__init__.py new file mode 100644 index 00000000..a9c19638 --- /dev/null +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .compatibility import create_engine, model_can_use_jetstream_pt diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py new file mode 100644 index 00000000..a7a33b86 --- /dev/null +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py @@ -0,0 +1,53 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any + +from transformers import AutoConfig + +from optimum.tpu import jetstream_pt_available + + +def model_can_use_jetstream_pt(model_path: str) -> bool: + """Checks if the model is supported by Jetstream Pytorch on Optimum TPU and if the required dependencies to provide + the engine are installed. + """ + config = AutoConfig.from_pretrained(model_path) + # For now only Llama 2 with tokenizer.model is supported + if config.model_type != "llama" or not os.path.exists( + os.path.join(model_path, "tokenizer.model") + ): + return False + if jetstream_pt_available(): + return True + return False + + +def create_engine( + model_path: str, + batch_size: int, + sequence_length: int, + max_input_tokens: int, + max_output_tokens: int, +) -> Any: + if not model_can_use_jetstream_pt(model_path): + # The model is not compatible with Jetstream PyTorch, just exit + return None + + # Now import engine_loader to prevent importing it at the top when not supported + from .engine_loader import create_engine + return create_engine( + model_path, batch_size, sequence_length, max_input_tokens, max_output_tokens + ) diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine.py new file mode 100644 index 00000000..4874aa3a --- /dev/null +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine.py @@ -0,0 +1,83 @@ +from typing import Any, Callable, Optional, Tuple + +import jax +import jax.numpy as jnp +import torch +from jetstream.engine import engine_api +from jetstream_pt import engine + + +class HfEngine(engine.PyTorchEngine): + def __init__( + self, + pt_model: torch.nn.Module, + env: engine.JetEngineEnvironment, + weights=None, + ): + super().__init__(pt_model, env, weights) + self.prefill_ex = jax.jit( + self.prefill_ex, + out_shardings=(self.get_prefix_destination_sharding(), None), + ) + + def generate_ex( + self, params: Any, decode_state: engine.DecodeState, sampling_fn: Callable[[Any, int], jax.Array] + ) -> tuple[engine.DecodeState, engine_api.ResultTokens]: + sampling_fn_backup = self._sampling + self._sampling = sampling_fn + new_decode_state, result_tokens = self.generate(params, decode_state) + self._sampling = sampling_fn_backup + return new_decode_state, result_tokens + + def prefill_ex( + self, + *, + params: Any, # Weights + _existing_prefix: Optional[engine.Prefix] = None, + padded_tokens: jax.Array, + true_length: int, + sampling_fn: Callable[[jax.Array], jax.Array], + ) -> Tuple[engine.Prefix, engine_api.ResultTokens]: + if isinstance(padded_tokens, jax.Array): + batched_token = padded_tokens.reshape(1, -1) + else: + raise TypeError("Input tokens should be of type Jax Array, but receiving:" " {prefill_inputs}") + seq_len = padded_tokens.shape[0] + input_indexes = jnp.arange(0, seq_len) + logits, updated_caches = self._call_model_prefill( + params, + batched_token, + input_indexes, + ) + if len(logits.shape) == 3: # b, seqlen, num words + logits = logits[0] # seqlen, num words + + # This is equivalent to last_logits = logits[:, true_length - 1, :], but it can be jitted + last_logits = jax.lax.dynamic_slice_in_dim(logits, true_length - 1, 1, axis=0) + token = sampling_fn(last_logits) + token_out = jnp.reshape(token, (1, 1)) + data = jnp.concatenate( + [ + token_out, # First token + jnp.ones_like(token_out), # validity of first token + jnp.zeros((1, 1), dtype=jnp.int32), # length = 0 + ], + axis=-1, + ) + length = token_out.shape[1] + result = engine_api.ResultTokens( + data=data, + tokens_idx=(0, length), + valid_idx=(length, 2 * length), + length_idx=(2 * length, 2 * length + 1), + samples_per_slot=1, + ) + # truncate to true_length didnt work need to be out side of jit + # caches = [ + # (jax.lax.dynamic_slice_in_dim( + # k, seq_len - true_length, true_length, axis=2), + # jax.lax.dynamic_slice_in_dim( + # v, seq_len - true_length, true_length, axis=2)) + # for k, v in updated_caches + # ] + return engine.Prefix(token, updated_caches, true_length), result diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py new file mode 100644 index 00000000..6e933df5 --- /dev/null +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py @@ -0,0 +1,147 @@ +# Import torch_xla2 first +import torch_xla2 # isort:skip + +from typing import Any + +import jax +from jetstream_pt import fetch_models, torchjax +from jetstream_pt.environment import ( + JetEngineEnvironment, + JetEngineEnvironmentData, + QuantizationConfig, +) +from loguru import logger +from transformers import AutoConfig + +from .engine import HfEngine +from .llama_model_exportable_hf import TransformerHf + + +def load_llama_model_info(model_path: str) -> Any: + # First get config + config = AutoConfig.from_pretrained(model_path) + num_layers = config.num_hidden_layers + num_heads = config.num_attention_heads + head_dim = config.hidden_size // num_heads + n_reps = config.num_key_value_heads // num_heads + model_info = fetch_models.ModelInfo( + TransformerHf, + num_layers=num_layers, + num_heads=num_heads, + head_dim=head_dim, + n_reps=n_reps, + ) + return model_info + + +def load_model_info(model_path: str) -> Any: + config = AutoConfig.from_pretrained(model_path) # For now only Llama 2 is supported + if config.model_type == "llama": + return load_llama_model_info(model_path) + # Other models supports can be added here later + return None + + +def create_engine_env_data( + model_path: str, + batch_size: int, + sequence_length: int, + max_input_tokens: int, + max_output_tokens: int, +) -> Any: + model_info = load_model_info(model_path) + if model_info is None: + return None + + shard_on_batch = False + max_cache_length = max_input_tokens + max_output_tokens + + env_data = JetEngineEnvironmentData( + tokenizer_path="", # Tokenizer is not user, HF tokenizer is used instead + checkpoint_path=model_path, + checkpoint_format="safetensors", + batch_size=batch_size, + max_decode_length=sequence_length, + max_input_sequence_length=max_input_tokens, + quant_config=QuantizationConfig(), + cache_sequence_length=max_cache_length, + bf16_enable=True, + sharding_config_path="", + shard_on_batch=shard_on_batch, + n_reps=model_info.n_reps, + ) + env_data.cache_shape = ( + batch_size, + model_info.num_heads, + max_cache_length, + model_info.head_dim, + ) + env_data.num_layers = model_info.num_layers + return env_data + + +def create_model(model_path: str, env: Any) -> Any: + config = AutoConfig.from_pretrained(model_path) + if config.model_type == "llama": + return TransformerHf.from_config(config, env) + + +def instantiate_model_from_repo_id( + model_dir: str, + env: Any, +): + """Create model instance by hf model_dir, and its config""" + model_info = load_model_info(model_dir) + # at this point we can be quite optimistic and just assert + assert model_info is not None + + env.device = "meta" + model = create_model(model_dir, env) + weights = fetch_models._load_weights(model_dir) + updated_keys = model.get_hf_names_to_real_name() + for name, updated in updated_keys.items(): + if name in weights: + val = weights.pop(name) + weights[updated] = val + + model.load_state_dict(weights, assign=True, strict=False) + + return model + + +def shard_weights(env, weights, weight_shardings): + """Shard weights according to weight_shardings""" + for k, v in weight_shardings.items(): + logger.debug(f"SHARDING {k} {v}") + sharded = {} + for key, val in weights.items(): + sharding = env.sharding_by_axis(weight_shardings.get(key, -1)) + with jax.default_device(jax.devices("cpu")[0]): + arr = torch_xla2.tensor.t2j(val) + arr = jax.device_put(arr, sharding) + sharded[key] = torchjax.to_torch(arr) + return sharded + + +def create_engine( + model_path: str, + batch_size: int, + sequence_length: int, + max_input_tokens: int, + max_output_tokens: int, +) -> HfEngine: + # NOTE: for now no quantization is done + env_data = create_engine_env_data(model_path, batch_size, sequence_length, max_input_tokens, max_output_tokens) + if env_data is None: + return None + + env = JetEngineEnvironment(env_data) + model = instantiate_model_from_repo_id(model_path, env) + weight_shardings = model.get_sharding_annotations() + sharded_weights = shard_weights(env, model.state_dict(), weight_shardings) + + return HfEngine( + pt_model=model, + env=env, + weights=torchjax.from_torch_with_copy(sharded_weights), + ) diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py new file mode 100644 index 00000000..59291bf7 --- /dev/null +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py @@ -0,0 +1,666 @@ +import copy +import logging +import time +from enum import Enum +from typing import Any, List, Optional, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +import torch +import torch_xla2 +from jetstream.engine.token_utils import pad_tokens, take_nearest_length +from loguru import logger +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from transformers.generation import GenerationConfig + +from ..generator_base import Generator +from ..pb.generate_pb2 import ( + Batch, + CachedBatch, + FinishReason, + GeneratedText, + Generation, + InfoResponse, + NextTokenChooserParameters, + Request, + StoppingCriteriaParameters, + Tokens, +) +from .engine import HfEngine +from .engine_loader import create_engine +from .token_selector import TokenSelector + + +# Disable optimum-tpu warnings as it seems to block the server after a while +optimum_logger = logging.getLogger("optimum.tpu") +optimum_logger.setLevel("CRITICAL") + +# These will do some bucketing on prefill lengths to avoid too many different sizes +PREFILL_LENGTHS = [ + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 32768, +] + +class Slot: + """Represents a slot in a static batch""" + + class State(Enum): + EMPTY = 0 + READY = 1 + + def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase): + self._id = id + self._tokenizer = tokenizer + self.clear() + + def clear(self): + """Clear the slot and mark it as available.""" + self._state = Slot.State.EMPTY + self._batch_id = None + self._request_id = None + self._inputs = "" + self._generation_config = None + self._tokens = [] + self._selector = None + self._generated_tokens = 0 + self._next_text_token_start = 0 + self._next_text_token_end = 0 + self._generated_text = "" + self._next_text = "" + self._truncate = 0 + + @property + def id(self) -> int: + return self._id + + @property + def state(self) -> "Slot.State": + return self._state + + @property + def batch_id(self) -> int: + return self._batch_id + + @property + def request_id(self) -> int: + return self._request_id + + @property + def generation_config(self) -> GenerationConfig: + return self._generation_config + + @property + def generated_tokens(self) -> int: + return self._generated_tokens + + @property + def truncate(self) -> int: + return self._truncate + + @property + def tokens(self) -> jax.Array: + return self._tokens + + def assign(self, batch_id: int, request: Request, generation_config: GenerationConfig): + """Assign a request to a slot. + + Args: + batch_id (`int`): The id of the batch containing the request. + request (`Request`): + The request to be assigned. Contains the inputs and tokens selection parameters. + generation_config (`transformers.GenerationConfig`): + The base generation config (might be modified by the request generation parameters). + """ + self._state = Slot.State.READY + self._batch_id = batch_id + self._request_id = request.id + self._inputs = request.inputs + self._generation_config = copy.deepcopy(generation_config) + # Update generation config with token chooser parameters + self._generation_config.temperature = request.parameters.temperature + self._generation_config.top_k = request.parameters.top_k + self._generation_config.top_p = request.parameters.top_p + self._generation_config.typical_p = request.parameters.typical_p + self._generation_config.do_sample = request.parameters.do_sample + self._generation_config.repetition_penalty = request.parameters.repetition_penalty + self._truncate = request.truncate + self.seed = request.parameters.seed + # TODO: watermark + self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens + self._max_new_tokens = self._generation_config.max_new_tokens + # TODO: stop_sequences, ignore_eos_token + + def reset(self, input_ids: jax.Array, selector: TokenSelector): + """Reset the slot for the next generation. + + Args: + input_ids: (`jax.Array`): + The new input_ids to use to generate the next token. + selector: (`TokenSelector`): + An object implementing the updated token selection logic. + """ + self._tokens = input_ids + self._next_text_token_start = 0 + self._next_text_token_end = self._tokens.shape[-1] + self._next_text = "" + self._selector = selector + + def _decode_next_tokens( + self, + ) -> str: + """Hack to hopefully support generate_stream for the maximum number of tokenizers""" + tokens = self._tokens + # We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode + # which decide to add a space or not depending on the surrounding ids. + new_text = self._tokenizer.decode(self._tokens[self._next_text_token_start :], skip_special_tokens=False) + if new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. + return "" + + # Compare the generated text with the one using only the tokens producing the last one + last_text = self._tokenizer.decode( + tokens[self._next_text_token_start : self._next_text_token_end], + skip_special_tokens=False, + ) + if len(new_text) == len(last_text): + # Nothing new was actually generated + return "" + # Return the decoded text and store its token offsets + self._next_text_token_start = self._next_text_token_end + self._next_text_token_end = tokens.shape[-1] + return new_text[len(last_text) :] + + def append(self, next_token: int) -> str: + """Append a new generated token to this slot + + The new token is added to the list of generated tokens, which impacts + directly the generated_text and stopped property. + + The new token is however not added immediately to the slot inputs: it will + be added later on when it has effectively been used to produce the next token. + + Args: + next_token (`int`): + The newly generated token. + + Return: + The corresponding decoded text (if any). + """ + self._tokens = jnp.concat([self._tokens, jnp.array([next_token])]) + self._generated_tokens += 1 + next_text = self._decode_next_tokens() + # Now that a new token has been generated, we can append the previous one to the generated text + self._generated_text += self._next_text + self._next_text = next_text + return next_text + + def select(self, logits: jnp.ndarray) -> int: + """Select the next token from the candidate logits. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + The logits corresponding to the generated tokens. + + Return: + int: A scalar of the selected token. + """ + return self._selector.select(self._tokens, logits)[0] + + @property + def stopped(self) -> bool: + # unsqueeze tokens to avoid problems with stopping criteria + tokens = torch_xla2.tensor.j2t(self._tokens).unsqueeze(0) + return bool(torch.all(self._selector.stopping_criteria(tokens, None))) + + @property + def generated_text(self) -> str: + return self._generated_text + self._next_text + + @property + def next_token(self) -> int: + return None if len(self._tokens) == 0 else self._tokens[-1] + + @property + def empty(self) -> bool: + return len(self._tokens) == 0 + + +class TpuGeneratorJetStream(Generator): + """A Generator for models running on TPU, single threaded.""" + + def __init__( + self, + engine: HfEngine, + tokenizer: PreTrainedTokenizerBase, + ): + self.engine = engine + logger.debug("Loading params (i.e. weights) on engine") + self.params = engine.load_params() + logger.debug("Weights loaded") + logger.debug("Initializing decode state") + self.decode_state = engine.init_decode_state() + logger.debug("Decode state initialized") + + # Note: Jetstream/Pytorch requires padding to be done with 0 (at least when not specified) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = 0 + tokenizer.padding_side = "left" + tokenizer.truncation_side = "left" + self.tokenizer = tokenizer + self.special_tokens = self.tokenizer.all_special_ids + + # Slots are empty to begin with, they will be populated as new batches arrive + self.slots = [] + self.batch_id = 0 + # Note: this index will _never_ be decremented, and that's fine. + self.slot_index = 0 + + @property + def info(self) -> InfoResponse: + """Returns the expected InfoResponse.""" + dtype = self.engine.default_dtype.dtype + # NOTE: the device type reported is "meta", even if it's a TPU + return InfoResponse( + requires_padding=True, + dtype=str(dtype), + device_type=self.engine.env.device, + ) + + def _create_dummy_request(self, max_tokens: int) -> Batch: + """Create a dummy request for warmup.""" + # Generate a random input with slightly more tokens than requested, because special tokens are going to be + # skipped. + MARGIN = 10 + input_tokens = np.random.randint(self.model.config.vocab_size, size=(1, max_tokens + MARGIN), dtype=np.int64) + text = self.tokenizer.decode(input_tokens[0], skip_special_tokens=True) + # These are just dummy params to allow Request creation + parameters = NextTokenChooserParameters( + temperature=1.0, + top_k=None, + top_p=None, + do_sample=False, + seed=None, + repetition_penalty=1.0, + typical_p=1.0, + ) + stopping_parameters = StoppingCriteriaParameters(max_new_tokens=20, ignore_eos_token=True) + dummy_request = Request( + id=0, + inputs=text, + truncate=max_tokens, + parameters=parameters, + stopping_parameters=stopping_parameters, + ) + return dummy_request + + def warmup(self, batch: Batch) -> int: + """Verify if the hardware can support the target load. + + Args: + batch (`Batch`): + A batch corresponding to the maximum number of concurrent requests. + + Return: + The maximum number of tokens the model supports. + """ + logger.debug("Warming up the model") + start = time.time() + # Just check that the warmup request parameters match the model capacity + batch_size = self.engine.env.batch_size + if len(batch.requests) > batch_size: + raise ValueError( + f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length." + ) + + # Counter-intuitively, now we ignore the input batch. Instead, we create dummy batches to cover all possible + # batch sizes and sequence lengths. + seq_len = self.model.config.sequence_length + bucket_seq_len = take_nearest_length(PREFILL_LENGTHS, seq_len) + dummy_request = self._create_dummy_request(seq_len) + decode_done = False + for l in reversed(PREFILL_LENGTHS): + # Skip all the unsupported lengths + if l > bucket_seq_len: + continue + # Set all truncate values for all requests + dummy_request.truncate = l + dummy_request.stopping_parameters.max_new_tokens = 10 + warmup_batch = Batch(id=0, + requests=[dummy_request], + size=1, + max_tokens=batch.max_tokens) + logger.debug(f"Warmup for requests, len {l} seq_len {seq_len}") + _generations, next_batch = self.prefill(warmup_batch) + if not decode_done and next_batch is not None: + self.decode([next_batch]) + decode_done = True + self.clear() + if not decode_done: + logger.debug("No decode done during warmup") + + self.prefill(batch) + self.clear() + elapsed = time.time() - start + logger.debug(f"Warmup done, took {elapsed:.2f}s") + seq_len = self.engine.env.seq_len + return batch_size * seq_len + + def _get_slot_id(self): + """Get the next available slot id.""" + batch_size = self.engine.env.batch_size + used_ids = [slot.id for slot in self.slots if slot.state != Slot.State.EMPTY] + for i in range(batch_size): + if i not in used_ids: + return i + # if we reach this point, all slots were used - this should not happen + raise ValueError("All slots are used, but we should have stopped earlier") + + @property + def model(self): + return self.engine.pt_model + + def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]: + """Tokenize the input text and return the corresponding input_ids and true_length. + + Args: + text (`str`): + The input text to tokenize. + max_length (`int`): + The maximum length of the input_ids (typically from request) + """ + if max_length == 0: + max_length = self.model.config.sequence_length + input_ids = self.tokenizer.encode( + text, + return_tensors="np", + truncation=True, + max_length=max_length, + add_special_tokens=False, + ) + tokens, true_length = pad_tokens(input_ids[0], + self.tokenizer.bos_token_id, + self.tokenizer.pad_token_id, + is_bos=True, + max_prefill_length=self.model.config.sequence_length, + jax_padding=True, + ) + return tokens, true_length + + def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: + """Prefill new requests. + + Args: + batch (`Batch`): + A batch containing the new requests. + + Return: + A list of `Generation` for each request and a `CachedBatch` containing all pending requests. + """ + + slots = {state: [] for state in Slot.State} + for slot in self.slots: + slots[slot.state].append(slot) + len_active_slots = len(slots[Slot.State.READY]) + # Delete all empty slots, no need to have them anymore + empty_slots = slots[Slot.State.EMPTY] + for slot in empty_slots: + self.slots.remove(slot) + len_requests = len(batch.requests) + model_batch_size = self.model.config.batch_size + if model_batch_size is not None and model_batch_size < len_active_slots + len_requests: + # If raising an error here wouldn't crash the server, we could raise a ValueError + error = ValueError( + f"Cannot prefill {len_requests} new request(s)." + f" Maximum batch size supported is: {model_batch_size}." + ) + # but since it's not possible, we just log the error and return an empty generation + logger.error(error) + return [], None + # Assign each request to an empty slot + logger.debug(f"Prefilling {len_requests} new request(s) adding to {len_active_slots} active slot(s)") + generations = [] + + for request in batch.requests: + # Dynamically create a new slot for each request + slot = Slot(self._get_slot_id(), self.tokenizer) + self.slot_index += 1 + slot.assign(self.batch_id, request, self.model.generation_config) + logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}") + + # Tokenize the inputs + input_ids, true_lengths = self._token_encode(request.inputs, slot.truncate) + truncated_input_ids = input_ids[:true_lengths] + selector = TokenSelector.create( + truncated_input_ids, + slot.generation_config, + self.model, + self.model.config.sequence_length, + seed=slot.seed, + ) + slot.reset(truncated_input_ids, selector) + # To allow jit'ing the select function, we need to wrap it in a partial + slot_select = jax.tree_util.Partial(slot.select) + # Ask for prefill and insert + prefill_results, _result_tokens = self.engine.prefill_ex( + params=self.params, + padded_tokens=input_ids, + true_length=true_lengths, + sampling_fn=slot_select, + ) + next_token = prefill_results.token.item() + self.decode_state = self.engine.insert(prefill_results, self.decode_state, slot.id) + + self._post_generate(slot, next_token, generations) + if not slot.empty: + # append current to list of active slots + self.slots.append(slot) + len_active_slots += 1 + + if len_active_slots > 0: + # Whatever initial batch these requests came from, we always return all pending requests in a single batch + request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY] + batch = self._cached_batch(self.batch_id, request_ids) + else: + logger.debug("No more pending requests") + self.batch_id += 1 + logger.debug("Model ready for decoding") + return generations, batch + + def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]: + """Decode the specified prefilled requests. + + Args: + batches (`List[CachedBatch]`): + A list of previous batches containing the prefilled requests. + + Return: + A list of `Generation` for each request and a `CachedBatch` containing all pending requests. + """ + # batches contains a list composed of ongoing requests: + # - the batch id returned by the last decode, + # - the batch id(s) returned by the last prefill(s) + # Batches are always concatenated during prefill, so we can + # just carry on with decoding. We adopt the id of the first + # batch in the list as our next batch id. + next_batch_id = batches[0].id + if len(batches) > 1: + logger.warning("Unexpected multiple batches received, only the first one will be processed.") + request_ids = [] + for batch in batches: + request_ids += batch.request_ids + cleared_request_ids = [] + for slot in self.slots: + if slot.state == slot.State.READY and slot.request_id not in request_ids: + cleared_request_ids.append(slot.request_id) + self.slots.remove(slot) + if len(cleared_request_ids) > 0: + logger.info(f"Clearing slot for requests {cleared_request_ids} as they are not requested.") + active_slots = [slot for slot in self.slots if slot.state == slot.State.READY] + if len(active_slots) < len(request_ids): + raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)") + + # Define a custom function to select the next token for each slot + pad_token_id = self.tokenizer.pad_token_id + + def select_from_slots(logits: Any, batch_size: int) -> jnp.ndarray: + tokens = jnp.full((batch_size, 1), pad_token_id) + for slot in active_slots: + # Every slot might have a different selection criteria, so we are obliged to call select in a loop + next_token = slot.select(logits) + tokens = tokens.at[slot.id].set(next_token) + return tokens + + select_fn = select_from_slots + self.decode_state, result_tokens = self.engine.generate_ex(self.params, self.decode_state, select_fn) + + newly_empty = [] + generations = [] + for slot in active_slots: + # Get the next token. + # Note that for now we ignore is_valid and length as we don't use them, we will re-parse these in post + # generation. + next_token, _is_valid, _length = result_tokens.data[slot.id] + + if slot.state != Slot.State.READY: + logger.error(f"Unexpected Slot {slot.id} is not ready for decoding, skipping.") + raise ValueError("Unexpected Slot is not ready for decoding") + + self._post_generate(slot, next_token, generations) + if slot.empty: + newly_empty.append(slot) + + # Remove empty slots + for slot in newly_empty: + self.slots.remove(slot) + batch = None + if len(self.slots) > 0: + # Whatever initial batch these requests came from, we always return all pending requests in a single batch + request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY] + batch = self._cached_batch(next_batch_id, request_ids) + else: + logger.debug("No more pending requests") + return generations, batch + + def _post_generate(self, slot: Slot, next_token: int, generations: List[Generation]) -> None: + """Post-generate a slot after the generation has been completed. + + This will check if the slot is finished and append the generated text to the response. + + Args: + slot (`Slot`): + The slot to post-generate. + next_token (`int`): + The next token generated by the model. + generations (`List[Generation]`): + The list of generations to append the slot to. + """ + # prepare the generation response + next_token_text = slot.append(next_token) + generated_text = None + finish_reason = None + if next_token == self.tokenizer.eos_token_id: + finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN + elif slot.stopped: + # For now we only support the length stopping criteria + finish_reason = FinishReason.FINISH_REASON_LENGTH + request_id = slot.request_id + if finish_reason is not None: + # We must include the generated text for each finished sequence in the response + generated_text = GeneratedText( + text=slot.generated_text, generated_tokens=slot.generated_tokens, finish_reason=finish_reason + ) + logger.debug(f"Finished generating tokens for request {request_id}") + # This slot is now empty, it will be removed from the list of + # active slots. + slot.clear() + generations.append( + Generation( + request_id=request_id, + prefill_tokens=None, + tokens=Tokens( + ids=[next_token], + logprobs=[0], + texts=[next_token_text], + is_special=[next_token in self.special_tokens], + ), + generated_text=generated_text, + ) + ) + + def _cached_batch(self, batch_id: int, request_ids: List): + size = len(request_ids) + max_tokens = size * self.model.config.sequence_length + return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens) + + def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch: + """Remove requests that are not listed from the specified batch + + Args: + batch_id (`int`): + The id of a cached batch. + request_ids(`List[int]`): + The list of requests that must be kept. + + Return: + A `CachedBatch` containing the pending requests. + """ + keep_slot_ids = [slot.id for slot in self.slots if slot.request_id in keep_request_ids] + self._clear(keep_slot_ids) + return self._cached_batch(batch_id, keep_request_ids) + + def clear(self, batch_id: Optional[int] = None): + """Remove a subset or all requests from the generator""" + keep_ids = [] + if batch_id is not None: + keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id] + return self._clear(keep_ids) + + def _clear(self, keep_slot_ids: List): + for slot in self.slots: + if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids: + logger.debug(f"Removing slot {slot.id} with request {slot.request_id}") + slot.clear() + + @classmethod + def from_pretrained(cls, model_path: str, revision: str, max_batch_size: int, max_sequence_length: int): + """Instantiate a Generator that uses JetStream/Pytorch engine. + + Args: + model_path (`str`): + The path to a local model. This path must also contain a Tokenizer. + revision (`str`): + Deprecated parameter, only an empty string or None is supported, other values are ignored. + max_batch_size (`int`): + The maximum batch size. + max_sequence_length (`int`): + The maximum sequence length. + + Returns: + A TpuGenerator. + """ + if revision != "": + logger.warning("Revision is not supported for JetStream/Pytorch engine, ignoring.") + logger.info("Loading model engine (this can take a few minutes).") + start = time.time() + engine = create_engine( + model_path, + max_batch_size, + sequence_length=max_sequence_length, + max_input_tokens=max_sequence_length, + max_output_tokens=max_sequence_length, + ) + end = time.time() + logger.info(f"Engine successfully loaded in {end - start:.2f} s.") + tokenizer = AutoTokenizer.from_pretrained(model_path) + return cls(engine, tokenizer) diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/llama_model_exportable_hf.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/llama_model_exportable_hf.py new file mode 100644 index 00000000..1bab00d3 --- /dev/null +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/llama_model_exportable_hf.py @@ -0,0 +1,297 @@ +from typing import Any, List, Optional + +import jax +import torch +import torch.nn.functional as F +from jetstream_pt.layers import ( + Attention, + RMSNorm, + get_quantized_embedding_layer, + get_quantized_linear_layer, +) +from jetstream_pt.model_base import ModuleBase +from transformers import GenerationConfig, GenerationMixin, LlamaConfig + + +class FeedForward(ModuleBase): + """Feed-forward module, AKA LlamaMLP on HuggingFace. + + Note the main difference is that it uses intermediate_size instead of multiple_of and ffn_dim_multiplier. + The parameter dim here corresponds to hidden_size in HuggingFace's Llama model, and hidden_dim is not really used, + because intermediate_size is used instead. + """ + + def __init__( + self, + dim: int, + intermediate_size: int, + device="meta", + env=None, + ): + super().__init__() + self.env = env + + LinearLayer = get_quantized_linear_layer(env.quant_config) + linear_kwargs = {} + if LinearLayer != torch.nn.Linear: + linear_kwargs["quant_config"] = env.quant_config + + self.w1 = LinearLayer( + dim, + intermediate_size, + bias=False, + device=device, + **linear_kwargs, + ) + self.w2 = LinearLayer( + intermediate_size, + dim, + bias=False, + device=device, + **linear_kwargs, + ) + self.w3 = LinearLayer( + dim, + intermediate_size, + bias=False, + device=device, + **linear_kwargs, + ) + self.hf_name("w1", "gate_proj") + self.hf_name("w2", "down_proj") + self.hf_name("w3", "up_proj") + + self.annotate_sharding("w1.weight", 0) + self.annotate_sharding("w2.weight", 1) + self.annotate_sharding("w3.weight", 0) + + def forward(self, x): + result = self.w2(F.silu(self.w1(x)) * self.w3(x)) + return result + + +class TransformerBlockHf(ModuleBase): + """This is essentially the same as the JetstreamPytoch Transformer, but it avoids using multiple_of and + ffn_dim_multiplier that are not available in HuggingFace's Llama model, and it uses intermediate_size instead. + """ + + def __init__( + self, + layer_id: int, + config: LlamaConfig, + device, + env, + ): + super().__init__() + self.env = env + self.n_heads = config.num_attention_heads + self.dim = config.hidden_size + self.head_dim = config.hidden_size // config.num_attention_heads + + self.attention = Attention( + config.num_attention_heads, + config.num_key_value_heads or config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + config.hidden_size, + env=env, + device=device, + layer_id=layer_id, + ) + self.feed_forward = FeedForward( + dim=config.hidden_size, + intermediate_size=config.intermediate_size, + device=device, + env=env, + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, device=device + ) + self.ffn_norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, device=device + ) + + self.hf_name("attention", "self_attn") + self.attention.hf_name("wq", "q_proj") + self.attention.hf_name("wk", "k_proj") + self.attention.hf_name("wv", "v_proj") + self.attention.hf_name("wo", "o_proj") + + self.attention.annotate_sharding("wq.weight", 0) + self.attention.annotate_sharding("wk.weight", 0) + self.attention.annotate_sharding("wv.weight", 0) + self.attention.annotate_sharding("wo.weight", 1) + + self.hf_name("feed_forward", "mlp") + self.hf_name("attention_norm", "input_layernorm") + self.hf_name("ffn_norm", "post_attention_layernorm") + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + cache, + start=None, + end=None, + ragged_batch_index=None, + ragged_block_index=None, + ): + with jax.named_scope("Attention"): + attn = self.attention.forward( + self.attention_norm(x), + freqs_cis, + mask, + cache, + start, + end, + ragged_batch_index, + ragged_block_index, + ) + with jax.named_scope("ffn_norm"): + h = x + attn + ffns = self.ffn_norm(h) + + with jax.named_scope("ffn"): + out = h + self.feed_forward.forward(ffns) + return out + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +class TransformerHf(ModuleBase, GenerationMixin): + """Transformer module that uses HF LlamaConfig instead of Jetstream Pytorch ModelArgs + device. + + Note that this class also derives from GenerationMixin, so that we can use its methods. + """ + + def __init__( + self, + config: LlamaConfig, + device, + env, + ): + super().__init__() + self.env = env + self.config = config + self.generation_config = GenerationConfig.from_model_config(config) + self.vocab_size = config.vocab_size + self.n_layers = config.num_hidden_layers + + Embedding = get_quantized_embedding_layer(env.quant_config) + self.tok_embeddings = Embedding( + config.vocab_size, + config.hidden_size, + device=device, + ) + + self.layers = torch.nn.ModuleList() + for layer_id in range(config.num_hidden_layers): + self.layers.append(TransformerBlockHf(layer_id, config, device, env)) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device) + + LinearLayer = get_quantized_linear_layer(env.quant_config) + linear_kwargs = {} + if LinearLayer != torch.nn.Linear: + linear_kwargs["quant_config"] = env.quant_config + + self.output = LinearLayer( + config.hidden_size, + config.vocab_size, + bias=False, + device=device, + **linear_kwargs, + ) + # TODO what to do with this + freqs_cis = precompute_freqs_cis( + config.hidden_size // config.num_attention_heads, + env.cache_len * 2, + theta=config.rope_theta, + ) + + self.register_buffer("freqs_cis", freqs_cis) + + self.hf_name("output", "lm_head") + self.hf_name("norm", "model.norm") + self.hf_name("layers", "model.layers") + self.hf_name("tok_embeddings", "model.embed_tokens") + + self.annotate_sharding("tok_embeddings.weight", 1) + self.annotate_sharding("output.weight", 0) + + @torch.no_grad() + def forward( + self, + tokens: torch.Tensor, + input_pos: torch.Tensor, + caches: List[Any], + mask, + start=None, + ragged_batch_index=None, + ragged_block_index=None, + ): + """ + tokens: the input token for decoding + input_pos: the decoding position relative to the start, which is the length of the decoding results + caches: kv caches + mask: causal mask to filter the attention results + start: the starting position for each slot + ragged_batch_index: precomputed batch index for ragged attention + ragged_block_index: precomputed block index for ragged attention + """ + + with jax.named_scope("transformer_tok"): + seqlen = tokens.shape[-1] + h = self.tok_embeddings(tokens) + + with jax.named_scope("transformer_freq"): + bsz, seqlen = tokens.shape + freqs_cis = self.freqs_cis[input_pos] + freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) + + end = None if start is None else (start + input_pos) % self.env.cache_len + # For stacked case, cannot get cache inside the loop which will cause cache copy + for layer_id, layer in enumerate(self.layers): + if caches[0].stacked: + cache = caches[0] + else: + cache = caches[layer_id] + # else: # For stacked case, there is only 1 yer of kv cache + + with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): + h = layer( + h, + freqs_cis, + mask, + cache, + start, + end, + ragged_batch_index, + ragged_block_index, + ) + + with jax.named_scope("transformer_norm"): + h = self.norm(h) + output = self.output(h).float() + return output + + @classmethod + def from_config(cls, config, env): + device = "meta" + model = cls(config, device, env) + return model + + def drop_weight(self, key): + return key.startswith("model") + + def shard_weights(self, _weights_dict): + """Shards the weights + + Assumes the weights_dict is a list of XLATensor2 + """ diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/logits_process.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/logits_process.py new file mode 100644 index 00000000..a525be61 --- /dev/null +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/logits_process.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +from typing import Tuple + +import jax +import jax.numpy as jnp +from transformers import GenerationConfig + + +@dataclass +class FusedLogitsWarper: + """ + A class that performs top-k then top-p filtering, optionally applying a temperature. + + Top-k filtering only keeps the `k` tokens with the best scores. + + Top-p filtering only keeps the top tokens whose cumulated probability is above `p`. + + The filtered tokens are returned as a list of indices, along with the corresponding subset of + the original logits. + + If only top-k filtering is active, the filtered tokens are sorted by descending order. + + If top-p filtering is active, the filtered tokens are sorted by ascending order. + + Args: + temperature (`float`): + Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases + randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely + token. + top_k (`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + """ + + temperature: float = 1.0 + top_k: int = 0 + top_p: float = 1.0 + + @classmethod + def from_config(cls, generation_config: GenerationConfig) -> "FusedLogitsWarper": + """Instantiate a fused warper from a generation configuration. + + Args: + generation_config (`~transformers.generation.GenerationConfig`): + The generation configuration to be used as base parametrization for the fused warper. + + Returns: + a `FusedLogitsWarper` or None if neither top-k nor top-p are configured. + """ + return cls(generation_config.temperature, generation_config.top_k, generation_config.top_p) + + def __call__(self, logits: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + if self.temperature != 1.0: + logits = logits / self.temperature + + do_top_k = self.top_k > 0 and self.top_k < logits.shape[-1] + do_top_p = self.top_p < 1.0 and self.top_p > 0.0 + + if do_top_k: + sorted_indices = jnp.argsort(logits, axis=-1)[..., ::-1][:, : self.top_k] + sorted_logits = jnp.take_along_axis(logits, sorted_indices, axis=-1) + else: + sorted_indices = jnp.argsort(logits, axis=-1) + sorted_logits = jnp.take_along_axis(logits, sorted_indices, axis=-1) + + if do_top_p: + if do_top_k: + # logits have been sorted in descending order, so we need to flip them + sorted_logits = jnp.flip(sorted_logits, axis=-1) + sorted_indices = jnp.flip(sorted_indices, axis=-1) + # We always keep the best logits and those whose cumulative probability is strictly higher than top_p + cum_probs = jax.nn.softmax(sorted_logits, axis=-1).cumsum(axis=-1) + keep_mask = cum_probs > (1 - self.top_p) + keep_mask = keep_mask.at[:, -1].set(True) + # Set rejected logits to -inf so that they are ignored in downstream comparisons + sorted_logits = jnp.where(keep_mask, sorted_logits, float("-Inf")) + + return sorted_logits, sorted_indices diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py new file mode 100644 index 00000000..9f9e21f8 --- /dev/null +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py @@ -0,0 +1,199 @@ +import copy +import logging +from typing import List, Optional, Union + +import jax +import jax.numpy as jnp +from jetstream.engine import sampling_utils +from transformers.generation import ( + GenerationConfig, + GenerationMixin, + LogitsProcessorList, + StoppingCriteriaList, +) +from transformers.generation.utils import GenerationMode + +from .logits_process import FusedLogitsWarper + + +logger = logging.getLogger(__name__) + + +class TokenSelector: + """Implements the token selection logic corresponding to a generation configuration. + + This class combines and uses the logits processors and stopping criteria implemented in + the transformers library. + + The algorithm to select these objects is heavily inspired by the transformers `GenerationMixin.generate()` + method, but the actual token selection methods are specific, and partially adapted from Jetstream/Pytorch sampling + implementation. + + The reason why this class does not inherit from `GenerationMixin` is because it does not + include the code to produce the tokens logits. + Separating the production of the tokens logits from the tokens selection allows this class + to be used with different generation paradigms, either synchronously using a single `TokenSelector` in + `GenerationMixin.generate()` or asynchronously using multiple `TokenSelector` inside an inference endpoint. + + The constructor of this class should not be called directly: instances should be obtained by + calling `TokenSelector.create()`. + """ + + def __init__( + self, + mode: GenerationMode, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + eos_token_ids: Union[int, List[int]], + pad_token_id: int, + logits_warper: Optional[LogitsProcessorList] = None, + seed: Optional[int] = 0, + ): + self.mode = mode + self.logits_processor = logits_processor + self.stopping_criteria = stopping_criteria + self.eos_token_ids = eos_token_ids + self.pad_token_id = pad_token_id + self.logits_warper = logits_warper + self.key = jax.random.PRNGKey(seed) + + @classmethod + def create( + cls, + input_ids: jnp.ndarray, + generation_config: GenerationConfig, + model: GenerationMixin, + max_seq_length: int, + stopping_criteria: Optional[StoppingCriteriaList] = None, + seed: Optional[int] = 0, + ) -> "TokenSelector": + r"""Creates the `TokenSelector` for a specific generation configuration. + + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + generation_config (`~transformers.generation.GenerationConfig`, *optional*): + The generation configuration to parametrize the token selection. + model (`~transformers.generation.GenerationMixin`): + The model provides the internal helpers allowing to select the logits processors and stopping criterias. + max_seq_length (`int`): + The maximum number of input + generated tokens for this model. It depends on the model compilation parameters. + stopping_criteria (`Optional[transformers.generation.StoppingCriteriaList], defaults to `None`): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. + seed(`Optional[int]`): + The optional seed for sampling. Defaults to zero. + Return: + The `TokenSelector` instance. + """ + generation_config.validate() + generation_config = copy.deepcopy(generation_config) + + unsupported_generation_flags = [ + "output_attentions", + "output_hidden_states", + "output_scores", + "return_dict_in_generate", + ] + for flag in unsupported_generation_flags: + if getattr(generation_config, flag, False): + raise ValueError("{flag} is not supported for generation.") + + if generation_config.max_new_tokens is not None: + logger.warning( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + generation_config.max_length = generation_config.max_new_tokens + input_ids.shape[-1] + + min_length = generation_config.min_length + if min_length > max_seq_length: + raise ValueError( + f"The minimum generation length ({min_length}) exceeds the model maximum sequence length ({max_seq_length})" + ) + max_length = generation_config.max_length + if max_length > max_seq_length: + logger.warning( + f"Adjusting the maximum generation length ({max_length}) to the model maximum sequence length ({max_seq_length})" + ) + generation_config.max_length = max_seq_length + + # Instantiate transformers library processors and criterias + logits_processor = model._get_logits_processor( + generation_config, + input_ids_seq_length=input_ids.shape[-1], + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=None, + logits_processor=LogitsProcessorList(), + ) + if stopping_criteria is None: + stopping_criteria = StoppingCriteriaList() + stopping_criteria = model._get_stopping_criteria(generation_config, stopping_criteria=stopping_criteria) + + # This is not supposed to happen for any of the models we support + eos_token_id = generation_config.eos_token_id + assert eos_token_id is not None + # The generation requires special tokens + eos_token_ids = eos_token_id if isinstance(eos_token_id, list) else [eos_token_id] + if generation_config.pad_token_id is None: + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_ids[0]} for open-end generation.") + generation_config.pad_token_id = eos_token_ids[0] + + generation_mode = generation_config.get_generation_mode() + if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]: + raise ValueError("Unsupported generation mode") + + logits_warper = None + if generation_mode == GenerationMode.SAMPLE: + logits_warper = FusedLogitsWarper.from_config(generation_config) + + return cls( + mode=generation_mode, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + logits_warper=logits_warper, + eos_token_ids=eos_token_ids, + pad_token_id=generation_config.pad_token_id, + seed=seed, + ) + + def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray: + """Select the next tokens from the candidate logits. + + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation (not used in all generation modes). + logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + The logits corresponding to the generated tokens. + + Return: + `jnp.ndarray`: A `jnp.ndarray` containing the selected tokens. + """ + scores = self.logits_processor(input_ids, logits) + if self.mode == GenerationMode.SAMPLE: + return self._sample(scores) + else: + return jnp.argmax(scores, axis=-1) + + def _sample(self, scores: jnp.ndarray) -> jnp.ndarray: + do_top_k = self.logits_warper.top_k > 0 and self.logits_warper.top_k < scores.shape[-1] + do_top_p = self.logits_warper.top_p < 1.0 and self.logits_warper.top_p > 0.0 + + if do_top_k: + return sampling_utils.sample_topk_logits( + scores, + self.logits_warper.top_k, + self.logits_warper.temperature, + self.key, + ) + elif do_top_p: + return sampling_utils.sample_nucleus_topp_logits( + scores, + self.logits_warper.top_p, + self.logits_warper.temperature, + self.key, + ) + + return jax.random.categorical(self.key, scores / self.logits_warper.temperature) diff --git a/text-generation-inference/server/text_generation_server/server.py b/text-generation-inference/server/text_generation_server/server.py index e53e8e14..d6525d90 100644 --- a/text-generation-inference/server/text_generation_server/server.py +++ b/text-generation-inference/server/text_generation_server/server.py @@ -6,7 +6,7 @@ from grpc_reflection.v1alpha import reflection from loguru import logger -from .generator import Generator, TpuGenerator +from .auto_generator import AutoGenerator, Generator from .interceptor import ExceptionInterceptor from .pb import generate_pb2, generate_pb2_grpc @@ -73,7 +73,7 @@ async def serve_inner(model_path: str): server_urls = [local_url] try: - generator = TpuGenerator.from_pretrained( + generator = AutoGenerator.from_pretrained( model_path, revision=revision, max_batch_size=max_batch_size, diff --git a/text-generation-inference/tests/test_decode.py b/text-generation-inference/tests/test_decode.py index 982c1da6..160e8003 100644 --- a/text-generation-inference/tests/test_decode.py +++ b/text-generation-inference/tests/test_decode.py @@ -3,16 +3,19 @@ import pytest from helpers import create_request, prepare_model -from text_generation_server.generator import TpuGenerator +from text_generation_server.auto_generator import AutoGenerator from text_generation_server.pb.generate_pb2 import Batch from tqdm import tqdm +from optimum.tpu.jetstream_pt_support import jetstream_pt_available + @dataclass class DecodeTestParams: model_id: str sequence_length: int expected_text: str + do_sample: bool = False @pytest.mark.parametrize("params", @@ -63,10 +66,10 @@ def _test_decode_single(params): input_text = "It was a bright cold day in April, and the clocks were striking thirteen." max_new_tokens = 20 - generator = TpuGenerator.from_pretrained( + generator = AutoGenerator.from_pretrained( model_path, revision="", max_batch_size=1, max_sequence_length=params.sequence_length ) - request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False) + request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=params.do_sample) batch = Batch(id=0, requests=[request], size=1, max_tokens=params.sequence_length) generations, next_batch = generator.prefill(batch) # We already generated one token: call decode max_new_tokens - 1 times @@ -84,4 +87,25 @@ def _test_decode_single(params): output = generations[0].generated_text assert output.generated_tokens == max_new_tokens assert output.finish_reason == 0 - assert output.text == params.expected_text + if params.do_sample: + assert output.text != params.expected_text + else: + assert output.text == params.expected_text + + +@pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"]) +@pytest.mark.parametrize("params", + [ + DecodeTestParams( + model_id="meta-llama/Llama-2-7b-hf", + sequence_length=256, + expected_text="\n\nThe clocks were striking thirteen\nThe clocks were striking thirteen\n", + ), + ], + ids=["Llama-2-7b-hf"], +) +def test_decode_single_jetstream_pytorch(params, do_sample): + if not jetstream_pt_available(): + pytest.skip("Jetstream PyTorch is not available") + params.do_sample = do_sample + _test_decode_single(params) diff --git a/text-generation-inference/tests/test_generator_slot.py b/text-generation-inference/tests/test_generator_slot.py index c40c1644..f0f23ed7 100644 --- a/text-generation-inference/tests/test_generator_slot.py +++ b/text-generation-inference/tests/test_generator_slot.py @@ -1,6 +1,5 @@ import pytest import torch -from text_generation_server.generator import Slot from text_generation_server.pb.generate_pb2 import Request from transformers import AutoTokenizer, GenerationConfig @@ -31,6 +30,7 @@ def tokenizer(request): ids=["spaces", "chinese-utf8", "emojis"], ) def test_decode_streaming(tokenizer, input_text, generated_text): + from text_generation_server.generator import Slot # Note: device used is cpu to make it faster slot = Slot(0, tokenizer, "cpu") request = Request(id=0, inputs=input_text) diff --git a/text-generation-inference/tests/test_gpt2.py b/text-generation-inference/tests/test_gpt2.py index 26d61a0b..776d8418 100644 --- a/text-generation-inference/tests/test_gpt2.py +++ b/text-generation-inference/tests/test_gpt2.py @@ -1,7 +1,7 @@ import pytest from helpers import create_request, prepare_model -from text_generation_server.generator import TpuGenerator +from text_generation_server.auto_generator import AutoGenerator from text_generation_server.pb.generate_pb2 import Batch from tqdm import tqdm @@ -16,7 +16,7 @@ def model_path(): def test_info(model_path): - generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1) + generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1) info = generator.info assert info.requires_padding is True assert info.device_type == "xla" @@ -44,7 +44,7 @@ def test_info(model_path): ) @pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"]) def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path): - generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH) + generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH) requests = [] max_new_tokens = 20 for i in range(batch_size): @@ -65,7 +65,7 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_ def test_decode_multiple(model_path): - generator = TpuGenerator.from_pretrained(model_path, + generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=2, max_sequence_length=SEQUENCE_LENGTH) diff --git a/text-generation-inference/tests/test_prefill_truncate.py b/text-generation-inference/tests/test_prefill_truncate.py index e3666c66..4ad78ab5 100644 --- a/text-generation-inference/tests/test_prefill_truncate.py +++ b/text-generation-inference/tests/test_prefill_truncate.py @@ -1,5 +1,5 @@ from helpers import create_request, prepare_model -from text_generation_server.generator import TpuGeneratorSingleThread as TpuGenerator +from text_generation_server.auto_generator import AutoGenerator from text_generation_server.pb.generate_pb2 import Batch @@ -10,7 +10,7 @@ def test_prefill_truncate(): model_path = prepare_model(model_id, sequence_length) max_new_tokens = 20 - generator = TpuGenerator.from_pretrained( + generator = AutoGenerator.from_pretrained( model_path, revision="", max_batch_size=1, max_sequence_length=sequence_length ) input_text = "This is something I will tell by the end of the story"