Skip to content

Commit

Permalink
Move setup funcs to conftest.py,
Browse files Browse the repository at this point in the history
Add logging throughout test,
Move integration test to `build_tools/integration_tests/llm`,
Rename ci file
  • Loading branch information
stbaione committed Nov 4, 2024
1 parent b55f332 commit 5798b62
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

name: CI - sharktank and shortfin
name: CI - shark-platform

on:
workflow_dispatch:
Expand Down Expand Up @@ -70,5 +70,5 @@ jobs:
iree-runtime \
"numpy<2.0"
- name: Run shortfin LLM Server Integration Test
run: pytest -v build_tools/integration_tests/
- name: Run LLM Integration Tests
run: pytest -v build_tools/integration_tests/llm --log-cli-level=INFO
Original file line number Diff line number Diff line change
@@ -1,41 +1,37 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import logging
import os
from pathlib import Path
import pytest
import requests
import shutil
import subprocess
import time
import uuid

pytest.importorskip("transformers")
from transformers import AutoTokenizer

CPU_SETTINGS = {
"device_flags": [
"-iree-hal-target-backends=llvm-cpu",
"--iree-llvmcpu-target-cpu=host",
],
"device": "local-task",
}
IREE_HIP_TARGET = os.environ.get("IREE_HIP_TARGET", "gfx1100")
gpu_settings = {
"device_flags": [
"-iree-hal-target-backends=rocm",
f"--iree-hip-target={IREE_HIP_TARGET}",
],
"device": "hip",
}
logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def model_test_dir(request, tmp_path_factory):
"""Prepare model artifacts for starting the LLM server.
Args:
request (FixtureRequest): The following params are accepted:
- repo_id (str): The Hugging Face repo ID.
- model_file (str): The model file to download.
- tokenizer_id (str): The tokenizer ID to download.
- settings (dict): The settings for sharktank export.
- batch_sizes (list): The batch sizes to use for the model.
tmp_path_factory (TempPathFactory): Temp dir to save artifacts to.
Yields:
Tuple[Path, Path]: The paths to the Hugging Face home and the temp dir.
"""
logger.info("Preparing model artifacts...")

repo_id = request.param["repo_id"]
model_file = request.param["model_file"]
tokenizer_id = request.param["tokenizer_id"]
Expand All @@ -48,25 +44,43 @@ def model_test_dir(request, tmp_path_factory):
try:
# Download model if it doesn't exist
model_path = hf_home / model_file
logger.info(f"Preparing model_path: {model_path}..")
if not os.path.exists(model_path):
logger.info(
f"Downloading model {repo_id} {model_file} from Hugging Face..."
)
subprocess.run(
f"huggingface-cli download --local-dir {hf_home} {repo_id} {model_file}",
shell=True,
check=True,
)
logger.info(f"Model downloaded to {model_path}")
else:
logger.info("Using cached model")

# Set up tokenizer if it doesn't exist
tokenizer_path = hf_home / "tokenizer.json"
logger.info(f"Preparing tokenizer_path: {tokenizer_path}...")
if not os.path.exists(tokenizer_path):
logger.info(f"Downloading tokenizer {tokenizer_id} from Hugging Face...")
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_id,
)
tokenizer.save_pretrained(hf_home)
logger.info(f"Tokenizer saved to {tokenizer_path}")
else:
logger.info("Using cached tokenizer")

# Export model if it doesn't exist
# Export model
mlir_path = tmp_dir / "model.mlir"
config_path = tmp_dir / "config.json"
bs_string = ",".join(map(str, batch_sizes))
logger.info(
"Exporting model with following settings:\n"
f" MLIR Path: {mlir_path}\n"
f" Config Path: {config_path}\n"
f" Batch Sizes: {bs_string}"
)
subprocess.run(
[
"python",
Expand All @@ -79,9 +93,11 @@ def model_test_dir(request, tmp_path_factory):
],
check=True,
)
logger.info(f"Model successfully exported to {mlir_path}")

# Compile model if it doesn't exist
# Compile model
vmfb_path = tmp_dir / "model.vmfb"
logger.info(f"Compiling model to {vmfb_path}")
subprocess.run(
[
"iree-compile",
Expand All @@ -92,6 +108,7 @@ def model_test_dir(request, tmp_path_factory):
+ settings["device_flags"],
check=True,
)
logger.info(f"Model successfully compiled to {vmfb_path}")

# Write config if it doesn't exist
edited_config_path = tmp_dir / "edited_config.json"
Expand All @@ -106,8 +123,11 @@ def model_test_dir(request, tmp_path_factory):
"transformer_block_count": 26,
"paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
}
logger.info(f"Saving edited config to: {edited_config_path}\n")
logger.info(f"Config: {json.dumps(config, indent=2)}")
with open(edited_config_path, "w") as f:
json.dump(config, f)
logger.info("Model artifacts setup successfully")
yield hf_home, tmp_dir
finally:
shutil.rmtree(tmp_dir)
Expand All @@ -117,13 +137,16 @@ def model_test_dir(request, tmp_path_factory):
def available_port(port=8000, max_port=8100):
import socket

logger.info(f"Finding available port in range {port}-{max_port}...")

starting_port = port

while port < max_port:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", port))
s.close()
logger.info(f"Found available port: {port}")
return port
except socket.error:
port += 1
Expand All @@ -132,10 +155,12 @@ def available_port(port=8000, max_port=8100):


def wait_for_server(url, timeout=10):
logger.info(f"Waiting for server to start at {url}...")
start = time.time()
while time.time() - start < timeout:
try:
requests.get(f"{url}/health")
logger.info("Server successfully started")
return
except requests.exceptions.ConnectionError:
time.sleep(1)
Expand All @@ -144,6 +169,19 @@ def wait_for_server(url, timeout=10):

@pytest.fixture(scope="module")
def llm_server(request, model_test_dir, available_port):
"""Start the LLM server.
Args:
request (FixtureRequest): The following params are accepted:
- model_file (str): The model file to download.
- settings (dict): The settings for starting the server.
model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir.
available_port (int): The available port to start the server on.
Yields:
subprocess.Popen: The server process that was started.
"""
logger.info("Starting LLM server...")
# Start the server
hf_home, tmp_dir = model_test_dir
model_file = request.param["model_file"]
Expand All @@ -166,58 +204,3 @@ def llm_server(request, model_test_dir, available_port):
# Teardown: kill the server
server_process.terminate()
server_process.wait()


def do_generate(prompt, port):
headers = {"Content-Type": "application/json"}
# Create a GenerateReqInput-like structure
data = {
"text": prompt,
"sampling_params": {"max_tokens": 50, "temperature": 0.7},
"rid": uuid.uuid4().hex,
"return_logprob": False,
"logprob_start_len": -1,
"top_logprobs_num": 0,
"return_text_in_logprobs": False,
"stream": False,
}
print("Prompt text:")
print(data["text"])
BASE_URL = f"http://localhost:{port}"
response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data)
print(f"Generate endpoint status code: {response.status_code}")
if response.status_code == 200:
print("Generated text:")
data = response.text
assert data.startswith("data: ")
data = data[6:]
assert data.endswith("\n\n")
data = data[:-2]
return data
else:
response.raise_for_status()


@pytest.mark.parametrize(
"model_test_dir,llm_server",
[
(
{
"repo_id": "SlyEcho/open_llama_3b_v2_gguf",
"model_file": "open-llama-3b-v2-f16.gguf",
"tokenizer_id": "openlm-research/open_llama_3b_v2",
"settings": CPU_SETTINGS,
"batch_sizes": [1, 4],
},
{"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS},
)
],
indirect=True,
)
def test_llm_server(llm_server, available_port):
# Here you would typically make requests to your server
# and assert on the responses
assert llm_server.poll() is None
output = do_generate("1 2 3 4 5 ", available_port)
print(output)
assert output.startswith("6 7 8")
85 changes: 85 additions & 0 deletions build_tools/integration_tests/llm/cpu_llm_server_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
import os
import pytest
import requests
import uuid

logger = logging.getLogger(__name__)

CPU_SETTINGS = {
"device_flags": [
"-iree-hal-target-backends=llvm-cpu",
"--iree-llvmcpu-target-cpu=host",
],
"device": "local-task",
}
IREE_HIP_TARGET = os.environ.get("IREE_HIP_TARGET", "gfx1100")
gpu_settings = {
"device_flags": [
"-iree-hal-target-backends=rocm",
f"--iree-hip-target={IREE_HIP_TARGET}",
],
"device": "hip",
}


def do_generate(prompt, port):
logger.info("Generating request...")
headers = {"Content-Type": "application/json"}
# Create a GenerateReqInput-like structure
data = {
"text": prompt,
"sampling_params": {"max_tokens": 50, "temperature": 0.7},
"rid": uuid.uuid4().hex,
"return_logprob": False,
"logprob_start_len": -1,
"top_logprobs_num": 0,
"return_text_in_logprobs": False,
"stream": False,
}
logger.info("Prompt text:")
logger.info(data["text"])
BASE_URL = f"http://localhost:{port}"
response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data)
logger.info(f"Generate endpoint status code: {response.status_code}")
if response.status_code == 200:
logger.info("Generated text:")
data = response.text
assert data.startswith("data: ")
data = data[6:]
assert data.endswith("\n\n")
data = data[:-2]
return data
else:
response.raise_for_status()


@pytest.mark.parametrize(
"model_test_dir,llm_server",
[
(
{
"repo_id": "SlyEcho/open_llama_3b_v2_gguf",
"model_file": "open-llama-3b-v2-f16.gguf",
"tokenizer_id": "openlm-research/open_llama_3b_v2",
"settings": CPU_SETTINGS,
"batch_sizes": [1, 4],
},
{"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS},
)
],
indirect=True,
)
def test_llm_server(llm_server, available_port):
# Here you would typically make requests to your server
# and assert on the responses
assert llm_server.poll() is None
output = do_generate("1 2 3 4 5 ", available_port)
logger.info(output)
assert output.startswith("6 7 8")

0 comments on commit 5798b62

Please sign in to comment.