Shortfin llm beam search #1011

Draft · wants to merge 7 commits into base: main
1 change: 1 addition & 0 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -118,6 +118,7 @@ def generate_params_json(
"block_seq_stride": llama_config.block_seq_stride,
"device_block_count": args.device_block_count, # so that this makes its way into the config file & can be edited.
},
"n_beams": llama_config.n_beams,
Contributor:
This should just be handled by a default loader. The exported JSON is just an example IR; our dataclass loaders should be resilient to missing optional values.
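For illustration, a minimal sketch of the kind of resilient dataclass loader the reviewer describes (hypothetical helper and field defaults, not part of this PR): unknown keys are dropped and missing optional keys fall back to the field defaults, so a config exported before `n_beams` existed still loads.

```python
from dataclasses import dataclass, fields

@dataclass
class ModelParams:
    block_seq_stride: int = 16
    device_block_count: int = 256
    # Optional knob: absent in older config files.
    n_beams: int = 1

def load_model_params(config: dict) -> ModelParams:
    # Keep only keys the dataclass knows about; anything missing
    # falls back to the field's declared default.
    known = {f.name for f in fields(ModelParams)}
    return ModelParams(**{k: v for k, v in config.items() if k in known})

# A config written before `n_beams` existed still loads cleanly.
params = load_model_params({"block_seq_stride": 32, "device_block_count": 512})
assert params.n_beams == 1
```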

}

# Unrolling cache updates by batch row makes dynamo sad without an
5 changes: 5 additions & 0 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -202,6 +202,11 @@ class LlamaModelConfig:
# the program and not.
static_tables: bool = True

# The number of beams to use when generating tokens for a given prompt.
# When n_beams == 1, `greedy` selection is used,
# when n_beams > 1, `beam search` is used.
n_beams: int = 1


@dataclass
class T5Config:
6 changes: 6 additions & 0 deletions sharktank/sharktank/utils/cli.py
@@ -127,6 +127,12 @@ def add_model_options(parser: argparse.ArgumentParser):
type=int,
default=512,
)
parser.add_argument(
Contributor:
Beam search is a service-level control, so there is no benefit to including it during export.

Contributor Author:
Will remove all references in sharktank.
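To illustrate the reviewer's point about keeping this at the service layer, one option is a per-request override with a server-side default rather than an export-time constant (hypothetical, simplified field names; not what this PR implements):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class GenerateRequest:
    text: str
    # Hypothetical per-request override; the server falls back to its own default.
    n_beams: Optional[int] = None

SERVER_DEFAULT_N_BEAMS = 1

def resolve_n_beams(req: GenerateRequest) -> int:
    # Service-level decision: nothing about beam count is baked into the export.
    return req.n_beams if req.n_beams is not None else SERVER_DEFAULT_N_BEAMS
```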

"--n-beams",
help="Number of beams to use when generating tokens.",
type=int,
default=1,
)


def add_quantization_options(parser: argparse.ArgumentParser):
149 changes: 149 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/beam_manager.py
@@ -0,0 +1,149 @@
from asyncio import gather
from typing import Dict, List, Tuple
from uuid import uuid4

import numpy as np

from .messages import InferenceExecRequest


class BeamGroup:
def __init__(self, n_beams: int, exec_reqs: list[InferenceExecRequest]):
self.n_beams = n_beams
self.exec_reqs = exec_reqs
self.completed_reqs: set[InferenceExecRequest] = set()

async def wait(self):
done_signals = [
req.done for req in self.exec_reqs if req not in self.completed_reqs
]
return await gather(*done_signals)

def topk(
self, logits: np.array, k: int, axis: int
) -> Tuple[List[float], List[int]]:
# TODO: Move this to sfnp.array
indices = np.argpartition(logits, -k, axis=axis)
topk_indices = indices[axis][-k:]
topk_values = logits[axis][topk_indices]

sorted_indices = np.argsort(topk_values)[::-1]
Contributor (@rsuderman), Feb 27, 2025:
Most importantly, when moving this we want to use a top-k algorithm and not sort-and-select top-k. This is easily 20k+ elements for most models, so even simple top-k implementations will outperform sorting the whole array. This can be included in the TODO above.

Contributor Author:
This avoids sorting the 20k+ elements in the array, unless I missed something somewhere.

I use argpartition to ensure that the last k elements of the array are the largest. This runs in O(n) time. It doesn't guarantee that those last k elements are sorted, though (source).

From there, I slice off those maximum k indices. Then I do an index view into logits to obtain topk_values.

So, len(topk_values) == k.

This argsort is just sorting the k top values (i.e. the 4 top values), not the values of the entire array.

Will double-check the shape in the debugger.

Contributor Author:
Yeah, double-checked, len(topk_values) == k
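For reference, a standalone NumPy sketch of the approach described in this thread (illustrative only, not part of the diff): argpartition does an O(n) partial partition over the vocabulary-sized array, and only the k surviving values are fully sorted.

```python
import numpy as np

def topk(logits: np.ndarray, k: int):
    # Partition so the k largest values occupy the last k slots (O(n)),
    # without fully sorting the vocabulary-sized array.
    part = np.argpartition(logits, -k)
    topk_indices = part[-k:]
    topk_values = logits[topk_indices]
    # Only these k values get sorted (descending), which is cheap for small k.
    order = np.argsort(topk_values)[::-1]
    return topk_values[order], topk_indices[order]

logits = np.random.randn(32_000)  # vocabulary-sized vector
values, indices = topk(logits, 4)
assert len(values) == 4 and np.all(np.diff(values) <= 0)
```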

topk_values = topk_values[sorted_indices]
topk_indices = topk_indices[sorted_indices]

return topk_values, topk_indices

def log_softmax(self, logits: np.array) -> np.array:
# TODO: Move this to sfnp.array
c = logits.max()
logsumexp = np.log(np.exp(logits - c).sum())
return logits - c - logsumexp

def evaluate_topk(self) -> List[tuple[float, InferenceExecRequest, int]]:
# TODO: Use temperature when processing logits for better diversity of
# outputs.
exec_reqs = self.exec_reqs

log_prob_map: Dict[float, tuple[InferenceExecRequest, int]] = {}
# Find the topk tokens for each req in our beam group
for exec_req in exec_reqs:
if exec_req in self.completed_reqs:
continue
# NOTE: This copy is slow, and part of why this needs to be moved to
# `shortfin.array`
logits = np.array(exec_req.result_logits)
# Take log_softmax. This is to avoid a req's cumulative probability
# becoming too small, which can lead precision issues.
# This allows us to obtain cumulative probability by summing
# the log_probabilities, instead of multiplying the probabilities.
log_logits = self.log_softmax(logits)
log_logits = np.squeeze(log_logits, 1)
values, tokens = self.topk(log_logits, self.n_beams, -1)
for value, token in zip(values, tokens):
cumulative_log_prob = exec_req.cumulative_log_prob + value
log_prob_map[cumulative_log_prob] = (exec_req, token)

# Find the topk tokens across all exec_reqs
sorted_keys = sorted(log_prob_map.keys(), reverse=True)
exec_req_selections: List[tuple[float, InferenceExecRequest, int]] = []
for key in sorted_keys[: self.n_beams - len(self.completed_reqs)]:
exec_req, token = log_prob_map[key]
exec_req.cumulative_log_prob = key
exec_req_selections.append(
(
key,
*log_prob_map[key],
)
)

Contributor:
After selecting the K beams you want to continue down, you should normalize the running decode scores. Float addition can accumulate error, so if you find the minimum value and subtract it from all continued beams we can continue down the hypothesis train. There will need to be some correction for hypotheses of different lengths, but that just means tracking what the accumulated normalization was.
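A minimal sketch of that normalization (illustrative only, not part of this diff; names are hypothetical): after pruning to the surviving beams, shift every running log-probability by the minimum among them and remember the total shift so hypotheses that finish at different lengths can still be compared.

```python
def normalize_beams(beams: dict[str, float], accumulated_offset: float):
    # `beams` maps beam id -> running cumulative log-probability.
    # Shift every surviving beam by the current minimum so magnitudes stay
    # small; relative ordering between beams at this step is unchanged.
    shift = min(beams.values())
    for beam_id in beams:
        beams[beam_id] -= shift
    # Track the total shift so a beam's true score can be recovered later
    # (e.g. when comparing hypotheses that finished at different lengths).
    return beams, accumulated_offset + shift

# Example: running scores after a decode step.
scores = {"beam-0": -41.7, "beam-1": -43.2, "beam-2": -40.9}
scores, offset = normalize_beams(scores, 0.0)
# The true score of beam-0 is scores["beam-0"] + offset.
```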

return exec_req_selections

def process_beams(self, eos_token_id):
exec_reqs_selections = self.evaluate_topk()
visited_reqs: Dict[str, InferenceExecRequest] = {}
new_reqs = set()

for log_prob, req, token in exec_reqs_selections:
new_req = req
if new_req.instance_id not in visited_reqs:
new_req.input_token_ids.append(token)
new_req.start_position += 1

else:
visited_req = visited_reqs[new_req.instance_id]
new_req = visited_req.replicate_self()
new_req.input_token_ids.append(token)

new_req.cumulative_log_prob = log_prob
visited_reqs[new_req.instance_id] = new_req
new_reqs.add(new_req)
if token == eos_token_id:
self.completed_reqs.add(new_req)

for req in self.exec_reqs:
if req not in new_reqs:
req.free_cache_pages()

self.exec_reqs = list(new_reqs)

def find_top_beam(self) -> InferenceExecRequest:
completed_reqs = list(self.completed_reqs)
if not completed_reqs:
completed_reqs = self.exec_reqs
max_score = completed_reqs[0].cumulative_log_prob
selected_req = completed_reqs[0]
for req in completed_reqs[1:]:
if req.cumulative_log_prob > max_score:
selected_req = req
max_score = req.cumulative_log_prob

return selected_req

def __del__(self):
for req in self.exec_reqs:
req.free_cache_pages()

for req in self.completed_reqs:
req.free_cache_pages()


class BeamManager:
def __init__(self, n_beams):
self.n_beams: int = n_beams
self.beam_map: dict[str, BeamGroup] = {}

def create_beam(self, requests: list[InferenceExecRequest]) -> BeamGroup:
beam_group_id = str(uuid4())
for req in requests:
req.beam_group_id = beam_group_id

beam_group = BeamGroup(
self.n_beams,
requests,
)
self.beam_map[beam_group_id] = beam_group
return beam_group

def delete_beam(self, beam_group_id: str):
beam_group = self.beam_map[beam_group_id]
del beam_group
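An aside illustrating the log-space accumulation used in evaluate_topk above (illustrative numbers, not part of the diff): multiplying per-token probabilities underflows quickly, while summing log-probabilities stays well within floating-point range.

```python
import numpy as np

# Multiplying raw probabilities underflows for long sequences...
p = 1e-5                      # a plausible per-token probability
probs = np.full(100, p)
print(np.prod(probs))         # 0.0 -- float64 underflows well before 100 steps

# ...while summing log-probabilities stays representable.
log_probs = np.log(probs)
print(np.sum(log_probs))      # ≈ -1151.3
```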
3 changes: 3 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/config_struct.py
@@ -141,6 +141,9 @@ class ModelParams:
# Cache parameters.
paged_kv_cache: PagedKVCacheParams | None = None

# Number of beams to use during token generation.
n_beams: int = 1

# Size in bytes of the KV cache dtype.
@property
def attn_dtype_size(self) -> int:
100 changes: 84 additions & 16 deletions shortfin/python/shortfin_apps/llm/components/generate.py
@@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import asyncio
import copy
import io
import json
import logging
@@ -15,6 +16,7 @@
# TODO: Have a generic "Responder" interface vs just the concrete impl.
from shortfin.interop.fastapi import FastAPIResponder

from .beam_manager import BeamManager
from .io_struct import GenerateReqInput
from .messages import InferenceExecRequest, InferencePhase
from .service import GenerateService
@@ -39,6 +41,7 @@ def __init__(
input_token_ids: list[int],
max_completion_tokens: int,
eos_token_id: int,
n_beams: int,
):
super().__init__(fiber=client.fiber)
self.client = client
@@ -48,6 +51,8 @@ def __init__(
self.result_token_ids: list[int] = []
self.max_completion_tokens = max_completion_tokens
self.eos_token_id = eos_token_id
self.n_beams = n_beams
self.beam_manager = BeamManager(n_beams)

self.streamed_tokens_index = 0

@@ -68,20 +73,70 @@ async def run(self):
self.append_token(token_int)
# Decode loop.
exec.start_position = len(self.input_token_ids) - 1
for i in range(self.max_completion_tokens):
exec.reset(InferencePhase.DECODE)
exec.input_token_ids.append(token_int)
exec.start_position += 1
self.client.batcher.submit(exec)
await exec.done
token = sfnp.argmax(exec.result_logits)
token_int = token.items[0]
self.append_token(token_int)
if token_int == self.eos_token_id:
break
exec.input_token_ids.append(token_int)
if self.n_beams > 1:
await self.beam_search_decode_loop(exec)
else:
await self.greedy_decode_loop(exec)
finally:
logger.info(f"Freeing cache pages: {exec.rid}")
exec.free_cache_pages()

async def greedy_decode_loop(self, exec_req: InferenceExecRequest):
Contributor:
These functions should be abstracted into a helper composite class instead of being integrated into generate. The idea is that you create a class for the type of search you do.

E.g.

class DecodeManager:
    def __init__(self):
        pass

    def decode(self):
        ...


class GreedyDecoder:
    def decode(self):
        ...


class BeamSearchDecoder:
    def decode(self):
        ...

Contributor:
Essentially, we want this class to define how someone would write different sampling methods, and we can just dependency-inject the different behavior into the generator.

Contributor Author:
Great idea, will make this refactor
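A sketch of how that injection could look (hypothetical, simplified versions of the classes in this diff; the actual refactor is not part of this PR): the generate process depends only on a Decoder interface, and the concrete sampling strategy is chosen once and injected.

```python
from abc import ABC, abstractmethod


class Decoder(ABC):
    """Interface the generator depends on; concrete sampling lives in subclasses."""

    @abstractmethod
    async def decode(self, exec_req) -> list[int]:
        ...


class GreedyDecoder(Decoder):
    async def decode(self, exec_req) -> list[int]:
        ...  # argmax over the logits each step


class BeamSearchDecoder(Decoder):
    def __init__(self, n_beams: int):
        self.n_beams = n_beams

    async def decode(self, exec_req) -> list[int]:
        ...  # expand and prune n_beams hypotheses each step


class GenerateItemProcess:
    def __init__(self, decoder: Decoder):
        # The strategy is injected; this class never branches on n_beams.
        self.decoder = decoder

    async def run(self, exec_req):
        self.result_token_ids = await self.decoder.decode(exec_req)


def make_decoder(n_beams: int) -> Decoder:
    return GreedyDecoder() if n_beams == 1 else BeamSearchDecoder(n_beams)
```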

for _ in range(self.max_completion_tokens):
exec_req.reset(InferencePhase.DECODE)
self.client.batcher.submit(exec_req)
await exec_req.done
token = sfnp.argmax(exec_req.result_logits)
token_int = token.items[0]
self.append_token(token_int)
if token_int == self.eos_token_id:
break
exec_req.input_token_ids.append(token_int)
exec_req.start_position += 1

async def beam_search_decode_loop(self, exec_req: InferenceExecRequest):
n_beams = self.n_beams
decode_reqs = [exec_req]
# First, we need to replicate our exec_req,
# such that len(decode_reqs) == self.n_beams
for _ in range(n_beams - 1):
decode_req = exec_req.replicate_self()
decode_reqs.append(decode_req)

self.beam_manager.create_beam(decode_reqs)
beam_group_id = exec_req.beam_group_id
beam_group = self.beam_manager.beam_map[beam_group_id]
for _ in range(self.max_completion_tokens):
if len(beam_group.completed_reqs) == self.n_beams:
break

# Submit all decode requests to the batcher from this beam
for exec in beam_group.exec_reqs:
if exec in beam_group.completed_reqs:
continue
exec.reset(InferencePhase.DECODE)
self.client.batcher.submit(exec)

# Wait for all beams to finish
await beam_group.wait()
beam_group.process_beams(self.eos_token_id)

if self.gen_req.return_top_k:
reqs = beam_group.completed_reqs
for req in beam_group.exec_reqs:
reqs.add(req)
results = [req.input_token_ids for req in reqs]
self.result_token_ids = results
self.client.stream_results(self)
self.beam_manager.delete_beam(beam_group_id)
return

selected_req = beam_group.find_top_beam()
self.result_token_ids = selected_req.input_token_ids
self.client.stream_results(self)
self.beam_manager.delete_beam(beam_group_id)

def append_token(self, token: int):
self.result_token_ids.append(token)
self.client.stream_results(self)
@@ -104,6 +159,7 @@ class ClientGenerateBatchProcess(sf.Process):
"gen_req",
"responder",
"tokenizer",
"n_beams",
]

def __init__(
Expand All @@ -118,6 +174,7 @@ def __init__(
self.tokenizer = service.tokenizer
self.batcher = service.batcher
self.complete_infeed = self.system.create_queue()
self.n_beams = service.model_params.n_beams

async def run(self):
logger.debug("Started ClientBatchGenerateProcess: %r", self)
@@ -148,6 +205,7 @@ async def run(self):
input_tokens if is_pretokenized else input_tokens.ids,
max_completion_tokens=max_completion_tokens,
eos_token_id=self.tokenizer.eos_token_id,
n_beams=self.n_beams,
)
gen_processes.append(gen_process)
gen_process.launch()
@@ -167,12 +225,22 @@ async def run(self):
result_tokens = result_tokens[0]
out.write(bytes(json.dumps(result_tokens), "utf-8"))
else:
result_texts = self.tokenizer.decode(result_tokens)
for result_text in result_texts:
out.write(b"data: ")
out.write(result_text.encode())
out.write(b"\n\n")
if self.gen_req.return_top_k:
for batch in result_tokens:
result_texts = self.tokenizer.decode(batch)
for result_text in result_texts:
out.write(b"data: ")
out.write(result_text.encode())
out.write(b"\n\n")
else:
result_texts = self.tokenizer.decode(result_tokens)
for result_text in result_texts:
out.write(b"data: ")
out.write(result_text.encode())
out.write(b"\n\n")
self.responder.send_response(out.getvalue())
except Exception as e:
logger.error(e)
finally:
self.responder.ensure_response()

2 changes: 2 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/io_struct.py
@@ -41,6 +41,8 @@ class GenerateReqInput:
top_logprobs_num: Optional[Union[List[int], int]] = None
# Whether to detokenize tokens in text in the returned logprobs.
return_text_in_logprobs: bool = False
# Whether to return multiple beams from server when using `beam_search`
return_top_k: bool = False
# Whether to stream output.
stream: bool = False
# The modalities of the image data [image, multi-images, video]
@@ -47,6 +47,11 @@ def release_pages(self) -> None:
"""Releases the allocation's reference to pages."""
pass

@abstractmethod
def replicate_self(self) -> "PageAllocation":
"""Replicate the pages of self in a new PageAllocation instance."""
pass

@abstractmethod
def extend_allocation(self, tokens, *, extra_token_slots=0) -> None:
"""
@@ -79,6 +84,13 @@ def release_pages(self) -> None:
self._cache.page_pool.free_pages(self._pages)
self._is_released = True

def replicate_self(self) -> "BasePagedAttentionCacheAllocation":
new_pages = self._cache.page_pool.copy_pages(self.pages)
if new_pages is None:
raise CacheAllocationFailure()

return BasePagedAttentionCacheAllocation(new_pages, self._cache)

def extend_allocation(self, tokens, *, extra_token_slots=0) -> None:
# assert old tokens are a prefix of incoming tokens
# if we don't have enough pages to hold the tokens, we need to allocate more pages