Prompt Lookup Decoding - merged under Speculative example #237
@awni perhaps we can leave this as T5 and then make an attempt at swapping to Llama in a new PR? I was thinking we could adopt the model format / conversion from |
Yea that sounds like a great plan to me! Sorry for the delay in the review here, I will get to it shortly! |
Looks really nice! I think we can get this in soon. I didn't look yet at the core of the prompt decoder but left a few comments.
Thanks a ton for refactoring them together, I think it makes a lot of sense this way.
…ricsson/mlx-examples into speculative_decoding_prompt_lookup
Thanks! Addressed all your comments |
I've been poking around your code @LeonEricsson because I have some long summarization tasks that I'd like to speed up, but noticed a significant bottleneck from the loop. This is probably still not perfect, but I've had a go at speeding it up. This implementation is about 500x faster:

    def find_draft(self, input_ids):
        # Convert MLX array to NumPy for vectorized operations
        input_ids_np = np.array(input_ids)
        # Create a sliding window of the last ngram_max tokens
        ngram = input_ids_np[-self.ngram_max:]
        # Vectorized comparison of ngram with all possible sub-arrays of input_ids
        matches = np.lib.stride_tricks.sliding_window_view(input_ids_np, self.ngram_max) == ngram
        # Check if all elements in ngram match for each sub-array
        match_indices = np.all(matches, axis=1).nonzero()[0]
        # Filter out matches that are too short or overlap with the ngram itself
        match_indices = match_indices[
            (match_indices + self.ngram_max <= input_ids_np.size - self.ngram_max)
            & (match_indices >= self.ngram_min)
        ]
        # Find the largest match
        if match_indices.size > 0:
            largest_match_idx = match_indices[-1]  # Assuming the last match is the largest
            start_idx = largest_match_idx + self.ngram_max
            end_idx = min(start_idx + self.n_draft, input_ids_np.size)
            candidate = input_ids_np[start_idx:end_idx]
            # Convert the candidate back to MLX array
            return mx.array(candidate, dtype=mx.uint32)
        return mx.array([], dtype=mx.uint32)
nice 🚀 the original implementation employed numpy's sliding windows, but I chose to maintain a purely mlx approach. However, as these are user examples, we should prioritize what is most beneficial for the user. A performance bottleneck like this is indeed a significant issue, and I concur that it warrants a change. sidenote: perhaps we can attain comparable speed improvements using mlx.core.vmap? |
That makes sense. And I'm guessing you didn't notice a huge performance gap, because you didn't try it on long texts? I'm shaving ~30 seconds off inference time. I was thinking about trying a vmap version next. Edit: I should clarify, without vectorization prompt lookup is slower than generate for anything but the most trivial task (e.g. repetition). So I think this change is necessary to really justify its existence as a useful example for the community. |
@cmcmaster1 finally implemented a pure MLX version that should be comparable in performance to the numpy one. Would be great if you could confirm this on your end. However, before you do so note that your current implementation does not consider ngram matches other than of size @awni imo this is ready to be merged, sorry for the delay. |
@LeonEricsson oops, you're right. I somehow missed that and just tested on examples where it made no difference! Still much faster than the original and definitely comparable to the (flawed) numpy implementation. |
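For reference, a minimal NumPy sketch of a lookup that does fall back to shorter n-grams, trying every size from ngram_max down to ngram_min. This is not the code merged in the PR: the standalone function signature, the default values, and the choice to allow a match that overlaps the trailing n-gram are all assumptions made for illustration.

```python
import numpy as np


def find_draft(input_ids, ngram_max=3, ngram_min=1, n_draft=5):
    """Return up to n_draft tokens that followed the most recent earlier
    occurrence of the longest matching trailing n-gram, or an empty array.

    Hypothetical standalone helper; parameter names mirror the thread.
    """
    ids = np.asarray(input_ids)
    # Try the longest n-gram first, then fall back to shorter ones.
    for n in range(ngram_max, ngram_min - 1, -1):
        if ids.size <= n:
            continue  # no room for an earlier occurrence of this size
        suffix = ids[-n:]
        # Build windows over everything except the last token, so the
        # trailing n-gram itself is never returned as its own match.
        windows = np.lib.stride_tricks.sliding_window_view(ids[:-1], n)
        starts = np.flatnonzero((windows == suffix).all(axis=1))
        if starts.size == 0:
            continue
        # Use the most recent match; at least one token always follows it.
        begin = starts[-1] + n
        return ids[begin:begin + n_draft]
    return ids[:0]


# No 3-gram match for [9, 2, 3], so the lookup falls back to the 2-gram [2, 3].
draft = find_draft([1, 2, 3, 4, 5, 9, 2, 3], ngram_max=3, ngram_min=1, n_draft=3)
print(draft.tolist())  # [4, 5, 9]
```

Searching from the largest size downwards means a longer, more specific match always wins over a shorter one, at the cost of one vectorized pass per n-gram size.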
ping @awni |
Sorry for the delay!! I will review and get this in early next week |
Continuation of #202. Decided to merge the Prompt Lookup Decoding under the Speculative Decoding example.
This PR implements an example for the "Prompt Lookup Decoding" technique:
https://github.com/apoorvumang/prompt-lookup-decoding
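The idea can be sketched in plain Python as follows. This is an illustrative toy, not the PR's implementation: `lookup_draft`, `generate`, and the `next_token` callable are hypothetical names, the "model" is any deterministic function from a token list to the next token, and verification is done one position at a time, whereas a real model would score all draft positions in a single forward pass.

```python
def lookup_draft(tokens, ngram=2, n_draft=4):
    """Hypothetical helper: copy the tokens that followed the most recent
    earlier occurrence of the last `ngram` tokens (empty list if none)."""
    suffix = tokens[-ngram:]
    # Scan backwards, skipping the position of the trailing n-gram itself.
    for start in range(len(tokens) - ngram - 1, -1, -1):
        if tokens[start:start + ngram] == suffix:
            return tokens[start + ngram:start + ngram + n_draft]
    return []


def generate(next_token, prompt, max_new_tokens, ngram=2, n_draft=4):
    """Greedy decoding with prompt-lookup drafts.

    Returns (tokens, steps), where steps counts verification rounds.
    """
    tokens = list(prompt)
    steps = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        draft = lookup_draft(tokens, ngram, n_draft)
        steps += 1
        accepted = []
        # Accept draft tokens for as long as they agree with the model.
        for tok in draft:
            if tok != next_token(tokens + accepted):
                break
            accepted.append(tok)
        # The model's own prediction after the accepted prefix is always kept,
        # so every round makes progress even when the draft is empty or wrong.
        accepted.append(next_token(tokens + accepted))
        tokens.extend(accepted)
    # A round can overshoot the budget, so truncate the final output.
    return tokens[:len(prompt) + max_new_tokens], steps


# Toy "model" that cycles 1 -> 2 -> 3 -> 1 ...
toy = lambda ctx: ctx[-1] % 3 + 1
out, steps = generate(toy, [1, 2, 3, 1, 2], max_new_tokens=6)
print(out, steps)  # [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2] 2
```

Because a draft token is only kept when it equals the model's own greedy choice, the output is identical to ordinary greedy decoding; the saving comes from needing fewer verification rounds than generated tokens when the text repeats earlier content.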
TODO
--color flag to SpeculativeDecoder
Ended up being quite a messy implementation; need to deal with the fact that output can be truncated and hence only come from draft model