Throttle API Calls #2283

Merged: 20 commits, Sep 25, 2024
27 changes: 26 additions & 1 deletion autogen/oai/client.py
@@ -16,6 +16,8 @@
from autogen.runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
from autogen.token_count_utils import count_token

from .rate_limiters import RateLimiter, TimeRateLimiter

TOOL_ENABLED = False
try:
    import openai
@@ -207,7 +209,9 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
"""
iostream = IOStream.get_default()

completions: Completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined]
completions: Completions = (
self._oai_client.chat.completions if "messages" in params else self._oai_client.completions
) # type: ignore [attr-defined]
# If streaming is enabled and has messages, then iterate over the chunks of the response.
if params.get("stream", False) and "messages" in params:
response_contents = [""] * params.get("n", 1)
@@ -427,8 +431,11 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base

        self._clients: List[ModelClient] = []
        self._config_list: List[Dict[str, Any]] = []
        self._rate_limiters: List[Optional[RateLimiter]] = []

        if config_list:
            self._initialize_rate_limiters(config_list)

            config_list = [config.copy() for config in config_list]  # make a copy before modifying
            for config in config_list:
                self._register_default_client(config, openai_config)  # could modify the config
@@ -749,6 +756,7 @@ def yes_or_no_filter(context, response):
                            return response
                        continue  # filter is not passed; try the next config
            try:
                self._throttle_api_calls(i)
                request_ts = get_current_ts()
                response = client.create(params)
            except APITimeoutError as err:
@@ -1042,3 +1050,20 @@ def extract_text_or_completion_object(
            A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
        """
        return response.message_retrieval_function(response)

    def _throttle_api_calls(self, idx: int) -> None:
        """Rate limit API calls by sleeping when a rate limiter is configured for the client at index `idx`."""
        limiter = self._rate_limiters[idx]
        if limiter is not None:
            limiter.sleep()

    def _initialize_rate_limiters(self, config_list: List[Dict[str, Any]]) -> None:
        for config in config_list:
            # Instantiate a rate limiter when "api_rate_limit" is given, then strip the key
            # so it is not forwarded to the underlying client constructor.
            if "api_rate_limit" in config:
                self._rate_limiters.append(TimeRateLimiter(config["api_rate_limit"]))
                del config["api_rate_limit"]
            else:
                self._rate_limiters.append(None)
36 changes: 36 additions & 0 deletions autogen/oai/rate_limiters.py
@@ -0,0 +1,36 @@
import time
from typing import Protocol


class RateLimiter(Protocol):
    def sleep(self, *args, **kwargs): ...


class TimeRateLimiter:
    """A class to implement a time-based rate limiter.

    This rate limiter ensures that a certain operation does not exceed a specified frequency.
    It can be used to limit the rate of requests sent to a server or the rate of any repeated action.
    """

    def __init__(self, rate: float):
        """
        Args:
            rate (float): The maximum number of operations permitted per second (NOT the time interval between them).
        """
        self._time_interval_seconds = 1.0 / rate
        self._last_time_called = 0.0

    def sleep(self, *args, **kwargs):
        """Synchronously waits until enough time has passed to allow the next operation.

        If the elapsed time since the last operation is less than the required time interval,
        this method will block execution by sleeping for the remaining time.
        """
        if self._elapsed_time() < self._time_interval_seconds:
            time.sleep(self._time_interval_seconds - self._elapsed_time())

        self._last_time_called = time.perf_counter()

    def _elapsed_time(self):
        return time.perf_counter() - self._last_time_called
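
A brief usage sketch of the new `TimeRateLimiter` (illustrative values):

```python
from autogen.oai.rate_limiters import TimeRateLimiter

limiter = TimeRateLimiter(rate=2)  # allow at most two operations per second

for i in range(5):
    limiter.sleep()  # blocks until at least 0.5 s has elapsed since the previous call
    print(f"operation {i}")
```

Any object with a compatible `sleep()` method satisfies the `RateLimiter` protocol, so alternative strategies (e.g. a token bucket) can be dropped in later.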
66 changes: 64 additions & 2 deletions test/oai/test_client.py
@@ -4,6 +4,7 @@
import shutil
import sys
import time
from types import SimpleNamespace

import pytest

@@ -31,6 +32,40 @@
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"


class _MockClient:
    def __init__(self, config, **kwargs):
        pass

    def create(self, params):
        # A custom response class could be used here; SimpleNamespace keeps the mock simple,
        # as long as the object adheres to the ModelClientResponseProtocol.
        response = SimpleNamespace()
        response.choices = []
        response.model = "mock_model"

        text = "this is a dummy text response"
        choice = SimpleNamespace()
        choice.message = SimpleNamespace()
        choice.message.content = text
        choice.message.function_call = None
        response.choices.append(choice)
        return response

    def message_retrieval(self, response):
        choices = response.choices
        return [choice.message.content for choice in choices]

    def cost(self, response) -> float:
        response.cost = 0
        return 0

    @staticmethod
    def get_usage(response):
        return {}


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_aoai_chat_completion():
    config_list = config_list_from_json(
@@ -322,12 +357,39 @@ def test_cache():
    assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))


def test_throttled_api_calls():
    # API calls limited to 0.2 requests per second, i.e. 1 request per 5 seconds
    rate = 1 / 5.0

    config_list = [
        {
            "model": "mock_model",
            "model_client_cls": "_MockClient",
            # Add a timeout to catch false positives
            "timeout": 1 / rate,
            "api_rate_limit": rate,
        }
    ]

    client = OpenAIWrapper(config_list=config_list, cache_seed=None)
    client.register_model_client(_MockClient)

    n_loops = 2
    current_time = time.time()
    for _ in range(n_loops):
        client.create(messages=[{"role": "user", "content": "hello"}])

    min_expected_time = (n_loops - 1) / rate
    assert time.time() - current_time > min_expected_time


if __name__ == "__main__":
    # test_aoai_chat_completion()
    # test_oai_tool_calling_extraction()
    # test_chat_completion()
    test_completion()
    # test_cost()
    # test_usage_summary()
-    # test_legacy_cache()
-    # test_cache()
+    test_legacy_cache()
+    test_cache()
+    test_throttled_api_calls()
21 changes: 21 additions & 0 deletions test/oai/test_rate_limiters.py
@@ -0,0 +1,21 @@
import time

import pytest

from autogen.oai.rate_limiters import TimeRateLimiter


@pytest.mark.parametrize("execute_n_times", range(5))
def test_time_rate_limiter(execute_n_times):
    current_time_seconds = time.time()

    rate = 1
    rate_limiter = TimeRateLimiter(rate)

    n_loops = 2
    for _ in range(n_loops):
        rate_limiter.sleep()

    total_time = time.time() - current_time_seconds
    min_expected_time = (n_loops - 1) / rate
    assert total_time >= min_expected_time
10 changes: 9 additions & 1 deletion website/docs/FAQ.mdx
@@ -37,7 +37,15 @@ Yes. You currently have two options:
- Autogen can work with any API endpoint that complies with OpenAI-compatible RESTful APIs, e.g. serving a local LLM via FastChat or LM Studio. Please check https://microsoft.github.io/autogen/blog/2023/07/14/Local-LLMs for an example.
- You can supply your own custom model implementation and use it with Autogen. Please check https://microsoft.github.io/autogen/blog/2024/01/26/Custom-Models for more information.

-## Handle Rate Limit Error and Timeout Error
+## Handling API Rate Limits

### Setting the API Rate Limit

You can set `api_rate_limit` in a `config_list` entry for an agent; it controls the maximum rate at which API requests are sent to that endpoint, as in the sketch below.

- `api_rate_limit` (float): the maximum number of API requests allowed per second.
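
For example, a minimal sketch that caps a single client at one request per second (the model name and environment variable are placeholders):

```python
import os

from autogen import OpenAIWrapper

config_list = [
    {
        "model": "gpt-4",
        "api_key": os.environ["OPENAI_API_KEY"],
        "api_rate_limit": 1.0,  # at most one API request per second for this client
    }
]

client = OpenAIWrapper(config_list=config_list)
```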

### Handle Rate Limit Error and Timeout Error

You can set `max_retries` to handle rate limit errors and `timeout` to handle timeout errors. Both can be specified in the `llm_config` for an agent, which is used in the OpenAI client for LLM inference, and they can be set per client when specified in the `config_list`.
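
As an illustration, a sketch of an `llm_config` with both knobs set (the retry and timeout values are arbitrary):

```python
llm_config = {
    "config_list": [
        {
            "model": "gpt-4",
            "api_key": os.environ["OPENAI_API_KEY"],
            "max_retries": 10,  # retry a rate-limited request up to 10 times
            "timeout": 120,  # raise a timeout error if a request takes longer than 120 seconds
        }
    ],
}
```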

4 changes: 4 additions & 0 deletions website/docs/topics/llm_configuration.ipynb
@@ -63,6 +63,7 @@
" <TabItem value=\"openai\" label=\"OpenAI\" default>\n",
" - `model` (str, required): The identifier of the model to be used, such as 'gpt-4', 'gpt-3.5-turbo'.\n",
" - `api_key` (str, optional): The API key required for authenticating requests to the model's API endpoint.\n",
" - `api_rate_limit` (float, optional): Specifies the maximum number of API requests permitted per second.\n",
" - `base_url` (str, optional): The base URL of the API endpoint. This is the root address where API calls are directed.\n",
" - `tags` (List[str], optional): Tags which can be used for filtering.\n",
"\n",
@@ -72,6 +73,7 @@
" {\n",
" \"model\": \"gpt-4\",\n",
" \"api_key\": os.environ['OPENAI_API_KEY']\n",
" \"api_rate_limit\": 60.0, // Set to allow up to 60 API requests per second.\n",
" }\n",
" ]\n",
" ```\n",
@@ -80,6 +82,7 @@
" - `model` (str, required): The deployment to be used. The model corresponds to the deployment name on Azure OpenAI.\n",
" - `api_key` (str, optional): The API key required for authenticating requests to the model's API endpoint.\n",
" - `api_type`: `azure`\n",
" - `api_rate_limit` (float, optional): Specifies the maximum number of API requests permitted per second.\n",
" - `base_url` (str, optional): The base URL of the API endpoint. This is the root address where API calls are directed.\n",
" - `api_version` (str, optional): The version of the Azure API you wish to use.\n",
" - `tags` (List[str], optional): Tags which can be used for filtering.\n",
@@ -100,6 +103,7 @@
" <TabItem value=\"other\" label=\"Other OpenAI compatible\">\n",
" - `model` (str, required): The identifier of the model to be used, such as 'llama-7B'.\n",
" - `api_key` (str, optional): The API key required for authenticating requests to the model's API endpoint.\n",
" - `api_rate_limit` (float, optional): Specifies the maximum number of API requests permitted per second.\n",
" - `base_url` (str, optional): The base URL of the API endpoint. This is the root address where API calls are directed.\n",
" - `tags` (List[str], optional): Tags which can be used for filtering.\n",
"\n",