diff --git a/libs/community/langchain_community/embeddings/localai.py b/libs/community/langchain_community/embeddings/localai.py index b5a926e8fe2ca..0ccd475050dce 100644 --- a/libs/community/langchain_community/embeddings/localai.py +++ b/libs/community/langchain_community/embeddings/localai.py @@ -18,6 +18,7 @@ from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain_community.utils.openai import is_openai_v1 from tenacity import ( AsyncRetrying, before_sleep_log, @@ -42,11 +43,11 @@ def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], An stop=stop_after_attempt(embeddings.max_retries), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), retry=( - retry_if_exception_type(openai.error.Timeout) - | retry_if_exception_type(openai.error.APIError) - | retry_if_exception_type(openai.error.APIConnectionError) - | retry_if_exception_type(openai.error.RateLimitError) - | retry_if_exception_type(openai.error.ServiceUnavailableError) + retry_if_exception_type(openai.Timeout) + | retry_if_exception_type(openai.APIError) + | retry_if_exception_type(openai.APIConnectionError) + | retry_if_exception_type(openai.RateLimitError) + | retry_if_exception_type(openai.InternalServerError) ), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -64,11 +65,11 @@ def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any: stop=stop_after_attempt(embeddings.max_retries), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), retry=( - retry_if_exception_type(openai.error.Timeout) - | retry_if_exception_type(openai.error.APIError) - | retry_if_exception_type(openai.error.APIConnectionError) - | retry_if_exception_type(openai.error.RateLimitError) - | retry_if_exception_type(openai.error.ServiceUnavailableError) + retry_if_exception_type(openai.Timeout) + | retry_if_exception_type(openai.APIError) + | retry_if_exception_type(openai.APIConnectionError) + | retry_if_exception_type(openai.RateLimitError) + | retry_if_exception_type(openai.InternalServerError) ), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -86,10 +87,10 @@ async def wrapped_f(*args: Any, **kwargs: Any) -> Callable: # https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings def _check_response(response: dict) -> dict: - if any(len(d["embedding"]) == 1 for d in response["data"]): + if any([len(d.embedding) == 1 for d in response.data]): import openai - raise openai.error.APIError("LocalAI API returned an empty embedding") + raise openai.APIError("LocalAI API returned an empty embedding") return response @@ -110,7 +111,7 @@ async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) - @_async_retry_decorator(embeddings) async def _async_embed_with_retry(**kwargs: Any) -> Any: - response = await embeddings.client.acreate(**kwargs) + response = await embeddings.async_client.acreate(**kwargs) return _check_response(response) return await _async_embed_with_retry(**kwargs) @@ -137,24 +138,27 @@ class LocalAIEmbeddings(BaseModel, Embeddings): """ - client: Any #: :meta private: + client: Any = Field(default=None, exclude=True) #: :meta private: + async_client: Any = Field(default=None, exclude=True) #: :meta private: model: str = "text-embedding-ada-002" deployment: str = model openai_api_version: Optional[str] = None - openai_api_base: Optional[str] = None + openai_api_base: Optional[str] = Field(default=None, alias="base_url") # to support explicit proxy for LocalAI openai_proxy: Optional[str] = None embedding_ctx_length: int = 8191 """The maximum number of tokens to embed at once.""" - openai_api_key: Optional[str] = None - openai_organization: Optional[str] = None + openai_api_key: Optional[str] = Field(default=None, alias="api_key") + openai_organization: Optional[str] = Field(default=None, alias="organization") allowed_special: Union[Literal["all"], Set[str]] = set() disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all" chunk_size: int = 1000 """Maximum number of texts to embed in each batch""" max_retries: int = 6 """Maximum number of retries to make when generating.""" - request_timeout: Optional[Union[float, Tuple[float, float]]] = None + request_timeout: Union[float, Tuple[float, float], Any, None] = Field( + default=None, alias="timeout" + ) """Timeout in seconds for the LocalAI request.""" headers: Any = None show_progress_bar: bool = False @@ -165,7 +169,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings): class Config: """Configuration for this pydantic object.""" - extra = Extra.forbid + allow_population_by_field_name = True @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: @@ -228,7 +232,25 @@ def validate_environment(cls, values: Dict) -> Dict: try: import openai - values["client"] = openai.Embedding + if is_openai_v1(): + client_params = { + "api_key": values["openai_api_key"], + "organization": values["openai_organization"], + "base_url": values["openai_api_base"], + "timeout": values["request_timeout"], + "max_retries": values["max_retries"], + } + + if not values.get("client"): + values["client"] = openai.OpenAI(**client_params).embeddings + if not values.get("async_client"): + values["async_client"] = openai.AsyncOpenAI( + **client_params + ).embeddings + elif not values.get("client"): + values["client"] = openai.Embedding + else: + pass except ImportError: raise ImportError( "Could not import openai python package. " @@ -240,12 +262,8 @@ def validate_environment(cls, values: Dict) -> Dict: def _invocation_params(self) -> Dict: openai_args = { "model": self.model, - "request_timeout": self.request_timeout, - "headers": self.headers, - "api_key": self.openai_api_key, - "organization": self.openai_organization, - "api_base": self.openai_api_base, - "api_version": self.openai_api_version, + "timeout": self.request_timeout, + "extra_headers": self.headers, **self.model_kwargs, } if self.openai_proxy: @@ -268,7 +286,7 @@ def _embedding_func(self, text: str, *, engine: str) -> List[float]: self, input=[text], **self._invocation_params, - )["data"][0]["embedding"] + ).data[0].embedding async def _aembedding_func(self, text: str, *, engine: str) -> List[float]: """Call out to LocalAI's embedding endpoint.""" @@ -283,7 +301,7 @@ async def _aembedding_func(self, text: str, *, engine: str) -> List[float]: input=[text], **self._invocation_params, ) - )["data"][0]["embedding"] + ).data[0].embedding def embed_documents( self, texts: List[str], chunk_size: Optional[int] = 0