Add Max Token Limit for Generation #1078

Open · wants to merge 2 commits into main
15 changes: 13 additions & 2 deletions llms/mlx_lm/chat.py
@@ -29,7 +29,16 @@ def setup_arg_parser():
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
)
parser.add_argument(
"--max-tokens-per-sec",
type=float,
default=None,
help="Maximum tokens to generate per second",
)
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
@@ -56,7 +65,7 @@ def main():
tokenizer_config={"trust_remote_code": True},
)

print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
query = input(">> ")
@@ -72,7 +81,9 @@ def main():
prompt,
temp=args.temp,
top_p=args.top_p,
max_tokens_per_sec=args.max_tokens_per_sec,
prompt_cache=prompt_cache,
max_tokens=4096,  # Ensure this is set to a reasonable limit
):
print(response, flush=True, end="")
print()
11 changes: 9 additions & 2 deletions llms/mlx_lm/generate.py
@@ -9,8 +9,8 @@
from .models.cache import load_prompt_cache
from .utils import generate, load

DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_PROMPT = "Tell me a story!"
DEFAULT_MAX_TOKENS = 1000
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
@@ -61,6 +61,12 @@ def setup_arg_parser():
default=DEFAULT_MAX_TOKENS,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--max-tokens-per-sec",
type=float,
default=None,
help="Maximum tokens to generate per second",
)
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
)
@@ -227,6 +233,7 @@ def main():
top_p=args.top_p,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
max_tokens_per_sec=args.max_tokens_per_sec,
)
if not args.verbose:
print(response)
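For reference, the new flag fits alongside the existing CLI options; a rough example invocation (the model path is only a placeholder):

    python -m mlx_lm.generate --model <model-path> --prompt "Tell me a story!" --max-tokens-per-sec 5
    python -m mlx_lm.chat --model <model-path> --max-tokens-per-sec 5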
41 changes: 29 additions & 12 deletions llms/mlx_lm/utils.py
@@ -123,11 +123,11 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
logits[:, tokens] = selected_logits
return logits


def generate_step(
prompt: mx.array,
model: nn.Module,
temp: float = 0.0,
max_tokens_per_sec: Optional[float] = None,  # Optional cap on generation speed, in tokens per second
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
top_p: float = 1.0,
@@ -145,8 +145,7 @@ def generate_step(
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``.
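max_tokens_per_sec (float, optional): If set, limits generation speed to
approximately ``max_tokens_per_sec`` tokens per second. Default: ``None``.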
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
@@ -171,7 +170,6 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
one token and a vector of log probabilities.
"""

def sample(logits: mx.array) -> Tuple[mx.array, float]:
logprobs = logits - mx.logsumexp(logits)

@@ -193,16 +191,21 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]:
raise ValueError(
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
)

if max_tokens_per_sec is not None:
if not isinstance(max_tokens_per_sec, (int, float)) or max_tokens_per_sec <= 0:
raise ValueError(
f"max_tokens_per_sec must be a positive number, got {max_tokens_per_sec}"
)

logits_processor = logits_processor or []

if repetition_penalty:

def repetition_penalty_processor(tokens, logits):
return apply_repetition_penalty(
logits, tokens[-repetition_context_size:], repetition_penalty
)

logits_processor.append(repetition_penalty_processor)

if logit_bias:
@@ -247,18 +250,30 @@ def _step(y):
y, logprobs = _step(y)

mx.async_eval(y, logprobs)
last_target_time = time.perf_counter()  # Anchor the pacing schedule at the time of the first token

while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)

if max_tokens_per_sec is not None:
target_time = 1.0 / max_tokens_per_sec
last_target_time += target_time # When we want next token

# Sleep until target time if needed
sleep_time = last_target_time - time.perf_counter()
if sleep_time > 0:
time.sleep(sleep_time)

yield y.item(), logprobs
y, logprobs = next_y, next_logprobs


def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
max_tokens: int = 100,
max_tokens_per_sec: Optional[float] = None,  # Optional cap on generation speed, in tokens per second
**kwargs,
) -> Union[str, Generator[str, None, None]]:
"""
@@ -267,13 +282,15 @@ def stream_generate(
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
max_tokens (int): The ma
max_tokens (int): The maximum number of tokens. Default: ``100``.
max_tokens_per_sec (float, optional): If set, limits generation speed to
approximately ``max_tokens_per_sec`` tokens per second; the actual rate may
briefly exceed this limit. Default: ``None``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.

Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
"""

if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)

@@ -283,13 +300,11 @@ def stream_generate(
detokenizer.reset()
for n, (token, _) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
generate_step(prompt_tokens, model, max_tokens_per_sec=max_tokens_per_sec, **kwargs),
):
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)

# Yield the last segment if streaming
yield detokenizer.last_segment

detokenizer.finalize()
Expand All @@ -301,6 +316,7 @@ def generate(
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
max_tokens: int = 100,
max_tokens_per_sec: Optional[float] = None,  # Optional cap on generation speed, in tokens per second
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
@@ -313,6 +329,7 @@ def generate(
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
max_tokens_per_sec (float, optional): If set, limits generation speed to
approximately ``max_tokens_per_sec`` tokens per second; the actual rate may
briefly exceed this limit. Default: ``None``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
formatter (Optional[Callable]): A function which takes a token and a
@@ -335,7 +352,7 @@ def generate(

for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
generate_step(prompt_tokens, model, max_tokens_per_sec=max_tokens_per_sec, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
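For reference, a minimal sketch of how the new parameter could be used from the package-level API once this change lands; the model path below is only a placeholder, and load/generate are the existing top-level mlx_lm helpers:

    from mlx_lm import load, generate

    # Placeholder model path; substitute any local path or Hugging Face repo.
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    # Cap decoding at roughly 5 tokens per second; the pacing is approximate
    # and may briefly exceed the target, as noted in the docstrings above.
    text = generate(
        model,
        tokenizer,
        prompt="Tell me a story!",
        max_tokens=1000,
        max_tokens_per_sec=5.0,
    )
    print(text)

The same keyword can be passed to stream_generate, which is what the chat REPL uses.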