-
My primary motivation for this question is to determine whether the recently added (and excellent) LM evaluation functionality can be sped up using a prefix cache. Many multiple-choice tasks in the LM evaluation harness share a common prefix, which I assume could benefit from a prefix cache in the same way as described in Long Prompts and Generations and LLM inference: Attention Layer. However, it is not clear from the documentation or the caching examples if and how the cache can be used for this. The lm_eval bits that use the cache are the only example I see of how it could be used for batched logit calculation, but unless I'm mistaken, it just uses the cache to quickly load a fresh state of the model before extracting logit scores. If you are evaluating the
If I wanted to create a cache for the common prompt prefix shared by all of a task's questions, and use it when evaluating their scores to save redundant computation, is that possible, and is this how it would be done?

```python
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load(".. some model and I ain't talking 'bout the ones that work the cat walk ..")
prefix_tokens = tokenizer.encode("The following are multiple choice questions (with answers) about anatomy.")

# Create a batch filled with the prefix
prefix_cache_batch = mx.repeat(mx.array(prefix_tokens)[None, ...], batch_size, 0)
cache = make_prompt_cache(model)
_ = model(prefix_cache_batch[:, :-1], cache=cache)
# ... save the cache for later if used outside running code ...

# question_batch is a padded batch of questions (each of which shares the common prefix)
logits = model(question_batch[:, :-1], cache=cache)
```

If so, will question_batch have to have the prefix removed? The other motivation is related to this use case and has to do with LoRA training on a batch of data, all of which also share a common prefix. At the point the loss is calculated, the process of extracting logits is identical to the situation above, which suggests this very specific training scenario could also benefit from a prefix cache. Would the same approach work for that situation?
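For context, this is roughly the kind of loss computation I have in mind during LoRA training. It's a simplified sketch of a typical causal-LM loss rather than the exact mlx_lm trainer code, and `inputs`, `targets`, and `lengths` are hypothetical batched arrays (token ids, shifted targets, and per-example lengths):

```python
import mlx.core as mx
import mlx.nn as nn

def loss_fn(model, inputs, targets, lengths):
    # Forward pass over the full (prefix + question) token batch
    logits = model(inputs).astype(mx.float32)

    # Mask out padding positions so they don't contribute to the loss
    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

    ce = nn.losses.cross_entropy(logits, targets) * length_mask
    ntoks = length_mask.sum()
    return ce.sum() / ntoks, ntoks
```

The `model(inputs)` call here is the spot where I imagine a prefilled prefix cache could be passed, the same way as in the evaluation snippet above.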
-
You could definitely re-use the computation for common prefixes. It would look something like this:

```python
import copy

from mlx_lm.models.cache import make_prompt_cache

# prefix_tokens, q1_tokens, q2_tokens are token id arrays of shape (1, sequence_length)
cache = make_prompt_cache(model)

# Prefill the cache on the common prefix
_ = model(prefix_tokens, cache=cache)

# Deep copy the cache when evaluating each per-example suffix
logits_q1 = model(q1_tokens, cache=copy.deepcopy(cache))
logits_q2 = model(q2_tokens, cache=copy.deepcopy(cache))
```
Yes.. you would need to remove the cached prefix tokens. So q1_tokens is the first example with the common prefix removed.
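For instance, a minimal sketch of building `q1_tokens` by stripping the prefix (assuming the prefix tokenizes the same way on its own as it does at the start of the full prompt, and reusing `tokenizer`, `prefix_tokens`, `model`, and `cache` from the snippets above; the question text is just an invented example):

```python
import copy

import mlx.core as mx

# Encode the full prompt: shared prefix followed by this example's question
full_q1 = tokenizer.encode(
    "The following are multiple choice questions (with answers) about anatomy."
    " Which bone is the longest in the human body?"
)

# Keep only the tokens that are not already in the prefilled prefix cache
q1_tokens = mx.array(full_q1[len(prefix_tokens):])[None, ...]

logits_q1 = model(q1_tokens, cache=copy.deepcopy(cache))
```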
The training case is more complicated because the cache, even for the prefix tokens, changes from step to step of training (since the adapter weights are being updated). In the very edge case where you have a bunch of examples in the same mini-batch with a large common prefix, you could indeed speed things up by sharing the prefix computation. But that case is pretty uncommon, and it gets quite messy if some examples in a batch share a prefix while others do not.