feat: added k, v cache for inference speed up #7
base: main
@@ -35,66 +35,95 @@ def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k]
     return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v


-def mha(x, c_attn, c_proj, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
+def mha(x, c_attn, c_proj, n_head, kvcache=None):  # [n_seq, n_embd] -> [n_seq, n_embd]
     # qkv projection
+    # when we pass kvcache, n_seq = 1, so we only compute new_q, new_k and new_v
     x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]

     # split into qkv
     qkv = np.split(x, 3, axis=-1)  # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]

+    if kvcache:
+        # stack the new k and v onto the cached k and v from previous steps
+        new_q, new_k, new_v = qkv  # new_q, new_k, new_v = [1, n_embd]
+        old_k, old_v = kvcache
+        k = np.vstack([old_k, new_k])  # k = [n_seq, n_embd], where n_seq = prev_n_seq + 1
+        v = np.vstack([old_v, new_v])  # v = [n_seq, n_embd], where n_seq = prev_n_seq + 1
+        qkv = [new_q, k, v]
+
+    current_cache = [qkv[1], qkv[2]]
+
     # split into heads
     qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [n_head, 3, n_seq, n_embd/n_head]

     # causal mask to hide future inputs from being attended to
-    causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]
+    if kvcache:
+        # when we pass kvcache, the input is a single token that needs to attend to all previous tokens, so the mask is all 0s
+        causal_mask = np.zeros((1, k.shape[0]))
+    else:
+        # create triangular causal mask
+        causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]

     # perform attention over each head
     out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # [n_head, 3, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]

     # merge heads
     x = np.hstack(out_heads)  # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]

     # out projection
     x = linear(x, **c_proj)  # [n_seq, n_embd] -> [n_seq, n_embd]

-    return x
+    return x, current_cache


-def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
+def transformer_block(x, mlp, attn, ln_1, ln_2, n_head, kvcache=None):  # [n_seq, n_embd] -> [n_seq, n_embd]
     # multi-head causal self attention
-    x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head)  # [n_seq, n_embd] -> [n_seq, n_embd]
+    attn_out, kvcache_updated = mha(layer_norm(x, **ln_1), **attn, n_head=n_head, kvcache=kvcache)
+    x = x + attn_out  # [n_seq, n_embd] -> [n_seq, n_embd]

     # position-wise feed forward network
     x = x + ffn(layer_norm(x, **ln_2), **mlp)  # [n_seq, n_embd] -> [n_seq, n_embd]

-    return x
+    return x, kvcache_updated


-def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):  # [n_seq] -> [n_seq, n_vocab]
+def gpt2(inputs, wte, wpe, blocks, ln_f, n_head, kvcache=None):  # [n_seq] -> [n_seq, n_vocab]
+    if not kvcache:
+        kvcache = [None] * len(blocks)
+        wpe_out = wpe[range(len(inputs))]
+    else:
+        # cache is warm: only the last token is new, so only its position embedding is needed
+        wpe_out = wpe[[len(inputs) - 1]]
+        inputs = [inputs[-1]]
+
     # token + positional embeddings
-    x = wte[inputs] + wpe[range(len(inputs))]  # [n_seq] -> [n_seq, n_embd]
+    x = wte[inputs] + wpe_out  # [n_seq] -> [n_seq, n_embd]

     # forward pass through n_layer transformer blocks
-    for block in blocks:
-        x = transformer_block(x, **block, n_head=n_head)  # [n_seq, n_embd] -> [n_seq, n_embd]
+    new_kvcache = []
+    for block, kvcache_block in zip(blocks, kvcache):
+        x, updated_cache = transformer_block(x, **block, n_head=n_head, kvcache=kvcache_block)  # [n_seq, n_embd] -> [n_seq, n_embd]
+        new_kvcache.append(updated_cache)  # TODO: inplace extend new cache instead of re-saving whole

     # projection to vocab
     x = layer_norm(x, **ln_f)  # [n_seq, n_embd] -> [n_seq, n_embd]
-    return x @ wte.T  # [n_seq, n_embd] -> [n_seq, n_vocab]
+    return x @ wte.T, new_kvcache  # [n_seq, n_embd] -> [n_seq, n_vocab]


 def generate(inputs, params, n_head, n_tokens_to_generate):
     from tqdm import tqdm

+    kvcache = None
     for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
-        logits = gpt2(inputs, **params, n_head=n_head)  # model forward pass
+        logits, kvcache = gpt2(inputs, **params, n_head=n_head, kvcache=kvcache)  # model forward pass

Review comment on the gpt2() forward-pass call above: The main benefit of KV caching is that you don't need to recalculate the MLPs again for the tokens you have already run the forward pass on, so in the decoding phase you only pass the new token as input to the network. You should only pass the new token. More: https://www.perplexity.ai/search/what-should-be-the-input-to-th-bsYpXZiuRFinjT11Ck33EA#0 (a short sketch of these mechanics follows the diff below).

         next_id = np.argmax(logits[-1])  # greedy sampling
         inputs = np.append(inputs, [next_id])  # append prediction to input

     return list(inputs[len(inputs) - n_tokens_to_generate :])  # only return generated ids


-def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
+def main(prompt: str = "Alan Turing theorized that computers would one day become", n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
     from utils import load_encoder_hparams_and_params

     # load encoder, hparams, and params from the released open-ai gpt-2 files
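
To make the review comment above concrete, here is a minimal NumPy sketch of the cache mechanics that the mha() change implements: on a decode step only the single new token produces a fresh k/v row, the cached rows are stacked with it, and the causal mask collapses to one all-zero row. The sizes (n_embd = 8, a 4-token prompt) and the random arrays standing in for real projections are purely illustrative.

import numpy as np

# Illustrative sizes only (hypothetical n_embd = 8, a 4-token prompt); random
# arrays stand in for the k/v projections of real tokens.
n_embd, prompt_len = 8, 4

# Prefill: k and v for the whole prompt are computed once and cached.
k = np.random.randn(prompt_len, n_embd)
v = np.random.randn(prompt_len, n_embd)
kvcache = (k, v)

# Decode step: only the new token yields new_k/new_v; cached rows are reused,
# mirroring the kvcache branch added to mha() above.
new_k = np.random.randn(1, n_embd)
new_v = np.random.randn(1, n_embd)
old_k, old_v = kvcache
k = np.vstack([old_k, new_k])            # (prompt_len + 1, n_embd)
v = np.vstack([old_v, new_v])            # (prompt_len + 1, n_embd)
causal_mask = np.zeros((1, k.shape[0]))  # one query row attends to every key

print(k.shape, v.shape, causal_mask.shape)  # (5, 8) (5, 8) (1, 5)

The per-step attention cost still grows with the cached length, but the qkv projection and the MLP run on just the one new token, which is the speed-up this PR targets.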
Author's reply: @panaali You're correct, if kvcache is there then only the last token should be passed. But that's me being lazy and not wanting to change the function signatures, so I do it inside the function: I just use the last token as input if kvcache is there.
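
A small sketch of that last-token shortcut in isolation, with hypothetical toy embedding tables (not the real GPT-2 weights): generate() keeps passing the full inputs list, and the warm-cache branch of gpt2() cuts it down to the newest token while still indexing wpe with that token's absolute position.

import numpy as np

# Hypothetical toy tables, just to show the slicing in gpt2()'s warm-cache branch.
n_vocab, n_ctx, n_embd = 10, 8, 4
wte = np.arange(n_vocab * n_embd, dtype=float).reshape(n_vocab, n_embd)
wpe = np.arange(n_ctx * n_embd, dtype=float).reshape(n_ctx, n_embd)

inputs = [3, 1, 4, 1, 5]          # full token ids, as generate() passes them

# Warm-cache branch: only the last token is embedded, but its absolute
# position (len(inputs) - 1) is still used to pick the positional embedding.
wpe_out = wpe[[len(inputs) - 1]]  # position 4 -> shape (1, n_embd)
inputs = [inputs[-1]]             # only the newest token id goes through the model
x = wte[inputs] + wpe_out         # a single-row "sequence" of shape (1, n_embd)

print(x.shape)                    # (1, 4)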