Commit

mistralai[patch]: 16k token batching logic embed (#17136)
efriis authored Feb 6, 2024
1 parent 863f96b commit f881a33
Showing 4 changed files with 472 additions and 213 deletions.
58 changes: 49 additions & 9 deletions libs/partners/mistralai/langchain_mistralai/embeddings.py
@@ -1,5 +1,6 @@
+import asyncio
 import logging
-from typing import Dict, List, Optional
+from typing import Dict, Iterable, List, Optional

 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import (
@@ -16,9 +17,12 @@
     ENDPOINT as DEFAULT_MISTRAL_ENDPOINT,
 )
 from mistralai.exceptions import MistralException
+from tokenizers import Tokenizer  # type: ignore

 logger = logging.getLogger(__name__)

+MAX_TOKENS = 16_000
+

 class MistralAIEmbeddings(BaseModel, Embeddings):
     """MistralAI embedding models.
@@ -43,6 +47,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
     max_retries: int = 5
     timeout: int = 120
     max_concurrent_requests: int = 64
+    tokenizer: Tokenizer = Field(default=None)

     model: str = "mistral-embed"

@@ -72,8 +77,33 @@ def validate_environment(cls, values: Dict) -> Dict:
             timeout=values["timeout"],
             max_concurrent_requests=values["max_concurrent_requests"],
         )
+        if values["tokenizer"] is None:
+            values["tokenizer"] = Tokenizer.from_pretrained(
+                "mistralai/Mixtral-8x7B-v0.1"
+            )
         return values

+    def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
+        """Split a list of texts into batches of less than 16k tokens
+        for Mistral API."""
+        batch: List[str] = []
+        batch_tokens = 0
+
+        text_token_lengths = [
+            len(encoded) for encoded in self.tokenizer.encode_batch(texts)
+        ]
+
+        for text, text_tokens in zip(texts, text_token_lengths):
+            if batch_tokens + text_tokens > MAX_TOKENS:
+                yield batch
+                batch = [text]
+                batch_tokens = text_tokens
+            else:
+                batch.append(text)
+                batch_tokens += text_tokens
+        if batch:
+            yield batch
+
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed a list of document texts.
@@ -84,13 +114,17 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
             List of embeddings, one for each text.
         """
         try:
-            embeddings_batch_response = self.client.embeddings(
-                model=self.model,
-                input=texts,
+            batch_responses = (
+                self.client.embeddings(
+                    model=self.model,
+                    input=batch,
+                )
+                for batch in self._get_batches(texts)
             )
             return [
                 list(map(float, embedding_obj.embedding))
-                for embedding_obj in embeddings_batch_response.data
+                for response in batch_responses
+                for embedding_obj in response.data
             ]
         except MistralException as e:
             logger.error(f"An error occurred with MistralAI: {e}")
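
Call sites do not change with this commit; the batching is internal to embed_documents. A usage sketch, assuming the langchain-mistralai package is installed and MISTRAL_API_KEY is set in the environment:

from langchain_mistralai import MistralAIEmbeddings

embeddings = MistralAIEmbeddings()  # defaults to the "mistral-embed" model
vectors = embeddings.embed_documents(["first document", "second document"])
print(len(vectors))  # one embedding per input text, however many batches were sent
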
@@ -106,13 +140,19 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
             List of embeddings, one for each text.
         """
         try:
-            embeddings_batch_response = await self.async_client.embeddings(
-                model=self.model,
-                input=texts,
+            batch_responses = await asyncio.gather(
+                *[
+                    self.async_client.embeddings(
+                        model=self.model,
+                        input=batch,
+                    )
+                    for batch in self._get_batches(texts)
+                ]
             )
             return [
                 list(map(float, embedding_obj.embedding))
-                for embedding_obj in embeddings_batch_response.data
+                for response in batch_responses
+                for embedding_obj in response.data
             ]
         except MistralException as e:
             logger.error(f"An error occurred with MistralAI: {e}")
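
The async path now fans the per-batch requests out concurrently with asyncio.gather instead of issuing a single request containing all texts. A matching sketch under the same assumptions as above:

import asyncio

from langchain_mistralai import MistralAIEmbeddings


async def main() -> None:
    embeddings = MistralAIEmbeddings()
    vectors = await embeddings.aembed_documents(["first document", "second document"])
    print(len(vectors), len(vectors[0]))  # number of embeddings, embedding width


asyncio.run(main())
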

0 comments on commit f881a33
