🦙 Llama3 on TGI - Jetstream Pytorch (#90)
* fix(engine_loader): correct n_reps and cache_shape settings

* feat(tokenizer): download all JSON files when fetching the model

This ensures that tokenizer_config.json is available when needed.

* feat(jetstream pt): relax llama compatibility requirements

This change will allow Llama3 models to be loaded.

* test(jetstream pt): move Llama2-7b test to runslow/nightly

* fix(jetstream pt): clone weight before mapping

This is a workaround for a core dump observed when testing on the
TinyLLama-v0 model, and it should prevent similar problems later.
It also allows re-adding the basic (non-slow) test that runs on PRs
and exercises Jetstream/Pytorch.

* test(jetstream pt): add test showing support for Llama3-8B

* review: fix imports for type checking

* fix: correct type hint

* fix(pyproject): correct jetstream git revision
tengomucho authored Sep 10, 2024
1 parent fa24cc4 commit b25e973
Showing 6 changed files with 56 additions and 20 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml
@@ -38,3 +38,12 @@ jobs:
find text-generation-inference/ -name "text_generation_server-*whl" -exec python -m pip install {} \;
python -m pytest --runslow -sv text-generation-inference/tests
+ # Use a different step to test the Jetstream Pytorch version, to avoid conflicts with torch-xla[tpu]
+ - name: Install and test TGI server slow tests (Jetstream Pytorch)
+   run: |
+     pip install -U .[jetstream-pt] \
+       -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+       -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
+       -f https://storage.googleapis.com/libtpu-releases/index.html
+     JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
+       pytest -sv text-generation-inference/tests --runslow -k jetstream
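In this step, --runslow enables the tests marked with pytest.mark.slow (the repository's usual opt-in convention for nightly-only tests), -k jetstream restricts the run to the Jetstream/Pytorch decode tests added in this commit, and JETSTREAM_PT=1 selects the Jetstream Pytorch backend for the TGI server.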
2 changes: 1 addition & 1 deletion optimum/tpu/model.py
@@ -49,7 +49,7 @@ def fetch_model(
local_path = snapshot_download(
repo_id=model_id,
revision=revision,
allow_patterns=["config.json", "model*.safetensors", SAFE_WEIGHTS_INDEX_NAME, "tokenizer*"],
allow_patterns=["*.json", "model*.safetensors", SAFE_WEIGHTS_INDEX_NAME],
)
end = time.time()
logger.info(f"Model successfully fetched in {end - start:.2f} s.")
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -61,7 +61,7 @@ quality = ["black", "ruff", "isort"]
# Jetstream/Pytorch support is experimental for now, requires installation from fixed commit.
# Pallas is pulled because it will install a compatible version of jax[tpu].
jetstream-pt = [
"jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git#df92015289953c506004e674d57651b03e4e89f2",
"jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@df92015289953c506004e674d57651b03e4e89f2",
"torch-xla[pallas] == 2.4.0"
]
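Note that in a pip direct reference the pinned commit must follow the repository URL after "@"; a "#" only starts a URL fragment, so the previous spec quietly installed the default branch instead of the intended revision.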

7 changes: 2 additions & 5 deletions (file path not shown)
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- import os
from typing import Any

from transformers import AutoConfig
@@ -25,10 +24,8 @@ def model_can_use_jetstream_pt(model_path: str) -> bool:
the engine are installed.
"""
config = AutoConfig.from_pretrained(model_path)
- # For now only Llama 2 with tokenizer.model is supported
- if config.model_type != "llama" or not os.path.exists(
- os.path.join(model_path, "tokenizer.model")
- ):
+ # For now only Llama is supported
+ if config.model_type != "llama":
return False
if jetstream_pt_available():
return True
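With the tokenizer.model requirement dropped, any checkpoint whose config declares model_type == "llama" can take the Jetstream/Pytorch path. Llama3 repositories ship a tokenizer.json / tokenizer_config.json pair rather than a sentencepiece tokenizer.model, which is why the previous check rejected them even though the architecture is supported.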
30 changes: 18 additions & 12 deletions (file path not shown; this is the engine_loader module referenced in the commit message)
@@ -1,7 +1,7 @@
# Import torch_xla2 first
import torch_xla2 # isort:skip

- from typing import Any
+ from typing import TYPE_CHECKING, Any

import jax
from jetstream_pt import fetch_models, torchjax
@@ -11,19 +11,21 @@
QuantizationConfig,
)
from loguru import logger
+
+
+ if TYPE_CHECKING:
+     from transformers import PretrainedConfig
from transformers import AutoConfig

from .engine import HfEngine
from .llama_model_exportable_hf import TransformerHf


- def load_llama_model_info(model_path: str) -> Any:
- # First get config
- config = AutoConfig.from_pretrained(model_path)
+ def load_llama_model_info(config: "PretrainedConfig") -> Any:
num_layers = config.num_hidden_layers
num_heads = config.num_attention_heads
head_dim = config.hidden_size // num_heads
- n_reps = config.num_key_value_heads // num_heads
+ n_reps = num_heads // config.num_key_value_heads
model_info = fetch_models.ModelInfo(
TransformerHf,
num_layers=num_layers,
@@ -34,10 +36,10 @@ def load_llama_model_info(model_path: str) -> Any:
return model_info


- def load_model_info(model_path: str) -> Any:
- config = AutoConfig.from_pretrained(model_path) # For now only Llama 2 is supported
+ def load_model_info(config: "PretrainedConfig") -> Any:
+ # For now only Llama is supported
if config.model_type == "llama":
- return load_llama_model_info(model_path)
+ return load_llama_model_info(config)
# Other models supports can be added here later
return None

@@ -49,7 +51,9 @@ def create_engine_env_data(
max_input_tokens: int,
max_output_tokens: int,
) -> Any:
- model_info = load_model_info(model_path)
+ # First get config
+ config = AutoConfig.from_pretrained(model_path)
+ model_info = load_model_info(config)
if model_info is None:
return None

@@ -72,7 +76,7 @@
)
env_data.cache_shape = (
batch_size,
- model_info.num_heads,
+ config.num_key_value_heads,
max_cache_length,
model_info.head_dim,
)
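Together with the n_reps fix above, this is the grouped-query-attention bookkeeping made explicit. A minimal sketch, assuming a Llama-3-8B-like config (32 query heads, 8 key/value heads, hidden size 4096; check the model's config.json for the actual values):

num_attention_heads = 32   # query heads
num_key_value_heads = 8    # key/value heads shared across query heads (GQA)
hidden_size = 4096
head_dim = hidden_size // num_attention_heads          # 128

# Each key/value head serves this many query heads; the old, inverted formula gave 8 // 32 == 0.
n_reps = num_attention_heads // num_key_value_heads    # 4

# The KV cache only stores the key/value heads, hence num_key_value_heads rather than num_heads.
batch_size, max_cache_length = 1, 2048                 # illustrative values
cache_shape = (batch_size, num_key_value_heads, max_cache_length, head_dim)

# Llama 2 7B uses plain multi-head attention (32 query heads, 32 KV heads), so both formulas
# give n_reps == 1 there; the bug only surfaces with grouped-query models such as Llama3.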
@@ -91,7 +95,8 @@ def instantiate_model_from_repo_id(
env: Any,
):
"""Create model instance by hf model_dir, and its config"""
- model_info = load_model_info(model_dir)
+ config = AutoConfig.from_pretrained(model_dir)
+ model_info = load_model_info(config)
# at this point we can be quite optimistic and just assert
assert model_info is not None

@@ -117,7 +122,8 @@ def shard_weights(env, weights, weight_shardings):
for key, val in weights.items():
sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
with jax.default_device(jax.devices("cpu")[0]):
- arr = torch_xla2.tensor.t2j(val)
+ # Note we clone to avoid a core-dump that might happen otherwise when calling device_put
+ arr = torch_xla2.tensor.t2j(val.clone())
arr = jax.device_put(arr, sharding)
sharded[key] = torchjax.to_torch(arr)
return sharded
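The clone gives the torch-to-JAX conversion an independent host copy of each weight before jax.device_put transfers it. The commit does not pin down the root cause; presumably the original tensors alias memory (for instance a memory-mapped checkpoint) that the transfer cannot safely consume in place, and the extra copy sidesteps that at the cost of some temporary host memory.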
26 changes: 25 additions & 1 deletion text-generation-inference/tests/test_decode.py
@@ -93,6 +93,7 @@ def _test_decode_single(params):
assert output.text == params.expected_text


+ @pytest.mark.slow
@pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
@pytest.mark.parametrize("params",
[
@@ -101,8 +102,31 @@ def _test_decode_single(params):
sequence_length=256,
expected_text="\n\nThe clocks were striking thirteen\nThe clocks were striking thirteen\n",
),
+ DecodeTestParams(
+ model_id="meta-llama/Meta-Llama-3-8B",
+ sequence_length=256,
+ expected_text=" Winston Winston Smith, his chin on his hands, and the clock in the Ministry of Truth, M",
+ ),
],
ids=["Llama-2-7b-hf", "Meta-Llama-3-8B"],
)
def test_decode_single_jetstream_pytorch_slow(params, do_sample):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
params.do_sample = do_sample
_test_decode_single(params)


@pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
@pytest.mark.parametrize("params",
[
DecodeTestParams(
model_id="Maykeye/TinyLLama-v0",
sequence_length=256,
expected_text=" She She had a big and it had a big, blue, and a big, red and a",
),
],
ids=["Llama-2-7b-hf"],
ids=["TinyLLama-v0"],
)
def test_decode_single_jetstream_pytorch(params, do_sample):
if not jetstream_pt_available():
