core: implement a batch_size parameter for CacheBackedEmbeddings (#18070)

**Description:**

Currently, `CacheBackedEmbeddings` computes vectors for *all* uncached
documents before updating the store. This pull request updates the
embedding computation loop to compute embeddings in batches, updating
the store after each batch.

I noticed this when I tried `CacheBackedEmbeddings` on our 30k document
set and the cache directory hadn't appeared on disk after 30 minutes.

The motivation is to minimize compute/data loss when problems occur:

* If there is a transient embedding failure (e.g. a network outage at
the embedding endpoint triggers an exception), at least the completed
vectors are written to the store instead of being discarded.
* If there is an issue with the store (e.g. no write permissions), the
condition is detected early without computing (and discarding!) all the
vectors.
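
To make the failure semantics concrete, here is a minimal standalone sketch of the batched loop (simplified from the `cache.py` change below; the helper name `embed_with_batched_writes` is illustrative, not part of the library):

```python
from typing import Any, List, Optional, Union

from langchain_core.utils.iter import batch_iterate


def embed_with_batched_writes(
    embedder: Any,
    store: Any,
    texts: List[str],
    batch_size: Optional[int] = None,
) -> List[Union[List[float], None]]:
    """Sketch only: write vectors to the store after *each* batch."""
    vectors: List[Union[List[float], None]] = store.mget(texts)
    all_missing_indices = [i for i, v in enumerate(vectors) if v is None]
    # batch_size=None yields a single batch, preserving the old behavior.
    for missing_indices in batch_iterate(batch_size, all_missing_indices):
        missing_texts = [texts[i] for i in missing_indices]
        missing_vectors = embedder.embed_documents(missing_texts)
        # If a later batch raises, everything written so far is kept.
        store.mset(list(zip(missing_texts, missing_vectors)))
        for index, vector in zip(missing_indices, missing_vectors):
            vectors[index] = vector
    return vectors
```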

**Issue:**
Implements enhancement #18026.

**Testing:**
I was unable to run unit tests; details in [this
post](#15019 (comment)).
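For reference, these tests are normally run from `libs/langchain` with `poetry run pytest tests/unit_tests/embeddings/test_caching.py` (assuming the repo's standard Poetry-based dev setup).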

---------

Signed-off-by: chrispy <[email protected]>
Co-authored-by: Eugene Yurtsev <[email protected]>
chrispy-snps and eyurtsev authored Mar 19, 2024
1 parent 89af308 commit 305d74c
Showing 4 changed files with 74 additions and 11 deletions.
@@ -22,10 +22,11 @@
"Caching embeddings can be done using a `CacheBackedEmbeddings`. The cache backed embedder is a wrapper around an embedder that caches\n",
"embeddings in a key-value store. The text is hashed and the hash is used as the key in the cache.\n",
"\n",
"The main supported way to initialized a `CacheBackedEmbeddings` is `from_bytes_store`. This takes in the following parameters:\n",
"The main supported way to initialize a `CacheBackedEmbeddings` is `from_bytes_store`. It takes the following parameters:\n",
"\n",
"- underlying_embedder: The embedder to use for embedding.\n",
"- document_embedding_cache: Any [`ByteStore`](/docs/integrations/stores/) for caching document embeddings.\n",
"- batch_size: (optional, defaults to `None`) The number of documents to embed between store updates.\n",
"- namespace: (optional, defaults to `\"\"`) The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used.\n",
"\n",
"**Attention**: Be sure to set the `namespace` parameter to avoid collisions of the same text embedded using different embeddings models."
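For context, a usage sketch matching the updated docs above (the embedder and cache path are assumptions; any `Embeddings` implementation and `ByteStore` will do):

```python
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_openai import OpenAIEmbeddings  # assumed embedder for illustration

underlying_embeddings = OpenAIEmbeddings()
store = LocalFileStore("./embedding_cache/")  # assumed cache location

cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    underlying_embeddings,
    store,
    namespace=underlying_embeddings.model,  # avoid cross-model collisions
    batch_size=100,  # flush vectors to the store every 100 documents
)
vectors = cached_embedder.embed_documents(["hello", "world"])
```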
12 changes: 10 additions & 2 deletions libs/core/langchain_core/utils/iter.py
@@ -165,8 +165,16 @@ def close(self) -> None:
 safetee = Tee


-def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
-    """Utility batching function."""
+def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]:
+    """Utility batching function.
+
+    Args:
+        size: The size of the batch. If None, returns a single batch.
+        iterable: The iterable to batch.
+
+    Returns:
+        An iterator over the batches.
+    """
     it = iter(iterable)
     while True:
         chunk = list(islice(it, size))
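A quick illustration of the new `None` handling (outputs inferred from the `islice`-based implementation above):

```python
from langchain_core.utils.iter import batch_iterate

print(list(batch_iterate(2, "abcde")))
# [['a', 'b'], ['c', 'd'], ['e']]

print(list(batch_iterate(None, "abcde")))
# [['a', 'b', 'c', 'd', 'e']] -- islice(it, None) drains the iterator,
# so all elements come back as a single batch
```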
25 changes: 17 additions & 8 deletions libs/langchain/langchain/embeddings/cache.py
@@ -12,10 +12,11 @@
 import json
 import uuid
 from functools import partial
-from typing import Callable, List, Sequence, Union, cast
+from typing import Callable, List, Optional, Sequence, Union, cast

 from langchain_core.embeddings import Embeddings
 from langchain_core.stores import BaseStore, ByteStore
+from langchain_core.utils.iter import batch_iterate

 from langchain.storage.encoder_backed import EncoderBackedStore
@@ -84,16 +85,20 @@ def __init__(
         self,
         underlying_embeddings: Embeddings,
         document_embedding_store: BaseStore[str, List[float]],
+        *,
+        batch_size: Optional[int] = None,
     ) -> None:
         """Initialize the embedder.

         Args:
             underlying_embeddings: the embedder to use for computing embeddings.
             document_embedding_store: The store to use for caching document embeddings.
+            batch_size: The number of documents to embed between store updates.
         """
         super().__init__()
         self.document_embedding_store = document_embedding_store
         self.underlying_embeddings = underlying_embeddings
+        self.batch_size = batch_size

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed a list of texts.
@@ -111,12 +116,12 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
         vectors: List[Union[List[float], None]] = self.document_embedding_store.mget(
             texts
         )
-        missing_indices: List[int] = [
+        all_missing_indices: List[int] = [
             i for i, vector in enumerate(vectors) if vector is None
         ]
-        missing_texts = [texts[i] for i in missing_indices]

-        if missing_texts:
+        for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
+            missing_texts = [texts[i] for i in missing_indices]
             missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
             self.document_embedding_store.mset(
                 list(zip(missing_texts, missing_vectors))
@@ -144,12 +149,14 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         vectors: List[
             Union[List[float], None]
         ] = await self.document_embedding_store.amget(texts)
-        missing_indices: List[int] = [
+        all_missing_indices: List[int] = [
             i for i, vector in enumerate(vectors) if vector is None
         ]
-        missing_texts = [texts[i] for i in missing_indices]

-        if missing_texts:
+        # batch_iterate supports None batch_size which returns all elements at once
+        # as a single batch.
+        for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
+            missing_texts = [texts[i] for i in missing_indices]
             missing_vectors = await self.underlying_embeddings.aembed_documents(
                 missing_texts
             )
@@ -210,6 +217,7 @@ def from_bytes_store(
         document_embedding_cache: ByteStore,
         *,
         namespace: str = "",
+        batch_size: Optional[int] = None,
     ) -> CacheBackedEmbeddings:
         """On-ramp that adds the necessary serialization and encoding to the store.
@@ -220,6 +228,7 @@
             namespace: The namespace to use for document cache.
                 This namespace is used to avoid collisions with other caches.
                 For example, set it to the name of the embedding model used.
+            batch_size: The number of documents to embed between store updates.
         """
         namespace = namespace
         key_encoder = _create_key_encoder(namespace)
@@ -229,4 +238,4 @@
             _value_serializer,
             _value_deserializer,
         )
-        return cls(underlying_embeddings, encoder_backed_store)
+        return cls(underlying_embeddings, encoder_backed_store, batch_size=batch_size)
45 changes: 45 additions & 0 deletions libs/langchain/tests/unit_tests/embeddings/test_caching.py
@@ -13,6 +13,8 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
         # Simulate embedding documents
         embeddings: List[List[float]] = []
         for text in texts:
+            if text == "RAISE_EXCEPTION":
+                raise ValueError("Simulated embedding failure")
             embeddings.append([len(text), len(text) + 1])
         return embeddings
@@ -31,6 +33,16 @@ def cache_embeddings() -> CacheBackedEmbeddings:
     )


+@pytest.fixture
+def cache_embeddings_batch() -> CacheBackedEmbeddings:
+    """Create a cache backed embeddings with a batch_size of 3."""
+    store = InMemoryStore()
+    embeddings = MockEmbeddings()
+    return CacheBackedEmbeddings.from_bytes_store(
+        embeddings, store, namespace="test_namespace", batch_size=3
+    )
+
+
 def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     texts = ["1", "22", "a", "333"]
     vectors = cache_embeddings.embed_documents(texts)
@@ -42,6 +54,20 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"


+def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
+    # "RAISE_EXCEPTION" forces a failure in batch 2
+    texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
+    try:
+        cache_embeddings_batch.embed_documents(texts)
+    except ValueError:
+        pass
+    keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
+    # only the first batch of three embeddings should exist
+    assert len(keys) == 3
+    # UUID is expected to be the same for the same text
+    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
+
+
 def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     text = "query_text"
     vector = cache_embeddings.embed_query(text)

def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
text = "query_text"
vector = cache_embeddings.embed_query(text)
@@ -62,6 +88,25 @@ async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"


+async def test_aembed_documents_batch(
+    cache_embeddings_batch: CacheBackedEmbeddings,
+) -> None:
+    # "RAISE_EXCEPTION" forces a failure in batch 2
+    texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
+    try:
+        await cache_embeddings_batch.aembed_documents(texts)
+    except ValueError:
+        pass
+    keys = [
+        key
+        async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
+    ]
+    # only the first batch of three embeddings should exist
+    assert len(keys) == 3
+    # UUID is expected to be the same for the same text
+    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
+
+
 async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     text = "query_text"
     vector = await cache_embeddings.aembed_query(text)
