@@ -1061,9 +1061,9 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa
1061
1061
docs = [
1062
1062
(
1063
1063
Document (
1064
- id = str (result .EmbeddingStore . id ),
1065
- page_content = result .EmbeddingStore . document ,
1066
- metadata = result .EmbeddingStore . cmetadata ,
1064
+ id = str (result .id ),
1065
+ page_content = result .document ,
1066
+ metadata = result .cmetadata ,
1067
1067
),
1068
1068
result .distance if self .embeddings is not None else None ,
1069
1069
)
@@ -1396,8 +1396,16 @@ def __query_collection(
1396
1396
embedding : List [float ],
1397
1397
k : int = 4 ,
1398
1398
filter : Optional [Dict [str , str ]] = None ,
1399
+ retrieve_embeddings : bool = False ,
1399
1400
) -> Sequence [Any ]:
1400
1401
"""Query the collection."""
1402
+ columns_to_select = [
1403
+ self .EmbeddingStore .id ,
1404
+ self .EmbeddingStore .document ,
1405
+ self .EmbeddingStore .cmetadata ,
1406
+ ]
1407
+ if retrieve_embeddings :
1408
+ columns_to_select .append (self .EmbeddingStore .embedding )
1401
1409
with self ._make_sync_session () as session : # type: ignore[arg-type]
1402
1410
collection = self .get_collection (session )
1403
1411
if not collection :
@@ -1418,7 +1426,7 @@ def __query_collection(
1418
1426
1419
1427
results : List [Any ] = (
1420
1428
session .query (
1421
- self . EmbeddingStore ,
1429
+ * columns_to_select ,
1422
1430
self .distance_strategy (embedding ).label ("distance" ),
1423
1431
)
1424
1432
.filter (* filter_by )
@@ -1439,8 +1447,16 @@ async def __aquery_collection(
1439
1447
embedding : List [float ],
1440
1448
k : int = 4 ,
1441
1449
filter : Optional [Dict [str , str ]] = None ,
1450
+ retrieve_embeddings : bool = False ,
1442
1451
) -> Sequence [Any ]:
1443
1452
"""Query the collection."""
1453
+ columns_to_select = [
1454
+ self .EmbeddingStore .id ,
1455
+ self .EmbeddingStore .document ,
1456
+ self .EmbeddingStore .cmetadata ,
1457
+ ]
1458
+ if retrieve_embeddings :
1459
+ columns_to_select .append (self .EmbeddingStore .embedding )
1444
1460
async with self ._make_async_session () as session : # type: ignore[arg-type]
1445
1461
collection = await self .aget_collection (session )
1446
1462
if not collection :
@@ -1900,9 +1916,11 @@ def max_marginal_relevance_search_with_score_by_vector(
1900
1916
relevance to the query and score for each.
1901
1917
"""
1902
1918
assert not self ._async_engine , "This method must be called without async_mode"
1903
- results = self .__query_collection (embedding = embedding , k = fetch_k , filter = filter )
1919
+ results = self .__query_collection (
1920
+ embedding = embedding , k = fetch_k , filter = filter , retrieve_embeddings = True
1921
+ )
1904
1922
1905
- embedding_list = [result .EmbeddingStore . embedding for result in results ]
1923
+ embedding_list = [result .embedding for result in results ]
1906
1924
1907
1925
mmr_selected = maximal_marginal_relevance (
1908
1926
np .array (embedding , dtype = np .float32 ),
@@ -1948,10 +1966,14 @@ async def amax_marginal_relevance_search_with_score_by_vector(
1948
1966
await self .__apost_init__ () # Lazy async init
1949
1967
async with self ._make_async_session () as session :
1950
1968
results = await self .__aquery_collection (
1951
- session = session , embedding = embedding , k = fetch_k , filter = filter
1969
+ session = session ,
1970
+ embedding = embedding ,
1971
+ k = fetch_k ,
1972
+ filter = filter ,
1973
+ retrieve_embeddings = True ,
1952
1974
)
1953
1975
1954
- embedding_list = [result .EmbeddingStore . embedding for result in results ]
1976
+ embedding_list = [result .embedding for result in results ]
1955
1977
1956
1978
mmr_selected = maximal_marginal_relevance (
1957
1979
np .array (embedding , dtype = np .float32 ),
0 commit comments