From 5f5287c3b00a477ee96f7786f74625c21b084624 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Tue, 3 Sep 2024 16:48:53 -0700 Subject: [PATCH] fmt --- .../langchain_openai/chat_models/azure.py | 20 +++++------ .../langchain_openai/chat_models/base.py | 34 +++++++------------ .../langchain_openai/embeddings/azure.py | 29 +++++++--------- .../langchain_openai/embeddings/base.py | 28 +++++---------- .../openai/langchain_openai/llms/azure.py | 15 ++++---- .../openai/langchain_openai/llms/base.py | 20 ++++------- .../chat_models/test_base.py | 2 +- .../openai/tests/unit_tests/fake/callbacks.py | 4 +-- 8 files changed, 60 insertions(+), 92 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py index e5ae0550fd7ff..b3f09333c20a3 100644 --- a/libs/partners/openai/langchain_openai/chat_models/azure.py +++ b/libs/partners/openai/langchain_openai/chat_models/azure.py @@ -31,16 +31,15 @@ PydanticToolsParser, ) from langchain_core.outputs import ChatResult -from pydantic import BaseModel, Field, SecretStr, model_validator from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils import from_env, secret_from_env from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.pydantic import is_basemodel_subclass - -from langchain_openai.chat_models.base import BaseChatOpenAI +from pydantic import BaseModel, Field, SecretStr, model_validator from typing_extensions import Self +from langchain_openai.chat_models.base import BaseChatOpenAI logger = logging.getLogger(__name__) @@ -604,19 +603,15 @@ def validate_environment(self) -> Self: "Or you can equivalently specify:\n\n" 'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"' ) - client_params = { + client_params: dict = { "api_version": self.openai_api_version, "azure_endpoint": self.azure_endpoint, "azure_deployment": self.deployment_name, "api_key": ( - self.openai_api_key.get_secret_value() - if self.openai_api_key - else None + self.openai_api_key.get_secret_value() if self.openai_api_key else None ), "azure_ad_token": ( - self.azure_ad_token.get_secret_value() - if self.azure_ad_token - else None + self.azure_ad_token.get_secret_value() if self.azure_ad_token else None ), "azure_ad_token_provider": self.azure_ad_token_provider, "organization": self.openai_organization, @@ -628,12 +623,13 @@ def validate_environment(self) -> Self: } if not (self.client or None): sync_specific = {"http_client": self.http_client} - self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) + self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) # type: ignore[arg-type] self.client = self.root_client.chat.completions if not (self.async_client or None): async_specific = {"http_client": self.http_async_client} self.root_async_client = openai.AsyncAzureOpenAI( - **client_params, **async_specific + **client_params, + **async_specific, # type: ignore[arg-type] ) self.async_client = self.root_async_client.chat.completions return self diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 018d8febefbb3..f21670d5c2f88 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -73,11 +73,10 @@ parse_tool_call, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from pydantic import BaseModel, Field, model_validator, SecretStr from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain from langchain_core.runnables.config import run_in_executor from langchain_core.tools import BaseTool -from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain_core.utils import get_pydantic_field_names from langchain_core.utils.function_calling import ( convert_to_openai_function, convert_to_openai_tool, @@ -88,10 +87,9 @@ is_basemodel_subclass, ) from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env -from pydantic import ConfigDict +from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self - logger = logging.getLogger(__name__) @@ -361,8 +359,7 @@ class BaseChatOpenAI(BaseChatModel): model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" openai_api_key: Optional[SecretStr] = Field( - alias="api_key", - default_factory=secret_from_env("OPENAI_API_KEY", default=None), + alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) openai_api_base: Optional[str] = Field(default=None, alias="base_url") """Base URL path for API requests, leave blank if not using a proxy or service @@ -431,7 +428,7 @@ class BaseChatOpenAI(BaseChatModel): include_response_headers: bool = False """Whether to include response headers in the output message response_metadata.""" - model_config = ConfigDict(populate_by_name=True,) + model_config = ConfigDict(populate_by_name=True) @model_validator(mode="before") @classmethod @@ -458,14 +455,10 @@ def validate_environment(self) -> Self: or os.getenv("OPENAI_ORG_ID") or os.getenv("OPENAI_ORGANIZATION") ) - self.openai_api_base = self.openai_api_base or os.getenv( - "OPENAI_API_BASE" - ) - client_params = { + self.openai_api_base = self.openai_api_base or os.getenv("OPENAI_API_BASE") + client_params: dict = { "api_key": ( - self.openai_api_key.get_secret_value() - if self.openai_api_key - else None + self.openai_api_key.get_secret_value() if self.openai_api_key else None ), "organization": self.openai_organization, "base_url": self.openai_api_base, @@ -474,9 +467,7 @@ def validate_environment(self) -> Self: "default_headers": self.default_headers, "default_query": self.default_query, } - if self.openai_proxy and ( - self.http_client or self.http_async_client - ): + if self.openai_proxy and (self.http_client or self.http_async_client): openai_proxy = self.openai_proxy http_client = self.http_client http_async_client = self.http_async_client @@ -496,7 +487,7 @@ def validate_environment(self) -> Self: ) from e self.http_client = httpx.Client(proxy=self.openai_proxy) sync_specific = {"http_client": self.http_client} - self.root_client = openai.OpenAI(**client_params, **sync_specific) + self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore[arg-type] self.client = self.root_client.chat.completions if not (self.async_client or None): if self.openai_proxy and not self.http_async_client: @@ -507,12 +498,11 @@ def validate_environment(self) -> Self: "Could not import httpx python package. " "Please install it with `pip install httpx`." ) from e - self.http_async_client = httpx.AsyncClient( - proxy=self.openai_proxy - ) + self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy) async_specific = {"http_client": self.http_async_client} self.root_async_client = openai.AsyncOpenAI( - **client_params, **async_specific + **client_params, + **async_specific, # type: ignore[arg-type] ) self.async_client = self.root_async_client.chat.completions return self diff --git a/libs/partners/openai/langchain_openai/embeddings/azure.py b/libs/partners/openai/langchain_openai/embeddings/azure.py index 6697cee53926e..a5e600b662b7b 100644 --- a/libs/partners/openai/langchain_openai/embeddings/azure.py +++ b/libs/partners/openai/langchain_openai/embeddings/azure.py @@ -2,15 +2,14 @@ from __future__ import annotations -from typing import Callable, Dict, Optional, Union +from typing import Callable, Optional, Union import openai -from pydantic import Field, SecretStr, root_validator, model_validator from langchain_core.utils import from_env, secret_from_env +from pydantic import Field, SecretStr, model_validator +from typing_extensions import Self, cast from langchain_openai.embeddings.base import OpenAIEmbeddings -from typing_extensions import Self - class AzureOpenAIEmbeddings(OpenAIEmbeddings): @@ -163,7 +162,7 @@ def validate_environment(self) -> Self: openai_api_base = self.openai_api_base if openai_api_base and self.validate_base_url: if "/openai" not in openai_api_base: - self.openai_api_base += "/openai" + self.openai_api_base = cast(str, self.openai_api_base) + "/openai" raise ValueError( "As of openai>=1.0.0, Azure endpoints should be specified via " "the `azure_endpoint` param not `openai_api_base` " @@ -177,19 +176,15 @@ def validate_environment(self) -> Self: "Instead use `deployment` (or alias `azure_deployment`) " "and `azure_endpoint`." ) - client_params = { + client_params: dict = { "api_version": self.openai_api_version, "azure_endpoint": self.azure_endpoint, "azure_deployment": self.deployment, "api_key": ( - self.openai_api_key.get_secret_value() - if self.openai_api_key - else None + self.openai_api_key.get_secret_value() if self.openai_api_key else None ), "azure_ad_token": ( - self.azure_ad_token.get_secret_value() - if self.azure_ad_token - else None + self.azure_ad_token.get_secret_value() if self.azure_ad_token else None ), "azure_ad_token_provider": self.azure_ad_token_provider, "organization": self.openai_organization, @@ -200,14 +195,16 @@ def validate_environment(self) -> Self: "default_query": self.default_query, } if not (self.client or None): - sync_specific = {"http_client": self.http_client} + sync_specific: dict = {"http_client": self.http_client} self.client = openai.AzureOpenAI( - **client_params, **sync_specific + **client_params, # type: ignore[arg-type] + **sync_specific, ).embeddings if not (self.async_client or None): - async_specific = {"http_client": self.http_async_client} + async_specific: dict = {"http_client": self.http_async_client} self.async_client = openai.AsyncAzureOpenAI( - **client_params, **async_specific + **client_params, # type: ignore[arg-type] + **async_specific, ).embeddings return self diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index e6294d32ac2a2..f58471d4bab8e 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -20,13 +20,10 @@ import openai import tiktoken from langchain_core.embeddings import Embeddings -from pydantic import BaseModel, Field, SecretStr, root_validator, model_validator from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env -from pydantic import ConfigDict +from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self - - logger = logging.getLogger(__name__) @@ -267,7 +264,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """Whether to check the token length of inputs and automatically split inputs longer than embedding_ctx_length.""" - model_config = ConfigDict(extra="forbid",populate_by_name=True,) + model_config = ConfigDict(extra="forbid", populate_by_name=True) @model_validator(mode="before") @classmethod @@ -304,11 +301,9 @@ def validate_environment(self) -> Self: "If you are using Azure, " "please use the `AzureOpenAIEmbeddings` class." ) - client_params = { + client_params: dict = { "api_key": ( - self.openai_api_key.get_secret_value() - if self.openai_api_key - else None + self.openai_api_key.get_secret_value() if self.openai_api_key else None ), "organization": self.openai_organization, "base_url": self.openai_api_base, @@ -318,9 +313,7 @@ def validate_environment(self) -> Self: "default_query": self.default_query, } - if self.openai_proxy and ( - self.http_client or self.http_async_client - ): + if self.openai_proxy and (self.http_client or self.http_async_client): openai_proxy = self.openai_proxy http_client = self.http_client http_async_client = self.http_async_client @@ -340,9 +333,7 @@ def validate_environment(self) -> Self: ) from e self.http_client = httpx.Client(proxy=self.openai_proxy) sync_specific = {"http_client": self.http_client} - self.client = openai.OpenAI( - **client_params, **sync_specific - ).embeddings + self.client = openai.OpenAI(**client_params, **sync_specific).embeddings # type: ignore[arg-type] if not (self.async_client or None): if self.openai_proxy and not self.http_async_client: try: @@ -352,12 +343,11 @@ def validate_environment(self) -> Self: "Could not import httpx python package. " "Please install it with `pip install httpx`." ) from e - self.http_async_client = httpx.AsyncClient( - proxy=self.openai_proxy - ) + self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy) async_specific = {"http_client": self.http_async_client} self.async_client = openai.AsyncOpenAI( - **client_params, **async_specific + **client_params, + **async_specific, # type: ignore[arg-type] ).embeddings return self diff --git a/libs/partners/openai/langchain_openai/llms/azure.py b/libs/partners/openai/langchain_openai/llms/azure.py index e48b1380ff79f..973fd9c271e63 100644 --- a/libs/partners/openai/langchain_openai/llms/azure.py +++ b/libs/partners/openai/langchain_openai/llms/azure.py @@ -5,12 +5,11 @@ import openai from langchain_core.language_models import LangSmithParams -from pydantic import Field, SecretStr, root_validator, model_validator from langchain_core.utils import from_env, secret_from_env +from pydantic import Field, SecretStr, model_validator +from typing_extensions import Self, cast from langchain_openai.llms.base import BaseOpenAI -from typing_extensions import Self - logger = logging.getLogger(__name__) @@ -117,7 +116,7 @@ def validate_environment(self) -> Self: if openai_api_base and self.validate_base_url: if "/openai" not in openai_api_base: self.openai_api_base = ( - self.openai_api_base.rstrip("/") + "/openai" + cast(str, self.openai_api_base).rstrip("/") + "/openai" ) raise ValueError( "As of openai>=1.0.0, Azure endpoints should be specified via " @@ -133,7 +132,7 @@ def validate_environment(self) -> Self: "and `azure_endpoint`." ) self.deployment_name = None - client_params = { + client_params: dict = { "api_version": self.openai_api_version, "azure_endpoint": self.azure_endpoint, "azure_deployment": self.deployment_name, @@ -154,12 +153,14 @@ def validate_environment(self) -> Self: if not (self.client or None): sync_specific = {"http_client": self.http_client} self.client = openai.AzureOpenAI( - **client_params, **sync_specific + **client_params, + **sync_specific, # type: ignore[arg-type] ).completions if not (self.async_client or None): async_specific = {"http_client": self.http_async_client} self.async_client = openai.AsyncAzureOpenAI( - **client_params, **async_specific + **client_params, + **async_specific, # type: ignore[arg-type] ).completions return self diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index 0db987c5e0d5a..248e7657466a0 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -26,14 +26,11 @@ ) from langchain_core.language_models.llms import BaseLLM from langchain_core.outputs import Generation, GenerationChunk, LLMResult -from pydantic import Field, SecretStr, root_validator, model_validator from langchain_core.utils import get_pydantic_field_names from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env -from pydantic import ConfigDict +from pydantic import ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self - - logger = logging.getLogger(__name__) @@ -156,7 +153,7 @@ class BaseOpenAI(BaseLLM): """Optional additional JSON properties to include in the request parameters when making requests to OpenAI compatible APIs, such as vLLM.""" - model_config = ConfigDict(populate_by_name=True,) + model_config = ConfigDict(populate_by_name=True) @model_validator(mode="before") @classmethod @@ -179,11 +176,9 @@ def validate_environment(self) -> Self: if self.streaming and self.best_of > 1: raise ValueError("Cannot stream results when best_of > 1.") - client_params = { + client_params: dict = { "api_key": ( - self.openai_api_key.get_secret_value() - if self.openai_api_key - else None + self.openai_api_key.get_secret_value() if self.openai_api_key else None ), "organization": self.openai_organization, "base_url": self.openai_api_base, @@ -194,13 +189,12 @@ def validate_environment(self) -> Self: } if not (self.client or None): sync_specific = {"http_client": self.http_client} - self.client = openai.OpenAI( - **client_params, **sync_specific - ).completions + self.client = openai.OpenAI(**client_params, **sync_specific).completions # type: ignore[arg-type] if not (self.async_client or None): async_specific = {"http_client": self.http_async_client} self.async_client = openai.AsyncOpenAI( - **client_params, **async_specific + **client_params, + **async_specific, # type: ignore[arg-type] ).completions return self diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 6f2a39bffdb57..96f32c754f68b 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -20,13 +20,13 @@ ) from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult from langchain_core.prompts import ChatPromptTemplate -from pydantic import BaseModel, Field from langchain_standard_tests.integration_tests.chat_models import ( _validate_tool_call_message, ) from langchain_standard_tests.integration_tests.chat_models import ( magic_function as invalid_magic_function, ) +from pydantic import BaseModel, Field from langchain_openai import ChatOpenAI from tests.unit_tests.fake.callbacks import FakeCallbackHandler diff --git a/libs/partners/openai/tests/unit_tests/fake/callbacks.py b/libs/partners/openai/tests/unit_tests/fake/callbacks.py index e786601aaa764..d4b8d4b2c256b 100644 --- a/libs/partners/openai/tests/unit_tests/fake/callbacks.py +++ b/libs/partners/openai/tests/unit_tests/fake/callbacks.py @@ -188,7 +188,7 @@ def on_retriever_end(self, *args: Any, **kwargs: Any) -> Any: def on_retriever_error(self, *args: Any, **kwargs: Any) -> Any: self.on_retriever_error_common() - def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": + def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore[override] return self @@ -266,5 +266,5 @@ async def on_agent_finish(self, *args: Any, **kwargs: Any) -> None: async def on_text(self, *args: Any, **kwargs: Any) -> None: self.on_text_common() - def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": + def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore[override] return self