diff --git a/prompt2model/utils/api_tools.py b/prompt2model/utils/api_tools.py
index bce8a1dd7..5b160ab88 100644
--- a/prompt2model/utils/api_tools.py
+++ b/prompt2model/utils/api_tools.py
@@ -72,6 +72,7 @@ def generate_one_completion(
         temperature: float = 0,
         presence_penalty: float = 0,
         frequency_penalty: float = 0,
+        token_buffer: int = 300,
     ) -> openai.Completion:
         """Generate a chat completion using an API-based model.
 
@@ -86,6 +87,11 @@ def generate_one_completion(
             frequency_penalty: Float between -2.0 and 2.0. Positive values penalize
                 new tokens based on their existing frequency in the text so far,
                 decreasing the model's likelihood of repeating the same line verbatim.
+            token_buffer: Number of tokens below the LLM's limit to generate. In case
+                our tokenizer does not exactly match the LLM API service's perceived
+                number of tokens, this prevents service errors. On the other hand, this
+                may lead to generating fewer tokens in the completion than is actually
+                possible.
 
         Returns:
             An OpenAI-like response object if there were no errors in generation.
@@ -93,9 +99,9 @@
         """
         num_prompt_tokens = count_tokens_from_string(prompt)
         if self.max_tokens:
-            max_tokens = self.max_tokens - num_prompt_tokens
+            max_tokens = self.max_tokens - num_prompt_tokens - token_buffer
         else:
-            max_tokens = 4 * num_prompt_tokens
+            max_tokens = 3 * num_prompt_tokens
 
         response = completion(  # completion gets the key from os.getenv
             model=self.model_name,
@@ -116,6 +122,7 @@ async def generate_batch_completion(
         temperature: float = 1,
         responses_per_request: int = 5,
         requests_per_minute: int = 80,
+        token_buffer: int = 300,
     ) -> list[openai.Completion]:
         """Generate a batch responses from OpenAI Chat Completion API.
 
@@ -126,6 +133,11 @@
             responses_per_request: Number of responses for each request. i.e. the
                 parameter n of API call.
             requests_per_minute: Number of requests per minute to allow.
+            token_buffer: Number of tokens below the LLM's limit to generate. In case
+                our tokenizer does not exactly match the LLM API service's perceived
+                number of tokens, this prevents service errors. On the other hand, this
+                may lead to generating fewer tokens in the completion than is actually
+                possible.
 
         Returns:
             List of generated responses.
@@ -183,9 +195,9 @@ async def _throttled_completion_acreate(
 
         num_prompt_tokens = max(count_tokens_from_string(prompt) for prompt in prompts)
         if self.max_tokens:
-            max_tokens = self.max_tokens - num_prompt_tokens
+            max_tokens = self.max_tokens - num_prompt_tokens - token_buffer
         else:
-            max_tokens = 4 * num_prompt_tokens
+            max_tokens = 3 * num_prompt_tokens
 
         async_responses = [
             _throttled_completion_acreate(
diff --git a/test_helpers/mock_api.py b/test_helpers/mock_api.py
index 5925d52ec..abec258ed 100644
--- a/test_helpers/mock_api.py
+++ b/test_helpers/mock_api.py
@@ -196,6 +196,7 @@ def generate_one_completion(
         temperature: float = 0,
         presence_penalty: float = 0,
         frequency_penalty: float = 0,
+        token_buffer: int = 300,
     ) -> openai.Completion:
         """Return a mocked object and increment the counter."""
         self.generate_one_call_counter += 1
@@ -207,6 +208,7 @@ async def generate_batch_completion(
         temperature: float = 1,
         responses_per_request: int = 5,
         requests_per_minute: int = 80,
+        token_buffer: int = 300,
    ) -> list[openai.Completion]:
         """Return a mocked object and increment the counter."""
         self.generate_batch_call_counter += 1
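
To illustrate the new parameter: below is a minimal sketch, outside the patch, of the completion-length cap that both generation methods now compute. The standalone helper `compute_max_tokens` is hypothetical; in the patch this logic lives inline in `generate_one_completion` and `generate_batch_completion`.

from __future__ import annotations


def compute_max_tokens(
    num_prompt_tokens: int,
    model_max_tokens: int | None = None,
    token_buffer: int = 300,
) -> int:
    """Mirror the patched cap on completion length (illustrative sketch only).

    When the model's context limit is known, reserve token_buffer tokens below
    it so that a mismatch between our tokenizer and the API service's own count
    does not trigger a context-length error; otherwise fall back to a multiple
    of the prompt length (3x after this patch, previously 4x).
    """
    if model_max_tokens:
        return model_max_tokens - num_prompt_tokens - token_buffer
    return 3 * num_prompt_tokens


# e.g. a 1,000-token prompt against a 4,096-token context window leaves
# 4096 - 1000 - 300 = 2796 tokens for the completion.
assert compute_max_tokens(1000, model_max_tokens=4096) == 2796
assert compute_max_tokens(1000) == 3000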