✈️ Introduce Jetstream/Pytorch in TGI (#88)
* feat(tgi): add functions to load Jetstream Pytorch engine for Llama2

* chore(TokenSelector): remove XLA xm rng seed set

* fix(version): remove warning on deprecated API

use packaging.version's parse instead of pkg_resources' parse_version.

* fix(generator): use pad_token_id for padding

* fix(decode): clear unrequested slots

* feat(imports): add function to check if Jetstream Pytorch can be used

* feat(Jetstream): improved support for engine load

The custom HfEngine contains functions that allow the prefill and
generate functions to use custom sampling functions.

* feat(TGI): Added Jetstream/Pytorch generator

This implementation is equivalent to the torch_xla one, but uses the
Jetstream/Pytorch engine instead.

* chore(fsdp v2): avoid importing PreTrainedModel

This way we can avoid trying to import torch xla.

* feat(tgi): introduce AutoGenerator

This is just a way to provide a factory class method to create either a
Jetstream/Pytorch or a Pytorch XLA generator.

* feat(Jetstream PT): Enable support only if env var is set

There are still some issues related to some fine-tuned models, so for
now support is enabled only when JETSTREAM_PT is set.

* feat(TGI): use AutoGenerator in model server

* feat(package): add optional dependency on Jetstream/Pytorch

For now it is possible to install the dependency after optimum-tpu has
been installed, by issuing this command:

pip install "optimum-tpu[jetstream-pt]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

* test(Jetstream Pytorch): added a simple decode test

Also adapted other tests to avoid the torch-xla generator implementation,
to avoid conflicts.
I also added the Jetstream/Pytorch test to the CI workflow.

* test(decode): added a variant with do_sample=True with Jetstream PT

* fix(README): correct link

* doc(README): add mention of how to install and enable Jetstream/Pytorch

* feat(build): make clean removes old TGI builds too

* review: comply with review comments

- Added warning when trying to load torch_xla2 after torch_xla
- renamed jetstream_pt_support.check to model_can_use_jetstream_pt

* review(AutoGenerator): log if using Jetstream/PT or torch xla
tengomucho authored Sep 9, 2024
1 parent 5eb6cf3 commit fa24cc4
Showing 26 changed files with 1,709 additions and 28 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/test-pytorch-xla-tpu-tgi.yml
@@ -31,3 +31,13 @@ jobs:
- name: Build and test TGI server
run: |
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test
# Use a different step to test the Jetstream Pytorch version, to avoid conflicts with torch-xla[tpu]
- name: Install and test TGI server (Jetstream Pytorch)
run: |
pip install -U .[jetstream-pt] \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-f https://storage.googleapis.com/libtpu-releases/index.html
JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
pytest -sv text-generation-inference/tests -k jetstream
1 change: 1 addition & 0 deletions Makefile
@@ -40,6 +40,7 @@ $(PACKAGE_DIST) $(PACKAGE_WHEEL): $(PACKAGE_FILES)

clean:
rm -rf dist
make -C text-generation-inference/server/ clean

tpu-tgi:
docker build --rm -f text-generation-inference/docker/Dockerfile \
15 changes: 14 additions & 1 deletion README.md
@@ -45,7 +45,20 @@ Other TPU versions will be supported along the way.
As part of the integration, we do support a [text-generation-inference (TGI)](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference) backend allowing to deploy and serve
incoming HTTP requests and execute them on Cloud TPUs.

Please see the [TGI specific documentation]() on how to get started
Please see the [TGI specific documentation](text-generation-inference) on how to get started.

### JetStream Pytorch Engine

`optimum-tpu` provides optional support for the JetStream Pytorch engine inside TGI. This support can be installed using the dedicated command:

```shell
pip install "optimum-tpu[jetstream-pt]" \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-f https://storage.googleapis.com/libtpu-releases/index.html
```

To enable the support, export the environment variable `JETSTREAM_PT=1`.
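
As a quick sanity check, here is a minimal sketch (not part of this commit) that verifies the support is detected once the optional dependencies are installed:

```python
# Minimal sketch: confirm Jetstream Pytorch support is detected.
import os

os.environ["JETSTREAM_PT"] = "1"  # normally exported in the shell before starting TGI

from optimum.tpu import jetstream_pt_available

# True only if JETSTREAM_PT=1 and torch_xla2/jetstream_pt can be imported.
print(jetstream_pt_available())
```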

## Training

1 change: 1 addition & 0 deletions optimum/tpu/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .jetstream_pt_support import jetstream_pt_available # isort:skip
from .fsdp_v2 import get_fsdp_config, use_fsdp_v2
from .modeling import AutoModelForCausalLM
from .version import VERSION, __version__
18 changes: 13 additions & 5 deletions optimum/tpu/fsdp_v2.py
@@ -15,12 +15,16 @@
"""
Utility functions to provide FSDPv2 configuration for TPU training.
"""
from typing import Dict, List, Union
from typing import Any, Dict, List, Union

from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging


PreTrainedModel = Any
# NOTE: modeling_utils.PreTrainedModel should be used instead of the alias above, but since it is only needed
# for type hinting, it is not imported here, to avoid pulling in torch_xla imports.


def use_fsdp_v2():
"""
Enable FSDPv2 for TPU training.
@@ -61,6 +65,7 @@ def _unwrap_model(model: PreTrainedModel) -> PreTrainedModel:
"""
try:
from peft.peft_model import LoraModel, PeftModel

if isinstance(model, PeftModel) and isinstance(model.base_model, LoraModel):
return model.base_model.model
return model
@@ -89,10 +94,13 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM):
logger = logging.get_logger(__name__)
from torch_xla import __version__ as xla_version

if xla_version == "2.3.0":
logger.warning_once("Fine-tuning Gemma on Pytorch XLA 2.3.0 might raise some issues. In case of any "
"issues consider using the nightly version, and report the issue on the optimum-tpu "
"GitHub repository: https://github.com/huggingface/optimum-tpu/issues/new.")
logger.warning_once(
"Fine-tuning Gemma on Pytorch XLA 2.3.0 might raise some issues. In case of any "
"issues consider using the nightly version, and report the issue on the optimum-tpu "
"GitHub repository: https://github.com/huggingface/optimum-tpu/issues/new."
)
cls_to_wrap = "GemmaDecoderLayer"
matched_model = True
elif model_type == "llama":
2 changes: 0 additions & 2 deletions optimum/tpu/generation/token_selector.py
@@ -3,7 +3,6 @@
from typing import List, Optional, Union

import torch
import torch_xla.core.xla_model as xm
from transformers.generation import (
GenerationConfig,
GenerationMixin,
@@ -53,7 +52,6 @@ def __init__(
self.eos_token_ids = eos_token_ids
self.pad_token_id = pad_token_id
self.logits_warper = logits_warper
xm.set_rng_state(seed)
self.generator = torch.Generator()
self.generator.manual_seed(seed)

26 changes: 26 additions & 0 deletions optimum/tpu/jetstream_pt_support.py
@@ -0,0 +1,26 @@
import os
import sys

from loguru import logger


def jetstream_pt_available() -> bool:
"""Check if the necessary imports to use jetstream_pt are available.
"""
try:
# For now Jetstream Pytorch is opt-in, it can be enabled with an ENV variable.
jetstream_pt_enabled = os.environ.get("JETSTREAM_PT", False) == "1"
if not jetstream_pt_enabled:
return False
# Torch XLA should not be imported before torch_xla2 to avoid conflicts.
if 'torch_xla2' not in sys.modules and 'torch_xla.core' in sys.modules:
logger.warning("torch_xla2 cannot be imported after torch_xla, disabling Jetstream PyTorch support.")
return False
# Import torch_xla2 first!
import torch_xla2 # noqa: F401, isort:skip

import jetstream_pt # noqa: F401

return True
except ImportError:
return False
4 changes: 2 additions & 2 deletions optimum/tpu/version.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pkg_resources import parse_version
from packaging.version import parse


__version__ = "0.1.5"
VERSION = parse_version(__version__)
VERSION = parse(__version__)
10 changes: 8 additions & 2 deletions pyproject.toml
@@ -43,8 +43,8 @@ keywords = [

dependencies = [
"transformers == 4.41.1",
"torch >= 2.3.0, <= 2.4.0",
"torch-xla[tpu] >= 2.3.0, <= 2.4.0",
"torch == 2.4.0",
"torch-xla[tpu] == 2.4.0",
"loguru == 0.6.0",
"sentencepiece == 0.2.0",
]
@@ -58,6 +58,12 @@ build-backend = "setuptools.build_meta"
[project.optional-dependencies]
tests = ["pytest", "safetensors"]
quality = ["black", "ruff", "isort"]
# Jetstream/Pytorch support is experimental for now; it requires installation from a fixed commit.
# Pallas is pulled because it will install a compatible version of jax[tpu].
jetstream-pt = [
"jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git#df92015289953c506004e674d57651b03e4e89f2",
"torch-xla[pallas] == 2.4.0"
]

[project.urls]
Homepage = "https://hf.co/hardware"
4 changes: 3 additions & 1 deletion text-generation-inference/server/Makefile
@@ -12,12 +12,14 @@ clean:

# List static sources to be deployed in the package
src_dir := $(mkfile_dir)/$(pkg_name)
sources := $(wildcard $(src_dir)/*.py)
rwildcard_py = $(wildcard $(1)/*.py) $(foreach d,$(wildcard $(1)/*),$(call rwildcard_py,$d))
sources := $(call rwildcard_py,$(src_dir))
deployed_sources := $(subst $(src_dir), $(pkg_dir), $(sources))

# Static files are just copied

define COPY
mkdir -p $(dir $@)
cp -f $< $@
endef

2 changes: 1 addition & 1 deletion text-generation-inference/server/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
]

[tool.setuptools]
packages = ["text_generation_server", "text_generation_server.pb"]
packages = ["text_generation_server", "text_generation_server.pb", "text_generation_server.jetstream_pt_support"]

[tool.setuptools.dynamic]
version = {attr = "text_generation_server.version.__version__"}
@@ -0,0 +1,39 @@
from loguru import logger

from .generator_base import Generator
from .jetstream_pt_support import model_can_use_jetstream_pt


class AutoGenerator:

@staticmethod
def from_pretrained(
model_path: str, revision: str, max_batch_size: int, max_sequence_length: int
) -> Generator:
"""Instantiate a Generator for TPU using Jetstream Pytorch or Pytorch/XLA.
Args:
model_path (`str`):
The path to a local model. This path must also contain a Tokenizer.
revision (`str`):
The revision of the model.
max_batch_size (`int`):
The maximum batch size.
max_sequence_length (`int`):
The maximum sequence length.
Returns:
A TpuGenerator.
"""
if model_can_use_jetstream_pt(model_path):
logger.debug("Using Jetstream PyTorch generator.")
from .jetstream_pt_support.generator import TpuGeneratorJetStream
return TpuGeneratorJetStream.from_pretrained(
model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
)
else:
logger.debug("Using PyTorch/XLA generator.")
from .generator import TpuGenerator
return TpuGenerator.from_pretrained(
model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
)
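
For illustration, a hedged usage sketch of the new factory; the import path and argument values below are assumptions, not part of this change:

```python
# Hypothetical usage of the AutoGenerator factory (illustrative values only).
from text_generation_server.auto_generator import AutoGenerator  # assumed module path

generator = AutoGenerator.from_pretrained(
    model_path="/data/llama2-7b",  # local path containing both model and tokenizer
    revision="main",
    max_batch_size=4,
    max_sequence_length=1024,
)
# The result is a Generator: Jetstream Pytorch when supported, Pytorch/XLA otherwise.
```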
@@ -590,6 +590,19 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
# just carry on with decoding. We adopt the id of the first
# batch in the list as our next batch id.
next_batch_id = batches[0].id
request_ids = []
for batch in batches:
request_ids += batch.request_ids
cleared_request_ids = []
for slot in self.slots:
if slot.state == slot.State.READY and slot.request_id not in request_ids:
cleared_request_ids.append(slot.request_id)
slot.clear()
if len(cleared_request_ids) > 0:
logger.info(f"Clearing slot for requests {cleared_request_ids} as they are not requested.")
active_slots = [slot for slot in self.slots if slot.state == slot.State.READY]
if len(active_slots) < len(request_ids):
logger.error("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
# Reconstruct input_ids and attention_mask from slots
input_ids = None
attention_mask = None
@@ -608,7 +621,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
# Create blank inputs covering all slots (even empty ones)
input_ids = torch.full(
[batch_size, 1],
fill_value=self.tokenizer.eos_token_id,
fill_value=pad_token_id,
dtype=torch.int64,
)
cache_position = torch.zeros([1], dtype=torch.int64)
@@ -0,0 +1,15 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .compatibility import create_engine, model_can_use_jetstream_pt
@@ -0,0 +1,53 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any

from transformers import AutoConfig

from optimum.tpu import jetstream_pt_available


def model_can_use_jetstream_pt(model_path: str) -> bool:
"""Checks if the model is supported by Jetstream Pytorch on Optimum TPU and if the required dependencies to provide
the engine are installed.
"""
config = AutoConfig.from_pretrained(model_path)
# For now only Llama 2 with tokenizer.model is supported
if config.model_type != "llama" or not os.path.exists(
os.path.join(model_path, "tokenizer.model")
):
return False
if jetstream_pt_available():
return True
return False


def create_engine(
model_path: str,
batch_size: int,
sequence_length: int,
max_input_tokens: int,
max_output_tokens: int,
) -> Any:
if not model_can_use_jetstream_pt(model_path):
# The model is not compatible with Jetstream PyTorch, just exit
return None

# Now import engine_loader to prevent importing it at the top when not supported
from .engine_loader import create_engine
return create_engine(
model_path, batch_size, sequence_length, max_input_tokens, max_output_tokens
)
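
A hedged sketch of how these compatibility helpers could be combined; the path and sizes are illustrative assumptions, not values from this commit:

```python
# Illustrative call sequence for the compatibility helpers (values are assumptions).
model_path = "/data/llama2-7b"  # must be a local Llama 2 checkpoint that ships tokenizer.model

if model_can_use_jetstream_pt(model_path):
    # Lazily builds the Jetstream Pytorch engine with the requested shapes.
    engine = create_engine(
        model_path,
        batch_size=4,
        sequence_length=1024,
        max_input_tokens=512,
        max_output_tokens=512,
    )
else:
    engine = None  # caller falls back to the Pytorch/XLA generator
```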