Prompt Lookup Decoding - merged under Speculative example #237

Open · wants to merge 12 commits into `main`
1 change: 1 addition & 0 deletions llms/speculative_decoding/.gitignore
@@ -0,0 +1 @@
*.npz
21 changes: 19 additions & 2 deletions llms/speculative_decoding/README.md
@@ -34,12 +34,12 @@ python convert.py --model t5-small
You can run with the default arguments:

```
python main.py
python speculative.py
```

To see a full list of options use:
```
python main.py --help
python speculative.py --help
```

### Notes
@@ -64,3 +64,20 @@ draft tokens at the expense of more large model evaluations.
Decoding](https://arxiv.org/abs/2211.17192)
[^2]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).

## Prompt Lookup Decoding
When speculative decoding works well, it can substantially accelerate inference, but choosing a suitable draft model is often difficult. Prompt lookup decoding[^3] replaces the draft model with a simple sliding-window search over the prompt. This removes the need for a draft model entirely while offering comparable speedups on the right tasks. Prompt lookup decoding excels at *input-grounded* tasks such as summarization, document Q/A, and code editing, where there is substantial overlap between the input and the output.
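
To make the lookup concrete, here is a minimal sketch of the sliding-window n-gram search in plain Python (the function name `find_draft` and the list-based types are illustrative, not this example's API):

```
def find_draft(tokens, n_draft=10, ngram_max=3, ngram_min=1):
    """Propose draft tokens by locating the last n-gram earlier in `tokens`."""
    tail = tokens[-ngram_max:]
    best_len, draft = 0, []
    for end in range(1, len(tokens) - ngram_max):
        window = tokens[max(0, end - ngram_max):end]
        # Count how many trailing tokens of this window match the tail n-gram.
        k = 0
        while k < min(len(window), len(tail)) and window[-1 - k] == tail[-1 - k]:
            k += 1
        if k >= ngram_min and k > best_len:
            best_len = k
            draft = tokens[end:end + n_draft]  # continuation after the match
    return draft

# The last two tokens [5, 6] also appear earlier, so their continuation is drafted.
print(find_draft([1, 5, 6, 7, 8, 2, 5, 6], n_draft=3))  # -> [7, 8, 2]
```

The drafted tokens are then verified with a single forward pass of the main model, and only the prefix that matches the model's own predictions is kept.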

### Run
[Setup](#setup) is the same as for Speculative Decoding. You can then run with the default arguments:

```
python prompt_lookup.py
```

To see a full list of options use:
```
python prompt_lookup.py --help
```
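
For example, to require at least a two-token n-gram match and propose ten draft tokens per match:
```
python prompt_lookup.py --ngram-min 2 --n-draft 10
```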

[^3]: Check out the [original implementation](https://github.com/apoorvumang/prompt-lookup-decoding).
193 changes: 183 additions & 10 deletions llms/speculative_decoding/decoder.py
@@ -46,14 +46,16 @@ def __init__(
model: Model,
draft_model: Model,
tokenizer: str,
color: bool,
num_draft: int = 5,
delta: float = 0.0,
delta: float = 0.0
):
self.tokenizer = Tokenizer(tokenizer)
self.model = model
self.draft_model = draft_model
self.num_draft = num_draft
self.delta = delta
self.color = color

def _generate(
self,
@@ -91,6 +93,7 @@ def generate(
print()
self.model.reset_cache()

# Accept / Reject criteria (see Section 2.3 https://arxiv.org/pdf/2211.17192.pdf)
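# In the paper, a draft token x is accepted with probability min(1, p(x) / q(x)),
# where p is the main model's distribution and q is the draft model's.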
def _get_num_accept(self, draft_tokens, draft_probs, model_logits):
# accept_toks = mx.argmax(model_logits, axis=-1) == draft_tokens
model_probs = mx.take_along_axis(
@@ -120,9 +123,9 @@ def sample(logits):
tokens = mx.array([self.tokenizer.decoder_start_id])

n_steps = 0
ntoks = 0
n_generated = 0
n_accepted = 0
n_draft = 0
n_drafted = 0

outputs = []
skip = 0
@@ -133,7 +136,7 @@ def sample(logits):
draft_tokens = []
draft_probs = []
for _, (t, p) in zip(
range(ntoks, min(ntoks + self.num_draft, max_tokens)),
range(n_generated, min(n_generated + self.num_draft, max_tokens)),
self._generate(draft_inputs, draft_memory, draft=True),
):
draft_tokens.append(t)
@@ -163,7 +166,7 @@ def sample(logits):
)

n_accepted += num_to_accept
n_draft += draft_tokens.size
n_drafted += draft_tokens.size

# Rewind the cache for unaccepted tokens:
if (n := draft_tokens.size) > num_to_accept:
@@ -172,17 +175,39 @@ def sample(logits):

n_steps += 1

truncated = False
for t in new_tokens.tolist():
if t == self.tokenizer.eos_id or ntoks >= max_tokens:
if t == self.tokenizer.eos_id or n_generated >= max_tokens:
truncated = True
break
outputs.append(t)
ntoks += 1
n_generated += 1

str_output = self.tokenizer.decode(outputs)
print(str_output[skip:], end="", flush=True)

if self.color and not truncated:
model_token = len(self.tokenizer.decode(outputs[-1]))
print(
"\033[34m"
+ str_output[skip:-model_token]
+ "\033[30m",
end="",
)
print(str_output[-model_token:], end="", flush=True)
elif self.color and truncated:
print(
"\033[34m"
+ str_output[skip:]
+ "\033[30m",
end="",
)
else:
print(str_output[skip:], end="", flush=True)

skip = len(str_output)

if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
break
draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
inputs = draft_inputs[-1:]
@@ -192,4 +217,152 @@ def sample(logits):

self.model.reset_cache()
self.draft_model.reset_cache()
return {"n_accepted": n_accepted, "n_draft": n_draft, "n_steps": n_steps}
return {"n_accepted": n_accepted, "n_draft": n_drafted, "n_steps": n_steps}


########################################################


class PromptLookupDecoder:
def __init__(
self,
model: Model,
tokenizer: str,
n_draft: int,
ngram_max: int,
ngram_min: int,
temp: float,
seed: int,
color: bool,
):
self.model = model
self.tokenizer = Tokenizer(tokenizer)
self.n_draft = n_draft
self.ngram_max = ngram_max
self.ngram_min = ngram_min
self.temp = temp
self.seed = seed
self.color = color

def generate_draft(self, input_ids):
ngram = input_ids[-self.ngram_max :]

largest_match = 0
draft = mx.array([], dtype=mx.uint32)

# Sliding window search
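# Example: with input_ids [1, 5, 6, 7, 5, 6], ngram_max=2, and ngram_min=1, the
# tail [5, 6] also occurs earlier, so the tokens that followed it,
# input_ids[3:3 + n_draft] = [7, 5, 6], become the draft.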
for i in range(1, input_ids.size - self.ngram_max):
matches = input_ids[max(0, i - self.ngram_max) : i] == ngram[-i:]

# reverse through the matches array
match_length = 0
for j in range(matches.size - 1, -1, -1):
if matches[j]:
match_length += 1
else:
break

if match_length >= self.ngram_min and match_length > largest_match:
largest_match = match_length
start_idx = i
end_idx = start_idx + self.n_draft
draft = input_ids[start_idx:end_idx]

return draft

def prompt_lookup(
self,
prompt: str,
max_tokens: int,
):
def sample(logits):
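# Greedy decoding at temperature 0; otherwise temperature-scaled sampling.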
if self.temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / self.temp))

prompt = mx.array(self.tokenizer.encode(prompt), mx.uint32)[None]
memory = self.model.encode(prompt)

history = prompt.squeeze(0)[
:-1
] # remove eos token from prompt lookup search space

n_steps = 0
n_generated = 0
n_accepted = 0
n_drafted = 0

outputs = []
skip = 0
inputs = mx.array([self.tokenizer.decoder_start_id])
while True:
# For each decoding step: generate n_draft tokens by searching the prompt
draft_tokens = self.generate_draft(history)

# Verify draft tokens with the last verified token
verify_tokens = mx.concatenate([inputs, draft_tokens])
logits = self.model.decode(verify_tokens[None], memory)

# Only keep sampled tokens that match the draft: the draft comes from the
# prompt rather than a draft model's distribution, so there is no
# probabilistic accept / reject criterion.
sampled = sample(logits).squeeze(0)
equal_toks = sampled[:-1] == draft_tokens
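# Appending False guarantees index() finds a stop, so num_to_accept is the
# count of leading draft tokens that agree with the model's own samples.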
num_to_accept = (equal_toks.tolist() + [False]).index(False)
new_tokens = sampled[
: max(1, num_to_accept + 1)
] # accepted draft tokens + next token from main model

n_accepted += num_to_accept
n_drafted += draft_tokens.size

# Rewind the cache for unaccepted tokens:
if (n := draft_tokens.size) > num_to_accept:
self.model.truncate_cache(n - new_tokens.size + 1)

n_steps += 1

truncated = False
for t in new_tokens.tolist():
if t == self.tokenizer.eos_id or n_generated >= max_tokens:
truncated = True
break
outputs.append(t)
n_generated += 1

str_output = self.tokenizer.decode(outputs)

if self.color and not truncated:
model_token = len(self.tokenizer.decode(outputs[-1]))
print(
"\033[34m"
+ str_output[skip:-model_token]
+ "\033[30m",
end="",
)
print(str_output[-model_token:], end="", flush=True)
elif self.color and truncated:
print(
"\033[34m"
+ str_output[skip:]
+ "\033[30m",
end="",
)
else:
print(str_output[skip:], end="", flush=True)

skip = len(str_output)

if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
break

history = mx.concatenate([history, new_tokens])
inputs = history[-1:]

print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
print()

self.model.reset_cache()

return {"n_accepted": n_accepted, "n_draft": n_drafted, "n_steps": n_steps}
100 changes: 100 additions & 0 deletions llms/speculative_decoding/prompt_lookup.py
@@ -0,0 +1,100 @@
import argparse
import time

import mlx.core as mx
from decoder import PromptLookupDecoder
from mlx.utils import tree_unflatten
from model import Model
from transformers import T5Config


def load_model(model_name: str):
config = T5Config.from_pretrained(model_name)
model = Model(config)
weights = mx.load(f"{model_name}.npz")
weights = tree_unflatten(list(weights.items()))
model.update(weights)
mx.eval(model.parameters())
return model


def main(args):
mx.random.seed(args.seed)

lookup_decoder = PromptLookupDecoder(
model=load_model(args.model_name),
tokenizer=args.model_name,
n_draft=args.n_draft,
ngram_max=args.ngram_max,
ngram_min=args.ngram_min,
temp=args.temp,
seed=args.seed,
color=args.color,
)

tic = time.time()
print(args.prompt)

stats = lookup_decoder.prompt_lookup(args.prompt, max_tokens=args.max_tokens)
print("=" * 10)
print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
print(f"Decoding steps {stats['n_steps']}.")

toc = time.time()
print("=" * 10)
print(f"Full generation time {toc - tic:.3f}")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Prompt Lookup Decoding")

parser.add_argument(
"--n-draft",
type=int,
default=10,
help="Number of draft tokens to generate upon prompt lookup match",
)
parser.add_argument(
"--model-name",
help="Name of the model.",
default="t5-base",
)

parser.add_argument(
"--prompt",
help="The prompt processed by the model.",
default="Translate the following from English to English: Let's go to the store and buy some groceries including eggs, avocadoes, and bread.",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--ngram-max",
type=int,
default=3,
help="Maximum ngrams to match against input during prompt lookup",
)
parser.add_argument(
"--ngram-min",
type=int,
default=1,
help="Minimum ngrams to match against input during prompt lookup",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
parser.add_argument(
"--color", type=bool, default=False, help="Color the accepted draft tokens"
)

args = parser.parse_args()

main(args)