Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added k, v cache for inference speed up #7

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

immortal3
Copy link

Hi, @jaymody, Awesome blog post. I was interested in learning kvcache during inference and searched for it but existing articles on kvcache don't focus on the implementation part of it. So, I decided to implement it in picoGPT.

Are you interested in writing a post for optimization inference time? I would love to collaborate on it.

@jaymody
Copy link
Owner

jaymody commented Feb 12, 2023

Thanks for the implementation!

What kind of speedups did you get with this and did you get an identical output to the non-kv cache version?

Just FYI, I'm going to leave this unmerged to keep the implementation as simple as possible. However, will keep this PR open if people want to reference it in the future.

There's also an inference optimization section in my blog post with some further resources to read up on.

@immortal3
Copy link
Author

Yes, the Output is identical. I am seeing a 25% speedup of CPU.

 the most powerful machines on the planet.

The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.

Yeah, it makes sense to not merge it. Probably, we can create another file gpt2_inference_speed.py which can have these sorts of optimizations.

@clam004
Copy link

clam004 commented Jan 14, 2024

Hi @immortal3 I love the minimal implementation I'm having trouble reproducing the 25% speedup though. I've been using time to compare the two implementations and the 125M model for generating the above output text. If you are up for it, a before and after comparison in your own repo would be so cool and very compelling.

@immortal3
Copy link
Author

immortal3 commented Jan 14, 2024

@clam004 i don't remember exactly how I ended as 25% speedup but it was definitely not a scientific one. 😄

The speedup number will heavily rely on the combination of CPU/Memory and the length of the input tokens. So, I think you might not be getting the exact number 25%, but try feeding a sufficiently longer sequence that should definitely indicate some performance improvement compared to a normal forward pass with KV cache.

On the proper comparison side, I am not sure if it would be worth it (time-wise) at this point to do it thoroughly.

for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop
logits = gpt2(inputs, **params, n_head=n_head) # model forward pass
logits, kvcache = gpt2(inputs, **params, n_head=n_head, kvcache=kvcache) # model forward pass
Copy link

@panaali panaali Jul 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main benefit of KV caching is that you don't need to recalculate the MLPs again for the tokens you already calculated the forward for, and so in the decoding phase you only pass the new token as input to the network.

You should only pass the next_id as input in the decoding phase. In prefill phase, the initial inputs should be passed. checkout https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L72 or https://github.com/meta-llama/llama/blob/main/llama/generation.py#L187C51-L187C59 for an example.

more: https://www.perplexity.ai/search/what-should-be-the-input-to-th-bsYpXZiuRFinjT11Ck33EA#0

wpe_out = wpe[range(len(inputs))]
else:
wpe_out = wpe[[len(inputs)-1]]
inputs = [inputs[-1]]
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@panaali You're correct, if kvcache is there then only the last token should be passed. But this is I being lazy and don't want to change function signatures. So, I am doing it inside function. I just use the last token as input if kvcache is there.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants