Skip to content

Commit cc40106

Browse files
committed
feat: retrieve embeddings from database only when necessary
When performing a similarity search without maximal marginal relevance, the database query includes the embeddings by default, even though the retrieved embeddings are then discarded without being used. This can be very suboptimal when retrieving a large number of documents, due to the communication overhead.
1 parent be2fc47 commit cc40106

File tree

1 file changed

+31
-9
lines changed

1 file changed

+31
-9
lines changed

langchain_postgres/vectorstores.py

+31-9
Original file line numberDiff line numberDiff line change
@@ -1060,9 +1060,9 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa
10601060
docs = [
10611061
(
10621062
Document(
1063-
id=str(result.EmbeddingStore.id),
1064-
page_content=result.EmbeddingStore.document,
1065-
metadata=result.EmbeddingStore.cmetadata,
1063+
id=str(result.id),
1064+
page_content=result.document,
1065+
metadata=result.cmetadata,
10661066
),
10671067
result.distance if self.embeddings is not None else None,
10681068
)
@@ -1395,8 +1395,16 @@ def __query_collection(
13951395
embedding: List[float],
13961396
k: int = 4,
13971397
filter: Optional[Dict[str, str]] = None,
1398+
retrieve_embeddings: bool = False,
13981399
) -> Sequence[Any]:
13991400
"""Query the collection."""
1401+
columns_to_select = [
1402+
self.EmbeddingStore.id,
1403+
self.EmbeddingStore.document,
1404+
self.EmbeddingStore.cmetadata,
1405+
]
1406+
if retrieve_embeddings:
1407+
columns_to_select.append(self.EmbeddingStore.embedding)
14001408
with self._make_sync_session() as session: # type: ignore[arg-type]
14011409
collection = self.get_collection(session)
14021410
if not collection:
@@ -1417,7 +1425,7 @@ def __query_collection(
14171425

14181426
results: List[Any] = (
14191427
session.query(
1420-
self.EmbeddingStore,
1428+
*columns_to_select,
14211429
self.distance_strategy(embedding).label("distance"),
14221430
)
14231431
.filter(*filter_by)
@@ -1438,8 +1446,16 @@ async def __aquery_collection(
14381446
embedding: List[float],
14391447
k: int = 4,
14401448
filter: Optional[Dict[str, str]] = None,
1449+
retrieve_embeddings: bool = False,
14411450
) -> Sequence[Any]:
14421451
"""Query the collection."""
1452+
columns_to_select = [
1453+
self.EmbeddingStore.id,
1454+
self.EmbeddingStore.document,
1455+
self.EmbeddingStore.cmetadata,
1456+
]
1457+
if retrieve_embeddings:
1458+
columns_to_select.append(self.EmbeddingStore.embedding)
14431459
async with self._make_async_session() as session: # type: ignore[arg-type]
14441460
collection = await self.aget_collection(session)
14451461
if not collection:
@@ -1460,7 +1476,7 @@ async def __aquery_collection(
14601476

14611477
stmt = (
14621478
select(
1463-
self.EmbeddingStore,
1479+
*columns_to_select,
14641480
self.distance_strategy(embedding).label("distance"),
14651481
)
14661482
.filter(*filter_by)
@@ -1899,9 +1915,11 @@ def max_marginal_relevance_search_with_score_by_vector(
18991915
relevance to the query and score for each.
19001916
"""
19011917
assert not self._async_engine, "This method must be called without async_mode"
1902-
results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)
1918+
results = self.__query_collection(
1919+
embedding=embedding, k=fetch_k, filter=filter, retrieve_embeddings=True
1920+
)
19031921

1904-
embedding_list = [result.EmbeddingStore.embedding for result in results]
1922+
embedding_list = [result.embedding for result in results]
19051923

19061924
mmr_selected = maximal_marginal_relevance(
19071925
np.array(embedding, dtype=np.float32),
@@ -1947,10 +1965,14 @@ async def amax_marginal_relevance_search_with_score_by_vector(
19471965
await self.__apost_init__() # Lazy async init
19481966
async with self._make_async_session() as session:
19491967
results = await self.__aquery_collection(
1950-
session=session, embedding=embedding, k=fetch_k, filter=filter
1968+
session=session,
1969+
embedding=embedding,
1970+
k=fetch_k,
1971+
filter=filter,
1972+
retrieve_embeddings=True,
19511973
)
19521974

1953-
embedding_list = [result.EmbeddingStore.embedding for result in results]
1975+
embedding_list = [result.embedding for result in results]
19541976

19551977
mmr_selected = maximal_marginal_relevance(
19561978
np.array(embedding, dtype=np.float32),

0 commit comments

Comments
 (0)