Merge branch 'main' into llm-server-int-test
stbaione authored Nov 5, 2024
2 parents 5798b62 + 46debb4 commit beb5757
Showing 3 changed files with 47 additions and 11 deletions.
39 changes: 32 additions & 7 deletions shortfin/python/shortfin_apps/llm/client.py
@@ -1,22 +1,24 @@
 import requests
 import json
 import uuid
+import argparse
+import time
 
 BASE_URL = "http://localhost:8000"
 
 
 def test_health():
     response = requests.get(f"{BASE_URL}/health")
     print(f"Health check status code: {response.status_code}")
-    return response.status_code == 200
+    response.raise_for_status()
 
 
-def test_generate():
+def test_generate(prompt_text):
     headers = {"Content-Type": "application/json"}
 
     # Create a GenerateReqInput-like structure
     data = {
-        "text": "1 2 3 4 5",
+        "text": prompt_text,
         "sampling_params": {"max_tokens": 50, "temperature": 0.7},
         "rid": uuid.uuid4().hex,
         "return_logprob": False,
@@ -49,10 +51,33 @@ def test_generate():


 def main():
-    print("Testing webapp...")
-
-    health_ok = test_health()
-    generate_ok = test_generate()
+    parser = argparse.ArgumentParser(description="Test webapp with custom prompt")
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="1 2 3 4 5 ",
+        help="The prompt text to send to the generate endpoint",
+    )
+
+    args = parser.parse_args()
+
+    print(f"Testing shortfin llm server at {BASE_URL}")
+
+    health_ok = False
+    # previous backoff for fibonacci backoff
+    prev_backoff = 0
+    backoff = 1
+    while not health_ok:
+        try:
+            health_ok = test_health()
+        except requests.exceptions.ConnectionError:
+            print(
+                f"Health check failed. Waiting for {backoff} seconds before retrying."
+            )
+            time.sleep(backoff)
+            prev_backoff, backoff = backoff, prev_backoff + backoff
+
+    generate_ok = test_generate(args.prompt)
 
     if health_ok and generate_ok:
         print("\nAll tests passed successfully!")
2 changes: 1 addition & 1 deletion shortfin/python/shortfin_apps/llm/components/generate.py
@@ -59,7 +59,7 @@ async def run(self):
         exec.start_position = len(self.input_token_ids) - 1
         # TODO: Use correct eot token from config.
         # while token_int != 128001:
-        for i in range(15):
+        for i in range(40):
             exec.reset(InferencePhase.DECODE)
             exec.input_token_ids = [token_int]
             exec.start_position += 1
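
The decode loop's fixed step budget rises from 15 to 40 tokens. The commented-out condition records the intended stopping rule: decode until the model emits the end-of-text token (128001 is Llama 3's <|end_of_text|> id). A sketch of that rule, assuming a hypothetical decode_one_token callable in place of the real InferenceExecRequest machinery:

EOT_TOKEN = 128001   # placeholder; per the TODO, the real id should come from the config
MAX_DECODE_STEPS = 40


def decode_until_eot(decode_one_token, first_token):
    # Emit one token per step until end-of-text or the step budget runs out.
    tokens = [first_token]
    for _ in range(MAX_DECODE_STEPS):
        token = decode_one_token(tokens[-1])
        if token == EOT_TOKEN:
            break
        tokens.append(token)
    return tokens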
17 changes: 14 additions & 3 deletions shortfin/python/shortfin_apps/llm/components/service.py
@@ -240,8 +240,10 @@ def board_decodes(self, cache: AttnPageCache):
             assert decode_request.phase == InferencePhase.DECODE
             if len(exec_process.exec_requests) >= self.ideal_batch_size:
                 break
+            incoming_token_count = len(decode_request.input_token_ids)
             needed_pages = math.ceil(
-                len(decode_request.input_token_ids) / self.page_seq_stride
+                (decode_request.start_position + incoming_token_count)
+                / self.page_seq_stride
             )
             if needed_pages > len(decode_request.locked_pages):
                 pages = cache.acquire_free_pages(needed_pages)
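
During decode, input_token_ids holds only the newly fed token, while the KV cache must cover every position up to start_position, so the old page count under-allocated once a sequence crossed a page boundary. A worked example with hypothetical numbers:

import math

page_seq_stride = 16      # tokens per KV-cache page (illustrative value)
start_position = 31       # tokens already materialized in the cache
incoming_token_count = 1  # decode appends one token per step

needed_pages = math.ceil((start_position + incoming_token_count) / page_seq_stride)
print(needed_pages)  # 2: positions 0..31 span two 16-token pages

# The old formula, math.ceil(incoming_token_count / page_seq_stride),
# would have asked for just 1 page here.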
@@ -307,7 +309,13 @@ async def run(self):

         # Compute block sequence length as maximum sequence length, rounded
         # up to the seq_stride.
-        bsl = max(len(r.input_token_ids) for r in self.exec_requests)
+        if self.phase == InferencePhase.PREFILL:
+            for r in self.exec_requests:
+                assert r.start_position == 0
+
+        bsl = max(
+            (r.start_position + len(r.input_token_ids)) for r in self.exec_requests
+        )
         bsl = int(math.ceil(bsl / seq_stride) * seq_stride)
         block_count = bsl // seq_stride
         req_count = len(self.exec_requests)
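
The block sequence length is now the batch maximum of start_position + len(input_token_ids), which reduces to the old expression for prefill (where start_position is asserted to be 0) and also covers decode, where input_token_ids holds only the newly fed token. A small worked example with hypothetical sizes:

import math

seq_stride = 16
# One prefill request with a 5-token prompt, and one decode request
# appending its 21st token (start_position == 20, one incoming token).
lengths = [0 + 5, 20 + 1]

bsl = max(lengths)                                   # 21
bsl = int(math.ceil(bsl / seq_stride) * seq_stride)  # 32, rounded up to the stride
block_count = bsl // seq_stride                      # 2 blocks per sequence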
@@ -358,7 +366,10 @@ async def run(self):
         seq_lens_host = seq_lens.for_transfer()
         with seq_lens_host.map(discard=True) as m:
             m.fill(0)
-            m.items = [req.start_position + 1 for req in self.exec_requests]
+            m.items = [
+                req.start_position + len(req.input_token_ids)
+                for req in self.exec_requests
+            ]
         seq_lens_host.copy_to(seq_lens)
 
         # Populate cache pages.
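
The old expression start_position + 1 was right for decode, where exactly one token arrives per step, but reported a sequence length of 1 for every prefill request. A self-contained sketch, with Req as a stand-in for the real request type:

class Req:
    def __init__(self, start_position, input_token_ids):
        self.start_position = start_position
        self.input_token_ids = input_token_ids


# A 5-token prefill and a decode appending its 21st token.
exec_requests = [Req(0, [11, 22, 33, 44, 55]), Req(20, [66])]

old_items = [r.start_position + 1 for r in exec_requests]
new_items = [r.start_position + len(r.input_token_ids) for r in exec_requests]
print(old_items)  # [1, 21]: the prefill length is wrong
print(new_items)  # [5, 21]: both phases are right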
