From 476981022c8340815685b53369e8c16605563b97 Mon Sep 17 00:00:00 2001 From: Kyle Cassidy <159751176+kyle-cassidy@users.noreply.github.com> Date: Thu, 16 May 2024 12:30:52 -0400 Subject: [PATCH] Standardized openai init params (#21739) ## Patch Summary community:openai[patch]: standardize init args ## Details I made changes to the OpenAI Chat API wrapper test in the Langchain open-source repository - **File**: `libs/community/tests/unit_tests/chat_models/test_openai.py` - **Changes**: - Updated `max_retries` with Pydantic Field - Updated the corresponding unit test - **Related Issues**: #20085 - Updated max_retries with Pydantic Field, updated the unit test. --------- Co-authored-by: JuHyung Son --- .../langchain_community/chat_models/openai.py | 3 ++- .../unit_tests/chat_models/test_openai.py | 20 ++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/libs/community/langchain_community/chat_models/openai.py b/libs/community/langchain_community/chat_models/openai.py index 515570c3af3f4..daea2c1050d81 100644 --- a/libs/community/langchain_community/chat_models/openai.py +++ b/libs/community/langchain_community/chat_models/openai.py @@ -1,4 +1,5 @@ """OpenAI chat wrapper.""" + from __future__ import annotations import logging @@ -217,7 +218,7 @@ def is_lc_serializable(cls) -> bool: ) """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or None.""" - max_retries: int = 2 + max_retries: int = Field(default=2) """Maximum number of retries to make when generating.""" streaming: bool = False """Whether to stream the results or not.""" diff --git a/libs/community/tests/unit_tests/chat_models/test_openai.py b/libs/community/tests/unit_tests/chat_models/test_openai.py index 8315821ed3c7f..4cb3efc6d8e74 100644 --- a/libs/community/tests/unit_tests/chat_models/test_openai.py +++ b/libs/community/tests/unit_tests/chat_models/test_openai.py @@ -1,6 +1,7 @@ """Test OpenAI Chat API wrapper.""" + import json -from typing import Any +from typing import Any, List from unittest.mock import MagicMock, patch import pytest @@ -17,10 +18,19 @@ @pytest.mark.requires("openai") def test_openai_model_param() -> None: - llm = ChatOpenAI(model="foo", openai_api_key="foo") # type: ignore[call-arg] - assert llm.model_name == "foo" - llm = ChatOpenAI(model_name="foo", openai_api_key="foo") # type: ignore[call-arg] - assert llm.model_name == "foo" + test_cases: List[dict] = [ + {"model_name": "foo", "openai_api_key": "foo"}, + {"model": "foo", "openai_api_key": "foo"}, + {"model_name": "foo", "api_key": "foo"}, + {"model_name": "foo", "openai_api_key": "foo", "max_retries": 2}, + ] + + for case in test_cases: + llm = ChatOpenAI(**case) + assert llm.model_name == "foo", "Model name should be 'foo'" + assert llm.openai_api_key == "foo", "API key should be 'foo'" + assert hasattr(llm, "max_retries"), "max_retries attribute should exist" + assert llm.max_retries == 2, "max_retries default should be set to 2" def test_function_message_dict_to_function_message() -> None: