Skip to content

Commit

Permalink
Refactor MMR flow so as to enable it with vectorize stores (#40)
Browse files Browse the repository at this point in the history
* Refactor MMR flow so as to enable it with vectorize stores

* bump package version
  • Loading branch information
hemidactylus authored Jul 10, 2024
1 parent 0a38120 commit f28d956
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 51 deletions.
170 changes: 135 additions & 35 deletions libs/astradb/langchain_astradb/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,9 +1283,107 @@ async def asimilarity_search_with_score(
filter=filter,
)

def _run_mmr_query_by_sort(
self,
sort: Dict[str, Any],
k: int,
fetch_k: int,
lambda_mult: float,
metadata_parameter: Dict[str, Any],
**kwargs: Any,
) -> List[Document]:
query_vector_l: List[Optional[List[float]]] = [None]

def _response_setter(
resp: Dict[str, Any],
qvl: List[Optional[List[float]]] = query_vector_l,
) -> None:
qvl[0] = resp["status"]["sortVector"]

prefetch_hits = list(
# self.collection is not None (by _ensure_astra_db_client)
self.collection.paginated_find( # type: ignore[union-attr]
filter=metadata_parameter,
sort=sort,
options={
"limit": fetch_k,
"includeSimilarity": True,
"includeSortVector": True,
},
projection={
"_id": 1,
"content": 1,
"metadata": 1,
"$vector": 1,
"$vectorize": 1,
},
raw_response_callback=_response_setter,
)
)
query_vector = query_vector_l[0]
# the callback has surely filled query_vector:
return self._get_mmr_hits(
embedding=query_vector, # type: ignore[arg-type]
k=k,
lambda_mult=lambda_mult,
prefetch_hits=prefetch_hits,
content_field="$vectorize" if self._using_vectorize() else "content",
)

async def _arun_mmr_query_by_sort(
self,
sort: Dict[str, Any],
k: int,
fetch_k: int,
lambda_mult: float,
metadata_parameter: Dict[str, Any],
**kwargs: Any,
) -> List[Document]:
query_vector_l: List[Optional[List[float]]] = [None]

def _response_setter(
resp: Dict[str, Any],
qvl: List[Optional[List[float]]] = query_vector_l,
) -> None:
qvl[0] = resp["status"]["sortVector"]

prefetch_hits = [
hit
async for hit in self.async_collection.paginated_find(
filter=metadata_parameter,
sort=sort,
options={
"limit": fetch_k,
"includeSimilarity": True,
"includeSortVector": True,
},
projection={
"_id": 1,
"content": 1,
"metadata": 1,
"$vector": 1,
"$vectorize": 1,
},
raw_response_callback=_response_setter,
)
]
# the callback has surely filled query_vector:
query_vector = query_vector_l[0]
return self._get_mmr_hits(
embedding=query_vector, # type: ignore[arg-type]
k=k,
lambda_mult=lambda_mult,
prefetch_hits=prefetch_hits,
content_field="$vectorize" if self._using_vectorize() else "content",
)

@staticmethod
def _get_mmr_hits(
embedding: List[float], k: int, lambda_mult: float, prefetch_hits: List[DocDict]
embedding: List[float],
k: int,
lambda_mult: float,
prefetch_hits: List[DocDict],
content_field: str,
) -> List[Document]:
mmr_chosen_indices = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
Expand All @@ -1300,7 +1398,7 @@ def _get_mmr_hits(
]
return [
Document(
page_content=hit["content"],
page_content=hit[content_field],
metadata=hit["metadata"],
)
for hit in mmr_hits
Expand Down Expand Up @@ -1335,23 +1433,14 @@ def max_marginal_relevance_search_by_vector(
self.astra_env.ensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)

prefetch_hits = list(
# self.collection is not None (by _ensure_astra_db_client)
self.collection.paginated_find( # type: ignore[union-attr]
filter=metadata_parameter,
sort={"$vector": embedding},
options={"limit": fetch_k, "includeSimilarity": True},
projection={
"_id": 1,
"content": 1,
"metadata": 1,
"$vector": 1,
},
)
return self._run_mmr_query_by_sort(
sort={"$vector": embedding},
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
metadata_parameter=metadata_parameter,
)

return self._get_mmr_hits(embedding, k, lambda_mult, prefetch_hits)

async def amax_marginal_relevance_search_by_vector(
self,
embedding: List[float],
Expand Down Expand Up @@ -1381,22 +1470,13 @@ async def amax_marginal_relevance_search_by_vector(
await self.astra_env.aensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)

prefetch_hits = [
hit
async for hit in self.async_collection.paginated_find(
filter=metadata_parameter,
sort={"$vector": embedding},
options={"limit": fetch_k, "includeSimilarity": True},
projection={
"_id": 1,
"content": 1,
"metadata": 1,
"$vector": 1,
},
)
]

return self._get_mmr_hits(embedding, k, lambda_mult, prefetch_hits)
return await self._arun_mmr_query_by_sort(
sort={"$vector": embedding},
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
metadata_parameter=metadata_parameter,
)

def max_marginal_relevance_search(
self,
Expand Down Expand Up @@ -1425,7 +1505,17 @@ def max_marginal_relevance_search(
The list of Documents selected by maximal marginal relevance.
"""
if self._using_vectorize():
raise ValueError("MMR search is unsupported for server-side embeddings.")
# this case goes directly to the "_by_sort" method
# (and does its own filter normalization, as it cannot
# use the path for the with-embedding mmr querying)
metadata_parameter = self._filter_to_metadata(filter)
return self._run_mmr_query_by_sort(
sort={"$vectorize": query},
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
metadata_parameter=metadata_parameter,
)
else:
assert self.embedding is not None
embedding_vector = self.embedding.embed_query(query)
Expand Down Expand Up @@ -1464,7 +1554,17 @@ async def amax_marginal_relevance_search(
The list of Documents selected by maximal marginal relevance.
"""
if self._using_vectorize():
raise ValueError("MMR search is unsupported for server-side embeddings.")
# this case goes directly to the "_by_sort" method
# (and does its own filter normalization, as it cannot
# use the path for the with-embedding mmr querying)
metadata_parameter = self._filter_to_metadata(filter)
return await self._arun_mmr_query_by_sort(
sort={"$vectorize": query},
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
metadata_parameter=metadata_parameter,
)
else:
assert self.embedding is not None
embedding_vector = await self.embedding.aembed_query(query)
Expand Down
24 changes: 18 additions & 6 deletions libs/astradb/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions libs/astradb/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-astradb"
version = "0.3.3"
version = "0.3.4"
description = "An integration package connecting Astra DB and LangChain"
authors = []
readme = "README.md"
Expand All @@ -13,7 +13,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.1.31,<0.3"
astrapy = "^1.2"
astrapy = "^1.3"
numpy = "^1"

[tool.poetry.group.test]
Expand Down
40 changes: 32 additions & 8 deletions libs/astradb/tests/integration_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,23 +805,47 @@ def _v_from_i(i: int, N: int) -> str:
res_i_vals = {doc.metadata["i"] for doc in res1}
assert res_i_vals == {0, 4}

def test_astradb_vectorstore_mmr_vectorize_unsupported_sync(
def test_astradb_vectorstore_mmr_vectorize_sync(
self, vectorize_store: AstraDBVectorStore
) -> None:
"""
MMR testing with vectorize, currently unsupported.
MMR testing with vectorize, sync.
"""
with pytest.raises(ValueError):
vectorize_store.max_marginal_relevance_search("aa", k=2, fetch_k=3)
vectorize_store.add_texts(
[
"Dog",
"Wolf",
"Ant",
"Sunshine and piadina",
],
ids=["d", "w", "a", "s"],
)

async def test_astradb_vectorstore_mmr_vectorize_unsupported_async(
hits = vectorize_store.max_marginal_relevance_search("Dingo", k=2, fetch_k=3)
assert {doc.page_content for doc in hits} == {"Dog", "Ant"}

async def test_astradb_vectorstore_mmr_vectorize_async(
self, vectorize_store: AstraDBVectorStore
) -> None:
"""
MMR async testing with vectorize, currently unsupported.
MMR async testing with vectorize, async.
"""
with pytest.raises(ValueError):
await vectorize_store.amax_marginal_relevance_search("aa", k=2, fetch_k=3)
await vectorize_store.aadd_texts(
[
"Dog",
"Wolf",
"Ant",
"Sunshine and piadina",
],
ids=["d", "w", "a", "s"],
)

hits = await vectorize_store.amax_marginal_relevance_search(
"Dingo",
k=2,
fetch_k=3,
)
assert {doc.page_content for doc in hits} == {"Dog", "Ant"}

@pytest.mark.parametrize(
"vector_store",
Expand Down

0 comments on commit f28d956

Please sign in to comment.