Skip to content

Commit

Permalink
feat: retrieve embeddings only from database when necessary
Browse files Browse the repository at this point in the history
When performing a similarity search without using maximal marginal
relevance, the database query includes the embeddings by default,
whereas the retrived embeddings are discarded without use.

This can be very suboptimal when retrieve a large number of documents
due to communication overhead.
  • Loading branch information
fangyi-zhou committed Sep 18, 2024
1 parent b72b86a commit 3faed2b
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions langchain_postgres/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,9 +1061,9 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa
docs = [
(
Document(
id=str(result.EmbeddingStore.id),
page_content=result.EmbeddingStore.document,
metadata=result.EmbeddingStore.cmetadata,
id=str(result.id),
page_content=result.document,
metadata=result.cmetadata,
),
result.distance if self.embeddings is not None else None,
)
Expand Down Expand Up @@ -1396,8 +1396,16 @@ def __query_collection(
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
retrieve_embeddings: bool = False,
) -> Sequence[Any]:
"""Query the collection."""
columns_to_select = [
self.EmbeddingStore.id,
self.EmbeddingStore.document,
self.EmbeddingStore.cmetadata,
]
if retrieve_embeddings:
columns_to_select.append(self.EmbeddingStore.embedding)
with self._make_sync_session() as session: # type: ignore[arg-type]
collection = self.get_collection(session)
if not collection:
Expand All @@ -1418,7 +1426,7 @@ def __query_collection(

results: List[Any] = (
session.query(
self.EmbeddingStore,
*columns_to_select,
self.distance_strategy(embedding).label("distance"),
)
.filter(*filter_by)
Expand All @@ -1439,8 +1447,16 @@ async def __aquery_collection(
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
retrieve_embeddings: bool = False,
) -> Sequence[Any]:
"""Query the collection."""
columns_to_select = [
self.EmbeddingStore.id,
self.EmbeddingStore.document,
self.EmbeddingStore.cmetadata,
]
if retrieve_embeddings:
columns_to_select.append(self.EmbeddingStore.embedding)
async with self._make_async_session() as session: # type: ignore[arg-type]
collection = await self.aget_collection(session)
if not collection:
Expand Down Expand Up @@ -1900,9 +1916,11 @@ def max_marginal_relevance_search_with_score_by_vector(
relevance to the query and score for each.
"""
assert not self._async_engine, "This method must be called without async_mode"
results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)
results = self.__query_collection(
embedding=embedding, k=fetch_k, filter=filter, retrieve_embeddings=True
)

embedding_list = [result.EmbeddingStore.embedding for result in results]
embedding_list = [result.embedding for result in results]

mmr_selected = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
Expand Down Expand Up @@ -1948,10 +1966,14 @@ async def amax_marginal_relevance_search_with_score_by_vector(
await self.__apost_init__() # Lazy async init
async with self._make_async_session() as session:
results = await self.__aquery_collection(
session=session, embedding=embedding, k=fetch_k, filter=filter
session=session,
embedding=embedding,
k=fetch_k,
filter=filter,
retrieve_embeddings=True,
)

embedding_list = [result.EmbeddingStore.embedding for result in results]
embedding_list = [result.embedding for result in results]

mmr_selected = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
Expand Down

0 comments on commit 3faed2b

Please sign in to comment.