-
Notifications
You must be signed in to change notification settings - Fork 16.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
community: make LocalAIEmbeddings comaptible with openai>1.0 #17154
Changes from all commits
7634c51
0879b58
5ea0e2a
ce3999d
bee87d9
f47cef8
50dca7c
a1871bc
44ad76a
daf04f0
69577ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
) | ||
|
||
from langchain_core.embeddings import Embeddings | ||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator | ||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator | ||
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names | ||
from tenacity import ( | ||
AsyncRetrying, | ||
|
@@ -27,27 +27,47 @@ | |
wait_exponential, | ||
) | ||
|
||
from langchain_community.utils.openai import is_openai_v1 | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class DataItem(BaseModel): | ||
embedding: List[float] | ||
|
||
|
||
class Response(BaseModel): | ||
data: List[DataItem] | ||
|
||
|
||
def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]: | ||
import openai | ||
|
||
min_seconds = 4 | ||
max_seconds = 10 | ||
# Wait 2^x * 1 second between each retry starting with | ||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards | ||
return retry( | ||
reraise=True, | ||
stop=stop_after_attempt(embeddings.max_retries), | ||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), | ||
retry=( | ||
if is_openai_v1(): | ||
retry_ = ( | ||
retry_if_exception_type(openai.Timeout) | ||
| retry_if_exception_type(openai.APIError) | ||
| retry_if_exception_type(openai.APIConnectionError) | ||
| retry_if_exception_type(openai.RateLimitError) | ||
| retry_if_exception_type(openai.InternalServerError) | ||
Comment on lines
+52
to
+56
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't think this is v0 compatible |
||
) | ||
else: | ||
retry_ = ( | ||
retry_if_exception_type(openai.error.Timeout) | ||
| retry_if_exception_type(openai.error.APIError) | ||
| retry_if_exception_type(openai.error.APIConnectionError) | ||
| retry_if_exception_type(openai.error.RateLimitError) | ||
| retry_if_exception_type(openai.error.ServiceUnavailableError) | ||
), | ||
) | ||
return retry( | ||
reraise=True, | ||
stop=stop_after_attempt(embeddings.max_retries), | ||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), | ||
retry=retry_, | ||
before_sleep=before_sleep_log(logger, logging.WARNING), | ||
) | ||
|
||
|
@@ -59,17 +79,27 @@ def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any: | |
max_seconds = 10 | ||
# Wait 2^x * 1 second between each retry starting with | ||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards | ||
async_retrying = AsyncRetrying( | ||
reraise=True, | ||
stop=stop_after_attempt(embeddings.max_retries), | ||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), | ||
retry=( | ||
if is_openai_v1(): | ||
retry_ = ( | ||
retry_if_exception_type(openai.Timeout) | ||
| retry_if_exception_type(openai.APIError) | ||
| retry_if_exception_type(openai.APIConnectionError) | ||
| retry_if_exception_type(openai.RateLimitError) | ||
| retry_if_exception_type(openai.InternalServerError) | ||
) | ||
else: | ||
retry_ = ( | ||
retry_if_exception_type(openai.error.Timeout) | ||
| retry_if_exception_type(openai.error.APIError) | ||
| retry_if_exception_type(openai.error.APIConnectionError) | ||
| retry_if_exception_type(openai.error.RateLimitError) | ||
| retry_if_exception_type(openai.error.ServiceUnavailableError) | ||
), | ||
) | ||
async_retrying = AsyncRetrying( | ||
reraise=True, | ||
stop=stop_after_attempt(embeddings.max_retries), | ||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), | ||
retry=retry_, | ||
before_sleep=before_sleep_log(logger, logging.WARNING), | ||
) | ||
|
||
|
@@ -85,11 +115,15 @@ async def wrapped_f(*args: Any, **kwargs: Any) -> Callable: | |
|
||
|
||
# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings | ||
def _check_response(response: dict) -> dict: | ||
if any(len(d["embedding"]) == 1 for d in response["data"]): | ||
def _check_response(response: Response) -> Response: | ||
if any(len(d.embedding) == 1 for d in response.data): | ||
import openai | ||
|
||
raise openai.error.APIError("LocalAI API returned an empty embedding") | ||
if is_openai_v1(): | ||
error_cls = openai.APIError | ||
else: | ||
error_cls = openai.error.APIError | ||
raise error_cls("LocalAI API returned an empty embedding") | ||
return response | ||
|
||
|
||
|
@@ -110,7 +144,7 @@ async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) - | |
|
||
@_async_retry_decorator(embeddings) | ||
async def _async_embed_with_retry(**kwargs: Any) -> Any: | ||
response = await embeddings.client.acreate(**kwargs) | ||
response = await embeddings.async_client.acreate(**kwargs) | ||
return _check_response(response) | ||
|
||
return await _async_embed_with_retry(**kwargs) | ||
|
@@ -137,24 +171,27 @@ class LocalAIEmbeddings(BaseModel, Embeddings): | |
|
||
""" | ||
|
||
client: Any #: :meta private: | ||
client: Any = Field(default=None, exclude=True) #: :meta private: | ||
async_client: Any = Field(default=None, exclude=True) #: :meta private: | ||
model: str = "text-embedding-ada-002" | ||
deployment: str = model | ||
openai_api_version: Optional[str] = None | ||
openai_api_base: Optional[str] = None | ||
openai_api_base: Optional[str] = Field(default=None, alias="base_url") | ||
# to support explicit proxy for LocalAI | ||
openai_proxy: Optional[str] = None | ||
embedding_ctx_length: int = 8191 | ||
"""The maximum number of tokens to embed at once.""" | ||
openai_api_key: Optional[str] = None | ||
openai_organization: Optional[str] = None | ||
openai_api_key: Optional[str] = Field(default=None, alias="api_key") | ||
openai_organization: Optional[str] = Field(default=None, alias="organization") | ||
allowed_special: Union[Literal["all"], Set[str]] = set() | ||
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all" | ||
chunk_size: int = 1000 | ||
"""Maximum number of texts to embed in each batch""" | ||
max_retries: int = 6 | ||
"""Maximum number of retries to make when generating.""" | ||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None | ||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field( | ||
default=None, alias="timeout" | ||
) | ||
"""Timeout in seconds for the LocalAI request.""" | ||
headers: Any = None | ||
show_progress_bar: bool = False | ||
|
@@ -165,7 +202,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings): | |
class Config: | ||
"""Configuration for this pydantic object.""" | ||
|
||
extra = Extra.forbid | ||
allow_population_by_field_name = True | ||
|
||
@root_validator(pre=True) | ||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: | ||
|
@@ -228,7 +265,25 @@ def validate_environment(cls, values: Dict) -> Dict: | |
try: | ||
import openai | ||
|
||
values["client"] = openai.Embedding | ||
if is_openai_v1(): | ||
client_params = { | ||
"api_key": values["openai_api_key"], | ||
"organization": values["openai_organization"], | ||
"base_url": values["openai_api_base"], | ||
"timeout": values["request_timeout"], | ||
"max_retries": values["max_retries"], | ||
} | ||
|
||
if not values.get("client"): | ||
values["client"] = openai.OpenAI(**client_params).embeddings | ||
if not values.get("async_client"): | ||
values["async_client"] = openai.AsyncOpenAI( | ||
**client_params | ||
).embeddings | ||
elif not values.get("client"): | ||
values["client"] = openai.Embedding | ||
else: | ||
pass | ||
except ImportError: | ||
raise ImportError( | ||
"Could not import openai python package. " | ||
|
@@ -238,16 +293,24 @@ def validate_environment(cls, values: Dict) -> Dict: | |
|
||
@property | ||
def _invocation_params(self) -> Dict: | ||
openai_args = { | ||
"model": self.model, | ||
"request_timeout": self.request_timeout, | ||
"headers": self.headers, | ||
"api_key": self.openai_api_key, | ||
"organization": self.openai_organization, | ||
"api_base": self.openai_api_base, | ||
"api_version": self.openai_api_version, | ||
**self.model_kwargs, | ||
} | ||
if is_openai_v1(): | ||
openai_args = { | ||
"model": self.model, | ||
"timeout": self.request_timeout, | ||
"extra_headers": self.headers, | ||
**self.model_kwargs, | ||
} | ||
else: | ||
openai_args = { | ||
"model": self.model, | ||
"request_timeout": self.request_timeout, | ||
"headers": self.headers, | ||
"api_key": self.openai_api_key, | ||
"organization": self.openai_organization, | ||
"api_base": self.openai_api_base, | ||
"api_version": self.openai_api_version, | ||
**self.model_kwargs, | ||
} | ||
if self.openai_proxy: | ||
import openai | ||
|
||
|
@@ -264,11 +327,15 @@ def _embedding_func(self, text: str, *, engine: str) -> List[float]: | |
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 | ||
# replace newlines, which can negatively affect performance. | ||
text = text.replace("\n", " ") | ||
return embed_with_retry( | ||
self, | ||
input=[text], | ||
**self._invocation_params, | ||
)["data"][0]["embedding"] | ||
return ( | ||
embed_with_retry( | ||
self, | ||
input=[text], | ||
**self._invocation_params, | ||
) | ||
.data[0] | ||
.embedding | ||
) | ||
Comment on lines
+331
to
+338
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't think this is v0 compatible There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. bump There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same below |
||
|
||
async def _aembedding_func(self, text: str, *, engine: str) -> List[float]: | ||
"""Call out to LocalAI's embedding endpoint.""" | ||
|
@@ -278,12 +345,16 @@ async def _aembedding_func(self, text: str, *, engine: str) -> List[float]: | |
# replace newlines, which can negatively affect performance. | ||
text = text.replace("\n", " ") | ||
return ( | ||
await async_embed_with_retry( | ||
self, | ||
input=[text], | ||
**self._invocation_params, | ||
( | ||
await async_embed_with_retry( | ||
self, | ||
input=[text], | ||
**self._invocation_params, | ||
) | ||
) | ||
)["data"][0]["embedding"] | ||
.data[0] | ||
.embedding | ||
) | ||
|
||
def embed_documents( | ||
self, texts: List[str], chunk_size: Optional[int] = 0 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could we make this compatible with both openai v0 and v1, the way langchain_community.embedding.OpenAIEmbeddings is? or if that's too hard we should at least add an explicit version check of the openai sdk and raise an informative error if it's less than 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
surely it has been done.