🦙 Llama3 on TGI - Jetstream Pytorch #90

Merged · 9 commits · Sep 10, 2024
9 changes: 9 additions & 0 deletions .github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml
@@ -38,3 +38,12 @@ jobs:
find text-generation-inference/ -name "text_generation_server-*whl" -exec python -m pip install {} \;
python -m pytest --runslow -sv text-generation-inference/tests

# Use a different step to test the Jetstream Pytorch version, to avoid conflicts with torch-xla[tpu]
- name: Install and test TGI server slow tests (Jetstream Pytorch)
run: |
pip install -U .[jetstream-pt] \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-f https://storage.googleapis.com/libtpu-releases/index.html
JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
pytest -sv text-generation-inference/tests --runslow -k jetstream
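The step above opts into the Jetstream PyTorch backend by exporting JETSTREAM_PT=1 and then selects only the Jetstream tests with -k jetstream. As a rough sketch of that kind of gate (the helper below is hypothetical, not code from this PR), the backend is only worth enabling when the flag is set and the package is actually importable:

    import importlib.util
    import os

    def jetstream_opted_in() -> bool:
        # Mirror the opt-in flag exported by the CI step above; default to disabled.
        if os.environ.get("JETSTREAM_PT", "0") != "1":
            return False
        # Only report the backend as usable if jetstream_pt is installed.
        return importlib.util.find_spec("jetstream_pt") is not None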
2 changes: 1 addition & 1 deletion optimum/tpu/model.py
@@ -49,7 +49,7 @@ def fetch_model(
local_path = snapshot_download(
repo_id=model_id,
revision=revision,
allow_patterns=["config.json", "model*.safetensors", SAFE_WEIGHTS_INDEX_NAME, "tokenizer*"],
allow_patterns=["*.json", "model*.safetensors", SAFE_WEIGHTS_INDEX_NAME],
)
end = time.time()
logger.info(f"Model successfully fetched in {end - start:.2f} s.")
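For reference, a minimal sketch of what the broadened patterns now download for a Llama 3 checkpoint (the repo id is an example; gated meta-llama repositories also need an accepted license and an HF token): every JSON file, which covers tokenizer.json, tokenizer_config.json and generation_config.json in addition to config.json, plus the sharded safetensors weights and their index, while a sentencepiece tokenizer.model is no longer pulled.

    from huggingface_hub import snapshot_download
    from transformers.utils import SAFE_WEIGHTS_INDEX_NAME  # "model.safetensors.index.json"

    local_path = snapshot_download(
        repo_id="meta-llama/Meta-Llama-3-8B",  # example id, gated repository
        allow_patterns=["*.json", "model*.safetensors", SAFE_WEIGHTS_INDEX_NAME],
    )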
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -61,7 +61,7 @@ quality = ["black", "ruff", "isort"]
# Jetstream/Pytorch support is experimental for now; it requires installation from a fixed commit.
# Pallas is pulled because it will install a compatible version of jax[tpu].
jetstream-pt = [
"jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git#df92015289953c506004e674d57651b03e4e89f2",
"jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@df92015289953c506004e674d57651b03e4e89f2",
"torch-xla[pallas] == 2.4.0"
]
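The dependency specifier now uses pip's direct-reference revision syntax (git+<url>@<commit>). In the previous form the commit hash followed a '#', which pip parses as a URL fragment rather than a revision, so the install was not actually pinned to that commit.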

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any

from transformers import AutoConfig
@@ -25,10 +24,8 @@ def model_can_use_jetstream_pt(model_path: str) -> bool:
the engine are installed.
"""
config = AutoConfig.from_pretrained(model_path)
# For now only Llama 2 with tokenizer.model is supported
if config.model_type != "llama" or not os.path.exists(
os.path.join(model_path, "tokenizer.model")
):
# For now only Llama is supported
if config.model_type != "llama":
return False
if jetstream_pt_available():
return True
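A minimal usage sketch of the relaxed gate (the import path is an assumption inferred from this diff rather than stated in it): only the config's model_type is inspected now, so a Llama 3 snapshot that ships tokenizer.json but no sentencepiece tokenizer.model also qualifies.

    from optimum.tpu.jetstream_pt_support import model_can_use_jetstream_pt  # assumed module path

    local_path = "/data/Meta-Llama-3-8B"  # example local snapshot directory
    backend = "jetstream-pt" if model_can_use_jetstream_pt(local_path) else "torch-xla"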
@@ -1,7 +1,7 @@
# Import torch_xla2 first
import torch_xla2 # isort:skip

from typing import Any
from typing import TYPE_CHECKING, Any

import jax
from jetstream_pt import fetch_models, torchjax
@@ -11,19 +11,21 @@
QuantizationConfig,
)
from loguru import logger


if TYPE_CHECKING:
from transformers import PretrainedConfig
from transformers import AutoConfig

from .engine import HfEngine
from .llama_model_exportable_hf import TransformerHf


def load_llama_model_info(model_path: str) -> Any:
# First get config
config = AutoConfig.from_pretrained(model_path)
def load_llama_model_info(config: "PretrainedConfig") -> Any:
num_layers = config.num_hidden_layers
num_heads = config.num_attention_heads
head_dim = config.hidden_size // num_heads
n_reps = config.num_key_value_heads // num_heads
n_reps = num_heads // config.num_key_value_heads
model_info = fetch_models.ModelInfo(
TransformerHf,
num_layers=num_layers,
@@ -34,10 +36,10 @@ def load_llama_model_info(model_path: str) -> Any:
return model_info


def load_model_info(model_path: str) -> Any:
config = AutoConfig.from_pretrained(model_path) # For now only Llama 2 is supported
def load_model_info(config: "PretrainedConfig") -> Any:
# For now only Llama is supported
if config.model_type == "llama":
return load_llama_model_info(model_path)
return load_llama_model_info(config)
# Support for other models can be added here later
return None

@@ -49,7 +51,9 @@ def create_engine_env_data(
max_input_tokens: int,
max_output_tokens: int,
) -> Any:
model_info = load_model_info(model_path)
# First get config
config = AutoConfig.from_pretrained(model_path)
model_info = load_model_info(config)
if model_info is None:
return None

@@ -72,7 +76,7 @@
)
env_data.cache_shape = (
batch_size,
model_info.num_heads,
config.num_key_value_heads,
max_cache_length,
model_info.head_dim,
)
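The n_reps fix earlier in this file and the cache_shape change just above both follow from grouped-query attention: num_key_value_heads is a divisor of num_attention_heads, each KV head is shared by n_reps query heads, and the KV cache holds one entry per KV head rather than per query head (sizing it by the 32 query heads of an 8-KV-head model would allocate four times the memory actually needed). A worked example with Meta-Llama-3-8B's dimensions (32 attention heads, 8 KV heads, hidden size 4096; batch and cache length are illustrative):

    num_heads = 32          # config.num_attention_heads
    num_kv_heads = 8        # config.num_key_value_heads
    hidden_size = 4096
    batch_size, max_cache_length = 4, 2048

    head_dim = hidden_size // num_heads      # 128
    n_reps = num_heads // num_kv_heads       # 4; the old expression (8 // 32) evaluated to 0
    cache_shape = (batch_size, num_kv_heads, max_cache_length, head_dim)
    print(head_dim, n_reps, cache_shape)     # 128 4 (4, 8, 2048, 128)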
@@ -91,7 +95,8 @@ def instantiate_model_from_repo_id(
env: Any,
):
"""Create model instance by hf model_dir, and its config"""
model_info = load_model_info(model_dir)
config = AutoConfig.from_pretrained(model_dir)
model_info = load_model_info(config)
# at this point we can be quite optimistic and just assert
assert model_info is not None

@@ -117,7 +122,8 @@ def shard_weights(env, weights, weight_shardings):
for key, val in weights.items():
sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
with jax.default_device(jax.devices("cpu")[0]):
arr = torch_xla2.tensor.t2j(val)
# Note we clone to avoid a core-dump that might happen otherwise when calling device_put
arr = torch_xla2.tensor.t2j(val.clone())
arr = jax.device_put(arr, sharding)
sharded[key] = torchjax.to_torch(arr)
return sharded
26 changes: 25 additions & 1 deletion text-generation-inference/tests/test_decode.py
@@ -93,6 +93,7 @@ def _test_decode_single(params):
assert output.text == params.expected_text


@pytest.mark.slow
@pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
@pytest.mark.parametrize("params",
[
@@ -101,8 +102,31 @@ def _test_decode_single(params):
sequence_length=256,
expected_text="\n\nThe clocks were striking thirteen\nThe clocks were striking thirteen\n",
),
DecodeTestParams(
model_id="meta-llama/Meta-Llama-3-8B",
sequence_length=256,
expected_text=" Winston Winston Smith, his chin on his hands, and the clock in the Ministry of Truth, M",
),
],
ids=["Llama-2-7b-hf", "Meta-Llama-3-8B"],
)
def test_decode_single_jetstream_pytorch_slow(params, do_sample):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
params.do_sample = do_sample
_test_decode_single(params)


@pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
@pytest.mark.parametrize("params",
[
DecodeTestParams(
model_id="Maykeye/TinyLLama-v0",
sequence_length=256,
expected_text=" She She had a big and it had a big, blue, and a big, red and a",
),
],
ids=["Llama-2-7b-hf"],
ids=["TinyLLama-v0"],
)
def test_decode_single_jetstream_pytorch(params, do_sample):
if not jetstream_pt_available():
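Outside CI, the new Llama 3 case can be exercised on a TPU host with the same invocation the nightly workflow uses, narrowed to a single parametrization, for example: JETSTREAM_PT=1 HF_TOKEN=<token> python -m pytest -sv text-generation-inference/tests/test_decode.py --runslow -k "jetstream and Meta-Llama-3-8B" (the -k expression is an example; the gated meta-llama weights require an authorized token).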