From 9bb0075a82b3f33e1927d45aa024b9dec3c8ccbc Mon Sep 17 00:00:00 2001
From: David Lapsley
Date: Sat, 26 Apr 2025 18:07:17 -0400
Subject: [PATCH] Add ProviderError exception and handle missing LLM response

This commit introduces a new ProviderError exception for cases where the
LLM provider response lacks the expected 'choices' attribute. Raising an
explicit error for invalid responses makes failures easier to diagnose
and improves the reliability of the system.

---
 src/agents/exceptions.py                    |  4 +++
 src/agents/models/openai_chatcompletions.py |  6 +++-
 tests/test_openai_chatcompletions.py        | 40 +++++++++++++++++++++
 3 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py
index 78898f01..57cbef70 100644
--- a/src/agents/exceptions.py
+++ b/src/agents/exceptions.py
@@ -61,3 +61,7 @@ def __init__(self, guardrail_result: "OutputGuardrailResult"):
         super().__init__(
             f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
         )
+
+
+class ProviderError(AgentsException):
+    """Exception raised when the provider fails."""
diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py
index 89619f83..9d1ee7d0 100644
--- a/src/agents/models/openai_chatcompletions.py
+++ b/src/agents/models/openai_chatcompletions.py
@@ -12,6 +12,7 @@

 from .. import _debug
 from ..agent_output import AgentOutputSchemaBase
+from ..exceptions import ProviderError
 from ..handoffs import Handoff
 from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
 from ..logger import logger
@@ -70,6 +71,9 @@ async def get_response(
             stream=False,
         )

+        if not getattr(response, "choices", None):
+            raise ProviderError(f"LLM provider error: {getattr(response, 'error', 'unknown')}")
+
         if _debug.DONT_LOG_MODEL_DATA:
             logger.debug("Received model response")
         else:
@@ -252,7 +256,7 @@ async def _fetch_response(
             stream_options=self._non_null_or_not_given(stream_options),
             store=self._non_null_or_not_given(store),
             reasoning_effort=self._non_null_or_not_given(reasoning_effort),
-            extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) },
+            extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
             extra_query=model_settings.extra_query,
             extra_body=model_settings.extra_body,
             metadata=self._non_null_or_not_given(model_settings.metadata),
diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py
index ba3ec68d..84ffe992 100644
--- a/tests/test_openai_chatcompletions.py
+++ b/tests/test_openai_chatcompletions.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import importlib
 from collections.abc import AsyncIterator
 from typing import Any

@@ -30,6 +31,7 @@
     OpenAIProvider,
     generation_span,
 )
+from agents.exceptions import ProviderError
 from agents.models.chatcmpl_helpers import ChatCmplHelpers
 from agents.models.fake_id import FAKE_RESPONSES_ID

@@ -330,3 +332,41 @@ def test_store_param():
     assert ChatCmplHelpers.get_store_param(client, model_settings) is True, (
         "Should respect explicitly set store=True"
     )
+
+
+@pytest.mark.asyncio
+async def test_get_response_raises_provider_error_if_no_choices(monkeypatch):
+    # Import the class under test _inside_ the function so
+    # pytest's conftest autouse fixtures don't stomp it out.
+    import agents.models.openai_chatcompletions as chatmod
+
+    chatmod = importlib.reload(chatmod)
+
+    ModelClass = chatmod.OpenAIChatCompletionsModel
+
+    dummy_client = AsyncOpenAI(api_key="fake", base_url="http://localhost")
+    model = ModelClass(model="test-model", openai_client=dummy_client)
+
+    class FakeResponse:
+        choices = []
+        error = "service unavailable"
+
+    async def fake_fetch_response(*args, **kwargs):
+        return FakeResponse()
+
+    monkeypatch.setattr(ModelClass, "_fetch_response", fake_fetch_response)
+
+    settings = ModelSettings(temperature=0.0, max_tokens=1)
+    with pytest.raises(ProviderError) as exc:
+        await model.get_response(
+            system_instructions="",
+            input="Hello?",
+            model_settings=settings,
+            tools=[],
+            output_schema=None,
+            handoffs=[],
+            tracing=ModelTracing.DISABLED,
+            previous_response_id=None,
+        )
+
+    assert "service unavailable" in str(exc.value)
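
Usage note: with this change, callers can catch ProviderError to separate
provider-side failures (an LLM response with no 'choices') from other agent
errors. A minimal sketch follows, assuming the SDK's public Agent/Runner
entry points; the agent name and prompt below are illustrative, not part of
the patch:

    import asyncio

    from agents import Agent, Runner
    from agents.exceptions import ProviderError

    async def main() -> None:
        agent = Agent(name="assistant", instructions="Be concise.")
        try:
            # Runner.run drives the agent loop; under the hood,
            # get_response now raises ProviderError when the provider
            # returns a response without 'choices'.
            result = await Runner.run(agent, "Hello?")
            print(result.final_output)
        except ProviderError as err:
            # The provider's own error string (or 'unknown') is embedded
            # in the exception message by get_response.
            print(f"LLM provider failed: {err}")

    asyncio.run(main())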