From 7e4413b1dd71026fdff207fd935321710e34a084 Mon Sep 17 00:00:00 2001
From: N8
Date: Thu, 31 Oct 2024 02:20:55 -0400
Subject: [PATCH 1/2] add max token limit

---
 llms/mlx_lm/chat.py     | 10 ++++++-
 llms/mlx_lm/generate.py | 11 +++++--
 llms/mlx_lm/utils.py    | 72 +++++++++++++++++++++++++++++++----------
 3 files changed, 72 insertions(+), 21 deletions(-)

diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index 7968a868..9a7667cf 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -29,7 +29,13 @@ def setup_arg_parser():
         help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
         "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
     )
+    parser.add_argument(
+        "--max-tokens-per-sec",
+        type=float,
+        default=None,
+        help="Maximum tokens to generate per second",
+    )
     parser.add_argument(
         "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
@@ -56,7 +62,7 @@ def main():
         tokenizer_config={"trust_remote_code": True},
     )
 
-    print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
         query = input(">> ")
@@ -72,7 +78,9 @@ def main():
             prompt,
             temp=args.temp,
             top_p=args.top_p,
+            max_tokens_per_sec=args.max_tokens_per_sec,
             prompt_cache=prompt_cache,
+            max_tokens=4096,  # upper bound on tokens generated per chat response
         ):
             print(response, flush=True, end="")
         print()
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 0bf98ab2..b6862d7f 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -9,8 +9,8 @@
 from .models.cache import load_prompt_cache
 from .utils import generate, load
 
-DEFAULT_PROMPT = "hello"
-DEFAULT_MAX_TOKENS = 100
+DEFAULT_PROMPT = "Tell me a story!"
+DEFAULT_MAX_TOKENS = 1000
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
@@ -61,6 +61,12 @@ def setup_arg_parser():
         default=DEFAULT_MAX_TOKENS,
         help="Maximum number of tokens to generate",
     )
+    parser.add_argument(
+        "--max-tokens-per-sec",
+        type=float,
+        default=None,
+        help="Maximum tokens to generate per second",
+    )
     parser.add_argument(
         "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
     )
@@ -227,6 +233,7 @@ def main():
             top_p=args.top_p,
             max_kv_size=args.max_kv_size,
             prompt_cache=prompt_cache if using_cache else None,
+            max_tokens_per_sec=args.max_tokens_per_sec,
         )
         if not args.verbose:
             print(response)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 92741b68..720151e9 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -123,11 +123,11 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
         logits[:, tokens] = selected_logits
     return logits
 
-
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
     temp: float = 0.0,
+    max_tokens_per_sec: Optional[float] = None,  # optional cap on decode speed (tokens/sec)
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = 20,
     top_p: float = 1.0,
@@ -145,9 +145,8 @@ def generate_step(
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        temp (float): The temperature for sampling, if 0 the argmax is used.
-          Default: ``0``.
-        repetition_penalty (float, optional): The penalty factor for repeating
+        temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``.
+        repetition_penalty (float, optional): The penalty factor for repeating
          tokens.
         repetition_context_size (int, optional): The number of tokens to
           consider for repetition penalty. Default: ``20``.
@@ -171,7 +170,6 @@ def generate_step(
         Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
         one token and a vector of log probabilities.
     """
-
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
         logprobs = logits - mx.logsumexp(logits)
@@ -193,16 +191,20 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]:
         raise ValueError(
             f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
         )
+
+    if max_tokens_per_sec is not None:
+        if not isinstance(max_tokens_per_sec, (int, float)) or max_tokens_per_sec <= 0:
+            raise ValueError(
+                f"max_tokens_per_sec must be a positive number, got {max_tokens_per_sec}"
+            )
 
     logits_processor = logits_processor or []
 
     if repetition_penalty:
-
         def repetition_penalty_processor(tokens, logits):
             return apply_repetition_penalty(
                 logits, tokens[-repetition_context_size:], repetition_penalty
             )
-
         logits_processor.append(repetition_penalty_processor)
 
     if logit_bias:
@@ -247,33 +249,67 @@ def _step(y):
     y, logprobs = _step(y)
     mx.async_eval(y, logprobs)
 
+    last_target_time = time.perf_counter()  # target emission time of the most recent token
+
     while True:
         next_y, next_logprobs = _step(y)
         mx.async_eval(next_y, next_logprobs)
+
+        if max_tokens_per_sec is not None:
+            target_time = 1.0 / max_tokens_per_sec
+            last_target_time += target_time  # target time for the next token
+
+            # Sleep only if we are ahead of the target schedule
+            sleep_time = last_target_time - time.perf_counter()
+            if sleep_time > 0:
+                time.sleep(sleep_time)
+
         yield y.item(), logprobs
         y, logprobs = next_y, next_logprobs
-
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
     max_tokens: int = 100,
+    max_tokens_per_sec: Optional[float] = None,  # optional cap on decode speed (tokens/sec)
     **kwargs,
 ) -> Union[str, Generator[str, None, None]]:
     """
-    A generator producing text based on the given prompt from the model.
+    A generator producing token ids based on the given prompt from the model.
 
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        max_tokens (int): The ma
-    kwargs: The remaining options get passed to :func:`generate_step`.
-      See :func:`generate_step` for more details.
-
+        temp (float): The temperature for sampling, if 0 the argmax is used.
+          Default: ``0``.
+        repetition_penalty (float, optional): The penalty factor for repeating
+          tokens.
+        repetition_context_size (int, optional): The number of tokens to
+          consider for repetition penalty. Default: ``20``.
+        top_p (float, optional): Nucleus sampling, higher means model considers
+          more less likely words.
+        min_p (float, optional): The minimum value (scaled by the top token's
+          probability) that a token probability must have to be considered.
+        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
+          be filtered by min_p sampling.
+        prefill_step_size (int): Step size for processing the prompt.
+        max_kv_size (int, optional): Maximum size of the key-value cache. Old
+          entries (except the first 4 tokens) will be overwritten.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+          provided, the cache will be updated in place.
+        logit_bias (dictionary, optional): Additive logit bias.
+        logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
+          A list of functions that take tokens and logits and return the processed
+          logits. Default: ``None``.
+        max_tokens_per_sec (float, optional): If set, limits generation speed to approximately
+          this many tokens per second by adding delays between tokens. Useful for thermal/power
+          management. Default: None (no limit).
     Yields:
-        Generator[Tuple[mx.array, mx.array]]: A generator producing text.
+        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
+        one token and a vector of log probabilities.
     """
+
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)
 
@@ -283,13 +319,11 @@ def stream_generate(
     detokenizer.reset()
     for n, (token, _) in zip(
         range(max_tokens),
-        generate_step(prompt_tokens, model, **kwargs),
+        generate_step(prompt_tokens, model, max_tokens_per_sec=max_tokens_per_sec, **kwargs),
     ):
         if token == tokenizer.eos_token_id:
             break
         detokenizer.add_token(token)
-
-        # Yield the last segment if streaming
         yield detokenizer.last_segment
 
     detokenizer.finalize()
@@ -301,6 +335,7 @@ def generate(
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
     max_tokens: int = 100,
+    max_tokens_per_sec: Optional[float] = None,  # optional cap on decode speed (tokens/sec)
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
@@ -313,6 +348,7 @@ def generate(
         tokenizer (PreTrainedTokenizer): The tokenizer.
         prompt (str): The string prompt.
         max_tokens (int): The maximum number of tokens. Default: ``100``.
+        max_tokens_per_sec (float, optional): If set, limits generation speed to approximately this many tokens per second; the actual rate may slightly exceed the limit.
         verbose (bool): If ``True``, print tokens and timing information.
           Default: ``False``.
         formatter (Optional[Callable]): A function which takes a token and a
@@ -335,7 +371,7 @@ def generate(
 
     for n, (token, logprobs) in zip(
         range(max_tokens),
-        generate_step(prompt_tokens, model, **kwargs),
+        generate_step(prompt_tokens, model, max_tokens_per_sec=max_tokens_per_sec, **kwargs),
     ):
         if n == 0:
             prompt_time = time.perf_counter() - tic

From e6d35301bde6f0a046a3e5ad38c943d797b0c6e1 Mon Sep 17 00:00:00 2001
From: N8
Date: Thu, 31 Oct 2024 02:37:14 -0400
Subject: [PATCH 2/2] smol modification

---
 llms/mlx_lm/utils.py | 36 ++++++++----------------------------
 1 file changed, 8 insertions(+), 28 deletions(-)

diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 720151e9..c26cdeb0 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -146,7 +146,7 @@ def generate_step(
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``.
-        repetition_penalty (float, optional): The penalty factor for repeating
+        repetition_penalty (float, optional): The penalty factor for repeating
          tokens.
         repetition_context_size (int, optional): The number of tokens to
           consider for repetition penalty. Default: ``20``.
@@ -276,38 +276,18 @@ def stream_generate(
     **kwargs,
 ) -> Union[str, Generator[str, None, None]]:
     """
-    A generator producing token ids based on the given prompt from the model.
+    A generator producing text based on the given prompt from the model.
 
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        temp (float): The temperature for sampling, if 0 the argmax is used.
-          Default: ``0``.
-        repetition_penalty (float, optional): The penalty factor for repeating
-          tokens.
-        repetition_context_size (int, optional): The number of tokens to
-          consider for repetition penalty. Default: ``20``.
-        top_p (float, optional): Nucleus sampling, higher means model considers
-          more less likely words.
-        min_p (float, optional): The minimum value (scaled by the top token's
-          probability) that a token probability must have to be considered.
-        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
-          be filtered by min_p sampling.
-        prefill_step_size (int): Step size for processing the prompt.
-        max_kv_size (int, optional): Maximum size of the key-value cache. Old
-          entries (except the first 4 tokens) will be overwritten.
-        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
-          provided, the cache will be updated in place.
-        logit_bias (dictionary, optional): Additive logit bias.
-        logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
-          A list of functions that take tokens and logits and return the processed
-          logits. Default: ``None``.
-        max_tokens_per_sec (float, optional): If set, limits generation speed to approximately
-          this many tokens per second by adding delays between tokens. Useful for thermal/power
-          management. Default: None (no limit).
+        max_tokens (int): The maximum number of tokens. Default: ``100``.
+        max_tokens_per_sec (float, optional): If set, limits generation speed to approximately this many tokens per second; the actual rate may slightly exceed the limit.
+        kwargs: The remaining options get passed to :func:`generate_step`.
+          See :func:`generate_step` for more details.
+
     Yields:
-        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
-        one token and a vector of log probabilities.
+        Generator[str, None, None]: A generator producing text.
     """
 
     if not isinstance(tokenizer, TokenizerWrapper):
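
The pacing scheme both patches rely on is easier to see outside of the diff context. The sketch below is a minimal, self-contained reimplementation of the same idea for an arbitrary token iterator; it is illustrative only, and the rate_limited helper and the dummy range(10) stream are hypothetical stand-ins rather than part of mlx_lm.

    import time
    from typing import Iterable, Iterator, Optional

    def rate_limited(
        tokens: Iterable[int], max_tokens_per_sec: Optional[float] = None
    ) -> Iterator[int]:
        # Yield items from `tokens` at no more than max_tokens_per_sec per second.
        if max_tokens_per_sec is None:
            yield from tokens  # no limit requested: pass tokens through unchanged
            return
        if max_tokens_per_sec <= 0:
            raise ValueError(f"max_tokens_per_sec must be positive, got {max_tokens_per_sec}")
        interval = 1.0 / max_tokens_per_sec
        next_target = time.perf_counter()
        for token in tokens:
            next_target += interval  # time at which this token should be released
            delay = next_target - time.perf_counter()
            if delay > 0:
                time.sleep(delay)  # only sleep when we are ahead of schedule
            yield token

    # Example: pace a dummy token stream to roughly 5 tokens per second.
    for tok in rate_limited(range(10), max_tokens_per_sec=5):
        print(tok, flush=True)

Because the target time advances by a fixed interval rather than being reset from "now" after each sleep, short hiccups are absorbed and the average rate stays close to the requested limit, which matches the behaviour of the loop added to generate_step. With the patches applied, the same behaviour is exposed on the command line through the new --max-tokens-per-sec flag in chat.py and generate.py.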