Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shortfin llm beam search #1011

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

stbaione
Copy link
Contributor

Initial implementation of beam_search for the LLM server. Putting it up as a draft for now to get some feedback on the code, and because it needs unit/integration tests. Too big/important of a change to merge without it.

The beam_search specific logic is contained with beam_manager.py, while generate.py just contains logic for managing the InferenceExecRequests and orchestrating the overall flow.

This also keeps all of the logic above the Batcher level, which minimized changes a lot from what I previously tried.

At a high level, the idea is that we:

  1. Run prefill normally. Get the kvcache initialized and obtain first token.
  2. Create n_beams InferenceExecRequests. Here we replicate the initial req that we used for prefill, including replicating the KVCaches pages.
  3. Group each of these reqs under a BeamGroup. This is a helper class that handles actually performing beam_search token selection, and tracking the reqs.
  4. Submit all reqs to batch, and wait for them all to finish.
  5. For each req select the top_k tokens, based on cumulative log probability.
  6. Select the overall top_k tokens, from all reqs.
  7. Update our beams. Do any replication/beam collapses if needed.
  8. If a req generates an eos_token, add it to the set of completed_reqs.
  9. Repeat until either, max_completion_tokens is reached, or all beams generated an eos_token
  10. When we return, we either select the highest cumulative_probability beam, or return all beams, depending on the request params.

Selecting top-k

Extra attention in this section would be appreciated.

For each beam:

<begin_loop>

  1. Obtain logits from decode invocation.
  2. Apply log_softmax to logits. By taking the log of softmax, we can use addition to track the cumulative probabilities, instead of multiplication with the raw probabilities. If we did multiplication with raw probabilities, our cumulative probabilities will become smaller and smaller, until we lose precision. source (search for 'From products to addition`)
  3. Select the top_k values and tokens from the -1 axis. Track the cumulative log probs for each possible token.
    <end_loop>

We then return the top possible selections, in sorted order, based on which potential tokens would yield the beams with the highest probability.

However, this achieves a stable implementation of unsharded beam search for varying `n_beams` and varying batch sizes.

Both an initial implementation, and a checkpoint
Cleanup `BeamGroup.process_beams`
Ensure beam_group gets deleted
@stbaione stbaione requested a review from rsuderman February 27, 2025 21:09
src_view = page_table.view(src_page.index)
dst_view = page_table.view(dst_page.index)
# Copy the data
dst_view.copy_from(src_view)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't be copying every page but only incomplete pages. I.e. the final page if it is incomplete. The rest should be tracked with the retain count system.

@renxida should be able to point to it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Messaged @renxida and asked if he can provide context for the page copying. I think I see what you're saying though. If a page is full, we just need to read from it, not write to it. So, we should be able to share the full pages, which is all except the last one. That last non-full one is where we may actually see differences, so needs to be copied?

topk_indices = indices[axis][-k:]
topk_values = logits[axis][topk_indices]

sorted_indices = np.argsort(topk_values)[::-1]
Copy link
Contributor

@rsuderman rsuderman Feb 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most importantly when moving we want to use a topk algorithm and not sort and select topk. This is easy 20k+ elements for most models so even simple top-k implementations will outperform sorting the whole array. This can be included in the TODO above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This avoids sorting the 20k+ elements in the array, unless I missed something somewhere.

I use argpartition to ensure that the k last elements of the array are the largest. This runs in O(n) time. It doesn't guarantee that those k last elements are sorted though.
source

From there, I slice off those maximum k indices. Then I do an index view into logits to obtain topk_values.

So, len(topk_values) == k.

This argsort is just sorting the k top values (i.e. 4 top values), not values of the entire array.

Will double-check the shape in debugger

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, double-checked, len(topk_values) == k

@@ -118,6 +118,7 @@ def generate_params_json(
"block_seq_stride": llama_config.block_seq_stride,
"device_block_count": args.device_block_count, # so that this makes its way into the config file & can be edited.
},
"n_beams": llama_config.n_beams,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be just handled by a default loader. The export json is just an example IR, our dataclass loaders should be resilient to missing optional values.

@@ -127,6 +127,12 @@ def add_model_options(parser: argparse.ArgumentParser):
type=int,
default=512,
)
parser.add_argument(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beam search is a service level control so there is no benefit to including it during export.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will remove all reference in sharktank

exec.free_cache_pages()

async def greedy_decode_loop(self, exec_req: InferenceExecRequest):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These functions should be abstracted into a helper componsite class instead of integrated into generate. The idea should be that you create the class for the type of search you do.

E.g.

class DecodeManager():
  def __init__():
    pass
   
   def decode():
      ...
 
class GreedyDecoder():
    def decode():
       ....
       
       
class BeamSearchDecoder():
    def decode():
       ....

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Essentially we want this class to define how someone would write different sampling methods and we can just dependency inject the different behavior into the generator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea, will make this refactor

self.rid,
)
new_exec_req.start_position = self.start_position
result_logits: sfnp.device_array = self.result_logits.for_transfer()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It concerns me that we see result_logits being retained / surviving. If we do want it to survive this buffer should be held via retainment and not copied.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate on the issue and diff between retainment and coped? I think there's some background/context that I'm missing

*log_prob_map[key],
)
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Post selecting the K beams you want to continue down you should normalize the running decode. Float addition can accumulate error so if you find the minimum value and subtract that from all continued beams we can continue down the hypothesis train. There will need to be some correction for different length hypothesis to accommodate for them but it just means tracking what the accumulated normalization was.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants