diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index 7968a868..9a7667cf 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -29,7 +29,13 @@ def setup_arg_parser():
         help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
         "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
     )
+    parser.add_argument(
+        "--max-tokens-per-sec",
+        type=float,
+        default=None,
+        help="Maximum tokens to generate per second",
+    )
     parser.add_argument(
         "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
@@ -56,7 +62,7 @@ def main():
         tokenizer_config={"trust_remote_code": True},
     )
 
-    print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
         query = input(">> ")
@@ -72,7 +78,9 @@ def main():
             prompt,
             temp=args.temp,
             top_p=args.top_p,
+            max_tokens_per_sec=args.max_tokens_per_sec,
             prompt_cache=prompt_cache,
+            max_tokens=4096,  # Reasonable upper limit for a chat response
         ):
             print(response, flush=True, end="")
         print()
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 0bf98ab2..b6862d7f 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -9,8 +9,8 @@
 from .models.cache import load_prompt_cache
 from .utils import generate, load
 
-DEFAULT_PROMPT = "hello"
-DEFAULT_MAX_TOKENS = 100
+DEFAULT_PROMPT = "Tell me a story!"
+DEFAULT_MAX_TOKENS = 1000
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
@@ -61,6 +61,12 @@ def setup_arg_parser():
         default=DEFAULT_MAX_TOKENS,
         help="Maximum number of tokens to generate",
     )
+    parser.add_argument(
+        "--max-tokens-per-sec",
+        type=float,
+        default=None,
+        help="Maximum tokens to generate per second",
+    )
     parser.add_argument(
         "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
     )
@@ -227,6 +233,7 @@ def main():
         top_p=args.top_p,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
+        max_tokens_per_sec=args.max_tokens_per_sec,
     )
     if not args.verbose:
         print(response)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 92741b68..c26cdeb0 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -123,11 +123,12 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
         logits[:, tokens] = selected_logits
     return logits
 
 
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
     temp: float = 0.0,
+    max_tokens_per_sec: Optional[float] = None,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = 20,
     top_p: float = 1.0,
@@ -145,8 +146,10 @@ def generate_step(
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling, if 0 the argmax is used.
           Default: ``0``.
+        max_tokens_per_sec (float, optional): If set, limits generation speed
+          to approximately ``max_tokens_per_sec``. May go slightly over this limit.
         repetition_penalty (float, optional): The penalty factor for repeating
           tokens.
         repetition_context_size (int, optional): The number of tokens to
""" - def sample(logits: mx.array) -> Tuple[mx.array, float]: logprobs = logits - mx.logsumexp(logits) @@ -193,16 +191,21 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: raise ValueError( f"repetition_penalty must be a non-negative float, got {repetition_penalty}" ) + + if max_tokens_per_sec is not None: + if not isinstance(max_tokens_per_sec, (int, float)) or max_tokens_per_sec <= 0: + raise ValueError( + f"max_tokens_per_sec must be a positive number, got {max_tokens_per_sec}" + ) logits_processor = logits_processor or [] + last_token_time = time.perf_counter() # Track time for rate limiting if repetition_penalty: - def repetition_penalty_processor(tokens, logits): return apply_repetition_penalty( logits, tokens[-repetition_context_size:], repetition_penalty ) - logits_processor.append(repetition_penalty_processor) if logit_bias: @@ -247,18 +250,30 @@ def _step(y): y, logprobs = _step(y) mx.async_eval(y, logprobs) + last_target_time = time.perf_counter() # Track when we WANTED the last token + while True: next_y, next_logprobs = _step(y) mx.async_eval(next_y, next_logprobs) + + if max_tokens_per_sec is not None: + target_time = 1.0 / max_tokens_per_sec + last_target_time += target_time # When we want next token + + # Sleep until target time if needed + sleep_time = last_target_time - time.perf_counter() + if sleep_time > 0: + time.sleep(sleep_time) + yield y.item(), logprobs y, logprobs = next_y, next_logprobs - def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], prompt: str, max_tokens: int = 100, + max_tokens_per_sec: Optional[float] = None, # Add parameter **kwargs, ) -> Union[str, Generator[str, None, None]]: """ @@ -267,13 +282,15 @@ def stream_generate( Args: prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. - max_tokens (int): The ma + max_tokens (int): The maximum number of tokens. Default: ``100``. + max_tokens_per_sec (float, optional): If set, limits generation speed to approximately max_tokens_per_sec. May go slightly over this limit. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. Yields: Generator[Tuple[mx.array, mx.array]]: A generator producing text. """ + if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) @@ -283,13 +300,11 @@ def stream_generate( detokenizer.reset() for n, (token, _) in zip( range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), + generate_step(prompt_tokens, model, max_tokens_per_sec=max_tokens_per_sec, **kwargs), ): if token == tokenizer.eos_token_id: break detokenizer.add_token(token) - - # Yield the last segment if streaming yield detokenizer.last_segment detokenizer.finalize() @@ -301,6 +316,7 @@ def generate( tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], prompt: str, max_tokens: int = 100, + max_tokens_per_sec: Optional[float] = None, # Add parameter verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, @@ -313,6 +329,7 @@ def generate( tokenizer (PreTrainedTokenizer): The tokenizer. prompt (str): The string prompt. max_tokens (int): The maximum number of tokens. Default: ``100``. + max_tokens_per_sec (float, optional): If set, limits generation speed to approximately max_tokens_per_sec. May go slightly over this limit. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. 
formatter (Optional[Callable]): A function which takes a token and a @@ -335,7 +352,7 @@ def generate( for n, (token, logprobs) in zip( range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), + generate_step(prompt_tokens, model, max_tokens_per_sec=max_tokens_per_sec, **kwargs), ): if n == 0: prompt_time = time.perf_counter() - tic
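For reference, the throttling added to generate_step reduces to a simple pattern: keep a moving target timestamp, advance it by 1 / max_tokens_per_sec for every token, and sleep only when generation is running ahead of that schedule, so tokens that were already slow are never delayed further. The following is a minimal standalone sketch of that pattern, not part of the patch; the helper name `throttle` is illustrative only.

    import time

    def throttle(tokens, max_tokens_per_sec):
        # Yield items from `tokens` at no more than `max_tokens_per_sec` per second.
        interval = 1.0 / max_tokens_per_sec
        target = time.perf_counter()
        for tok in tokens:
            target += interval  # when the next token is allowed to appear
            delay = target - time.perf_counter()
            if delay > 0:  # only sleep when we are ahead of schedule
                time.sleep(delay)
            yield tok

    # Example: emits roughly ten items per second.
    # for tok in throttle(range(50), max_tokens_per_sec=10):
    #     print(tok)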