
Can key-value caching be used for optimizing batched retrieval of logits from inputs with common prompt prefixes? #1169

Answered by awni
chimezie asked this question in Q&A


You could definitely re-use the computation for common prefixes. It would look something like this:

import copy
from mlx_lm.models.cache import make_prompt_cache

cache = make_prompt_cache(model)

# Prefill the cache on the common prefix
_ = model(prefix_tokens, cache=cache)

# Deep-copy the prefilled cache when evaluating each per-example suffix
logits_q1 = model(q1_tokens, cache=copy.deepcopy(cache))

logits_q2 = model(q2_tokens, cache=copy.deepcopy(cache))

If so, will question_batch have to have the prefix removed?

Yes, you would need to remove the cached prefix tokens. So q1_tokens is the first example with the common prefix removed.
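
For completeness, a minimal end-to-end sketch of the above might look like the following. The model name, prompt strings, and the slice-by-token-count prefix removal are illustrative assumptions, not part of the original answer:

import copy
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache

# Placeholder model; any mlx-lm model follows the same pattern
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

prefix = "Shared instructions and context that every example starts with.\n\n"
questions = ["First question?", "Second question?"]

# Tokenize the shared prefix once and prefill the cache with it
prefix_tokens = mx.array(tokenizer.encode(prefix))[None]
cache = make_prompt_cache(model)
mx.eval(model(prefix_tokens, cache=cache))

n_prefix = prefix_tokens.shape[1]
all_logits = []
for q in questions:
    # Tokenize the full prompt, then drop the tokens already covered by the
    # cached prefix. (Slicing by token count assumes the tokenizer splits
    # identically at the prefix/suffix boundary; verify for your tokenizer.)
    full_tokens = tokenizer.encode(prefix + q)
    suffix_tokens = mx.array(full_tokens[n_prefix:])[None]
    # Run only the suffix against a deep copy of the prefilled cache
    logits = model(suffix_tokens, cache=copy.deepcopy(cache))
    all_logits.append(logits)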

The other motivation is related to this use case and has to do with LoRA training on a batch of data, all of which also share a commo…
