Skip to content

Commit

Permalink
Standardized openai init params (#21739)
Browse files Browse the repository at this point in the history
## Patch Summary
community:openai[patch]: standardize init args

## Details
I made changes to the OpenAI Chat API wrapper test in the Langchain
open-source repository

- **File**: `libs/community/tests/unit_tests/chat_models/test_openai.py`
- **Changes**:
  - Updated `max_retries` with Pydantic Field
  - Updated the corresponding unit test
- **Related Issues**: #20085
  - Updated max_retries with Pydantic Field, updated the unit test.

---------

Co-authored-by: JuHyung Son <[email protected]>
  • Loading branch information
2 people authored and hinthornw committed Jun 20, 2024
1 parent c2bdbfb commit 4769810
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
3 changes: 2 additions & 1 deletion libs/community/langchain_community/chat_models/openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""OpenAI chat wrapper."""

from __future__ import annotations

import logging
Expand Down Expand Up @@ -217,7 +218,7 @@ def is_lc_serializable(cls) -> bool:
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None."""
max_retries: int = 2
max_retries: int = Field(default=2)
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
Expand Down
20 changes: 15 additions & 5 deletions libs/community/tests/unit_tests/chat_models/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test OpenAI Chat API wrapper."""

import json
from typing import Any
from typing import Any, List
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -17,10 +18,19 @@

@pytest.mark.requires("openai")
def test_openai_model_param() -> None:
llm = ChatOpenAI(model="foo", openai_api_key="foo") # type: ignore[call-arg]
assert llm.model_name == "foo"
llm = ChatOpenAI(model_name="foo", openai_api_key="foo") # type: ignore[call-arg]
assert llm.model_name == "foo"
test_cases: List[dict] = [
{"model_name": "foo", "openai_api_key": "foo"},
{"model": "foo", "openai_api_key": "foo"},
{"model_name": "foo", "api_key": "foo"},
{"model_name": "foo", "openai_api_key": "foo", "max_retries": 2},
]

for case in test_cases:
llm = ChatOpenAI(**case)
assert llm.model_name == "foo", "Model name should be 'foo'"
assert llm.openai_api_key == "foo", "API key should be 'foo'"
assert hasattr(llm, "max_retries"), "max_retries attribute should exist"
assert llm.max_retries == 2, "max_retries default should be set to 2"


def test_function_message_dict_to_function_message() -> None:
Expand Down

0 comments on commit 4769810

Please sign in to comment.