Shortfin llm beam search #1011
base: main
Conversation
However, this achieves a stable implementation of unsharded beam search for varying `n_beams` and varying batch sizes. It serves as both an initial implementation and a checkpoint.
Cleanup `BeamGroup.process_beams`
…in-llm-beam-search
Ensure beam_group gets deleted
src_view = page_table.view(src_page.index)
dst_view = page_table.view(dst_page.index)
# Copy the data
dst_view.copy_from(src_view)
We shouldn't be copying every page, only incomplete pages, i.e. the final page if it is incomplete. The rest should be tracked with the retain-count system.
@renxida should be able to point to it.
Messaged @renxida and asked if he could provide context for the page copying. I think I see what you're saying, though: if a page is full, we just need to read from it, not write to it, so we should be able to share the full pages, which is all except the last one. That last, non-full page is where we may actually see differences, so it needs to be copied?
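To make that concrete, here is a rough sketch of forking pages that way, reusing the `view`/`copy_from` pattern from the diff above; `fork_pages`, `allocate_page`, `retain`, and the page objects are illustrative names, not the actual cache API:

```python
def fork_pages(page_table, src_pages, block_seq_stride, seq_len):
    """Share full pages between beams; copy only the trailing partial page.

    All names here are placeholders for whatever the real KV cache exposes;
    the point is only the copy-vs-share split.
    """
    forked = []
    for i, src_page in enumerate(src_pages):
        page_is_full = (i + 1) * block_seq_stride <= seq_len
        if page_is_full:
            # Full pages are read-only from here on, so both beams can share
            # them; bump the retain count instead of copying.
            src_page.retain()
            forked.append(src_page)
        else:
            # Only the last, partially filled page can diverge between beams,
            # so it is the only one that needs a physical copy.
            dst_page = page_table.allocate_page()
            dst_view = page_table.view(dst_page.index)
            dst_view.copy_from(page_table.view(src_page.index))
            forked.append(dst_page)
    return forked
```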
topk_indices = indices[axis][-k:]
topk_values = logits[axis][topk_indices]

sorted_indices = np.argsort(topk_values)[::-1]
Most importantly, going forward we want to use a top-k algorithm rather than sorting and then selecting the top k. This is easily 20k+ elements for most models, so even simple top-k implementations will outperform sorting the whole array. This can be included in the TODO above.
This avoids sorting the 20k+ elements in the array, unless I missed something somewhere.

I use `argpartition` to ensure that the `k` last elements of the array are the largest. This runs in O(n) time, but it doesn't guarantee that those `k` last elements are sorted (source).

From there, I slice off those maximum `k` indices, then do an index view into logits to obtain `topk_values`, so `len(topk_values) == k`.

This `argsort` is just sorting the `k` top values (i.e. 4 top values), not the values of the entire array. Will double-check the shape in the debugger.
Yeah, double-checked, `len(topk_values) == k`.
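For reference, a standalone NumPy sketch of the selection described above (`select_top_k` is illustrative, not the shortfin code):

```python
import numpy as np

def select_top_k(logits: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
    """Return the k largest logits and their token ids without a full sort.

    np.argpartition is O(n): it only guarantees that the last k indices hold
    the k largest values; only those k values are then sorted.
    """
    partitioned = np.argpartition(logits, -k)   # O(n) partition over the vocab axis
    topk_indices = partitioned[-k:]             # k largest, in arbitrary order
    topk_values = logits[topk_indices]
    order = np.argsort(topk_values)[::-1]       # sort only the k values, descending
    return topk_values[order], topk_indices[order]

# Example: top-4 over a 32k-entry logits vector.
logits = np.random.randn(32_000).astype(np.float32)
values, tokens = select_top_k(logits, k=4)
```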
@@ -118,6 +118,7 @@ def generate_params_json(
"block_seq_stride": llama_config.block_seq_stride,
"device_block_count": args.device_block_count, # so that this makes its way into the config file & can be edited.
},
"n_beams": llama_config.n_beams,
This should just be handled by a default loader. The export JSON is just an example IR; our dataclass loaders should be resilient to missing optional values.
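A minimal sketch of what such a resilient loader could look like; `ServerConfig`, `from_json`, and the fields other than `n_beams` are made up for illustration, and defaulting `n_beams` to 1 is an assumption:

```python
from dataclasses import dataclass, fields

@dataclass
class ServerConfig:
    block_seq_stride: int = 16
    device_block_count: int = 256
    # Optional: older exported configs won't contain this key at all (default is assumed).
    n_beams: int = 1

    @classmethod
    def from_json(cls, data: dict) -> "ServerConfig":
        # Keep only keys the dataclass knows about; anything missing falls
        # back to the field default, so older configs keep loading.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})

# A config exported before n_beams existed still loads, defaulting to 1.
cfg = ServerConfig.from_json({"block_seq_stride": 16, "device_block_count": 256})
```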
@@ -127,6 +127,12 @@ def add_model_options(parser: argparse.ArgumentParser):
type=int,
default=512,
)
parser.add_argument(
Beam search is a service-level control, so there is no benefit to including it during export.
Will remove all references in sharktank.
exec.free_cache_pages()

async def greedy_decode_loop(self, exec_req: InferenceExecRequest):
These functions should be abstracted into a helper composite class instead of being integrated into generate. The idea should be that you create the class for the type of search you do.
E.g.

class DecodeManager:
    def __init__(self):
        pass

    def decode(self):
        ...

class GreedyDecoder:
    def decode(self):
        ...

class BeamSearchDecoder:
    def decode(self):
        ...
Essentially, we want this class to define how someone would write different sampling methods, and we can just dependency-inject the different behavior into the generator.
Great idea, will make this refactor
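For illustration, a rough sketch of how that injection could look, reusing the decoder classes from the comment above; `GenerateService`, `batcher`, and the method signatures are placeholders rather than the actual generate.py API, and `decode` is assumed to be a coroutine here:

```python
class GenerateService:
    """Generator with the token-selection strategy injected instead of hard-coded."""

    def __init__(self, batcher, decoder):
        self.batcher = batcher
        # Any object with a `decode` coroutine works: GreedyDecoder, BeamSearchDecoder, ...
        self.decoder = decoder

    async def run(self, exec_req):
        # Prefill and batching stay the same regardless of strategy; only
        # token selection differs between greedy and beam search.
        await self.batcher.prefill(exec_req)
        return await self.decoder.decode(exec_req)

# Chosen at construction time:
#   service = GenerateService(batcher, BeamSearchDecoder(n_beams=4))
#   service = GenerateService(batcher, GreedyDecoder())
```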
self.rid,
)
new_exec_req.start_position = self.start_position
result_logits: sfnp.device_array = self.result_logits.for_transfer()
It concerns me that we see `result_logits` being retained / surviving. If we do want it to survive, this buffer should be held via retainment, not copied.
Could you elaborate on the issue, and on the difference between retainment and copying? I think there's some background/context that I'm missing.
*log_prob_map[key],
)
)
After selecting the K beams you want to continue down, you should normalize the running decode scores. Float addition can accumulate error, so if you find the minimum value and subtract it from all continued beams, we can keep going down the hypothesis train. There will need to be some correction for hypotheses of different lengths to accommodate them, but that just means tracking what the accumulated normalization was.
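A small sketch of the rescaling being suggested, assuming each beam carries a cumulative log probability (plain Python, not the shortfin code; the length correction is only hinted at via the returned offset):

```python
def normalize_beam_scores(cumulative_log_probs: list[float]) -> tuple[list[float], float]:
    """Shift all surviving beams by the minimum score so values stay near zero.

    The same shift is applied to every beam, so their relative order is
    unchanged; the accumulated offset has to be remembered so hypotheses of
    different lengths can still be corrected and compared later.
    """
    offset = min(cumulative_log_probs)
    return [score - offset for score in cumulative_log_probs], offset

# After selecting the K continued beams each decode step:
scores, step_offset = normalize_beam_scores([-41.7, -42.3, -45.0, -44.1])
# scores is roughly [3.3, 2.7, 0.0, 0.9]; add step_offset to each beam's running normalization.
```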
Initial implementation of `beam_search` for the LLM server. Putting it up as a draft for now to get some feedback on the code, and because it needs unit/integration tests; it is too big/important a change to merge without them.

The beam_search-specific logic is contained in `beam_manager.py`, while `generate.py` just contains logic for managing the `InferenceExecRequest`s and orchestrating the overall flow. This also keeps all of the logic above the Batcher level, which minimized changes a lot from what I previously tried.
At a high level, the idea is that we (a rough sketch follows this list):

1. Create `n_beams` InferenceExecRequests. Here we `replicate` the initial req that we used for prefill, including replicating the KV cache pages.
2. Create a `BeamGroup`. This is a helper class that handles actually performing beam_search token selection, and tracking the reqs.
3. For each req, select the `top_k` tokens, based on cumulative log probability.
4. Select the overall `top_k` tokens from all reqs.
5. If a req generated an `eos_token`, add it to the set of `completed_reqs`.
6. Repeat until `max_completion_tokens` is reached, or all beams generated an `eos_token`.
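Rough pseudocode of that flow; the `BeamGroup` method names and the `replicate` signature below are stand-ins, not the real beam_manager.py API:

```python
async def beam_search_decode_loop(exec_req, n_beams, max_completion_tokens, eos_token):
    """Illustrative outline only; all helper names are placeholders."""
    # 1. Fork the prefilled request into n_beams requests (KV cache pages included).
    beams = [exec_req.replicate() for _ in range(n_beams)]
    beam_group = BeamGroup(beams, eos_token=eos_token)

    for _ in range(max_completion_tokens):
        # 2. Each active beam runs one decode step, then the group keeps the
        #    overall top_k continuations by cumulative log probability.
        await beam_group.decode_step()
        beam_group.select_top_k()

        # 3. Beams that emitted eos_token move to completed_reqs.
        beam_group.retire_finished()
        if not beam_group.active_beams:
            break

    return beam_group.completed_reqs
```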
Selecting top-k

Extra attention in this section would be appreciated.

For each beam:

<begin_loop>

- Apply `log_softmax` to the logits. By taking the log of the softmax, we can use addition to track the cumulative probabilities instead of multiplying the raw probabilities. If we multiplied raw probabilities, our cumulative probabilities would become smaller and smaller until we lose precision. source (search for 'From products to addition')
- Select the `top_k` values and tokens from the `-1` axis. Track the cumulative log probs for each possible token.

<end_loop>

We then return the top possible selections, in sorted order, based on which potential tokens would yield the beams with the highest probability.
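To make the bookkeeping concrete, here is a condensed standalone NumPy version of that loop (a sketch, not the shortfin implementation; `candidate_expansions` and the beam tuple layout are illustrative):

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Stable log-softmax: shift by the max before exponentiating to avoid overflow.
    shifted = logits - logits.max()
    return shifted - np.log(np.exp(shifted).sum())

def candidate_expansions(beams, k):
    """For each beam, take top_k tokens on the last axis and accumulate log probs.

    `beams` is a list of (cumulative_log_prob, logits) pairs; returns the
    globally best k (beam_index, token, new_cumulative_log_prob) triples.
    """
    candidates = []
    for beam_idx, (cum_log_prob, logits) in enumerate(beams):
        log_probs = log_softmax(logits)                    # adding log probs replaces multiplying raw probs
        topk_tokens = np.argpartition(log_probs, -k)[-k:]  # O(n) top-k, unsorted
        for token in topk_tokens:
            candidates.append((beam_idx, int(token), cum_log_prob + log_probs[token]))
    # Keep the k expansions that would yield the most probable beams, best first.
    candidates.sort(key=lambda c: c[2], reverse=True)
    return candidates[:k]
```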