Shortfin llm beam search #1011
base: main
Changes from all commits
a5f9c51
a4194a3
a92623c
0b8f388
abce7f3
730664b
46371ef
@@ -127,6 +127,12 @@ def add_model_options(parser: argparse.ArgumentParser):
        type=int,
        default=512,
    )
    parser.add_argument(
        "--n-beams",
        help="Number of beams to use when generating tokens.",
        type=int,
        default=1,
    )


def add_quantization_options(parser: argparse.ArgumentParser):

Review comment: Beam search is a service-level control, so there is no benefit to including it during export.

Reply: Will remove all references to it in sharktank.
@@ -0,0 +1,149 @@
from asyncio import gather
from typing import Dict, List, Tuple
from uuid import uuid4

import numpy as np

from .messages import InferenceExecRequest


class BeamGroup:
    def __init__(self, n_beams: int, exec_reqs: list[InferenceExecRequest]):
        self.n_beams = n_beams
        self.exec_reqs = exec_reqs
        self.completed_reqs: set[InferenceExecRequest] = set()

    async def wait(self):
        done_signals = [
            req.done for req in self.exec_reqs if req not in self.completed_reqs
        ]
        return await gather(*done_signals)

    def topk(
        self, logits: np.array, k: int, axis: int
    ) -> Tuple[List[float], List[int]]:
        # TODO: Move this to sfnp.array
        indices = np.argpartition(logits, -k, axis=axis)
        topk_indices = indices[axis][-k:]
        topk_values = logits[axis][topk_indices]

        sorted_indices = np.argsort(topk_values)[::-1]
        topk_values = topk_values[sorted_indices]
        topk_indices = topk_indices[sorted_indices]

        return topk_values, topk_indices

Review comment (on the argsort above): Most importantly, when moving this we want to use a top-k algorithm rather than sorting and then selecting the top k. Logits are easily 20k+ elements for most models, so even a simple top-k implementation will outperform sorting the whole array. This can be included in the TODO above.

Reply: This avoids sorting the 20k+ elements in the array, unless I missed something somewhere. I use argpartition to ensure the k largest values land in the last k slots of the index array. From there, I slice off those maximum k indices, then take an index view into logits to obtain the corresponding values. The argsort is only sorting the k top values (i.e. 4 values), not the entire array. Will double-check the shape in the debugger.

Reply: Yeah, double-checked: len(topk_values) == k.

Reply: Ah, I misunderstood. In that case we don't need to bother including this sort. You don't have to remove it, but typically we don't care that the concurrent hypotheses are sorted, only that we iterate on the right k. Once we have the final k options, we sort after all decode steps and return. Basically there is no point sorting per step in this case; we just sort at the end.
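To make the point concrete for readers outside the thread, here is a small standalone NumPy sketch with toy values (not code from this PR): argpartition does the O(n) selection, and the final argsort only ever touches k elements.

import numpy as np

# Toy logits over an 8-token "vocabulary"; real vocabularies are 20k+ entries.
logits = np.array([0.1, 2.3, -1.0, 5.2, 0.7, 4.8, -0.3, 3.9])
k = 3

# argpartition only guarantees the k largest values occupy the last k slots;
# it does not sort the rest of the array.
partitioned = np.argpartition(logits, -k)
topk_indices = partitioned[-k:]     # indices of the k largest logits, unordered
topk_values = logits[topk_indices]

# This argsort operates on k elements only (k == n_beams), not the full array.
order = np.argsort(topk_values)[::-1]
print(topk_indices[order])  # [3 5 7]
print(topk_values[order])   # [5.2 4.8 3.9]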
    def log_softmax(self, logits: np.array) -> np.array:
        # TODO: Move this to sfnp.array
        c = logits.max()
        logsumexp = np.log(np.exp(logits - c).sum())
        return logits - c - logsumexp

    def evaluate_topk(self) -> List[tuple[float, InferenceExecRequest, int]]:
        # TODO: Use temperature when processing logits for better diversity of
        # outputs.
        exec_reqs = self.exec_reqs

        log_prob_map: Dict[float, tuple[InferenceExecRequest, int]] = {}
        # Find the topk tokens for each req in our beam group
        for exec_req in exec_reqs:
            if exec_req in self.completed_reqs:
                continue
            # NOTE: This copy is slow, and part of why this needs to be moved to
            # `shortfin.array`
            logits = np.array(exec_req.result_logits)
            # Take log_softmax. This is to avoid a req's cumulative probability
            # becoming too small, which can lead to precision issues.
            # This allows us to obtain the cumulative probability by summing
            # the log probabilities, instead of multiplying the probabilities.
            log_logits = self.log_softmax(logits)
            log_logits = np.squeeze(log_logits, 1)
            values, tokens = self.topk(log_logits, self.n_beams, -1)
            for value, token in zip(values, tokens):
                cumulative_log_prob = exec_req.cumulative_log_prob + value
                log_prob_map[cumulative_log_prob] = (exec_req, token)

        # Find the topk tokens across all exec_reqs
        sorted_keys = sorted(log_prob_map.keys(), reverse=True)
        exec_req_selections: List[tuple[float, InferenceExecRequest, int]] = []
        for key in sorted_keys[: self.n_beams - len(self.completed_reqs)]:
            exec_req, token = log_prob_map[key]
            exec_req.cumulative_log_prob = key
            exec_req_selections.append(
                (
                    key,
                    *log_prob_map[key],
                )
            )

        return exec_req_selections
Review comment: After selecting the K beams you want to continue down, you should normalize the running decode. Float addition can accumulate error, so if you find the minimum value and subtract it from all continued beams, we can keep following the hypothesis chain. There will need to be some correction to accommodate hypotheses of different lengths, but that just means tracking what the accumulated normalization was.

Reply: Gotcha. So, IIUC, for each decode iteration, shift all log_probs to the right by subtracting min(all_log_probs):

cumulative_log_prob += (log_prob - min(all_log_probs))

For each beam, track the total accumulated normalization:

accumulated_normalization += abs(min(all_log_probs))

When a beam reaches eos, add the total accumulated normalization to its cumulative_log_prob. After our decode iterations are finished, we apply length normalization to our top beams, so the final score for each beam is calculated like this:

final_score = (cumulative_log_prob + accumulated_normalization) / len(beam)

Does that make sense, or should I be doing the correction and/or length normalization while I'm in the decode loop?
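To make the bookkeeping concrete, here is a minimal hypothetical sketch (names invented, not code from this PR) of one way the per-step shift and a final length-normalized score could fit together; the exact sign of the correction and where length normalization happens are exactly what this thread is asking about.

class BeamScore:
    """Hypothetical per-beam bookkeeping for the running-normalization idea."""

    def __init__(self):
        self.cumulative_log_prob = 0.0  # shifted running sum
        self.accumulated_shift = 0.0    # total shift applied so far
        self.tokens = []


def apply_step(beams, step_log_probs):
    # Shift each surviving beam by the step minimum so running sums stay
    # near zero instead of drifting toward large negative values.
    step_min = min(step_log_probs)
    for beam, log_prob in zip(beams, step_log_probs):
        beam.cumulative_log_prob += log_prob - step_min
        beam.accumulated_shift += -step_min  # log probs are <= 0, so this is abs(step_min)


def final_score(beam):
    # Undo the accumulated shift to recover the true log probability,
    # then length-normalize so shorter hypotheses are not unfairly favored.
    true_log_prob = beam.cumulative_log_prob - beam.accumulated_shift
    return true_log_prob / max(len(beam.tokens), 1)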
    def process_beams(self, eos_token_id):
        exec_reqs_selections = self.evaluate_topk()
        visited_reqs: Dict[str, InferenceExecRequest] = {}
        new_reqs = set()

        for log_prob, req, token in exec_reqs_selections:
            new_req = req
            if new_req.instance_id not in visited_reqs:
                new_req.input_token_ids.append(token)
                new_req.start_position += 1
            else:
                visited_req = visited_reqs[new_req.instance_id]
                new_req = visited_req.replicate_self()
                new_req.input_token_ids.append(token)

            new_req.cumulative_log_prob = log_prob
            visited_reqs[new_req.instance_id] = new_req
            new_reqs.add(new_req)
            if token == eos_token_id:
                self.completed_reqs.add(new_req)

        for req in self.exec_reqs:
            if req not in new_reqs:
                req.free_cache_pages()

        self.exec_reqs = list(new_reqs)

    def find_top_beam(self) -> InferenceExecRequest:
        completed_reqs = list(self.completed_reqs)
        if not completed_reqs:
            completed_reqs = self.exec_reqs
        max_score = completed_reqs[0].cumulative_log_prob
        selected_req = completed_reqs[0]
        for req in completed_reqs[1:]:
            if req.cumulative_log_prob > max_score:
                selected_req = req
                max_score = req.cumulative_log_prob

        return selected_req

    def __del__(self):
        for req in self.exec_reqs:
            req.free_cache_pages()

        for req in self.completed_reqs:
            req.free_cache_pages()


class BeamManager:
    def __init__(self, n_beams):
        self.n_beams: int = n_beams
        self.beam_map: dict[str, BeamGroup] = {}

    def create_beam(self, requests: list[InferenceExecRequest]) -> BeamGroup:
        beam_group_id = str(uuid4())
        for req in requests:
            req.beam_group_id = beam_group_id

        beam_group = BeamGroup(
            self.n_beams,
            requests,
        )
        self.beam_map[beam_group_id] = beam_group
        return beam_group
    def delete_beam(self, beam_group_id: str):
        # Drop the map entry so the BeamGroup can be garbage collected
        # (its __del__ frees any remaining cache pages).
        del self.beam_map[beam_group_id]
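For orientation, the decode side of this PR (beam_search_decode_loop, further down in generate.py) drives these classes roughly as follows. This is only a condensed sketch of the flow already shown in the diff, with request reset, streaming, and cleanup details elided.

async def run_beam_search(beam_manager, batcher, decode_reqs, n_beams,
                          max_completion_tokens, eos_token_id):
    # Condensed driver mirroring beam_search_decode_loop below.
    beam_group = beam_manager.create_beam(decode_reqs)  # one InferenceExecRequest per beam
    for _ in range(max_completion_tokens):
        if len(beam_group.completed_reqs) == n_beams:
            break
        for req in beam_group.exec_reqs:                # resubmit every surviving beam
            if req not in beam_group.completed_reqs:
                batcher.submit(req)
        await beam_group.wait()                         # wait until each beam has logits
        beam_group.process_beams(eos_token_id)          # keep the k best continuations
    return beam_group.find_top_beam()                   # highest cumulative log prob wins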
@@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import asyncio
import copy
import io
import json
import logging

@@ -15,6 +16,7 @@
# TODO: Have a generic "Responder" interface vs just the concrete impl.
from shortfin.interop.fastapi import FastAPIResponder

from .beam_manager import BeamManager
from .io_struct import GenerateReqInput
from .messages import InferenceExecRequest, InferencePhase
from .service import GenerateService

@@ -39,6 +41,7 @@ def __init__(
        input_token_ids: list[int],
        max_completion_tokens: int,
        eos_token_id: int,
        n_beams: int,
    ):
        super().__init__(fiber=client.fiber)
        self.client = client

@@ -48,6 +51,8 @@ def __init__(
        self.result_token_ids: list[int] = []
        self.max_completion_tokens = max_completion_tokens
        self.eos_token_id = eos_token_id
        self.n_beams = n_beams
        self.beam_manager = BeamManager(n_beams)

        self.streamed_tokens_index = 0
@@ -68,20 +73,70 @@ async def run(self):
            self.append_token(token_int)
            # Decode loop.
            exec.start_position = len(self.input_token_ids) - 1
            for i in range(self.max_completion_tokens):
                exec.reset(InferencePhase.DECODE)
                exec.input_token_ids.append(token_int)
                exec.start_position += 1
                self.client.batcher.submit(exec)
                await exec.done
                token = sfnp.argmax(exec.result_logits)
                token_int = token.items[0]
                self.append_token(token_int)
                if token_int == self.eos_token_id:
                    break
                exec.input_token_ids.append(token_int)
            if self.n_beams > 1:
                await self.beam_search_decode_loop(exec)
            else:
                await self.greedy_decode_loop(exec)
        finally:
            logger.info(f"Freeing cache pages: {exec.rid}")
            exec.free_cache_pages()
Review comment (on the decode loops below): These functions should be abstracted into a helper composite class instead of being integrated into generate. The idea is that you create a class for the type of search you do (e.g., see the sketch after this thread).

Review comment: Essentially we want this class to define how someone would write different sampling methods, so we can just dependency-inject the different behavior into the generator.

Reply: Great idea, will make this refactor.
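The reviewer's example did not survive this extract, so the following is only a hypothetical sketch of the suggested composite/strategy shape. DecodeStrategy and GreedyDecodeStrategy are invented names, and the sfnp import alias is assumed from the surrounding module; the decode body itself mirrors greedy_decode_loop below.

from abc import ABC, abstractmethod

import shortfin.array as sfnp  # assumed alias, matching sfnp usage in this file

from .messages import InferenceExecRequest, InferencePhase


class DecodeStrategy(ABC):
    """Hypothetical interface: one class per sampling method."""

    @abstractmethod
    async def decode(self, exec_req: InferenceExecRequest) -> list[int]:
        """Run the decode loop for exec_req and return the generated token ids."""


class GreedyDecodeStrategy(DecodeStrategy):
    def __init__(self, batcher, eos_token_id: int, max_completion_tokens: int):
        self.batcher = batcher
        self.eos_token_id = eos_token_id
        self.max_completion_tokens = max_completion_tokens

    async def decode(self, exec_req: InferenceExecRequest) -> list[int]:
        # Same shape as greedy_decode_loop below, but owned by the strategy object.
        tokens: list[int] = []
        for _ in range(self.max_completion_tokens):
            exec_req.reset(InferencePhase.DECODE)
            self.batcher.submit(exec_req)
            await exec_req.done
            token_int = sfnp.argmax(exec_req.result_logits).items[0]
            tokens.append(token_int)
            if token_int == self.eos_token_id:
                break
            exec_req.input_token_ids.append(token_int)
            exec_req.start_position += 1
        return tokens


# A BeamSearchDecodeStrategy would wrap BeamManager the same way, and the
# generate process would pick a strategy up front and call `await strategy.decode(req)`.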
    async def greedy_decode_loop(self, exec_req: InferenceExecRequest):
        for _ in range(self.max_completion_tokens):
            exec_req.reset(InferencePhase.DECODE)
            self.client.batcher.submit(exec_req)
            await exec_req.done
            token = sfnp.argmax(exec_req.result_logits)
            token_int = token.items[0]
            self.append_token(token_int)
            if token_int == self.eos_token_id:
                break
            exec_req.input_token_ids.append(token_int)
            exec_req.start_position += 1
    async def beam_search_decode_loop(self, exec_req: InferenceExecRequest):
        n_beams = self.n_beams
        decode_reqs = [exec_req]
        # First, we need to replicate our exec_req,
        # such that len(decode_reqs) == self.n_beams
        for _ in range(n_beams - 1):
            decode_req = exec_req.replicate_self()
            decode_reqs.append(decode_req)

        self.beam_manager.create_beam(decode_reqs)
        beam_group_id = exec_req.beam_group_id
        beam_group = self.beam_manager.beam_map[beam_group_id]
        for _ in range(self.max_completion_tokens):
            if len(beam_group.completed_reqs) == self.n_beams:
                break

            # Submit all decode requests to the batcher from this beam
            for exec in beam_group.exec_reqs:
                if exec in beam_group.completed_reqs:
                    continue
                exec.reset(InferencePhase.DECODE)
                self.client.batcher.submit(exec)

            # Wait for all beams to finish
            await beam_group.wait()
            beam_group.process_beams(self.eos_token_id)

        if self.gen_req.return_top_k:
            reqs = beam_group.completed_reqs
            for req in beam_group.exec_reqs:
                reqs.add(req)
            results = [req.input_token_ids for req in reqs]
            self.result_token_ids = results
            self.client.stream_results(self)
            self.beam_manager.delete_beam(beam_group_id)
            return

        selected_req = beam_group.find_top_beam()
        self.result_token_ids = selected_req.input_token_ids
        self.client.stream_results(self)
        self.beam_manager.delete_beam(beam_group_id)
    def append_token(self, token: int):
        self.result_token_ids.append(token)
        self.client.stream_results(self)

@@ -104,6 +159,7 @@ class ClientGenerateBatchProcess(sf.Process):
        "gen_req",
        "responder",
        "tokenizer",
        "n_beams",
    ]

    def __init__(

@@ -118,6 +174,7 @@ def __init__(
        self.tokenizer = service.tokenizer
        self.batcher = service.batcher
        self.complete_infeed = self.system.create_queue()
        self.n_beams = service.model_params.n_beams

    async def run(self):
        logger.debug("Started ClientBatchGenerateProcess: %r", self)

@@ -148,6 +205,7 @@ async def run(self):
            input_tokens if is_pretokenized else input_tokens.ids,
            max_completion_tokens=max_completion_tokens,
            eos_token_id=self.tokenizer.eos_token_id,
            n_beams=self.n_beams,
        )
        gen_processes.append(gen_process)
        gen_process.launch()
@@ -167,12 +225,22 @@ async def run(self):
                result_tokens = result_tokens[0]
                out.write(bytes(json.dumps(result_tokens), "utf-8"))
            else:
                result_texts = self.tokenizer.decode(result_tokens)
                for result_text in result_texts:
                    out.write(b"data: ")
                    out.write(result_text.encode())
                    out.write(b"\n\n")
                if self.gen_req.return_top_k:
                    for batch in result_tokens:
                        result_texts = self.tokenizer.decode(batch)
                        for result_text in result_texts:
                            out.write(b"data: ")
                            out.write(result_text.encode())
                            out.write(b"\n\n")
                else:
                    result_texts = self.tokenizer.decode(result_tokens)
                    for result_text in result_texts:
                        out.write(b"data: ")
                        out.write(result_text.encode())
                        out.write(b"\n\n")
            self.responder.send_response(out.getvalue())
        except Exception as e:
            logger.error(e)
        finally:
            self.responder.ensure_response()
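For context, the return_top_k branch above changes what the client receives: one "data: ..." block per surviving or completed beam instead of a single best sequence. The request schema is not part of this diff, so the following is only an illustrative guess; the /generate route and the "text" field name are assumptions, and only the return_top_k field is visible here (via gen_req.return_top_k).

import requests

# Hypothetical request against a running shortfin LLM server.
resp = requests.post(
    "http://localhost:8000/generate",
    json={
        "text": "1 2 3 4 5 ",
        "return_top_k": True,  # ask for every beam, not just the highest-scoring one
    },
)
# With return_top_k the server streams one "data:" block per beam;
# otherwise it streams only the top beam's decoded text.
print(resp.text)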
Review comment: This should just be handled by a default loader. The exported json is only an example IR; our dataclass loaders should be resilient to missing optional values.
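A minimal sketch of that suggestion, assuming the service's model parameters are loaded into a dataclass; the ModelParams name follows service.model_params in this diff, but the field layout and loader shape here are invented.

from dataclasses import dataclass


@dataclass
class ModelParams:
    # ...existing fields elided...
    # Optional with a default, so configs exported without the value still load;
    # n_beams == 1 falls back to the greedy decode loop.
    n_beams: int = 1


def load_model_params(config: dict) -> ModelParams:
    # A loader that tolerates a missing optional key rather than requiring
    # every exported json to carry it.
    return ModelParams(n_beams=config.get("n_beams", 1))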