mistral: catch GatedRepoError, release 0.1.3 #20802

Merged · 5 commits · Apr 23, 2024
Changes shown from 3 commits.
28 changes: 24 additions & 4 deletions libs/partners/mistralai/langchain_mistralai/embeddings.py
```diff
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import warnings
 from typing import Dict, Iterable, List, Optional
 
 import httpx
```
```diff
@@ -19,6 +20,13 @@
 MAX_TOKENS = 16_000
 
 
+class DummyTokenizer:
+    """Dummy tokenizer for when the tokenizer cannot be accessed (e.g., via Huggingface)."""
+
+    def encode_batch(self, texts: List[str]) -> List[List[str]]:
+        return [list(text) for text in texts]
+
+
 class MistralAIEmbeddings(BaseModel, Embeddings):
     """MistralAI embedding models.
 
```
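The fallback tokenizer treats each character as one token, so the token count used for batching is effectively `len(text)`. A minimal standalone sketch of that behavior (illustrative only, not part of the diff):

```python
from typing import List


class DummyTokenizer:
    """Character-level stand-in for the real Mixtral tokenizer."""

    def encode_batch(self, texts: List[str]) -> List[List[str]]:
        # One "token" per character, so len(tokens) == len(text).
        return [list(text) for text in texts]


tokenizer = DummyTokenizer()
print([len(t) for t in tokenizer.encode_batch(["hello", "hi"])])  # [5, 2]
```

Since a real token is at least one character, the character count is an upper bound on the true token count, so the fallback can only produce batches that are the same size or smaller relative to the 16,000-token limit; a conservative failure mode.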
```diff
@@ -83,9 +91,18 @@ def validate_environment(cls, values: Dict) -> Dict:
             timeout=values["timeout"],
         )
         if values["tokenizer"] is None:
-            values["tokenizer"] = Tokenizer.from_pretrained(
-                "mistralai/Mixtral-8x7B-v0.1"
-            )
+            try:
+                values["tokenizer"] = Tokenizer.from_pretrained(
+                    "mistralai/Mixtral-8x7B-v0.1"
+                )
+            except IOError:  # huggingface_hub GatedRepoError
+                warnings.warn(
+                    "Could not download mistral tokenizer from Huggingface for "
+                    "calculating batch sizes. Set a Huggingface token via the "
+                    "HF_TOKEN environment variable to download the real tokenizer. "
+                    "Falling back to a dummy tokenizer that uses `len()`."
+                )
+                values["tokenizer"] = DummyTokenizer()
         return values
 
     def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
```
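The plain `except IOError` works because `GatedRepoError` derives from `requests.HTTPError`, which in turn derives from `OSError`/`IOError`. For anyone wanting to exercise the fallback locally, a rough sketch (assumes `tokenizers` is installed and no valid `HF_TOKEN` is set; `load_tokenizer` is a hypothetical helper mirroring the diff above, not part of this PR):

```python
import warnings

from tokenizers import Tokenizer


def load_tokenizer():  # hypothetical helper, mirrors the diff above
    try:
        # Gated repo: downloading requires an accepted license and an HF_TOKEN.
        return Tokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
    except IOError:
        warnings.warn("Gated or unreachable repo; using DummyTokenizer.")
        return DummyTokenizer()  # from the sketch above
```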
```diff
@@ -100,7 +117,10 @@ def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
 
         for text, text_tokens in zip(texts, text_token_lengths):
             if batch_tokens + text_tokens > MAX_TOKENS:
-                yield batch
+                if len(batch) > 0:
+                    # edge case where the first text exceeds max tokens;
+                    # do not yield an empty batch
+                    yield batch
                 batch = [text]
                 batch_tokens = text_tokens
             else:
```
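The new guard matters when the very first text on its own exceeds `MAX_TOKENS`: at that point `batch` is still empty, and the old code yielded an empty list. A toy mirror of the fixed generator (hypothetical `toy_get_batches` with a tiny limit, and character counts standing in for token counts as above):

```python
from typing import Iterable, List

TOY_MAX_TOKENS = 10  # tiny stand-in for MAX_TOKENS = 16_000


def toy_get_batches(texts: List[str]) -> Iterable[List[str]]:
    batch: List[str] = []
    batch_tokens = 0
    for text in texts:
        text_tokens = len(text)  # character count, as with DummyTokenizer
        if batch_tokens + text_tokens > TOY_MAX_TOKENS:
            if len(batch) > 0:  # the fix: never yield an empty first batch
                yield batch
            batch = [text]
            batch_tokens = text_tokens
        else:
            batch.append(text)
            batch_tokens += text_tokens
    if batch:  # final partial batch (assumed from the surrounding method)
        yield batch


print(list(toy_get_batches(["a very long first text", "hi", "there"])))
# [['a very long first text'], ['hi', 'there']]  (no leading empty batch)
```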