@@ -1060,9 +1060,9 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa
1060
1060
docs = [
1061
1061
(
1062
1062
Document (
1063
- id = str (result .EmbeddingStore . id ),
1064
- page_content = result .EmbeddingStore . document ,
1065
- metadata = result .EmbeddingStore . cmetadata ,
1063
+ id = str (result .id ),
1064
+ page_content = result .document ,
1065
+ metadata = result .cmetadata ,
1066
1066
),
1067
1067
result .distance if self .embeddings is not None else None ,
1068
1068
)
@@ -1395,8 +1395,16 @@ def __query_collection(
1395
1395
embedding : List [float ],
1396
1396
k : int = 4 ,
1397
1397
filter : Optional [Dict [str , str ]] = None ,
1398
+ retrieve_embeddings : bool = False ,
1398
1399
) -> Sequence [Any ]:
1399
1400
"""Query the collection."""
1401
+ columns_to_select = [
1402
+ self .EmbeddingStore .id ,
1403
+ self .EmbeddingStore .document ,
1404
+ self .EmbeddingStore .cmetadata ,
1405
+ ]
1406
+ if retrieve_embeddings :
1407
+ columns_to_select .append (self .EmbeddingStore .embedding )
1400
1408
with self ._make_sync_session () as session : # type: ignore[arg-type]
1401
1409
collection = self .get_collection (session )
1402
1410
if not collection :
@@ -1417,7 +1425,7 @@ def __query_collection(
1417
1425
1418
1426
results : List [Any ] = (
1419
1427
session .query (
1420
- self . EmbeddingStore ,
1428
+ * columns_to_select ,
1421
1429
self .distance_strategy (embedding ).label ("distance" ),
1422
1430
)
1423
1431
.filter (* filter_by )
@@ -1438,8 +1446,16 @@ async def __aquery_collection(
1438
1446
embedding : List [float ],
1439
1447
k : int = 4 ,
1440
1448
filter : Optional [Dict [str , str ]] = None ,
1449
+ retrieve_embeddings : bool = False ,
1441
1450
) -> Sequence [Any ]:
1442
1451
"""Query the collection."""
1452
+ columns_to_select = [
1453
+ self .EmbeddingStore .id ,
1454
+ self .EmbeddingStore .document ,
1455
+ self .EmbeddingStore .cmetadata ,
1456
+ ]
1457
+ if retrieve_embeddings :
1458
+ columns_to_select .append (self .EmbeddingStore .embedding )
1443
1459
async with self ._make_async_session () as session : # type: ignore[arg-type]
1444
1460
collection = await self .aget_collection (session )
1445
1461
if not collection :
@@ -1460,7 +1476,7 @@ async def __aquery_collection(
1460
1476
1461
1477
stmt = (
1462
1478
select (
1463
- self . EmbeddingStore ,
1479
+ * columns_to_select ,
1464
1480
self .distance_strategy (embedding ).label ("distance" ),
1465
1481
)
1466
1482
.filter (* filter_by )
@@ -1899,9 +1915,11 @@ def max_marginal_relevance_search_with_score_by_vector(
1899
1915
relevance to the query and score for each.
1900
1916
"""
1901
1917
assert not self ._async_engine , "This method must be called without async_mode"
1902
- results = self .__query_collection (embedding = embedding , k = fetch_k , filter = filter )
1918
+ results = self .__query_collection (
1919
+ embedding = embedding , k = fetch_k , filter = filter , retrieve_embeddings = True
1920
+ )
1903
1921
1904
- embedding_list = [result .EmbeddingStore . embedding for result in results ]
1922
+ embedding_list = [result .embedding for result in results ]
1905
1923
1906
1924
mmr_selected = maximal_marginal_relevance (
1907
1925
np .array (embedding , dtype = np .float32 ),
@@ -1947,10 +1965,14 @@ async def amax_marginal_relevance_search_with_score_by_vector(
1947
1965
await self .__apost_init__ () # Lazy async init
1948
1966
async with self ._make_async_session () as session :
1949
1967
results = await self .__aquery_collection (
1950
- session = session , embedding = embedding , k = fetch_k , filter = filter
1968
+ session = session ,
1969
+ embedding = embedding ,
1970
+ k = fetch_k ,
1971
+ filter = filter ,
1972
+ retrieve_embeddings = True ,
1951
1973
)
1952
1974
1953
- embedding_list = [result .EmbeddingStore . embedding for result in results ]
1975
+ embedding_list = [result .embedding for result in results ]
1954
1976
1955
1977
mmr_selected = maximal_marginal_relevance (
1956
1978
np .array (embedding , dtype = np .float32 ),
0 commit comments