From fb2c393b4acdda3b54671ff52f5cdb9c74e2307c Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Tue, 5 Nov 2024 18:02:51 -0600 Subject: [PATCH] progress on simplification --- .../langchain_astradb/graph_vectorstores.py | 1267 +---------------- 1 file changed, 75 insertions(+), 1192 deletions(-) diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index e14e78e..a50bcd0 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -2,28 +2,29 @@ from __future__ import annotations -import asyncio import json import logging -import secrets from dataclasses import asdict, is_dataclass from typing import ( TYPE_CHECKING, Any, - AsyncIterable, Iterable, - Sequence, cast, ) -from langchain_community.graph_vectorstores.base import GraphVectorStore, Node -from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link +from langchain_community.graph_vectorstores.cassandra_base import ( + CassandraGraphVectorStoreBase, +) +from langchain_community.graph_vectorstores.links import ( + METADATA_LINKS_KEY, + Link, + get_links, + incoming_links, +) from langchain_core._api import beta -from langchain_core.documents import Document from typing_extensions import override from langchain_astradb.utils.astradb import COMPONENT_NAME_GRAPHVECTORSTORE, SetupMode -from langchain_astradb.utils.mmr_helper import MmrHelper from langchain_astradb.vectorstores import AstraDBVectorStore if TYPE_CHECKING: @@ -31,6 +32,7 @@ from astrapy.db import AstraDB as AstraDBClient from astrapy.db import AsyncAstraDB as AsyncAstraDBClient from astrapy.info import CollectionVectorServiceOptions + from langchain_core.documents import Document from langchain_core.embeddings import Embeddings DEFAULT_INDEXING_OPTIONS = {"allow": ["metadata"]} @@ -39,19 +41,6 @@ logger = logging.getLogger(__name__) -class EmbeddedNode: - id: str - links: list[Link] - embedding: list[float] - - def __init__(self, doc: Document, embedding: list[float]) -> None: - """Create an Embedded Node.""" - node = _doc_to_node(doc=doc) - self.id = node.id or "" - self.links = node.links - self.embedding = embedding - - def _serialize_links(links: list[Link]) -> str: class SetAndLinkEncoder(json.JSONEncoder): def default(self, obj: Any) -> Any: # noqa: ANN401 @@ -78,29 +67,8 @@ def _metadata_link_key(link: Link) -> str: return f"link:{link.kind}:{link.tag}" -def _doc_to_node(doc: Document) -> Node: - metadata = doc.metadata.copy() - links = _deserialize_links(metadata.get(METADATA_LINKS_KEY)) - metadata[METADATA_LINKS_KEY] = links - - return Node( - id=doc.id, - text=doc.page_content, - metadata=metadata, - links=list(links), - ) - - -def _incoming_links(node: Node | EmbeddedNode) -> set[Link]: - return {link for link in node.links if link.direction in ["in", "bidir"]} - - -def _outgoing_links(node: Node | EmbeddedNode) -> set[Link]: - return {link for link in node.links if link.direction in ["out", "bidir"]} - - @beta() -class AstraDBGraphVectorStore(GraphVectorStore): +class AstraDBGraphVectorStore(CassandraGraphVectorStoreBase): def __init__( self, *, @@ -276,32 +244,34 @@ def __init__( collection_indexing_policy["allow"] = list(allow_list) try: - self.vector_store = AstraDBVectorStore( - collection_name=collection_name, - embedding=embedding, - token=token, - api_endpoint=api_endpoint, - environment=environment, - namespace=namespace, - metric=metric, - batch_size=batch_size, - 
bulk_insert_batch_concurrency=bulk_insert_batch_concurrency, - bulk_insert_overwrite_concurrency=bulk_insert_overwrite_concurrency, - bulk_delete_concurrency=bulk_delete_concurrency, - setup_mode=setup_mode, - pre_delete_collection=pre_delete_collection, - metadata_indexing_include=metadata_indexing_include, - metadata_indexing_exclude=metadata_indexing_exclude, - collection_indexing_policy=collection_indexing_policy, - collection_vector_service_options=collection_vector_service_options, - collection_embedding_api_key=collection_embedding_api_key, - content_field=content_field, - ignore_invalid_documents=ignore_invalid_documents, - autodetect_collection=autodetect_collection, - ext_callers=ext_callers, - component_name=component_name, - astra_db_client=astra_db_client, - async_astra_db_client=async_astra_db_client, + super().__init__( + vector_store=AstraDBVectorStore( + collection_name=collection_name, + embedding=embedding, + token=token, + api_endpoint=api_endpoint, + environment=environment, + namespace=namespace, + metric=metric, + batch_size=batch_size, + bulk_insert_batch_concurrency=bulk_insert_batch_concurrency, + bulk_insert_overwrite_concurrency=bulk_insert_overwrite_concurrency, + bulk_delete_concurrency=bulk_delete_concurrency, + setup_mode=setup_mode, + pre_delete_collection=pre_delete_collection, + metadata_indexing_include=metadata_indexing_include, + metadata_indexing_exclude=metadata_indexing_exclude, + collection_indexing_policy=collection_indexing_policy, + collection_vector_service_options=collection_vector_service_options, + collection_embedding_api_key=collection_embedding_api_key, + content_field=content_field, + ignore_invalid_documents=ignore_invalid_documents, + autodetect_collection=autodetect_collection, + ext_callers=ext_callers, + component_name=component_name, + astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, + ) ) # for the test search, if setup_mode is ASYNC, @@ -359,95 +329,59 @@ def __init__( self.astra_env = self.vector_store.astra_env - @property - @override - def embeddings(self) -> Embeddings | None: - return self.vector_store.embedding - - def _get_metadata_filter( - self, - metadata: dict[str, Any] | None = None, - outgoing_link: Link | None = None, - ) -> dict[str, Any]: - if outgoing_link is None: - return metadata or {} - - metadata_filter = {} if metadata is None else metadata.copy() - metadata_filter[self.metadata_incoming_links_key] = _metadata_link_key( - link=outgoing_link - ) - return metadata_filter - - def _restore_links(self, doc: Document) -> Document: - """Restores the links in the document by deserializing them from metadata. + def get_metadata_for_insertion(self, doc: Document) -> dict[str, Any]: + """Prepares the links in a document by serializing them to metadata. Args: - doc: A single Document + doc: Document to prepare Returns: - The same Document with restored links. + The document metadata ready for insertion into the database. 
""" - links = _deserialize_links(doc.metadata.get(METADATA_LINKS_KEY)) - doc.metadata[METADATA_LINKS_KEY] = links - if self.metadata_incoming_links_key in doc.metadata: - del doc.metadata[self.metadata_incoming_links_key] - return doc - - def _get_node_metadata_for_insertion(self, node: Node) -> dict[str, Any]: - metadata = node.metadata.copy() - metadata[METADATA_LINKS_KEY] = _serialize_links(node.links) + links = get_links(doc=doc) + metadata = doc.metadata.copy() + metadata[METADATA_LINKS_KEY] = _serialize_links(links=links) metadata[self.metadata_incoming_links_key] = [ - _metadata_link_key(link=link) for link in _incoming_links(node=node) + _metadata_link_key(link=link) for link in incoming_links(links=links) ] return metadata - def _get_docs_for_insertion( - self, nodes: Iterable[Node] - ) -> tuple[list[Document], list[str]]: - docs = [] - ids = [] - for node in nodes: - node_id = secrets.token_hex(8) if not node.id else node.id - - doc = Document( - page_content=node.text, - metadata=self._get_node_metadata_for_insertion(node=node), - id=node_id, - ) - docs.append(doc) - ids.append(node_id) - return (docs, ids) - - @override - def add_nodes( - self, - nodes: Iterable[Node], - **kwargs: Any, - ) -> Iterable[str]: - """Add nodes to the graph store. + def restore_links(self, doc: Document) -> Document: + """Restores links in a document by deserializing them from metadata. Args: - nodes: the nodes to add. - **kwargs: Additional keyword arguments. + doc: Document to restore + + Returns: + The document ready for use in the graph vector store. """ - (docs, ids) = self._get_docs_for_insertion(nodes=nodes) - return self.vector_store.add_documents(docs, ids=ids) + links = _deserialize_links(doc.metadata.get(METADATA_LINKS_KEY)) + doc.metadata[METADATA_LINKS_KEY] = links + doc.metadata.pop(self.metadata_incoming_links_key) + return doc - @override - async def aadd_nodes( + def get_metadata_filter( self, - nodes: Iterable[Node], - **kwargs: Any, - ) -> AsyncIterable[str]: - """Add nodes to the graph store. + metadata: dict[str, Any] | None = None, + outgoing_link: Link | None = None, + ) -> dict[str, Any]: + """Builds a metadata filter to search for document. Args: - nodes: the nodes to add. - **kwargs: Additional keyword arguments. + metadata: Any metadata that should be used for hybrid search + outgoing_link: An optional outgoing link to add to the search + + Returns: + The document metadata ready for insertion into the database. """ - (docs, ids) = self._get_docs_for_insertion(nodes=nodes) - for inserted_id in await self.vector_store.aadd_documents(docs, ids=ids): - yield inserted_id + if outgoing_link is None: + return metadata or {} + + metadata_filter = {} if metadata is None else metadata.copy() + metadata_filter[self.metadata_incoming_links_key] = _metadata_link_key( + link=outgoing_link + ) + return metadata_filter @classmethod @override @@ -536,1054 +470,3 @@ async def afrom_documents( ) await store.aadd_documents(documents, ids=ids) return store - - @override - def similarity_search( - self, - query: str, - k: int = 4, - filter: dict[str, Any] | None = None, - **kwargs: Any, - ) -> list[Document]: - """Retrieve documents from this graph store. - - Args: - query: The query string. - k: The number of Documents to return. Defaults to 4. - filter: Optional metadata to filter the results. - **kwargs: Additional keyword arguments. - - Returns: - Collection of retrieved documents. 
- """ - return [ - self._restore_links(doc) - for doc in self.vector_store.similarity_search( - query=query, - k=k, - filter=filter, - **kwargs, - ) - ] - - @override - async def asimilarity_search( - self, - query: str, - k: int = 4, - filter: dict[str, Any] | None = None, - **kwargs: Any, - ) -> list[Document]: - """Retrieve documents from this graph store. - - Args: - query: The query string. - k: The number of Documents to return. Defaults to 4. - filter: Optional metadata to filter the results. - **kwargs: Additional keyword arguments. - - Returns: - Collection of retrieved documents. - """ - return [ - self._restore_links(doc) - for doc in await self.vector_store.asimilarity_search( - query=query, - k=k, - filter=filter, - **kwargs, - ) - ] - - @override - def similarity_search_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: dict[str, Any] | None = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter on the metadata to apply. - **kwargs: Additional arguments are ignored. - - Returns: - The list of Documents most similar to the query vector. - """ - return [ - self._restore_links(doc) - for doc in self.vector_store.similarity_search_by_vector( - embedding, - k=k, - filter=filter, - **kwargs, - ) - ] - - @override - async def asimilarity_search_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: dict[str, Any] | None = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter on the metadata to apply. - **kwargs: Additional arguments are ignored. - - Returns: - The list of Documents most similar to the query vector. - """ - return [ - self._restore_links(doc) - for doc in await self.vector_store.asimilarity_search_by_vector( - embedding, - k=k, - filter=filter, - **kwargs, - ) - ] - - def metadata_search( - self, - filter: dict[str, Any] | None = None, # noqa: A002 - n: int = 5, - ) -> Iterable[Document]: - """Get documents via a metadata search. - - Args: - filter: the metadata to query for. - n: the maximum number of documents to return. - """ - return [ - self._restore_links(doc) - for doc in self.vector_store.metadata_search( - filter=filter or {}, - n=n, - ) - ] - - async def ametadata_search( - self, - filter: dict[str, Any] | None = None, # noqa: A002 - n: int = 5, - ) -> Iterable[Document]: - """Get documents via a metadata search. - - Args: - filter: the metadata to query for. - n: the maximum number of documents to return. - """ - return [ - self._restore_links(doc) - for doc in await self.vector_store.ametadata_search( - filter=filter or {}, - n=n, - ) - ] - - def get_by_document_id(self, document_id: str) -> Document | None: - """Retrieve a single document from the store, given its document ID. - - Args: - document_id: The document ID - - Returns: - The the document if it exists. Otherwise None. - """ - doc = self.vector_store.get_by_document_id(document_id=document_id) - return self._restore_links(doc) if doc is not None else None - - async def aget_by_document_id(self, document_id: str) -> Document | None: - """Retrieve a single document from the store, given its document ID. - - Args: - document_id: The document ID - - Returns: - The the document if it exists. Otherwise None. 
- """ - doc = await self.vector_store.aget_by_document_id(document_id=document_id) - return self._restore_links(doc) if doc is not None else None - - def get_node(self, node_id: str) -> Node | None: - """Retrieve a single node from the store, given its ID. - - Args: - node_id: The node ID - - Returns: - The the node if it exists. Otherwise None. - """ - doc = self.vector_store.get_by_document_id(document_id=node_id) - if doc is None: - return None - return _doc_to_node(doc=doc) - - async def aget_node(self, node_id: str) -> Node | None: - """Retrieve a single node from the store, given its ID. - - Args: - node_id: The node ID - - Returns: - The the node if it exists. Otherwise None. - """ - doc = await self.vector_store.aget_by_document_id(document_id=node_id) - if doc is None: - return None - return _doc_to_node(doc=doc) - - @override - async def ammr_traversal_search( # noqa: C901 - self, - query: str, - *, - initial_roots: Sequence[str] = (), - k: int = 4, - depth: int = 2, - fetch_k: int = 100, - adjacent_k: int = 10, - lambda_mult: float = 0.5, - score_threshold: float = float("-inf"), - filter: dict[str, Any] | None = None, - **kwargs: Any, - ) -> AsyncIterable[Document]: - """Retrieve documents from this graph store using MMR-traversal. - - This strategy first retrieves the top `fetch_k` results by similarity to - the question. It then selects the top `k` results based on - maximum-marginal relevance using the given `lambda_mult`. - - At each step, it considers the (remaining) documents from `fetch_k` as - well as any documents connected by edges to a selected document - retrieved based on similarity (a "root"). - - Args: - query: The query string to search for. - initial_roots: Optional list of document IDs to use for initializing search. - The top `adjacent_k` nodes adjacent to each initial root will be - included in the set of initial candidates. To fetch only in the - neighborhood of these nodes, set `fetch_k = 0`. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of initial Documents to fetch via similarity. - Will be added to the nodes adjacent to `initial_roots`. - Defaults to 100. - adjacent_k: Number of adjacent Documents to fetch. - Defaults to 10. - depth: Maximum depth of a node (number of edges) from a node - retrieved via similarity. Defaults to 2. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding to maximum - diversity and 1 to minimum diversity. Defaults to 0.5. - score_threshold: Only documents with a score greater than or equal - this threshold will be chosen. Defaults to -infinity. - filter: Optional metadata to filter the results. - **kwargs: Additional keyword arguments. - """ - # For each unselected node, stores the outgoing links. - outgoing_links_map: dict[str, set[Link]] = {} - visited_links: set[Link] = set() - # Map from id to Document, used as a cache - retrieved_docs: dict[str, Document] = {} - - def get_candidates(nodes: Iterable[EmbeddedNode]) -> dict[str, list[float]]: - nonlocal outgoing_links_map - - candidates: dict[str, list[float]] = {} - for node in nodes: - if node.id not in outgoing_links_map: - outgoing_links_map[node.id] = _outgoing_links(node=node) - candidates[node.id] = node.embedding - return candidates - - async def fetch_initial_candidates() -> ( - tuple[list[float], dict[str, list[float]]] - ): - """Gets the embedded query and the set of initial candidates. - - If fetch_k is zero, there will be no initial candidates. 
- """ - nonlocal retrieved_docs - - query_embedding, initial_nodes = await self._aget_initial( - query=query, - retrieved_docs=retrieved_docs, - fetch_k=fetch_k, - filter=filter, - ) - - return query_embedding, get_candidates(nodes=initial_nodes) - - async def fetch_neighborhood_candidates( - neighborhood: Sequence[str], - ) -> dict[str, list[float]]: - nonlocal outgoing_links_map, visited_links, retrieved_docs - - # Put the neighborhood into the outgoing links, to avoid adding it - # to the candidate set in the future. - outgoing_links_map.update( - {content_id: set() for content_id in neighborhood} - ) - - # Initialize the visited_links with the set of outgoing links from the - # neighborhood. This prevents re-visiting them. - visited_links = await self._aget_outgoing_links(neighborhood) - - # Call `self._aget_adjacent` to fetch the candidates. - adjacent_nodes = await self._aget_adjacent( - links=visited_links, - query_embedding=query_embedding, - k_per_link=adjacent_k, - filter=filter, - retrieved_docs=retrieved_docs, - ) - - return get_candidates(nodes=adjacent_nodes) - - query_embedding, initial_candidates = await fetch_initial_candidates() - helper = MmrHelper( - k=k, - query_embedding=query_embedding, - lambda_mult=lambda_mult, - score_threshold=score_threshold, - ) - helper.add_candidates(candidates=initial_candidates) - - if initial_roots: - neighborhood_candidates = await fetch_neighborhood_candidates(initial_roots) - helper.add_candidates(candidates=neighborhood_candidates) - - # Tracks the depth of each candidate. - depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()} - - # Select the best item, K times. - for _ in range(k): - selected_id = helper.pop_best() - - if selected_id is None: - break - - next_depth = depths[selected_id] + 1 - if next_depth < depth: - # If the next nodes would not exceed the depth limit, find the - # adjacent nodes. - - # Find the links linked to from the selected ID. - selected_outgoing_links = outgoing_links_map.pop(selected_id) - - # Don't re-visit already visited links. - selected_outgoing_links.difference_update(visited_links) - - # Find the nodes with incoming links from those links. - adjacent_nodes = await self._aget_adjacent( - links=selected_outgoing_links, - query_embedding=query_embedding, - k_per_link=adjacent_k, - filter=filter, - retrieved_docs=retrieved_docs, - ) - - # Record the selected_outgoing_links as visited. - visited_links.update(selected_outgoing_links) - - new_candidates = {} - for adjacent_node in adjacent_nodes: - if adjacent_node.id not in outgoing_links_map: - outgoing_links_map[adjacent_node.id] = _outgoing_links( - node=adjacent_node - ) - new_candidates[adjacent_node.id] = adjacent_node.embedding - if next_depth < depths.get(adjacent_node.id, depth + 1): - # If this is a new shortest depth, or there was no - # previous depth, update the depths. This ensures that - # when we discover a node we will have the shortest - # depth available. - # - # NOTE: No effort is made to traverse from nodes that - # were previously selected if they become reachable via - # a shorter path via nodes selected later. This is - # currently "intended", but may be worth experimenting - # with. 
- depths[adjacent_node.id] = next_depth - helper.add_candidates(new_candidates) - - for doc_id, similarity_score, mmr_score in zip( - helper.selected_ids, - helper.selected_similarity_scores, - helper.selected_mmr_scores, - ): - if doc_id in retrieved_docs: - doc = self._restore_links(retrieved_docs[doc_id]) - doc.metadata["similarity_score"] = similarity_score - doc.metadata["mmr_score"] = mmr_score - yield doc - else: - msg = f"retrieved_docs should contain id: {doc_id}" - raise RuntimeError(msg) - - @override - def mmr_traversal_search( # noqa: C901 - self, - query: str, - *, - initial_roots: Sequence[str] = (), - k: int = 4, - depth: int = 2, - fetch_k: int = 100, - adjacent_k: int = 10, - lambda_mult: float = 0.5, - score_threshold: float = float("-inf"), - filter: dict[str, Any] | None = None, - **kwargs: Any, - ) -> Iterable[Document]: - """Retrieve documents from this graph store using MMR-traversal. - - This strategy first retrieves the top `fetch_k` results by similarity to - the question. It then selects the top `k` results based on - maximum-marginal relevance using the given `lambda_mult`. - - At each step, it considers the (remaining) documents from `fetch_k` as - well as any documents connected by edges to a selected document - retrieved based on similarity (a "root"). - - Args: - query: The query string to search for. - initial_roots: Optional list of document IDs to use for initializing search. - The top `adjacent_k` nodes adjacent to each initial root will be - included in the set of initial candidates. To fetch only in the - neighborhood of these nodes, set `fetch_k = 0`. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of initial Documents to fetch via similarity. - Will be added to the nodes adjacent to `initial_roots`. - Defaults to 100. - adjacent_k: Number of adjacent Documents to fetch. - Defaults to 10. - depth: Maximum depth of a node (number of edges) from a node - retrieved via similarity. Defaults to 2. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding to maximum - diversity and 1 to minimum diversity. Defaults to 0.5. - score_threshold: Only documents with a score greater than or equal - this threshold will be chosen. Defaults to -infinity. - filter: Optional metadata to filter the results. - **kwargs: Additional keyword arguments. - """ - # For each unselected node, stores the outgoing links. - outgoing_links_map: dict[str, set[Link]] = {} - visited_links: set[Link] = set() - # Map from id to Document, used as a cache - retrieved_docs: dict[str, Document] = {} - - def get_candidates(nodes: Iterable[EmbeddedNode]) -> dict[str, list[float]]: - nonlocal outgoing_links_map - - candidates: dict[str, list[float]] = {} - for node in nodes: - if node.id not in outgoing_links_map: - outgoing_links_map[node.id] = _outgoing_links(node=node) - candidates[node.id] = node.embedding - return candidates - - def fetch_initial_candidates() -> tuple[list[float], dict[str, list[float]]]: - """Gets the embedded query and the set of initial candidates. - - If fetch_k is zero, there will be no initial candidates. 
- """ - nonlocal retrieved_docs - - query_embedding, initial_nodes = self._get_initial( - query=query, - retrieved_docs=retrieved_docs, - fetch_k=fetch_k, - filter=filter, - ) - - return query_embedding, get_candidates(nodes=initial_nodes) - - def fetch_neighborhood_candidates( - neighborhood: Sequence[str], - ) -> dict[str, list[float]]: - nonlocal outgoing_links_map, visited_links, retrieved_docs - - # Put the neighborhood into the outgoing links, to avoid adding it - # to the candidate set in the future. - outgoing_links_map.update( - {content_id: set() for content_id in neighborhood} - ) - - # Initialize the visited_links with the set of outgoing links from the - # neighborhood. This prevents re-visiting them. - visited_links = self._get_outgoing_links(neighborhood) - - # Call `self._get_adjacent` to fetch the candidates. - adjacent_nodes = self._get_adjacent( - links=visited_links, - query_embedding=query_embedding, - k_per_link=adjacent_k, - filter=filter, - retrieved_docs=retrieved_docs, - ) - - return get_candidates(nodes=adjacent_nodes) - - query_embedding, initial_candidates = fetch_initial_candidates() - helper = MmrHelper( - k=k, - query_embedding=query_embedding, - lambda_mult=lambda_mult, - score_threshold=score_threshold, - ) - helper.add_candidates(candidates=initial_candidates) - - if initial_roots: - neighborhood_candidates = fetch_neighborhood_candidates(initial_roots) - helper.add_candidates(candidates=neighborhood_candidates) - - # Tracks the depth of each candidate. - depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()} - - # Select the best item, K times. - for _ in range(k): - selected_id = helper.pop_best() - - if selected_id is None: - break - - next_depth = depths[selected_id] + 1 - if next_depth < depth: - # If the next nodes would not exceed the depth limit, find the - # adjacent nodes. - - # Find the links linked to from the selected ID. - selected_outgoing_links = outgoing_links_map.pop(selected_id) - - # Don't re-visit already visited links. - selected_outgoing_links.difference_update(visited_links) - - # Find the nodes with incoming links from those links. - adjacent_nodes = self._get_adjacent( - links=selected_outgoing_links, - query_embedding=query_embedding, - k_per_link=adjacent_k, - filter=filter, - retrieved_docs=retrieved_docs, - ) - - # Record the selected_outgoing_links as visited. - visited_links.update(selected_outgoing_links) - - new_candidates = {} - for adjacent_node in adjacent_nodes: - if adjacent_node.id not in outgoing_links_map: - outgoing_links_map[adjacent_node.id] = _outgoing_links( - node=adjacent_node - ) - new_candidates[adjacent_node.id] = adjacent_node.embedding - if next_depth < depths.get(adjacent_node.id, depth + 1): - # If this is a new shortest depth, or there was no - # previous depth, update the depths. This ensures that - # when we discover a node we will have the shortest - # depth available. - # - # NOTE: No effort is made to traverse from nodes that - # were previously selected if they become reachable via - # a shorter path via nodes selected later. This is - # currently "intended", but may be worth experimenting - # with. 
- depths[adjacent_node.id] = next_depth - helper.add_candidates(new_candidates) - - for doc_id, similarity_score, mmr_score in zip( - helper.selected_ids, - helper.selected_similarity_scores, - helper.selected_mmr_scores, - ): - if doc_id in retrieved_docs: - doc = self._restore_links(retrieved_docs[doc_id]) - doc.metadata["similarity_score"] = similarity_score - doc.metadata["mmr_score"] = mmr_score - yield doc - else: - msg = f"retrieved_docs should contain id: {doc_id}" - raise RuntimeError(msg) - - @override - async def atraversal_search( # noqa: C901 - self, - query: str, - *, - k: int = 4, - depth: int = 1, - filter: dict[str, Any] | None = None, - **kwargs: Any, - ) -> AsyncIterable[Document]: - """Retrieve documents from this knowledge store. - - First, `k` nodes are retrieved using a vector search for the `query` string. - Then, additional nodes are discovered up to the given `depth` from those - starting nodes. - - Args: - query: The query string. - k: The number of Documents to return from the initial vector search. - Defaults to 4. - depth: The maximum depth of edges to traverse. Defaults to 1. - filter: Optional metadata to filter the results. - **kwargs: Additional keyword arguments. - - Returns: - Collection of retrieved documents. - """ - # Depth 0: - # Query for `k` nodes similar to the question. - # Retrieve `content_id` and `outgoing_links()`. - # - # Depth 1: - # Query for nodes that have an incoming link in the `outgoing_links()` set. - # Combine node IDs. - # Query for `outgoing_links()` of those "new" node IDs. - # - # ... - - # Map from visited ID to depth - visited_ids: dict[str, int] = {} - - # Map from visited link to depth - visited_links: dict[Link, int] = {} - - # Map from id to Document - retrieved_docs: dict[str, Document] = {} - - async def visit_nodes(d: int, docs: Iterable[Document]) -> None: - """Recursively visit nodes and their outgoing links.""" - nonlocal visited_ids, visited_links, retrieved_docs - - # Iterate over nodes, tracking the *new* outgoing links for this - # depth. These are links that are either new, or newly discovered at a - # lower depth. - outgoing_links: set[Link] = set() - for doc in docs: - if doc.id is not None: - if doc.id not in retrieved_docs: - retrieved_docs[doc.id] = doc - - # If this node is at a closer depth, update visited_ids - if d <= visited_ids.get(doc.id, depth): - visited_ids[doc.id] = d - - # If we can continue traversing from this node, - if d < depth: - node = _doc_to_node(doc=doc) - # Record any new (or newly discovered at a lower depth) - # links to the set to traverse. 
- for link in _outgoing_links(node=node): - if d <= visited_links.get(link, depth): - # Record that we'll query this link at the - # given depth, so we don't fetch it again - # (unless we find it an earlier depth) - visited_links[link] = d - outgoing_links.add(link) - - if outgoing_links: - metadata_search_tasks = [] - for outgoing_link in outgoing_links: - metadata_filter = self._get_metadata_filter( - metadata=filter, - outgoing_link=outgoing_link, - ) - metadata_search_tasks.append( - asyncio.create_task( - self.vector_store.ametadata_search( - filter=metadata_filter, n=1000 - ) - ) - ) - results = await asyncio.gather(*metadata_search_tasks) - - # Visit targets concurrently - visit_target_tasks = [ - visit_targets(d=d + 1, docs=docs) for docs in results - ] - await asyncio.gather(*visit_target_tasks) - - async def visit_targets(d: int, docs: Iterable[Document]) -> None: - """Visit target nodes retrieved from outgoing links.""" - nonlocal visited_ids, retrieved_docs - - new_ids_at_next_depth = set() - for doc in docs: - if doc.id is not None: - if doc.id not in retrieved_docs: - retrieved_docs[doc.id] = doc - - if d <= visited_ids.get(doc.id, depth): - new_ids_at_next_depth.add(doc.id) - - if new_ids_at_next_depth: - visit_node_tasks = [ - visit_nodes(d=d, docs=[retrieved_docs[doc_id]]) - for doc_id in new_ids_at_next_depth - if doc_id in retrieved_docs - ] - - fetch_tasks = [ - asyncio.create_task( - self.vector_store.aget_by_document_id(document_id=doc_id) - ) - for doc_id in new_ids_at_next_depth - if doc_id not in retrieved_docs - ] - - new_docs: list[Document | None] = await asyncio.gather(*fetch_tasks) - - visit_node_tasks.extend( - visit_nodes(d=d, docs=[new_doc]) - for new_doc in new_docs - if new_doc is not None - ) - - await asyncio.gather(*visit_node_tasks) - - # Start the traversal - initial_docs = self.vector_store.similarity_search( - query=query, - k=k, - filter=filter, - ) - await visit_nodes(d=0, docs=initial_docs) - - for doc_id in visited_ids: - if doc_id in retrieved_docs: - yield self._restore_links(retrieved_docs[doc_id]) - else: - msg = f"retrieved_docs should contain id: {doc_id}" - raise RuntimeError(msg) - - @override - def traversal_search( # noqa: C901 - self, - query: str, - *, - k: int = 4, - depth: int = 1, - filter: dict[str, Any] | None = None, - **kwargs: Any, - ) -> Iterable[Document]: - """Retrieve documents from this knowledge store. - - First, `k` nodes are retrieved using a vector search for the `query` string. - Then, additional nodes are discovered up to the given `depth` from those - starting nodes. - - Args: - query: The query string. - k: The number of Documents to return from the initial vector search. - Defaults to 4. - depth: The maximum depth of edges to traverse. Defaults to 1. - filter: Optional metadata to filter the results. - **kwargs: Additional keyword arguments. - - Returns: - Collection of retrieved documents. - """ - # Depth 0: - # Query for `k` nodes similar to the question. - # Retrieve `content_id` and `outgoing_links()`. - # - # Depth 1: - # Query for nodes that have an incoming link in the `outgoing_links()` set. - # Combine node IDs. - # Query for `outgoing_links()` of those "new" node IDs. - # - # ... 
- - # Map from visited ID to depth - visited_ids: dict[str, int] = {} - - # Map from visited link to depth - visited_links: dict[Link, int] = {} - - # Map from id to Document - retrieved_docs: dict[str, Document] = {} - - def visit_nodes(d: int, docs: Iterable[Document]) -> None: - """Recursively visit nodes and their outgoing links.""" - nonlocal visited_ids, visited_links, retrieved_docs - - # Iterate over nodes, tracking the *new* outgoing links for this - # depth. These are links that are either new, or newly discovered at a - # lower depth. - outgoing_links: set[Link] = set() - for doc in docs: - if doc.id is not None: - if doc.id not in retrieved_docs: - retrieved_docs[doc.id] = doc - - # If this node is at a closer depth, update visited_ids - if d <= visited_ids.get(doc.id, depth): - visited_ids[doc.id] = d - - # If we can continue traversing from this node, - if d < depth: - node = _doc_to_node(doc=doc) - # Record any new (or newly discovered at a lower depth) - # links to the set to traverse. - for link in _outgoing_links(node=node): - if d <= visited_links.get(link, depth): - # Record that we'll query this link at the - # given depth, so we don't fetch it again - # (unless we find it an earlier depth) - visited_links[link] = d - outgoing_links.add(link) - - if outgoing_links: - for outgoing_link in outgoing_links: - metadata_filter = self._get_metadata_filter( - metadata=filter, - outgoing_link=outgoing_link, - ) - - docs = self.vector_store.metadata_search( - filter=metadata_filter, n=1000 - ) - - visit_targets(d=d + 1, docs=docs) - - def visit_targets(d: int, docs: Iterable[Document]) -> None: - """Visit target nodes retrieved from outgoing links.""" - nonlocal visited_ids, retrieved_docs - - new_ids_at_next_depth = set() - for doc in docs: - if doc.id is not None: - if doc.id not in retrieved_docs: - retrieved_docs[doc.id] = doc - - if d <= visited_ids.get(doc.id, depth): - new_ids_at_next_depth.add(doc.id) - - if new_ids_at_next_depth: - for doc_id in new_ids_at_next_depth: - if doc_id in retrieved_docs: - visit_nodes(d=d, docs=[retrieved_docs[doc_id]]) - else: - new_doc = self.vector_store.get_by_document_id( - document_id=doc_id - ) - if new_doc is not None: - visit_nodes(d=d, docs=[new_doc]) - - # Start the traversal - initial_docs = self.vector_store.similarity_search( - query=query, - k=k, - filter=filter, - ) - visit_nodes(d=0, docs=initial_docs) - - for doc_id in visited_ids: - if doc_id in retrieved_docs: - yield self._restore_links(retrieved_docs[doc_id]) - else: - msg = f"retrieved_docs should contain id: {doc_id}" - raise RuntimeError(msg) - - def _get_outgoing_links(self, source_ids: Iterable[str]) -> set[Link]: - """Return the set of outgoing links for the given source IDs synchronously. - - Args: - source_ids: The IDs of the source nodes to retrieve outgoing links for. - - Returns: - A set of `Link` objects representing the outgoing links from the source - nodes. - """ - links = set() - - for source_id in source_ids: - doc = self.vector_store.get_by_document_id(document_id=source_id) - if doc is not None: - node = _doc_to_node(doc=doc) - links.update(_outgoing_links(node=node)) - - return links - - async def _aget_outgoing_links(self, source_ids: Iterable[str]) -> set[Link]: - """Return the set of outgoing links for the given source IDs asynchronously. - - Args: - source_ids: The IDs of the source nodes to retrieve outgoing links for. - - Returns: - A set of `Link` objects representing the outgoing links from the source - nodes. 
- """ - links = set() - - # Create coroutine objects without scheduling them yet - coroutines = [ - self.vector_store.aget_by_document_id(document_id=source_id) - for source_id in source_ids - ] - - # Schedule and await all coroutines - docs = await asyncio.gather(*coroutines) - - for doc in docs: - if doc is not None: - node = _doc_to_node(doc=doc) - links.update(_outgoing_links(node=node)) - - return links - - def _get_initial( - self, - query: str, - retrieved_docs: dict[str, Document], - fetch_k: int, - filter: dict[str, Any] | None = None, # noqa: A002 - ) -> tuple[list[float], list[EmbeddedNode]]: - ( - query_embedding, - result, - ) = self.vector_store.similarity_search_with_embedding( - query=query, - k=fetch_k, - filter=filter, - ) - - initial_nodes: list[EmbeddedNode] = [] - for doc, embedding in result: - if doc.id is not None: - retrieved_docs[doc.id] = doc - initial_nodes.append(EmbeddedNode(doc=doc, embedding=embedding)) - - return query_embedding, initial_nodes - - async def _aget_initial( - self, - query: str, - retrieved_docs: dict[str, Document], - fetch_k: int, - filter: dict[str, Any] | None = None, # noqa: A002 - ) -> tuple[list[float], list[EmbeddedNode]]: - ( - query_embedding, - result, - ) = await self.vector_store.asimilarity_search_with_embedding( - query=query, - k=fetch_k, - filter=filter, - ) - - initial_nodes: list[EmbeddedNode] = [] - for doc, embedding in result: - if doc.id is not None: - retrieved_docs[doc.id] = doc - initial_nodes.append(EmbeddedNode(doc=doc, embedding=embedding)) - - return query_embedding, initial_nodes - - def _get_adjacent( - self, - links: set[Link], - query_embedding: list[float], - retrieved_docs: dict[str, Document], - k_per_link: int | None = None, - filter: dict[str, Any] | None = None, # noqa: A002 - ) -> Iterable[EmbeddedNode]: - """Return the target nodes with incoming links from any of the given links. - - Args: - links: The links to look for. - query_embedding: The query embedding. Used to rank target nodes. - retrieved_docs: A cache of retrieved docs. This will be added to. - k_per_link: The number of target nodes to fetch for each link. - filter: Optional metadata to filter the results. - - Returns: - Iterable of adjacent edges. - """ - targets: dict[str, EmbeddedNode] = {} - - for link in links: - metadata_filter = self._get_metadata_filter( - metadata=filter, - outgoing_link=link, - ) - - result = self.vector_store.similarity_search_with_embedding_by_vector( - embedding=query_embedding, - k=k_per_link or 10, - filter=metadata_filter, - ) - - for doc, embedding in result: - if doc.id is not None: - retrieved_docs[doc.id] = doc - if doc.id not in targets: - targets[doc.id] = EmbeddedNode(doc=doc, embedding=embedding) - - # TODO: Consider a combined limit based on the similarity and/or - # predicated MMR score? - return targets.values() - - async def _aget_adjacent( - self, - links: set[Link], - query_embedding: list[float], - retrieved_docs: dict[str, Document], - k_per_link: int | None = None, - filter: dict[str, Any] | None = None, # noqa: A002 - ) -> Iterable[EmbeddedNode]: - """Return the target nodes with incoming links from any of the given links. - - Args: - links: The links to look for. - query_embedding: The query embedding. Used to rank target nodes. - retrieved_docs: A cache of retrieved docs. This will be added to. - k_per_link: The number of target nodes to fetch for each link. - filter: Optional metadata to filter the results. - - Returns: - Iterable of adjacent edges. 
- """ - targets: dict[str, EmbeddedNode] = {} - - tasks = [] - for link in links: - metadata_filter = self._get_metadata_filter( - metadata=filter, - outgoing_link=link, - ) - - tasks.append( - self.vector_store.asimilarity_search_with_embedding_by_vector( - embedding=query_embedding, - k=k_per_link or 10, - filter=metadata_filter, - ) - ) - - results: list[list[tuple[Document, list[float]]]] = await asyncio.gather(*tasks) - - for result in results: - for doc, embedding in result: - if doc.id is not None: - retrieved_docs[doc.id] = doc - if doc.id not in targets: - targets[doc.id] = EmbeddedNode(doc=doc, embedding=embedding) - - # TODO: Consider a combined limit based on the similarity and/or - # predicated MMR score? - return targets.values()