
community: make LocalAIEmbeddings compatible with openai>1.0 #17154

Closed
wants to merge 11 commits
159 changes: 115 additions & 44 deletions in libs/community/langchain_community/embeddings/localai.py
Collaborator
Could we make this compatible with both openai v0 and v1, the way langchain_community.embeddings.OpenAIEmbeddings is? Or, if that's too hard, we should at least add an explicit version check of the openai SDK and raise an informative error if the installed version is less than 1.
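For reference, the fallback the reviewer describes could look roughly like this; _require_openai_v1 is a hypothetical helper, not part of this PR:

```python
from importlib.metadata import version


def _require_openai_v1() -> None:
    """Hypothetical sketch of the reviewer's fallback: fail fast on openai<1."""
    installed = version("openai")
    if int(installed.split(".")[0]) < 1:
        raise ImportError(
            f"LocalAIEmbeddings requires openai>=1.0, found {installed}. "
            "Upgrade with `pip install -U openai`."
        )
```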

Contributor (author)

Done; it now supports both v0 and v1.
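The diff below branches on is_openai_v1() from langchain_community.utils.openai; a minimal sketch of such a helper, assuming it simply inspects the installed SDK's major version:

```python
from importlib.metadata import version


def is_openai_v1() -> bool:
    """Return True if the installed openai SDK is 1.x or newer (sketch)."""
    return int(version("openai").split(".")[0]) >= 1
```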

@@ -16,7 +16,7 @@
 )

 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
 from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
 from tenacity import (
     AsyncRetrying,
@@ -27,27 +27,47 @@
     wait_exponential,
 )

+from langchain_community.utils.openai import is_openai_v1
+
 logger = logging.getLogger(__name__)


+class DataItem(BaseModel):
+    embedding: List[float]
+
+
+class Response(BaseModel):
+    data: List[DataItem]
+
+
 def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]:
     import openai

     min_seconds = 4
     max_seconds = 10
     # Wait 2^x * 1 second between each retry starting with
     # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(embeddings.max_retries),
-        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(
+    if is_openai_v1():
+        retry_ = (
+            retry_if_exception_type(openai.Timeout)
+            | retry_if_exception_type(openai.APIError)
+            | retry_if_exception_type(openai.APIConnectionError)
+            | retry_if_exception_type(openai.RateLimitError)
+            | retry_if_exception_type(openai.InternalServerError)
Comment on lines +52 to +56
Collaborator
don't think this is v0 compatible
+        )
+    else:
+        retry_ = (
             retry_if_exception_type(openai.error.Timeout)
             | retry_if_exception_type(openai.error.APIError)
             | retry_if_exception_type(openai.error.APIConnectionError)
             | retry_if_exception_type(openai.error.RateLimitError)
             | retry_if_exception_type(openai.error.ServiceUnavailableError)
-        ),
+        )
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(embeddings.max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=retry_,
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
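A sketch of how this decorator is applied, mirroring the module's sync retry helper in simplified form; the embeddings argument is assumed to be a configured LocalAIEmbeddings instance:

```python
def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
    """Simplified sketch of the module's sync retry helper."""
    retry_decorator = _create_retry_decorator(embeddings)

    @retry_decorator
    def _embed_with_retry(**kwargs: Any) -> Any:
        # Retried on the version-appropriate openai exceptions selected
        # above, with exponential backoff capped between 4 and 10 seconds.
        return embeddings.client.create(**kwargs)

    return _embed_with_retry(**kwargs)
```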

@@ -59,17 +79,27 @@ def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
     max_seconds = 10
     # Wait 2^x * 1 second between each retry starting with
     # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-    async_retrying = AsyncRetrying(
-        reraise=True,
-        stop=stop_after_attempt(embeddings.max_retries),
-        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(
+    if is_openai_v1():
+        retry_ = (
+            retry_if_exception_type(openai.Timeout)
+            | retry_if_exception_type(openai.APIError)
+            | retry_if_exception_type(openai.APIConnectionError)
+            | retry_if_exception_type(openai.RateLimitError)
+            | retry_if_exception_type(openai.InternalServerError)
+        )
+    else:
+        retry_ = (
             retry_if_exception_type(openai.error.Timeout)
             | retry_if_exception_type(openai.error.APIError)
             | retry_if_exception_type(openai.error.APIConnectionError)
             | retry_if_exception_type(openai.error.RateLimitError)
             | retry_if_exception_type(openai.error.ServiceUnavailableError)
-        ),
+        )
+    async_retrying = AsyncRetrying(
+        reraise=True,
+        stop=stop_after_attempt(embeddings.max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=retry_,
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )

@@ -85,11 +115,15 @@ async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:


 # https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
-def _check_response(response: dict) -> dict:
-    if any(len(d["embedding"]) == 1 for d in response["data"]):
+def _check_response(response: Response) -> Response:
+    if any(len(d.embedding) == 1 for d in response.data):
         import openai

-        raise openai.error.APIError("LocalAI API returned an empty embedding")
+        if is_openai_v1():
+            error_cls = openai.APIError
+        else:
+            error_cls = openai.error.APIError
+        raise error_cls("LocalAI API returned an empty embedding")
     return response
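Since _check_response now takes the pydantic Response model rather than a raw dict, a v0-style payload would need to be validated first; a minimal sketch, assuming a dict-shaped payload:

```python
# Sketch: validating a dict-shaped (v0-style) payload into the Response
# model before checking it for degenerate length-1 embeddings.
raw = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
parsed = Response(**raw)            # pydantic validates the structure
checked = _check_response(parsed)   # raises APIError on length-1 embeddings
print(checked.data[0].embedding)    # -> [0.1, 0.2, 0.3]
```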


@@ -110,7 +144,7 @@ async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -

     @_async_retry_decorator(embeddings)
     async def _async_embed_with_retry(**kwargs: Any) -> Any:
-        response = await embeddings.client.acreate(**kwargs)
+        response = await embeddings.async_client.acreate(**kwargs)
         return _check_response(response)

     return await _async_embed_with_retry(**kwargs)
@@ -137,24 +171,27 @@ class LocalAIEmbeddings(BaseModel, Embeddings):

     """

-    client: Any  #: :meta private:
+    client: Any = Field(default=None, exclude=True)  #: :meta private:
+    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model: str = "text-embedding-ada-002"
     deployment: str = model
     openai_api_version: Optional[str] = None
-    openai_api_base: Optional[str] = None
+    openai_api_base: Optional[str] = Field(default=None, alias="base_url")
     # to support explicit proxy for LocalAI
     openai_proxy: Optional[str] = None
     embedding_ctx_length: int = 8191
     """The maximum number of tokens to embed at once."""
-    openai_api_key: Optional[str] = None
-    openai_organization: Optional[str] = None
+    openai_api_key: Optional[str] = Field(default=None, alias="api_key")
+    openai_organization: Optional[str] = Field(default=None, alias="organization")
     allowed_special: Union[Literal["all"], Set[str]] = set()
     disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
     chunk_size: int = 1000
     """Maximum number of texts to embed in each batch"""
     max_retries: int = 6
     """Maximum number of retries to make when generating."""
-    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
+    request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
+        default=None, alias="timeout"
+    )
     """Timeout in seconds for the LocalAI request."""
     headers: Any = None
     show_progress_bar: bool = False
@@ -165,7 +202,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
     class Config:
         """Configuration for this pydantic object."""

-        extra = Extra.forbid
+        allow_population_by_field_name = True

     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@@ -228,7 +265,25 @@ def validate_environment(cls, values: Dict) -> Dict:
         try:
             import openai

-            values["client"] = openai.Embedding
+            if is_openai_v1():
+                client_params = {
+                    "api_key": values["openai_api_key"],
+                    "organization": values["openai_organization"],
+                    "base_url": values["openai_api_base"],
+                    "timeout": values["request_timeout"],
+                    "max_retries": values["max_retries"],
+                }
+
+                if not values.get("client"):
+                    values["client"] = openai.OpenAI(**client_params).embeddings
+                if not values.get("async_client"):
+                    values["async_client"] = openai.AsyncOpenAI(
+                        **client_params
+                    ).embeddings
+            elif not values.get("client"):
+                values["client"] = openai.Embedding
+            else:
+                pass
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "
@@ -238,16 +293,24 @@ def validate_environment(cls, values: Dict) -> Dict:

     @property
     def _invocation_params(self) -> Dict:
-        openai_args = {
-            "model": self.model,
-            "request_timeout": self.request_timeout,
-            "headers": self.headers,
-            "api_key": self.openai_api_key,
-            "organization": self.openai_organization,
-            "api_base": self.openai_api_base,
-            "api_version": self.openai_api_version,
-            **self.model_kwargs,
-        }
+        if is_openai_v1():
+            openai_args = {
+                "model": self.model,
+                "timeout": self.request_timeout,
+                "extra_headers": self.headers,
+                **self.model_kwargs,
+            }
+        else:
+            openai_args = {
+                "model": self.model,
+                "request_timeout": self.request_timeout,
+                "headers": self.headers,
+                "api_key": self.openai_api_key,
+                "organization": self.openai_organization,
+                "api_base": self.openai_api_base,
+                "api_version": self.openai_api_version,
+                **self.model_kwargs,
+            }
         if self.openai_proxy:
             import openai

@@ -264,11 +327,15 @@ def _embedding_func(self, text: str, *, engine: str) -> List[float]:
         # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
         # replace newlines, which can negatively affect performance.
         text = text.replace("\n", " ")
-        return embed_with_retry(
-            self,
-            input=[text],
-            **self._invocation_params,
-        )["data"][0]["embedding"]
+        return (
+            embed_with_retry(
+                self,
+                input=[text],
+                **self._invocation_params,
+            )
+            .data[0]
+            .embedding
+        )
Comment on lines +331 to +338
Collaborator
don't think this is v0 compatible

Collaborator
bump

Collaborator
same below

     async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
         """Call out to LocalAI's embedding endpoint."""
@@ -278,12 +345,16 @@ async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
         # replace newlines, which can negatively affect performance.
         text = text.replace("\n", " ")
         return (
-            await async_embed_with_retry(
-                self,
-                input=[text],
-                **self._invocation_params,
+            (
+                await async_embed_with_retry(
+                    self,
+                    input=[text],
+                    **self._invocation_params,
+                )
             )
-        )["data"][0]["embedding"]
+            .data[0]
+            .embedding
+        )

     def embed_documents(
         self, texts: List[str], chunk_size: Optional[int] = 0
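For illustration, the field aliases plus allow_population_by_field_name = True let callers use either naming style; a sketch, assuming a LocalAI server listening at http://localhost:8080/v1:

```python
from langchain_community.embeddings import LocalAIEmbeddings

# The declared field names still work...
emb = LocalAIEmbeddings(
    openai_api_key="not-needed-for-local",
    openai_api_base="http://localhost:8080/v1",
)
# ...and so do the aliases introduced by this PR.
emb = LocalAIEmbeddings(
    api_key="not-needed-for-local",
    base_url="http://localhost:8080/v1",
)
vector = emb.embed_query("hello")  # calls the LocalAI embeddings endpoint
```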