Enhance openai-agents with Model Improvements, Retry Logic, and Caching #450

Open · wants to merge 1 commit into main
55 changes: 55 additions & 0 deletions src/agents/models/__init__.py
@@ -0,0 +1,55 @@
"""Model implementations and utilities for working with language models."""

from ._openai_shared import (
TOpenAIClient,
create_client,
get_default_openai_client,
get_default_openai_key,
get_use_responses_by_default,
set_default_openai_client,
set_default_openai_key,
set_use_responses_by_default,
)
from .interface import Model, ModelProvider, ModelRetrySettings, ModelTracing
from .openai_chatcompletions import OpenAIChatCompletionsModel
from .openai_provider import OpenAIProvider
from .openai_responses import OpenAIResponsesModel
from .utils import (
cache_model_response,
clear_cache,
compute_cache_key,
get_token_count_estimate,
set_cache_ttl,
validate_response,
)

__all__ = [
# Interface
"Model",
"ModelProvider",
"ModelRetrySettings",
"ModelTracing",

# OpenAI utilities
"get_default_openai_client",
"get_default_openai_key",
"get_use_responses_by_default",
"set_default_openai_client",
"set_default_openai_key",
"set_use_responses_by_default",
"TOpenAIClient",
"create_client",

# Model implementations
"OpenAIChatCompletionsModel",
"OpenAIProvider",
"OpenAIResponsesModel",

# Caching and utilities
"cache_model_response",
"clear_cache",
"compute_cache_key",
"get_token_count_estimate",
"set_cache_ttl",
"validate_response",
]
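
For orientation, here is a minimal sketch of how these re-exports might be pulled in by downstream code. The agents.models import path is inferred from the file location, and the caching helpers are left out because utils.py itself is not shown in this diff:

from agents.models import (
    ModelRetrySettings,
    OpenAIProvider,
    set_default_openai_key,
)

set_default_openai_key("sk-example")  # placeholder key, stored as the module default
provider = OpenAIProvider()           # the client is created lazily on first use
retry_settings = ModelRetrySettings(max_retries=5)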
74 changes: 71 additions & 3 deletions src/agents/models/_openai_shared.py
@@ -1,34 +1,102 @@
from __future__ import annotations

import logging
from typing import Any, TypeAlias

from openai import AsyncOpenAI

# Type aliases for common OpenAI types
TOpenAIClient: TypeAlias = AsyncOpenAI
TOpenAIClientOptions: TypeAlias = dict[str, Any]

_default_openai_key: str | None = None
_default_openai_client: AsyncOpenAI | None = None
_default_openai_client: TOpenAIClient | None = None
_use_responses_by_default: bool = True
_logger = logging.getLogger(__name__)


def set_default_openai_key(key: str) -> None:
"""Set the default OpenAI API key to use when creating clients.

Args:
key: The OpenAI API key
"""
global _default_openai_key
_default_openai_key = key


def get_default_openai_key() -> str | None:
"""Get the default OpenAI API key.

Returns:
The default OpenAI API key, or None if not set
"""
return _default_openai_key


def set_default_openai_client(client: AsyncOpenAI) -> None:
def set_default_openai_client(client: TOpenAIClient) -> None:
"""Set the default OpenAI client to use.

Args:
client: The OpenAI client instance
"""
global _default_openai_client
_default_openai_client = client


def get_default_openai_client() -> AsyncOpenAI | None:
def get_default_openai_client() -> TOpenAIClient | None:
"""Get the default OpenAI client.

Returns:
The default OpenAI client, or None if not set
"""
return _default_openai_client


def set_use_responses_by_default(use_responses: bool) -> None:
"""Set whether to use the Responses API by default.

Args:
use_responses: Whether to use the Responses API
"""
global _use_responses_by_default
_use_responses_by_default = use_responses


def get_use_responses_by_default() -> bool:
"""Get whether to use the Responses API by default.

Returns:
Whether to use the Responses API by default
"""
return _use_responses_by_default


def create_client(
api_key: str | None = None,
base_url: str | None = None,
organization: str | None = None,
project: str | None = None,
http_client: Any = None,
) -> TOpenAIClient:
"""Create a new OpenAI client with the given parameters.

This is a utility function to standardize client creation across the codebase.

Args:
api_key: The API key to use. If not provided, uses the default.
base_url: The base URL to use. If not provided, uses the default.
organization: The organization to use.
project: The project to use.
http_client: The HTTP client to use.

Returns:
A new OpenAI client
"""
return AsyncOpenAI(
api_key=api_key or get_default_openai_key(),
base_url=base_url,
organization=organization,
project=project,
http_client=http_client,
)
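
As a usage sketch, create_client composes with the default-key setter above; the base URL here is a hypothetical gateway and the key is a placeholder:

import httpx

set_default_openai_key("sk-example")  # placeholder key

# api_key is omitted, so create_client falls back to get_default_openai_key().
client = create_client(
    base_url="https://llm-proxy.example.com/v1",  # hypothetical gateway URL
    http_client=httpx.AsyncClient(timeout=30.0),
)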
74 changes: 73 additions & 1 deletion src/agents/models/interface.py
@@ -1,9 +1,11 @@
from __future__ import annotations

import abc
import asyncio
import enum
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Awaitable, Callable

from ..agent_output import AgentOutputSchema
from ..handoffs import Handoff
@@ -31,6 +33,76 @@ def include_data(self) -> bool:
return self == ModelTracing.ENABLED


@dataclass
class ModelRetrySettings:
"""Settings for retrying model calls on failure.

This class helps manage backoff and retry logic when API calls fail.
"""

max_retries: int = 3
"""Maximum number of retries to attempt."""

initial_backoff_seconds: float = 1.0
"""Initial backoff time in seconds before the first retry."""

max_backoff_seconds: float = 30.0
"""Maximum backoff time in seconds between retries."""

backoff_multiplier: float = 2.0
"""Multiplier for backoff time after each retry."""

retryable_status_codes: list[int] = field(default_factory=lambda: [429, 500, 502, 503, 504])
"""HTTP status codes that should trigger a retry."""

async def execute_with_retry(
self,
        operation: Callable[[], Awaitable[Any]],
should_retry: Callable[[Exception], bool] | None = None
) -> Any:
"""Execute an operation with retry logic.

Args:
operation: Async function to execute
should_retry: Optional function to determine if an exception should trigger a retry

Returns:
The result of the operation if successful

Raises:
The last exception encountered if all retries fail
"""
last_exception = None
backoff = self.initial_backoff_seconds

for attempt in range(self.max_retries + 1):
try:
return await operation()
except Exception as e:
last_exception = e

# Check if we should retry
if attempt >= self.max_retries:
break

should_retry_exception = True
if should_retry is not None:
should_retry_exception = should_retry(e)

if not should_retry_exception:
break

# Wait before retrying
await asyncio.sleep(backoff)
backoff = min(backoff * self.backoff_multiplier, self.max_backoff_seconds)

if last_exception:
raise last_exception

# This should never happen, but just in case
raise RuntimeError("Retry logic failed in an unexpected way")


class Model(abc.ABC):
"""The base interface for calling an LLM."""

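To make the retry flow concrete, here is a small sketch of driving execute_with_retry with an async operation. The fetch_completion coroutine is hypothetical, and openai.APIStatusError is assumed to carry a status_code attribute. Note that execute_with_retry does not consult retryable_status_codes on its own, so a should_retry hook like this is where that field can be applied:

import asyncio

from openai import APIStatusError

settings = ModelRetrySettings(max_retries=2, initial_backoff_seconds=0.5)

async def fetch_completion() -> str:
    # Stand-in for a real model call that may raise APIStatusError.
    return "ok"

def is_retryable(exc: Exception) -> bool:
    # Only retry HTTP failures whose status code is in retryable_status_codes.
    return isinstance(exc, APIStatusError) and exc.status_code in settings.retryable_status_codes

result = asyncio.run(
    settings.execute_with_retry(fetch_completion, should_retry=is_retryable)
)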
50 changes: 37 additions & 13 deletions src/agents/models/openai_provider.py
@@ -1,15 +1,18 @@
from __future__ import annotations

import logging

import httpx
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from openai import DefaultAsyncHttpxClient, OpenAIError

from . import _openai_shared
from ._openai_shared import TOpenAIClient, create_client
from .interface import Model, ModelProvider
from .openai_chatcompletions import OpenAIChatCompletionsModel
from .openai_responses import OpenAIResponsesModel

DEFAULT_MODEL: str = "gpt-4o"

_logger = logging.getLogger(__name__)

_http_client: httpx.AsyncClient | None = None

@@ -29,10 +32,11 @@ def __init__(
*,
api_key: str | None = None,
base_url: str | None = None,
openai_client: AsyncOpenAI | None = None,
openai_client: TOpenAIClient | None = None,
organization: str | None = None,
project: str | None = None,
use_responses: bool | None = None,
default_model: str = DEFAULT_MODEL,
) -> None:
"""Create a new OpenAI provider.

Expand All @@ -46,12 +50,13 @@ def __init__(
organization: The organization to use for the OpenAI client.
project: The project to use for the OpenAI client.
use_responses: Whether to use the OpenAI responses API.
default_model: The default model to use if none is specified.
"""
if openai_client is not None:
assert api_key is None and base_url is None, (
"Don't provide api_key or base_url if you provide openai_client"
)
self._client: AsyncOpenAI | None = openai_client
self._client: TOpenAIClient | None = openai_client
else:
self._client = None
self._stored_api_key = api_key
@@ -64,23 +69,42 @@ def __init__(
else:
self._use_responses = _openai_shared.get_use_responses_by_default()

self._default_model = default_model

    # We lazy-load the client in case you never actually use OpenAIProvider();
    # otherwise AsyncOpenAI() raises an error if you don't have an API key set.
def _get_client(self) -> AsyncOpenAI:
def _get_client(self) -> TOpenAIClient:
if self._client is None:
self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
api_key=self._stored_api_key or _openai_shared.get_default_openai_key(),
base_url=self._stored_base_url,
organization=self._stored_organization,
project=self._stored_project,
http_client=shared_http_client(),
)
default_client = _openai_shared.get_default_openai_client()
if default_client:
self._client = default_client
else:
try:
self._client = create_client(
api_key=self._stored_api_key,
base_url=self._stored_base_url,
organization=self._stored_organization,
project=self._stored_project,
http_client=shared_http_client(),
)
except OpenAIError as e:
_logger.error(f"Failed to create OpenAI client: {e}")
raise

return self._client

def get_model(self, model_name: str | None) -> Model:
"""Get a model instance by name.

Args:
model_name: The name of the model to get. If None, uses the default model.

Returns:
An OpenAI model implementation (either Responses or ChatCompletions
based on configuration)
"""
if model_name is None:
model_name = DEFAULT_MODEL
model_name = self._default_model

client = self._get_client()

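Finally, a brief sketch of the new default_model parameter on the provider (the key is a placeholder and the model names are illustrative):

provider = OpenAIProvider(api_key="sk-example", default_model="gpt-4o-mini")

# With no name given, get_model now falls back to the provider's default_model
# rather than the module-level DEFAULT_MODEL constant.
model = provider.get_model(None)

# An explicit name still takes precedence.
other_model = provider.get_model("gpt-4o")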