-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat(tgi): add functions to load Jetstream Pytorch engine for Llama2 * chore(TokenSelector): remove XLA xm rng seed set * fix(version): remove warning on deprecated API use packaging.version's parse instead of pkg_resources' parse_version. * fix(generator): use pad_token_id for padding * fix(decode): clear unrequested slots * feat(imports): add function to check if Jetstream Pytorch can be used * feat(Jetstream): improved support for engine load The custom HfEngine contains functions that will allow for prefill and generate functions to use custom sampling functions. * feat(TGI): Added Jetstream/Pytorch generator This implementation is equivalent to the torch_xla one, but uses the Jetstream/Pytorch engine instead. * chore(fsdp v2): avoid importing PretrainedModel This way we can avoid trying to import torch xla. * feat(tgi): introduce AutoGenerator This is just a way to provide a factory class method to create Jetstream/Pytorch or Pytorch XLA generator. * feat(Jetstream PT): Enable support only if env var is set There are still some issues related to some fine-tuned models, so for now just enable only when JETSTREAM_PT is set. * feat(TGI): use AutoGenerator in model server * feat(package): add optional dependency on Jetstream/Pytorch For now it is possible to install the dependency after optimum-tpu has been installed, issuing this command: pip install "optimum-tpu[jetstream-pt]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html * test(Jetstream Pytorch): added a simple decode test Also adapted other tests to avoid the torch-xla generator implementation, to avoid conflict. I also added the Jetstream/pytorch test to the workflow in CI.
* test(decode): added a variant with do_sample=True with Jetstream PT * fix(README): correct link * doc(README): add mention on how to install and enable Pytorch/Jetstream * feat(build): make clean removes old TGI builds too * review: comply with comment requests - Added warning when trying to load torch_xla2 after torch_xla - renamed jetstream_pt_support.check to model_can_use_jetstream_pt * review(AutoGenerator): log if using Jetstream/PT or torch xla
- Loading branch information
1 parent
5eb6cf3
commit fa24cc4
Showing
26 changed files
with
1,709 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
import sys | ||
|
||
from loguru import logger | ||
|
||
|
||
def jetstream_pt_available() -> bool:
    """Check if the necessary imports to use jetstream_pt are available.

    Returns:
        `True` when Jetstream Pytorch support is enabled and importable,
        `False` otherwise.
    """
    # For now Jetstream Pytorch is opt-in, it can be enabled with an ENV variable.
    if os.environ.get("JETSTREAM_PT", False) != "1":
        return False
    # Torch XLA should not be imported before torch_xla2 to avoid conflicts.
    if "torch_xla2" not in sys.modules and "torch_xla.core" in sys.modules:
        logger.warning("torch_xla2 cannot be imported after torch_xla, disabling Jetstream PyTorch support.")
        return False
    try:
        # Import torch_xla2 first!
        import torch_xla2  # noqa: F401, isort:skip

        import jetstream_pt  # noqa: F401
    except ImportError:
        return False
    return True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
39 changes: 39 additions & 0 deletions
39
text-generation-inference/server/text_generation_server/auto_generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from loguru import logger | ||
|
||
from .generator_base import Generator | ||
from .jetstream_pt_support import model_can_use_jetstream_pt | ||
|
||
|
||
class AutoGenerator:
    """Factory that selects the TPU generator backend at load time."""

    @staticmethod
    def from_pretrained(
        model_path: str, revision: str, max_batch_size: int, max_sequence_length: int
    ) -> Generator:
        """Instantiate a Generator for TPU using Jetstream Pytorch or Pytorch/XLA.

        Args:
            model_path (`str`):
                The path to a local model. This path must also contain a Tokenizer.
            revision (`str`):
                The revision of the model.
            max_batch_size (`int`):
                The maximum batch size.
            max_sequence_length (`int`):
                The maximum sequence length.

        Returns:
            A TpuGenerator.
        """
        # Both generator classes are imported lazily so only the selected
        # backend's dependencies are pulled in.
        if model_can_use_jetstream_pt(model_path):
            logger.debug("Using Jetstream PyTorch generator.")
            from .jetstream_pt_support.generator import TpuGeneratorJetStream as generator_cls
        else:
            logger.debug("Using PyTorch/XLA generator.")
            from .generator import TpuGenerator as generator_cls
        return generator_cls.from_pretrained(
            model_path,
            revision=revision,
            max_batch_size=max_batch_size,
            max_sequence_length=max_sequence_length,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
15 changes: 15 additions & 0 deletions
15
text-generation-inference/server/text_generation_server/jetstream_pt_support/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright 2024 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .compatibility import create_engine, model_can_use_jetstream_pt |
53 changes: 53 additions & 0 deletions
53
...-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright 2024 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
from typing import Any | ||
|
||
from transformers import AutoConfig | ||
|
||
from optimum.tpu import jetstream_pt_available | ||
|
||
|
||
def model_can_use_jetstream_pt(model_path: str) -> bool:
    """Checks if the model is supported by Jetstream Pytorch on Optimum TPU and if the required dependencies to provide
    the engine are installed.

    Args:
        model_path (`str`):
            Path to a local model directory containing the model config and tokenizer.

    Returns:
        `True` when the model architecture is supported and the Jetstream
        Pytorch dependencies are importable, `False` otherwise.
    """
    config = AutoConfig.from_pretrained(model_path)
    # For now only Llama 2 with tokenizer.model is supported
    if config.model_type != "llama" or not os.path.exists(
        os.path.join(model_path, "tokenizer.model")
    ):
        return False
    # The Jetstream Pytorch packages must also be importable in this
    # environment (return the check directly instead of an if/True/False ladder).
    return jetstream_pt_available()
|
||
|
||
def create_engine(
    model_path: str,
    batch_size: int,
    sequence_length: int,
    max_input_tokens: int,
    max_output_tokens: int,
) -> Any:
    """Build a Jetstream Pytorch engine for the given model, or return `None`.

    Args:
        model_path (`str`):
            Path to a local model directory.
        batch_size (`int`):
            Engine batch size.
        sequence_length (`int`):
            Maximum total sequence length.
        max_input_tokens (`int`):
            Maximum number of input tokens.
        max_output_tokens (`int`):
            Maximum number of output tokens.

    Returns:
        The engine object, or `None` when the model cannot use Jetstream Pytorch.
    """
    if not model_can_use_jetstream_pt(model_path):
        # The model is not compatible with Jetstream PyTorch, just exit
        return None
    # Now import engine_loader to prevent importing it at the top when not
    # supported; alias it to avoid shadowing this function's own name.
    from .engine_loader import create_engine as load_engine
    return load_engine(
        model_path, batch_size, sequence_length, max_input_tokens, max_output_tokens
    )
Oops, something went wrong.