Add Vectorize support to AstraDBGraphVectorStore (#98)
* progress on graph vectorize

* about to remove by_vector tests

* now only mmr failing

* changes for mmr

* lint

* combined tests

* minor tweaks

* simplify method in vectorstore

* shortcut return

* split search_with_embedding method and added sync version
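
For context, a minimal usage sketch of what this commit enables: constructing the graph vector store with server-side ("Vectorize") embeddings instead of a client-side Embeddings object. The collection name, endpoint, token, and provider/model values below are illustrative placeholders, not values taken from this repository:

from astrapy.info import CollectionVectorServiceOptions

from langchain_astradb import AstraDBGraphVectorStore

# With Vectorize enabled, the Astra DB collection computes embeddings
# server-side, so no `embedding` object is passed to the store.
store = AstraDBGraphVectorStore.from_texts(
    texts=["node a", "node b"],
    collection_name="graph_demo",  # illustrative
    api_endpoint="https://<db-id>-<region>.apps.astra.datastax.com",  # illustrative
    token="AstraCS:...",  # illustrative
    collection_vector_service_options=CollectionVectorServiceOptions(
        provider="nvidia",  # example provider
        model_name="NV-Embed-QA",  # example model
    ),
)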
epinzur authored Oct 25, 2024
1 parent 89250df commit 9f0ca93
Showing 5 changed files with 909 additions and 259 deletions.
166 changes: 102 additions & 64 deletions libs/astradb/langchain_astradb/graph_vectorstores.py
@@ -39,13 +39,14 @@
 logger = logging.getLogger(__name__)
 
 
-class AdjacentNode:
+class EmbeddedNode:
     id: str
     links: list[Link]
     embedding: list[float]
 
-    def __init__(self, node: Node, embedding: list[float]) -> None:
-        """Create an Adjacent Node."""
+    def __init__(self, doc: Document, embedding: list[float]) -> None:
+        """Create an Embedded Node."""
+        node = _doc_to_node(doc=doc)
         self.id = node.id or ""
         self.links = node.links
         self.embedding = embedding
@@ -90,11 +91,11 @@ def _doc_to_node(doc: Document) -> Node:
     )
 
 
-def _incoming_links(node: Node | AdjacentNode) -> set[Link]:
+def _incoming_links(node: Node | EmbeddedNode) -> set[Link]:
     return {link for link in node.links if link.direction in ["in", "bidir"]}
 
 
-def _outgoing_links(node: Node | AdjacentNode) -> set[Link]:
+def _outgoing_links(node: Node | EmbeddedNode) -> set[Link]:
     return {link for link in node.links if link.direction in ["out", "bidir"]}
 
 
@@ -104,7 +105,7 @@ def __init__(
         self,
         *,
         collection_name: str,
-        embedding: Embeddings,
+        embedding: Embeddings | None = None,
         metadata_incoming_links_key: str = "incoming_links",
         token: str | TokenProvider | None = None,
         api_endpoint: str | None = None,
Expand Down Expand Up @@ -262,7 +263,6 @@ def __init__(
:meth:`~add_texts` and :meth:`~add_documents` as well.
"""
self.metadata_incoming_links_key = metadata_incoming_links_key
self.embedding = embedding

# update indexing policy to ensure incoming_links are indexed
if metadata_indexing_include is not None:
@@ -362,7 +362,7 @@ def __init__(
     @property
     @override
     def embeddings(self) -> Embeddings | None:
-        return self.embedding
+        return self.vector_store.embedding
 
     def _get_metadata_filter(
         self,
@@ -454,13 +454,20 @@ async def aadd_nodes(
     def from_texts(
         cls: type[AstraDBGraphVectorStore],
         texts: Iterable[str],
-        embedding: Embeddings,
+        embedding: Embeddings | None = None,
         metadatas: list[dict] | None = None,
         ids: Iterable[str] | None = None,
+        collection_vector_service_options: CollectionVectorServiceOptions | None = None,
+        collection_embedding_api_key: str | EmbeddingHeadersProvider | None = None,
         **kwargs: Any,
     ) -> AstraDBGraphVectorStore:
         """Return AstraDBGraphVectorStore initialized from texts and embeddings."""
-        store = cls(embedding=embedding, **kwargs)
+        store = cls(
+            embedding=embedding,
+            collection_vector_service_options=collection_vector_service_options,
+            collection_embedding_api_key=collection_embedding_api_key,
+            **kwargs,
+        )
         store.add_texts(texts, metadatas, ids=ids)
         return store
 
@@ -469,12 +476,19 @@ def from_texts(
     def from_documents(
         cls: type[AstraDBGraphVectorStore],
         documents: Iterable[Document],
-        embedding: Embeddings,
+        embedding: Embeddings | None = None,
         ids: Iterable[str] | None = None,
+        collection_vector_service_options: CollectionVectorServiceOptions | None = None,
+        collection_embedding_api_key: str | EmbeddingHeadersProvider | None = None,
         **kwargs: Any,
     ) -> AstraDBGraphVectorStore:
         """Return AstraDBGraphVectorStore initialized from docs and embeddings."""
-        store = cls(embedding=embedding, **kwargs)
+        store = cls(
+            embedding=embedding,
+            collection_vector_service_options=collection_vector_service_options,
+            collection_embedding_api_key=collection_embedding_api_key,
+            **kwargs,
+        )
         store.add_documents(documents, ids=ids)
         return store
 
@@ -717,21 +731,43 @@ async def ammr_traversal_search(  # noqa: C901
             filter: Optional metadata to filter the results.
             **kwargs: Additional keyword arguments.
         """
-        query_embedding = self.embedding.embed_query(query)
-        helper = MmrHelper(
-            k=k,
-            query_embedding=query_embedding,
-            lambda_mult=lambda_mult,
-            score_threshold=score_threshold,
-        )
 
         # For each unselected node, stores the outgoing links.
         outgoing_links_map: dict[str, set[Link]] = {}
         visited_links: set[Link] = set()
-        # Map from id to Document
+        # Map from id to Document, used as a cache
        retrieved_docs: dict[str, Document] = {}
 
-        async def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
+        def get_candidates(nodes: Iterable[EmbeddedNode]) -> dict[str, list[float]]:
+            nonlocal outgoing_links_map
+
+            candidates: dict[str, list[float]] = {}
+            for node in nodes:
+                if node.id not in outgoing_links_map:
+                    outgoing_links_map[node.id] = _outgoing_links(node=node)
+                    candidates[node.id] = node.embedding
+            return candidates
+
+        async def fetch_initial_candidates() -> (
+            tuple[list[float], dict[str, list[float]]]
+        ):
+            """Gets the embedded query and the set of initial candidates.
+            If fetch_k is zero, there will be no initial candidates.
+            """
+            nonlocal retrieved_docs
+
+            query_embedding, initial_nodes = await self._get_initial(
+                query=query,
+                retrieved_docs=retrieved_docs,
+                fetch_k=fetch_k,
+                filter=filter,
+            )
+
+            return query_embedding, get_candidates(nodes=initial_nodes)
+
+        async def fetch_neighborhood_candidates(
+            neighborhood: Sequence[str],
+        ) -> dict[str, list[float]]:
             nonlocal outgoing_links_map, visited_links, retrieved_docs
 
             # Put the neighborhood into the outgoing links, to avoid adding it
@@ -753,41 +789,20 @@ async def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
                 retrieved_docs=retrieved_docs,
             )
 
-            new_candidates: dict[str, list[float]] = {}
-            for adjacent_node in adjacent_nodes:
-                if adjacent_node.id not in outgoing_links_map:
-                    outgoing_links_map[adjacent_node.id] = _outgoing_links(
-                        node=adjacent_node
-                    )
-                    new_candidates[adjacent_node.id] = adjacent_node.embedding
-            helper.add_candidates(new_candidates)
-
-        async def fetch_initial_candidates() -> None:
-            nonlocal outgoing_links_map, visited_links, retrieved_docs
-
-            results = (
-                await self.vector_store.asimilarity_search_with_embedding_id_by_vector(
-                    embedding=query_embedding,
-                    k=fetch_k,
-                    filter=filter,
-                )
-            )
-
-            candidates: dict[str, list[float]] = {}
-            for doc, embedding, doc_id in results:
-                if doc_id not in retrieved_docs:
-                    retrieved_docs[doc_id] = doc
+            return get_candidates(nodes=adjacent_nodes)
 
-                if doc_id not in outgoing_links_map:
-                    node = _doc_to_node(doc)
-                    outgoing_links_map[doc_id] = _outgoing_links(node=node)
-                    candidates[doc_id] = embedding
-            helper.add_candidates(candidates)
+        query_embedding, initial_candidates = await fetch_initial_candidates()
+        helper = MmrHelper(
+            k=k,
+            query_embedding=query_embedding,
+            lambda_mult=lambda_mult,
+            score_threshold=score_threshold,
+        )
+        helper.add_candidates(candidates=initial_candidates)
 
         if initial_roots:
-            await fetch_neighborhood(initial_roots)
-        if fetch_k > 0:
-            await fetch_initial_candidates()
+            neighborhood_candidates = await fetch_neighborhood_candidates(initial_roots)
+            helper.add_candidates(candidates=neighborhood_candidates)
 
         # Tracks the depth of each candidate.
         depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()}
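
For orientation, a hedged usage sketch (not part of the diff) of the traversal method refactored above. ammr_traversal_search is an async generator, so callers drain it with async for; the query string and numeric arguments are illustrative:

from langchain_core.documents import Document

from langchain_astradb import AstraDBGraphVectorStore

async def run_traversal(store: AstraDBGraphVectorStore) -> list[Document]:
    # fetch_k seed candidates come from one similarity search (which, after
    # this commit, also returns the server-computed query embedding); links
    # are then traversed and results re-ranked for diversity with MmrHelper.
    found: list[Document] = []
    async for doc in store.ammr_traversal_search(
        query="what pages link to pricing?",  # illustrative
        k=4,  # final number of documents after MMR selection
        fetch_k=20,  # size of the initial candidate pool
    ):
        found.append(doc)
    return found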
@@ -1142,14 +1157,38 @@ async def _get_outgoing_links(self, source_ids: Iterable[str]) -> set[Link]:
 
         return links
 
+    async def _get_initial(
+        self,
+        query: str,
+        retrieved_docs: dict[str, Document],
+        fetch_k: int,
+        filter: dict[str, Any] | None = None,  # noqa: A002
+    ) -> tuple[list[float], list[EmbeddedNode]]:
+        (
+            query_embedding,
+            result,
+        ) = await self.vector_store.asimilarity_search_with_embedding(
+            query=query,
+            k=fetch_k,
+            filter=filter,
+        )
+
+        initial_nodes: list[EmbeddedNode] = []
+        for doc, embedding in result:
+            if doc.id is not None:
+                retrieved_docs[doc.id] = doc
+                initial_nodes.append(EmbeddedNode(doc=doc, embedding=embedding))
+
+        return query_embedding, initial_nodes
+
     async def _get_adjacent(
         self,
         links: set[Link],
         query_embedding: list[float],
         retrieved_docs: dict[str, Document],
         k_per_link: int | None = None,
         filter: dict[str, Any] | None = None,  # noqa: A002
-    ) -> Iterable[AdjacentNode]:
+    ) -> Iterable[EmbeddedNode]:
         """Return the target nodes with incoming links from any of the given links.
 
         Args:
@@ -1162,7 +1201,7 @@ async def _get_adjacent(
         Returns:
             Iterable of adjacent edges.
         """
-        targets: dict[str, AdjacentNode] = {}
+        targets: dict[str, EmbeddedNode] = {}
 
         tasks = []
         for link in links:
@@ -1172,22 +1211,21 @@ async def _get_adjacent(
             )
 
             tasks.append(
-                self.vector_store.asimilarity_search_with_embedding_id_by_vector(
+                self.vector_store.asimilarity_search_with_embedding_by_vector(
                     embedding=query_embedding,
                     k=k_per_link or 10,
                     filter=metadata_filter,
                 )
             )
 
-        results = await asyncio.gather(*tasks)
+        results: list[list[tuple[Document, list[float]]]] = await asyncio.gather(*tasks)
 
         for result in results:
-            for doc, embedding, doc_id in result:
-                if doc_id not in retrieved_docs:
-                    retrieved_docs[doc_id] = doc
-                if doc_id not in targets:
-                    node = _doc_to_node(doc=doc)
-                    targets[doc_id] = AdjacentNode(node=node, embedding=embedding)
+            for doc, embedding in result:
+                if doc.id is not None:
+                    retrieved_docs[doc.id] = doc
+                    if doc.id not in targets:
+                        targets[doc.id] = EmbeddedNode(doc=doc, embedding=embedding)
 
         # TODO: Consider a combined limit based on the similarity and/or
         # predicated MMR score?
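
One last hedged sketch (not part of the diff) showing the renamed helper's new return shape: (Document, embedding) pairs rather than (Document, embedding, id) triples, with the id now read from doc.id. The store and query_embedding variables are illustrative:

# `store` is an AstraDBGraphVectorStore; `query_embedding` is any query vector.
results = await store.vector_store.asimilarity_search_with_embedding_by_vector(
    embedding=query_embedding,
    k=10,
)
for doc, embedding in results:  # pairs; no separate doc_id anymore
    if doc.id is not None:  # the id travels on the Document itself
        print(doc.id, len(embedding))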