diff --git a/libs/community/langchain_community/embeddings/localai.py b/libs/community/langchain_community/embeddings/localai.py
index b5a926e8fe2ca..d6c341f825d58 100644
--- a/libs/community/langchain_community/embeddings/localai.py
+++ b/libs/community/langchain_community/embeddings/localai.py
@@ -16,7 +16,7 @@
 )
 
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
 from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
 from tenacity import (
     AsyncRetrying,
@@ -27,9 +27,19 @@
     wait_exponential,
 )
 
+from langchain_community.utils.openai import is_openai_v1
+
 logger = logging.getLogger(__name__)
 
 
+class DataItem(BaseModel):
+    embedding: List[float]
+
+
+class Response(BaseModel):
+    data: List[DataItem]
+
+
 def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]:
     import openai
 
@@ -37,17 +47,27 @@ def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]:
     max_seconds = 10
     # Wait 2^x * 1 second between each retry starting with
     # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(embeddings.max_retries),
-        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(
+    if is_openai_v1():
+        retry_ = (
+            retry_if_exception_type(openai.APITimeoutError)
+            | retry_if_exception_type(openai.APIError)
+            | retry_if_exception_type(openai.APIConnectionError)
+            | retry_if_exception_type(openai.RateLimitError)
+            | retry_if_exception_type(openai.InternalServerError)
+        )
+    else:
+        retry_ = (
             retry_if_exception_type(openai.error.Timeout)
             | retry_if_exception_type(openai.error.APIError)
             | retry_if_exception_type(openai.error.APIConnectionError)
             | retry_if_exception_type(openai.error.RateLimitError)
             | retry_if_exception_type(openai.error.ServiceUnavailableError)
-        ),
+        )
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(embeddings.max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=retry_,
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
 
@@ -59,17 +79,27 @@ def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
     max_seconds = 10
     # Wait 2^x * 1 second between each retry starting with
     # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-    async_retrying = AsyncRetrying(
-        reraise=True,
-        stop=stop_after_attempt(embeddings.max_retries),
-        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(
+    if is_openai_v1():
+        retry_ = (
+            retry_if_exception_type(openai.APITimeoutError)
+            | retry_if_exception_type(openai.APIError)
+            | retry_if_exception_type(openai.APIConnectionError)
+            | retry_if_exception_type(openai.RateLimitError)
+            | retry_if_exception_type(openai.InternalServerError)
+        )
+    else:
+        retry_ = (
             retry_if_exception_type(openai.error.Timeout)
             | retry_if_exception_type(openai.error.APIError)
             | retry_if_exception_type(openai.error.APIConnectionError)
             | retry_if_exception_type(openai.error.RateLimitError)
             | retry_if_exception_type(openai.error.ServiceUnavailableError)
-        ),
+        )
+    async_retrying = AsyncRetrying(
+        reraise=True,
+        stop=stop_after_attempt(embeddings.max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=retry_,
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
 
@@ -85,11 +115,15 @@ async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
 
 
 # https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
-def _check_response(response: dict) -> dict:
-    if any(len(d["embedding"]) == 1 for d in response["data"]):
+def _check_response(response: Response) -> Response:
+    if any(len(d.embedding) == 1 for d in response.data):
         import openai
 
-        raise openai.error.APIError("LocalAI API returned an empty embedding")
+        if is_openai_v1():
+            error_cls = openai.APIError
+        else:
+            error_cls = openai.error.APIError
+        raise error_cls("LocalAI API returned an empty embedding")
     return response
 
 
@@ -110,7 +144,7 @@ async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
 
     @_async_retry_decorator(embeddings)
     async def _async_embed_with_retry(**kwargs: Any) -> Any:
-        response = await embeddings.client.acreate(**kwargs)
+        response = await embeddings.async_client.acreate(**kwargs)
         return _check_response(response)
 
     return await _async_embed_with_retry(**kwargs)
@@ -137,24 +171,27 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
 
     """
 
-    client: Any  #: :meta private:
+    client: Any = Field(default=None, exclude=True)  #: :meta private:
+    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model: str = "text-embedding-ada-002"
     deployment: str = model
    openai_api_version: Optional[str] = None
-    openai_api_base: Optional[str] = None
+    openai_api_base: Optional[str] = Field(default=None, alias="base_url")
     # to support explicit proxy for LocalAI
     openai_proxy: Optional[str] = None
     embedding_ctx_length: int = 8191
     """The maximum number of tokens to embed at once."""
-    openai_api_key: Optional[str] = None
-    openai_organization: Optional[str] = None
+    openai_api_key: Optional[str] = Field(default=None, alias="api_key")
+    openai_organization: Optional[str] = Field(default=None, alias="organization")
     allowed_special: Union[Literal["all"], Set[str]] = set()
     disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
     chunk_size: int = 1000
     """Maximum number of texts to embed in each batch"""
     max_retries: int = 6
     """Maximum number of retries to make when generating."""
-    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
+    request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
+        default=None, alias="timeout"
+    )
     """Timeout in seconds for the LocalAI request."""
     headers: Any = None
     show_progress_bar: bool = False
@@ -165,7 +202,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
     class Config:
         """Configuration for this pydantic object."""
 
-        extra = Extra.forbid
+        allow_population_by_field_name = True
 
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@@ -228,7 +265,25 @@ def validate_environment(cls, values: Dict) -> Dict:
         try:
             import openai
 
-            values["client"] = openai.Embedding
+            if is_openai_v1():
+                client_params = {
+                    "api_key": values["openai_api_key"],
+                    "organization": values["openai_organization"],
+                    "base_url": values["openai_api_base"],
+                    "timeout": values["request_timeout"],
+                    "max_retries": values["max_retries"],
+                }
+
+                if not values.get("client"):
+                    values["client"] = openai.OpenAI(**client_params).embeddings
+                if not values.get("async_client"):
+                    values["async_client"] = openai.AsyncOpenAI(
+                        **client_params
+                    ).embeddings
+            elif not values.get("client"):
+                values["client"] = openai.Embedding
+            else:
+                pass
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "
" @@ -238,16 +293,24 @@ def validate_environment(cls, values: Dict) -> Dict: @property def _invocation_params(self) -> Dict: - openai_args = { - "model": self.model, - "request_timeout": self.request_timeout, - "headers": self.headers, - "api_key": self.openai_api_key, - "organization": self.openai_organization, - "api_base": self.openai_api_base, - "api_version": self.openai_api_version, - **self.model_kwargs, - } + if is_openai_v1(): + openai_args = { + "model": self.model, + "timeout": self.request_timeout, + "extra_headers": self.headers, + **self.model_kwargs, + } + else: + openai_args = { + "model": self.model, + "request_timeout": self.request_timeout, + "headers": self.headers, + "api_key": self.openai_api_key, + "organization": self.openai_organization, + "api_base": self.openai_api_base, + "api_version": self.openai_api_version, + **self.model_kwargs, + } if self.openai_proxy: import openai @@ -264,11 +327,15 @@ def _embedding_func(self, text: str, *, engine: str) -> List[float]: # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 # replace newlines, which can negatively affect performance. text = text.replace("\n", " ") - return embed_with_retry( - self, - input=[text], - **self._invocation_params, - )["data"][0]["embedding"] + return ( + embed_with_retry( + self, + input=[text], + **self._invocation_params, + ) + .data[0] + .embedding + ) async def _aembedding_func(self, text: str, *, engine: str) -> List[float]: """Call out to LocalAI's embedding endpoint.""" @@ -278,12 +345,16 @@ async def _aembedding_func(self, text: str, *, engine: str) -> List[float]: # replace newlines, which can negatively affect performance. text = text.replace("\n", " ") return ( - await async_embed_with_retry( - self, - input=[text], - **self._invocation_params, + ( + await async_embed_with_retry( + self, + input=[text], + **self._invocation_params, + ) ) - )["data"][0]["embedding"] + .data[0] + .embedding + ) def embed_documents( self, texts: List[str], chunk_size: Optional[int] = 0