diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml
index 806fd4a3..4b95d988 100644
--- a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml
+++ b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml
@@ -38,3 +38,12 @@ jobs:
           find text-generation-inference/ -name "text_generation_server-*whl" -exec python -m pip install {} \;
           python -m pytest --runslow -sv text-generation-inference/tests
+      # Use a different step to test the Jetstream Pytorch version, to avoid conflicts with torch-xla[tpu]
+      - name: Install and test TGI server slow tests (Jetstream Pytorch)
+        run: |
+          pip install -U .[jetstream-pt] \
+            -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+            -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
+            -f https://storage.googleapis.com/libtpu-releases/index.html
+          JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
+            pytest -sv text-generation-inference/tests --runslow -k jetstream
diff --git a/optimum/tpu/model.py b/optimum/tpu/model.py
index 38e1bb84..030d97d3 100644
--- a/optimum/tpu/model.py
+++ b/optimum/tpu/model.py
@@ -49,7 +49,7 @@ def fetch_model(
     local_path = snapshot_download(
         repo_id=model_id,
         revision=revision,
-        allow_patterns=["config.json", "model*.safetensors", SAFE_WEIGHTS_INDEX_NAME, "tokenizer*"],
+        allow_patterns=["*.json", "model*.safetensors", SAFE_WEIGHTS_INDEX_NAME],
     )
     end = time.time()
     logger.info(f"Model successfully fetched in {end - start:.2f} s.")
diff --git a/pyproject.toml b/pyproject.toml
index 571dd7a1..30c1ef5b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -61,7 +61,7 @@ quality = ["black", "ruff", "isort"]
 # Jetstream/Pytorch support is experimental for now, requires installation from fixed commit.
 # Pallas is pulled because it will install a compatible version of jax[tpu].
 jetstream-pt = [
-    "jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git#df92015289953c506004e674d57651b03e4e89f2",
+    "jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@df92015289953c506004e674d57651b03e4e89f2",
     "torch-xla[pallas] == 2.4.0"
 ]
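The pyproject.toml change above is worth spelling out: in a PEP 508 direct reference the git revision must follow `@` after the repository URL, while `#<sha>` is only a URL fragment, so the old specifier did not actually pin jetstream-pytorch to the intended commit. A minimal sketch (not part of this patch, and assuming the `packaging` library is installed) of how the corrected requirement string parses:

```python
# Hypothetical check, not part of this patch: parse the corrected requirement
# and confirm the commit hash travels in the URL that pip will resolve.
from packaging.requirements import Requirement

req = Requirement(
    "jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git"
    "@df92015289953c506004e674d57651b03e4e89f2"
)
print(req.name)  # jetstream-pt
print(req.url)   # git+https://github.com/google/jetstream-pytorch.git@df92015...
```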
""" config = AutoConfig.from_pretrained(model_path) - # For now only Llama 2 with tokenizer.model is supported - if config.model_type != "llama" or not os.path.exists( - os.path.join(model_path, "tokenizer.model") - ): + # For now only Llama is supported + if config.model_type != "llama": return False if jetstream_pt_available(): return True diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py index 6e933df5..332cf5de 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py @@ -1,7 +1,7 @@ # Import torch_xla2 first import torch_xla2 # isort:skip -from typing import Any +from typing import TYPE_CHECKING, Any import jax from jetstream_pt import fetch_models, torchjax @@ -11,19 +11,21 @@ QuantizationConfig, ) from loguru import logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig from transformers import AutoConfig from .engine import HfEngine from .llama_model_exportable_hf import TransformerHf -def load_llama_model_info(model_path: str) -> Any: - # First get config - config = AutoConfig.from_pretrained(model_path) +def load_llama_model_info(config: "PretrainedConfig") -> Any: num_layers = config.num_hidden_layers num_heads = config.num_attention_heads head_dim = config.hidden_size // num_heads - n_reps = config.num_key_value_heads // num_heads + n_reps = num_heads // config.num_key_value_heads model_info = fetch_models.ModelInfo( TransformerHf, num_layers=num_layers, @@ -34,10 +36,10 @@ def load_llama_model_info(model_path: str) -> Any: return model_info -def load_model_info(model_path: str) -> Any: - config = AutoConfig.from_pretrained(model_path) # For now only Llama 2 is supported +def load_model_info(config: "PretrainedConfig") -> Any: + # For now only Llama is supported if config.model_type == "llama": - return load_llama_model_info(model_path) + return load_llama_model_info(config) # Other models supports can be added here later return None @@ -49,7 +51,9 @@ def create_engine_env_data( max_input_tokens: int, max_output_tokens: int, ) -> Any: - model_info = load_model_info(model_path) + # First get config + config = AutoConfig.from_pretrained(model_path) + model_info = load_model_info(config) if model_info is None: return None @@ -72,7 +76,7 @@ def create_engine_env_data( ) env_data.cache_shape = ( batch_size, - model_info.num_heads, + config.num_key_value_heads, max_cache_length, model_info.head_dim, ) @@ -91,7 +95,8 @@ def instantiate_model_from_repo_id( env: Any, ): """Create model instance by hf model_dir, and its config""" - model_info = load_model_info(model_dir) + config = AutoConfig.from_pretrained(model_dir) + model_info = load_model_info(config) # at this point we can be quite optimistic and just assert assert model_info is not None @@ -117,7 +122,8 @@ def shard_weights(env, weights, weight_shardings): for key, val in weights.items(): sharding = env.sharding_by_axis(weight_shardings.get(key, -1)) with jax.default_device(jax.devices("cpu")[0]): - arr = torch_xla2.tensor.t2j(val) + # Note we clone to avoid a core-dump that might happen otherwise when calling device_put + arr = torch_xla2.tensor.t2j(val.clone()) arr = jax.device_put(arr, sharding) sharded[key] = torchjax.to_torch(arr) return sharded diff --git a/text-generation-inference/tests/test_decode.py 
diff --git a/text-generation-inference/tests/test_decode.py b/text-generation-inference/tests/test_decode.py
index 160e8003..1f431a3e 100644
--- a/text-generation-inference/tests/test_decode.py
+++ b/text-generation-inference/tests/test_decode.py
@@ -93,6 +93,7 @@ def _test_decode_single(params):
     assert output.text == params.expected_text
 
 
+@pytest.mark.slow
 @pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
 @pytest.mark.parametrize("params",
     [
@@ -101,8 +102,31 @@ def _test_decode_single(params):
             sequence_length=256,
             expected_text="\n\nThe clocks were striking thirteen\nThe clocks were striking thirteen\n",
         ),
+        DecodeTestParams(
+            model_id="meta-llama/Meta-Llama-3-8B",
+            sequence_length=256,
+            expected_text=" Winston Winston Smith, his chin on his hands, and the clock in the Ministry of Truth, M",
+        ),
+    ],
+    ids=["Llama-2-7b-hf", "Meta-Llama-3-8B"],
+)
+def test_decode_single_jetstream_pytorch_slow(params, do_sample):
+    if not jetstream_pt_available():
+        pytest.skip("Jetstream PyTorch is not available")
+    params.do_sample = do_sample
+    _test_decode_single(params)
+
+
+@pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
+@pytest.mark.parametrize("params",
+    [
+        DecodeTestParams(
+            model_id="Maykeye/TinyLLama-v0",
+            sequence_length=256,
+            expected_text=" She She had a big and it had a big, blue, and a big, red and a",
+        ),
     ],
-    ids=["Llama-2-7b-hf"],
+    ids=["TinyLLama-v0"],
 )
 def test_decode_single_jetstream_pytorch(params, do_sample):
     if not jetstream_pt_available():