diff --git a/docs/docs/modules/data_connection/text_embedding/caching_embeddings.ipynb b/docs/docs/modules/data_connection/text_embedding/caching_embeddings.ipynb
index de2947e312462..2e5241d00d8e3 100644
--- a/docs/docs/modules/data_connection/text_embedding/caching_embeddings.ipynb
+++ b/docs/docs/modules/data_connection/text_embedding/caching_embeddings.ipynb
@@ -22,10 +22,11 @@
     "Caching embeddings can be done using a `CacheBackedEmbeddings`. The cache backed embedder is a wrapper around an embedder that caches\n",
     "embeddings in a key-value store. The text is hashed and the hash is used as the key in the cache.\n",
     "\n",
-    "The main supported way to initialized a `CacheBackedEmbeddings` is `from_bytes_store`. This takes in the following parameters:\n",
+    "The main supported way to initialize a `CacheBackedEmbeddings` is `from_bytes_store`. It takes the following parameters:\n",
     "\n",
     "- underlying_embedder: The embedder to use for embedding.\n",
     "- document_embedding_cache: Any [`ByteStore`](/docs/integrations/stores/) for caching document embeddings.\n",
+    "- batch_size: (optional, defaults to `None`) The number of documents to embed between store updates.\n",
     "- namespace: (optional, defaults to `\"\"`) The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used.\n",
     "\n",
     "**Attention**: Be sure to set the `namespace` parameter to avoid collisions of the same text embedded using different embeddings models."
diff --git a/libs/core/langchain_core/utils/iter.py b/libs/core/langchain_core/utils/iter.py
index 60834163c3f0f..4191ea9ab5d57 100644
--- a/libs/core/langchain_core/utils/iter.py
+++ b/libs/core/langchain_core/utils/iter.py
@@ -165,8 +165,16 @@ def close(self) -> None:
 safetee = Tee
 
 
-def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
-    """Utility batching function."""
+def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]:
+    """Utility batching function.
+
+    Args:
+        size: The size of the batch. If None, returns a single batch.
+        iterable: The iterable to batch.
+
+    Returns:
+        An iterator over the batches.
+    """
     it = iter(iterable)
     while True:
         chunk = list(islice(it, size))
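Reviewer note: the `size=None` path is what `CacheBackedEmbeddings` relies on for its default behavior. A quick illustrative sketch of the documented semantics (not part of the diff):

```python
from langchain_core.utils.iter import batch_iterate

# An integer size chunks the iterable into size-length batches.
print(list(batch_iterate(2, ["a", "b", "c", "d", "e"])))
# [['a', 'b'], ['c', 'd'], ['e']]

# With size=None, islice() consumes the whole iterator in one call,
# so the entire iterable comes back as a single batch -- the default
# batch_size=None therefore preserves the old one-shot behavior.
print(list(batch_iterate(None, ["a", "b", "c"])))
# [['a', 'b', 'c']]
```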
""" super().__init__() self.document_embedding_store = document_embedding_store self.underlying_embeddings = underlying_embeddings + self.batch_size = batch_size def embed_documents(self, texts: List[str]) -> List[List[float]]: """Embed a list of texts. @@ -111,12 +116,12 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: vectors: List[Union[List[float], None]] = self.document_embedding_store.mget( texts ) - missing_indices: List[int] = [ + all_missing_indices: List[int] = [ i for i, vector in enumerate(vectors) if vector is None ] - missing_texts = [texts[i] for i in missing_indices] - if missing_texts: + for missing_indices in batch_iterate(self.batch_size, all_missing_indices): + missing_texts = [texts[i] for i in missing_indices] missing_vectors = self.underlying_embeddings.embed_documents(missing_texts) self.document_embedding_store.mset( list(zip(missing_texts, missing_vectors)) @@ -144,12 +149,14 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]: vectors: List[ Union[List[float], None] ] = await self.document_embedding_store.amget(texts) - missing_indices: List[int] = [ + all_missing_indices: List[int] = [ i for i, vector in enumerate(vectors) if vector is None ] - missing_texts = [texts[i] for i in missing_indices] - if missing_texts: + # batch_iterate supports None batch_size which returns all elements at once + # as a single batch. + for missing_indices in batch_iterate(self.batch_size, all_missing_indices): + missing_texts = [texts[i] for i in missing_indices] missing_vectors = await self.underlying_embeddings.aembed_documents( missing_texts ) @@ -210,6 +217,7 @@ def from_bytes_store( document_embedding_cache: ByteStore, *, namespace: str = "", + batch_size: Optional[int] = None, ) -> CacheBackedEmbeddings: """On-ramp that adds the necessary serialization and encoding to the store. @@ -220,6 +228,7 @@ def from_bytes_store( namespace: The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used. + batch_size: The number of documents to embed between store updates. 
""" namespace = namespace key_encoder = _create_key_encoder(namespace) @@ -229,4 +238,4 @@ def from_bytes_store( _value_serializer, _value_deserializer, ) - return cls(underlying_embeddings, encoder_backed_store) + return cls(underlying_embeddings, encoder_backed_store, batch_size=batch_size) diff --git a/libs/langchain/tests/unit_tests/embeddings/test_caching.py b/libs/langchain/tests/unit_tests/embeddings/test_caching.py index 8c24e73b95b3f..154f248d6494c 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_caching.py +++ b/libs/langchain/tests/unit_tests/embeddings/test_caching.py @@ -13,6 +13,8 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: # Simulate embedding documents embeddings: List[List[float]] = [] for text in texts: + if text == "RAISE_EXCEPTION": + raise ValueError("Simulated embedding failure") embeddings.append([len(text), len(text) + 1]) return embeddings @@ -31,6 +33,16 @@ def cache_embeddings() -> CacheBackedEmbeddings: ) +@pytest.fixture +def cache_embeddings_batch() -> CacheBackedEmbeddings: + """Create a cache backed embeddings with a batch_size of 3.""" + store = InMemoryStore() + embeddings = MockEmbeddings() + return CacheBackedEmbeddings.from_bytes_store( + embeddings, store, namespace="test_namespace", batch_size=3 + ) + + def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None: texts = ["1", "22", "a", "333"] vectors = cache_embeddings.embed_documents(texts) @@ -42,6 +54,20 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None: assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12" +def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None: + # "RAISE_EXCEPTION" forces a failure in batch 2 + texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"] + try: + cache_embeddings_batch.embed_documents(texts) + except ValueError: + pass + keys = list(cache_embeddings_batch.document_embedding_store.yield_keys()) + # only the first batch of three embeddings should exist + assert len(keys) == 3 + # UUID is expected to be the same for the same text + assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12" + + def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None: text = "query_text" vector = cache_embeddings.embed_query(text) @@ -62,6 +88,25 @@ async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12" +async def test_aembed_documents_batch( + cache_embeddings_batch: CacheBackedEmbeddings, +) -> None: + # "RAISE_EXCEPTION" forces a failure in batch 2 + texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"] + try: + await cache_embeddings_batch.aembed_documents(texts) + except ValueError: + pass + keys = [ + key + async for key in cache_embeddings_batch.document_embedding_store.ayield_keys() + ] + # only the first batch of three embeddings should exist + assert len(keys) == 3 + # UUID is expected to be the same for the same text + assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12" + + async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None: text = "query_text" vector = await cache_embeddings.aembed_query(text)