From 289d44626fc69088b0e6fe5d260764c9be4b9221 Mon Sep 17 00:00:00 2001 From: leo-gan Date: Mon, 26 Feb 2024 11:27:48 -0800 Subject: [PATCH 1/3] refactor --- .../langchain_community/indexes/_api.py | 598 +++++++++ .../langchain_community/indexes/graph.py | 47 + .../indexes/prompts/__init__.py | 0 .../indexes/prompts/entity_extraction.py | 0 .../indexes/prompts/entity_summarization.py | 0 .../prompts/knowledge_triplet_extraction.py | 0 .../indexes/vectorstore.py | 91 ++ .../indexes/test_hashed_document.py | 50 + .../tests/unit_tests/indexes/test_indexing.py | 1155 +++++++++++++++++ libs/langchain/langchain/indexes/_api.py | 609 +-------- .../langchain/indexes/_sql_record_manager.py | 521 +------- libs/langchain/langchain/indexes/base.py | 175 +-- libs/langchain/langchain/indexes/graph.py | 50 +- .../langchain/indexes/vectorstore.py | 94 +- 14 files changed, 1970 insertions(+), 1420 deletions(-) create mode 100644 libs/community/langchain_community/indexes/_api.py create mode 100644 libs/community/langchain_community/indexes/graph.py rename libs/{langchain/langchain => community/langchain_community}/indexes/prompts/__init__.py (100%) rename libs/{langchain/langchain => community/langchain_community}/indexes/prompts/entity_extraction.py (100%) rename libs/{langchain/langchain => community/langchain_community}/indexes/prompts/entity_summarization.py (100%) rename libs/{langchain/langchain => community/langchain_community}/indexes/prompts/knowledge_triplet_extraction.py (100%) create mode 100644 libs/community/langchain_community/indexes/vectorstore.py create mode 100644 libs/community/tests/unit_tests/indexes/test_hashed_document.py create mode 100644 libs/community/tests/unit_tests/indexes/test_indexing.py diff --git a/libs/community/langchain_community/indexes/_api.py b/libs/community/langchain_community/indexes/_api.py new file mode 100644 index 0000000000000..4dfb5e9eeac5a --- /dev/null +++ b/libs/community/langchain_community/indexes/_api.py @@ -0,0 +1,598 @@ +"""Module contains logic for indexing documents into vector stores.""" +from __future__ import annotations + +import hashlib +import json +import uuid +from itertools import islice +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Callable, + Dict, + Iterable, + Iterator, + List, + Literal, + Optional, + Sequence, + Set, + TypedDict, + TypeVar, + Union, + cast, +) + +from langchain_core.documents import Document +from langchain_core.pydantic_v1 import root_validator +from langchain_core.vectorstores import VectorStore + +from langchain_community.document_loaders.base import BaseLoader +from langchain_community.indexes.base import NAMESPACE_UUID, RecordManager + +T = TypeVar("T") + + +def _hash_string_to_uuid(input_string: str) -> uuid.UUID: + """Hashes a string and returns the corresponding UUID.""" + hash_value = hashlib.sha1(input_string.encode("utf-8")).hexdigest() + return uuid.uuid5(NAMESPACE_UUID, hash_value) + + +def _hash_nested_dict_to_uuid(data: dict[Any, Any]) -> uuid.UUID: + """Hashes a nested dictionary and returns the corresponding UUID.""" + serialized_data = json.dumps(data, sort_keys=True) + hash_value = hashlib.sha1(serialized_data.encode("utf-8")).hexdigest() + return uuid.uuid5(NAMESPACE_UUID, hash_value) + + +class _HashedDocument(Document): + """A hashed document with a unique ID.""" + + uid: str + hash_: str + """The hash of the document including content and metadata.""" + content_hash: str + """The hash of the document content.""" + metadata_hash: str + """The hash of the document metadata.""" 
+ + @classmethod + def is_lc_serializable(cls) -> bool: + return False + + @root_validator(pre=True) + def calculate_hashes(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Root validator to calculate content and metadata hash.""" + content = values.get("page_content", "") + metadata = values.get("metadata", {}) + + forbidden_keys = ("hash_", "content_hash", "metadata_hash") + + for key in forbidden_keys: + if key in metadata: + raise ValueError( + f"Metadata cannot contain key {key} as it " + f"is reserved for internal use." + ) + + content_hash = str(_hash_string_to_uuid(content)) + + try: + metadata_hash = str(_hash_nested_dict_to_uuid(metadata)) + except Exception as e: + raise ValueError( + f"Failed to hash metadata: {e}. " + f"Please use a dict that can be serialized using json." + ) + + values["content_hash"] = content_hash + values["metadata_hash"] = metadata_hash + values["hash_"] = str(_hash_string_to_uuid(content_hash + metadata_hash)) + + _uid = values.get("uid", None) + + if _uid is None: + values["uid"] = values["hash_"] + return values + + def to_document(self) -> Document: + """Return a Document object.""" + return Document( + page_content=self.page_content, + metadata=self.metadata, + ) + + @classmethod + def from_document( + cls, document: Document, *, uid: Optional[str] = None + ) -> _HashedDocument: + """Create a HashedDocument from a Document.""" + return cls( + uid=uid, + page_content=document.page_content, + metadata=document.metadata, + ) + + +def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]: + """Utility batching function.""" + it = iter(iterable) + while True: + chunk = list(islice(it, size)) + if not chunk: + return + yield chunk + + +async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[List[T]]: + """Utility batching function.""" + batch: List[T] = [] + async for element in iterable: + if len(batch) < size: + batch.append(element) + + if len(batch) >= size: + yield batch + batch = [] + + if batch: + yield batch + + +def _get_source_id_assigner( + source_id_key: Union[str, Callable[[Document], str], None], +) -> Callable[[Document], Union[str, None]]: + """Get the source id from the document.""" + if source_id_key is None: + return lambda doc: None + elif isinstance(source_id_key, str): + return lambda doc: doc.metadata[source_id_key] + elif callable(source_id_key): + return source_id_key + else: + raise ValueError( + f"source_id_key should be either None, a string or a callable. " + f"Got {source_id_key} of type {type(source_id_key)}." 
+ ) + + +def _deduplicate_in_order( + hashed_documents: Iterable[_HashedDocument], +) -> Iterator[_HashedDocument]: + """Deduplicate a list of hashed documents while preserving order.""" + seen: Set[str] = set() + + for hashed_doc in hashed_documents: + if hashed_doc.hash_ not in seen: + seen.add(hashed_doc.hash_) + yield hashed_doc + + +# PUBLIC API + + +class IndexingResult(TypedDict): + """Return a detailed a breakdown of the result of the indexing operation.""" + + num_added: int + """Number of added documents.""" + num_updated: int + """Number of updated documents because they were not up to date.""" + num_deleted: int + """Number of deleted documents.""" + num_skipped: int + """Number of skipped documents because they were already up to date.""" + + +def index( + docs_source: Union[BaseLoader, Iterable[Document]], + record_manager: RecordManager, + vector_store: VectorStore, + *, + batch_size: int = 100, + cleanup: Literal["incremental", "full", None] = None, + source_id_key: Union[str, Callable[[Document], str], None] = None, + cleanup_batch_size: int = 1_000, + force_update: bool = False, +) -> IndexingResult: + """Index data from the loader into the vector store. + + Indexing functionality uses a manager to keep track of which documents + are in the vector store. + + This allows us to keep track of which documents were updated, and which + documents were deleted, which documents should be skipped. + + For the time being, documents are indexed using their hashes, and users + are not able to specify the uid of the document. + + IMPORTANT: + if auto_cleanup is set to True, the loader should be returning + the entire dataset, and not just a subset of the dataset. + Otherwise, the auto_cleanup will remove documents that it is not + supposed to. + + Args: + docs_source: Data loader or iterable of documents to index. + record_manager: Timestamped set to keep track of which documents were + updated. + vector_store: Vector store to index the documents into. + batch_size: Batch size to use when indexing. + cleanup: How to handle clean up of documents. + - Incremental: Cleans up all documents that haven't been updated AND + that are associated with source ids that were seen + during indexing. + Clean up is done continuously during indexing helping + to minimize the probability of users seeing duplicated + content. + - Full: Delete all documents that haven to been returned by the loader. + Clean up runs after all documents have been indexed. + This means that users may see duplicated content during indexing. + - None: Do not delete any documents. + source_id_key: Optional key that helps identify the original source + of the document. + cleanup_batch_size: Batch size to use when cleaning up documents. + force_update: Force update documents even if they are present in the + record manager. Useful if you are re-indexing with updated embeddings. + + Returns: + Indexing result which contains information about how many documents + were added, updated, deleted, or skipped. + """ + if cleanup not in {"incremental", "full", None}: + raise ValueError( + f"cleanup should be one of 'incremental', 'full' or None. " + f"Got {cleanup}." 
+ ) + + if cleanup == "incremental" and source_id_key is None: + raise ValueError("Source id key is required when cleanup mode is incremental.") + + # Check that the Vectorstore has required methods implemented + methods = ["delete", "add_documents"] + + for method in methods: + if not hasattr(vector_store, method): + raise ValueError( + f"Vectorstore {vector_store} does not have required method {method}" + ) + + if type(vector_store).delete == VectorStore.delete: + # Checking if the vectorstore has overridden the default delete method + # implementation which just raises a NotImplementedError + raise ValueError("Vectorstore has not implemented the delete method") + + if isinstance(docs_source, BaseLoader): + try: + doc_iterator = docs_source.lazy_load() + except NotImplementedError: + doc_iterator = iter(docs_source.load()) + else: + doc_iterator = iter(docs_source) + + source_id_assigner = _get_source_id_assigner(source_id_key) + + # Mark when the update started. + index_start_dt = record_manager.get_time() + num_added = 0 + num_skipped = 0 + num_updated = 0 + num_deleted = 0 + + for doc_batch in _batch(batch_size, doc_iterator): + hashed_docs = list( + _deduplicate_in_order( + [_HashedDocument.from_document(doc) for doc in doc_batch] + ) + ) + + source_ids: Sequence[Optional[str]] = [ + source_id_assigner(doc) for doc in hashed_docs + ] + + if cleanup == "incremental": + # If the cleanup mode is incremental, source ids are required. + for source_id, hashed_doc in zip(source_ids, hashed_docs): + if source_id is None: + raise ValueError( + "Source ids are required when cleanup mode is incremental. " + f"Document that starts with " + f"content: {hashed_doc.page_content[:100]} was not assigned " + f"as source id." + ) + # source ids cannot be None after for loop above. + source_ids = cast(Sequence[str], source_ids) # type: ignore[assignment] + + exists_batch = record_manager.exists([doc.uid for doc in hashed_docs]) + + # Filter out documents that already exist in the record store. + uids = [] + docs_to_index = [] + uids_to_refresh = [] + seen_docs: Set[str] = set() + for hashed_doc, doc_exists in zip(hashed_docs, exists_batch): + if doc_exists: + if force_update: + seen_docs.add(hashed_doc.uid) + else: + uids_to_refresh.append(hashed_doc.uid) + continue + uids.append(hashed_doc.uid) + docs_to_index.append(hashed_doc.to_document()) + + # Update refresh timestamp + if uids_to_refresh: + record_manager.update(uids_to_refresh, time_at_least=index_start_dt) + num_skipped += len(uids_to_refresh) + + # Be pessimistic and assume that all vector store write will fail. + # First write to vector store + if docs_to_index: + vector_store.add_documents(docs_to_index, ids=uids) + num_added += len(docs_to_index) - len(seen_docs) + num_updated += len(seen_docs) + + # And only then update the record store. + # Update ALL records, even if they already exist since we want to refresh + # their timestamp. + record_manager.update( + [doc.uid for doc in hashed_docs], + group_ids=source_ids, + time_at_least=index_start_dt, + ) + + # If source IDs are provided, we can do the deletion incrementally! + if cleanup == "incremental": + # Get the uids of the documents that were not returned by the loader. + + # mypy isn't good enough to determine that source ids cannot be None + # here due to a check that's happening above, so we check again. 
for source_id in source_ids:
+                if source_id is None:
+                    raise AssertionError("Source ids cannot be None here.")
+
+            _source_ids = cast(Sequence[str], source_ids)
+
+            uids_to_delete = record_manager.list_keys(
+                group_ids=_source_ids, before=index_start_dt
+            )
+            if uids_to_delete:
+                # First delete from the vector store.
+                vector_store.delete(uids_to_delete)
+                # Then delete from the record store.
+                record_manager.delete_keys(uids_to_delete)
+                num_deleted += len(uids_to_delete)
+
+    if cleanup == "full":
+        while uids_to_delete := record_manager.list_keys(
+            before=index_start_dt, limit=cleanup_batch_size
+        ):
+            # First delete from the vector store.
+            vector_store.delete(uids_to_delete)
+            # Then delete from the record manager.
+            record_manager.delete_keys(uids_to_delete)
+            num_deleted += len(uids_to_delete)
+
+    return {
+        "num_added": num_added,
+        "num_updated": num_updated,
+        "num_skipped": num_skipped,
+        "num_deleted": num_deleted,
+    }
+
+
+# Define an asynchronous generator function
+async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]:
+    """Convert an iterable to an async iterator."""
+    for item in iterator:
+        yield item
+
+
+async def aindex(
+    docs_source: Union[BaseLoader, Iterable[Document], AsyncIterator[Document]],
+    record_manager: RecordManager,
+    vector_store: VectorStore,
+    *,
+    batch_size: int = 100,
+    cleanup: Literal["incremental", "full", None] = None,
+    source_id_key: Union[str, Callable[[Document], str], None] = None,
+    cleanup_batch_size: int = 1_000,
+    force_update: bool = False,
+) -> IndexingResult:
+    """Index data from the loader into the vector store.
+
+    Indexing functionality uses a manager to keep track of which documents
+    are in the vector store.
+
+    This allows us to keep track of which documents were updated, which
+    documents were deleted, and which documents should be skipped.
+
+    For the time being, documents are indexed using their hashes, and users
+    are not able to specify the uid of the document.
+
+    IMPORTANT:
+        If cleanup is set to "incremental" or "full", the loader should return
+        the entire dataset, and not just a subset of the dataset.
+        Otherwise, the cleanup will remove documents that it is not
+        supposed to.
+
+    Args:
+        docs_source: Data loader or iterable of documents to index.
+        record_manager: Timestamped set to keep track of which documents were
+            updated.
+        vector_store: Vector store to index the documents into.
+        batch_size: Batch size to use when indexing.
+        cleanup: How to handle clean up of documents.
+            - Incremental: Cleans up all documents that haven't been updated AND
+                           that are associated with source ids that were seen
+                           during indexing.
+                           Clean up is done continuously during indexing, helping
+                           to minimize the probability of users seeing duplicated
+                           content.
+            - Full: Delete all documents that haven't been returned by the loader.
+                    Clean up runs after all documents have been indexed.
+                    This means that users may see duplicated content during indexing.
+            - None: Do not delete any documents.
+        source_id_key: Optional key that helps identify the original source
+            of the document.
+        cleanup_batch_size: Batch size to use when cleaning up documents.
+        force_update: Force update documents even if they are present in the
+            record manager. Useful if you are re-indexing with updated embeddings.
+
+    Returns:
+        Indexing result which contains information about how many documents
+        were added, updated, deleted, or skipped.
+ """ + + if cleanup not in {"incremental", "full", None}: + raise ValueError( + f"cleanup should be one of 'incremental', 'full' or None. " + f"Got {cleanup}." + ) + + if cleanup == "incremental" and source_id_key is None: + raise ValueError("Source id key is required when cleanup mode is incremental.") + + # Check that the Vectorstore has required methods implemented + methods = ["adelete", "aadd_documents"] + + for method in methods: + if not hasattr(vector_store, method): + raise ValueError( + f"Vectorstore {vector_store} does not have required method {method}" + ) + + if type(vector_store).adelete == VectorStore.adelete: + # Checking if the vectorstore has overridden the default delete method + # implementation which just raises a NotImplementedError + raise ValueError("Vectorstore has not implemented the delete method") + + async_doc_iterator: AsyncIterator[Document] + if isinstance(docs_source, BaseLoader): + try: + async_doc_iterator = docs_source.alazy_load() + except NotImplementedError: + # Exception triggered when neither lazy_load nor alazy_load are implemented. + # * The default implementation of alazy_load uses lazy_load. + # * The default implementation of lazy_load raises NotImplementedError. + # In such a case, we use the load method and convert it to an async + # iterator. + async_doc_iterator = _to_async_iterator(docs_source.load()) + else: + if hasattr(docs_source, "__aiter__"): + async_doc_iterator = docs_source # type: ignore[assignment] + else: + async_doc_iterator = _to_async_iterator(docs_source) + + source_id_assigner = _get_source_id_assigner(source_id_key) + + # Mark when the update started. + index_start_dt = await record_manager.aget_time() + num_added = 0 + num_skipped = 0 + num_updated = 0 + num_deleted = 0 + + async for doc_batch in _abatch(batch_size, async_doc_iterator): + hashed_docs = list( + _deduplicate_in_order( + [_HashedDocument.from_document(doc) for doc in doc_batch] + ) + ) + + source_ids: Sequence[Optional[str]] = [ + source_id_assigner(doc) for doc in hashed_docs + ] + + if cleanup == "incremental": + # If the cleanup mode is incremental, source ids are required. + for source_id, hashed_doc in zip(source_ids, hashed_docs): + if source_id is None: + raise ValueError( + "Source ids are required when cleanup mode is incremental. " + f"Document that starts with " + f"content: {hashed_doc.page_content[:100]} was not assigned " + f"as source id." + ) + # source ids cannot be None after for loop above. + source_ids = cast(Sequence[str], source_ids) + + exists_batch = await record_manager.aexists([doc.uid for doc in hashed_docs]) + + # Filter out documents that already exist in the record store. + uids: list[str] = [] + docs_to_index: list[Document] = [] + uids_to_refresh = [] + seen_docs: Set[str] = set() + for hashed_doc, doc_exists in zip(hashed_docs, exists_batch): + if doc_exists: + if force_update: + seen_docs.add(hashed_doc.uid) + else: + uids_to_refresh.append(hashed_doc.uid) + continue + uids.append(hashed_doc.uid) + docs_to_index.append(hashed_doc.to_document()) + + if uids_to_refresh: + # Must be updated to refresh timestamp. + await record_manager.aupdate(uids_to_refresh, time_at_least=index_start_dt) + num_skipped += len(uids_to_refresh) + + # Be pessimistic and assume that all vector store write will fail. 
# First write to vector store
+        if docs_to_index:
+            await vector_store.aadd_documents(docs_to_index, ids=uids)
+            num_added += len(docs_to_index) - len(seen_docs)
+            num_updated += len(seen_docs)
+
+        # And only then update the record store.
+        # Update ALL records, even if they already exist since we want to refresh
+        # their timestamp.
+        await record_manager.aupdate(
+            [doc.uid for doc in hashed_docs],
+            group_ids=source_ids,
+            time_at_least=index_start_dt,
+        )
+
+        # If source IDs are provided, we can do the deletion incrementally!
+
+        if cleanup == "incremental":
+            # Get the uids of the documents that were not returned by the loader.
+
+            # mypy isn't good enough to determine that source ids cannot be None
+            # here due to a check that's happening above, so we check again.
+            for source_id in source_ids:
+                if source_id is None:
+                    raise AssertionError("Source ids cannot be None here.")
+
+            _source_ids = cast(Sequence[str], source_ids)
+
+            uids_to_delete = await record_manager.alist_keys(
+                group_ids=_source_ids, before=index_start_dt
+            )
+            if uids_to_delete:
+                # First delete from the vector store.
+                await vector_store.adelete(uids_to_delete)
+                # Then delete from the record store.
+                await record_manager.adelete_keys(uids_to_delete)
+                num_deleted += len(uids_to_delete)
+
+    if cleanup == "full":
+        while uids_to_delete := await record_manager.alist_keys(
+            before=index_start_dt, limit=cleanup_batch_size
+        ):
+            # First delete from the vector store.
+            await vector_store.adelete(uids_to_delete)
+            # Then delete from the record manager.
+            await record_manager.adelete_keys(uids_to_delete)
+            num_deleted += len(uids_to_delete)
+
+    return {
+        "num_added": num_added,
+        "num_updated": num_updated,
+        "num_skipped": num_skipped,
+        "num_deleted": num_deleted,
+    }
diff --git a/libs/community/langchain_community/indexes/graph.py b/libs/community/langchain_community/indexes/graph.py
new file mode 100644
index 0000000000000..4998cce5a6467
--- /dev/null
+++ b/libs/community/langchain_community/indexes/graph.py
@@ -0,0 +1,47 @@
+"""Graph Index Creator."""
+from typing import Optional, Type
+
+from langchain.chains.llm import LLMChain
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.prompts import BasePromptTemplate
+from langchain_core.pydantic_v1 import BaseModel
+
+from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, parse_triples
+from langchain_community.indexes.prompts.knowledge_triplet_extraction import (
+    KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
+)
+
+
+class GraphIndexCreator(BaseModel):
+    """Functionality to create graph index."""
+
+    llm: Optional[BaseLanguageModel] = None
+    graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph
+
+    def from_text(
+        self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
+    ) -> NetworkxEntityGraph:
+        """Create graph index from text."""
+        if self.llm is None:
+            raise ValueError("llm should not be None")
+        graph = self.graph_type()
+        chain = LLMChain(llm=self.llm, prompt=prompt)
+        output = chain.predict(text=text)
+        knowledge = parse_triples(output)
+        for triple in knowledge:
+            graph.add_triple(triple)
+        return graph
+
+    async def afrom_text(
+        self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
+    ) -> NetworkxEntityGraph:
+        """Create graph index from text asynchronously."""
+        if self.llm is None:
+            raise ValueError("llm should not be None")
+        graph = self.graph_type()
+        chain = LLMChain(llm=self.llm, prompt=prompt)
+        output = await chain.apredict(text=text)
+        knowledge =
parse_triples(output) + for triple in knowledge: + graph.add_triple(triple) + return graph diff --git a/libs/langchain/langchain/indexes/prompts/__init__.py b/libs/community/langchain_community/indexes/prompts/__init__.py similarity index 100% rename from libs/langchain/langchain/indexes/prompts/__init__.py rename to libs/community/langchain_community/indexes/prompts/__init__.py diff --git a/libs/langchain/langchain/indexes/prompts/entity_extraction.py b/libs/community/langchain_community/indexes/prompts/entity_extraction.py similarity index 100% rename from libs/langchain/langchain/indexes/prompts/entity_extraction.py rename to libs/community/langchain_community/indexes/prompts/entity_extraction.py diff --git a/libs/langchain/langchain/indexes/prompts/entity_summarization.py b/libs/community/langchain_community/indexes/prompts/entity_summarization.py similarity index 100% rename from libs/langchain/langchain/indexes/prompts/entity_summarization.py rename to libs/community/langchain_community/indexes/prompts/entity_summarization.py diff --git a/libs/langchain/langchain/indexes/prompts/knowledge_triplet_extraction.py b/libs/community/langchain_community/indexes/prompts/knowledge_triplet_extraction.py similarity index 100% rename from libs/langchain/langchain/indexes/prompts/knowledge_triplet_extraction.py rename to libs/community/langchain_community/indexes/prompts/knowledge_triplet_extraction.py diff --git a/libs/community/langchain_community/indexes/vectorstore.py b/libs/community/langchain_community/indexes/vectorstore.py new file mode 100644 index 0000000000000..380ef4767a1a0 --- /dev/null +++ b/libs/community/langchain_community/indexes/vectorstore.py @@ -0,0 +1,91 @@ +from typing import Any, Dict, List, Optional, Type + +from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain +from langchain.chains.retrieval_qa.base import RetrievalQA +from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.language_models import BaseLanguageModel +from langchain_core.pydantic_v1 import BaseModel, Extra, Field +from langchain_core.vectorstores import VectorStore + +from langchain_community.document_loaders.base import BaseLoader +from langchain_community.embeddings.openai import OpenAIEmbeddings +from langchain_community.llms.openai import OpenAI +from langchain_community.vectorstores.chroma import Chroma + + +def _get_default_text_splitter() -> TextSplitter: + return RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) + + +class VectorStoreIndexWrapper(BaseModel): + """Wrapper around a vectorstore for easy access.""" + + vectorstore: VectorStore + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + def query( + self, + question: str, + llm: Optional[BaseLanguageModel] = None, + retriever_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> str: + """Query the vectorstore.""" + llm = llm or OpenAI(temperature=0) + retriever_kwargs = retriever_kwargs or {} + chain = RetrievalQA.from_chain_type( + llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs + ) + return chain.run(question) + + def query_with_sources( + self, + question: str, + llm: Optional[BaseLanguageModel] = None, + retriever_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> dict: + """Query the vectorstore and get back sources.""" + llm = 
llm or OpenAI(temperature=0) + retriever_kwargs = retriever_kwargs or {} + chain = RetrievalQAWithSourcesChain.from_chain_type( + llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs + ) + return chain({chain.question_key: question}) + + +class VectorstoreIndexCreator(BaseModel): + """Logic for creating indexes.""" + + vectorstore_cls: Type[VectorStore] = Chroma + embedding: Embeddings = Field(default_factory=OpenAIEmbeddings) + text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter) + vectorstore_kwargs: dict = Field(default_factory=dict) + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + def from_loaders(self, loaders: List[BaseLoader]) -> VectorStoreIndexWrapper: + """Create a vectorstore index from loaders.""" + docs = [] + for loader in loaders: + docs.extend(loader.load()) + return self.from_documents(docs) + + def from_documents(self, documents: List[Document]) -> VectorStoreIndexWrapper: + """Create a vectorstore index from documents.""" + sub_docs = self.text_splitter.split_documents(documents) + vectorstore = self.vectorstore_cls.from_documents( + sub_docs, self.embedding, **self.vectorstore_kwargs + ) + return VectorStoreIndexWrapper(vectorstore=vectorstore) diff --git a/libs/community/tests/unit_tests/indexes/test_hashed_document.py b/libs/community/tests/unit_tests/indexes/test_hashed_document.py new file mode 100644 index 0000000000000..e43e4eea200d8 --- /dev/null +++ b/libs/community/tests/unit_tests/indexes/test_hashed_document.py @@ -0,0 +1,50 @@ +import pytest +from langchain_core.documents import Document + +from langchain_community.indexes._api import _HashedDocument + + +def test_hashed_document_hashing() -> None: + hashed_document = _HashedDocument( + uid="123", page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"} + ) + assert isinstance(hashed_document.hash_, str) + + +def test_hashing_with_missing_content() -> None: + """Check that ValueError is raised if page_content is missing.""" + with pytest.raises(TypeError): + _HashedDocument( + metadata={"key": "value"}, + ) # type: ignore + + +def test_uid_auto_assigned_to_hash() -> None: + """Test uid is auto-assigned to the hashed_document hash.""" + hashed_document = _HashedDocument( + page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"} + ) + assert hashed_document.uid == hashed_document.hash_ + + +def test_to_document() -> None: + """Test to_document method.""" + hashed_document = _HashedDocument( + page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"} + ) + doc = hashed_document.to_document() + assert isinstance(doc, Document) + assert doc.page_content == "Lorem ipsum dolor sit amet" + assert doc.metadata == {"key": "value"} + + +def test_from_document() -> None: + """Test from document class method.""" + document = Document( + page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"} + ) + + hashed_document = _HashedDocument.from_document(document) + # hash should be deterministic + assert hashed_document.hash_ == "fd1dc827-051b-537d-a1fe-1fa043e8b276" + assert hashed_document.uid == hashed_document.hash_ diff --git a/libs/community/tests/unit_tests/indexes/test_indexing.py b/libs/community/tests/unit_tests/indexes/test_indexing.py new file mode 100644 index 0000000000000..e439306ccaa81 --- /dev/null +++ b/libs/community/tests/unit_tests/indexes/test_indexing.py @@ -0,0 +1,1155 @@ +from datetime import datetime +from typing import ( + Any, 
+ AsyncIterator, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Type, +) +from unittest.mock import patch + +import pytest +import pytest_asyncio +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VST, VectorStore + +from langchain_community.document_loaders.base import BaseLoader +from langchain_community.indexes._api import _abatch, aindex, index +from langchain_community.indexes._sql_record_manager import SQLRecordManager + + +class ToyLoader(BaseLoader): + """Toy loader that always returns the same documents.""" + + def __init__(self, documents: Sequence[Document]) -> None: + """Initialize with the documents to return.""" + self.documents = documents + + def lazy_load( + self, + ) -> Iterator[Document]: + yield from self.documents + + def load(self) -> List[Document]: + """Load the documents from the source.""" + return list(self.lazy_load()) + + async def alazy_load( + self, + ) -> AsyncIterator[Document]: + for document in self.documents: + yield document + + +class InMemoryVectorStore(VectorStore): + """In-memory implementation of VectorStore using a dictionary.""" + + def __init__(self, permit_upserts: bool = False) -> None: + """Vector store interface for testing things in memory.""" + self.store: Dict[str, Document] = {} + self.permit_upserts = permit_upserts + + def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None: + """Delete the given documents from the store using their IDs.""" + if ids: + for _id in ids: + self.store.pop(_id, None) + + async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None: + """Delete the given documents from the store using their IDs.""" + if ids: + for _id in ids: + self.store.pop(_id, None) + + def add_documents( # type: ignore + self, + documents: Sequence[Document], + *, + ids: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Add the given documents to the store (insert behavior).""" + if ids and len(ids) != len(documents): + raise ValueError( + f"Expected {len(ids)} ids, got {len(documents)} documents." + ) + + if not ids: + raise NotImplementedError("This is not implemented yet.") + + for _id, document in zip(ids, documents): + if _id in self.store and not self.permit_upserts: + raise ValueError( + f"Document with uid {_id} already exists in the store." + ) + self.store[_id] = document + + return list(ids) + + async def aadd_documents( + self, + documents: Sequence[Document], + *, + ids: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> List[str]: + if ids and len(ids) != len(documents): + raise ValueError( + f"Expected {len(ids)} ids, got {len(documents)} documents." + ) + + if not ids: + raise NotImplementedError("This is not implemented yet.") + + for _id, document in zip(ids, documents): + if _id in self.store and not self.permit_upserts: + raise ValueError( + f"Document with uid {_id} already exists in the store." 
+ ) + self.store[_id] = document + return list(ids) + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[Dict[Any, Any]]] = None, + **kwargs: Any, + ) -> List[str]: + """Add the given texts to the store (insert behavior).""" + raise NotImplementedError() + + @classmethod + def from_texts( + cls: Type[VST], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[Dict[Any, Any]]] = None, + **kwargs: Any, + ) -> VST: + """Create a vector store from a list of texts.""" + raise NotImplementedError() + + def similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + """Find the most similar documents to the given query.""" + raise NotImplementedError() + + +@pytest.fixture +def record_manager() -> SQLRecordManager: + """Timestamped set fixture.""" + record_manager = SQLRecordManager("kittens", db_url="sqlite:///:memory:") + record_manager.create_schema() + return record_manager + + +@pytest_asyncio.fixture # type: ignore +@pytest.mark.requires("aiosqlite") +async def arecord_manager() -> SQLRecordManager: + """Timestamped set fixture.""" + record_manager = SQLRecordManager( + "kittens", + db_url="sqlite+aiosqlite:///:memory:", + async_mode=True, + ) + await record_manager.acreate_schema() + return record_manager + + +@pytest.fixture +def vector_store() -> InMemoryVectorStore: + """Vector store fixture.""" + return InMemoryVectorStore() + + +@pytest.fixture +def upserting_vector_store() -> InMemoryVectorStore: + """Vector store fixture.""" + return InMemoryVectorStore(permit_upserts=True) + + +def test_indexing_same_content( + record_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Indexing some content to confirm it gets added only once.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + ), + Document( + page_content="This is another document.", + ), + ] + ) + + assert index(loader, record_manager, vector_store) == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + assert len(list(vector_store.store)) == 2 + + for _ in range(2): + # Run the indexing again + assert index(loader, record_manager, vector_store) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + +@pytest.mark.requires("aiosqlite") +async def test_aindexing_same_content( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Indexing some content to confirm it gets added only once.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + ), + Document( + page_content="This is another document.", + ), + ] + ) + + assert await aindex(loader, arecord_manager, vector_store) == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + assert len(list(vector_store.store)) == 2 + + for _ in range(2): + # Run the indexing again + assert await aindex(loader, arecord_manager, vector_store) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + +def test_index_simple_delete_full( + record_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Indexing some content to confirm it gets added only once.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + ), + Document( + page_content="This is another document.", + ), + ] + ) + + with patch.object( + record_manager, "get_time", return_value=datetime(2021, 1, 1).timestamp() + ): + 
assert index(loader, record_manager, vector_store, cleanup="full") == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + with patch.object( + record_manager, "get_time", return_value=datetime(2021, 1, 1).timestamp() + ): + assert index(loader, record_manager, vector_store, cleanup="full") == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + loader = ToyLoader( + documents=[ + Document( + page_content="mutated document 1", + ), + Document( + page_content="This is another document.", # <-- Same as original + ), + ] + ) + + with patch.object( + record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert index(loader, record_manager, vector_store, cleanup="full") == { + "num_added": 1, + "num_deleted": 1, + "num_skipped": 1, + "num_updated": 0, + } + + doc_texts = set( + # Ignoring type since doc should be in the store and not a None + vector_store.store.get(uid).page_content # type: ignore + for uid in vector_store.store + ) + assert doc_texts == {"mutated document 1", "This is another document."} + + # Attempt to index again verify that nothing changes + with patch.object( + record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert index(loader, record_manager, vector_store, cleanup="full") == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + +@pytest.mark.requires("aiosqlite") +async def test_aindex_simple_delete_full( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Indexing some content to confirm it gets added only once.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + ), + Document( + page_content="This is another document.", + ), + ] + ) + + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp() + ): + assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp() + ): + assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + loader = ToyLoader( + documents=[ + Document( + page_content="mutated document 1", + ), + Document( + page_content="This is another document.", # <-- Same as original + ), + ] + ) + + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { + "num_added": 1, + "num_deleted": 1, + "num_skipped": 1, + "num_updated": 0, + } + + doc_texts = set( + # Ignoring type since doc should be in the store and not a None + vector_store.store.get(uid).page_content # type: ignore + for uid in vector_store.store + ) + assert doc_texts == {"mutated document 1", "This is another document."} + + # Attempt to index again verify that nothing changes + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + +def test_incremental_fails_with_bad_source_ids( + record_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Test indexing 
with incremental deletion strategy.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + Document( + page_content="This is yet another document.", + metadata={"source": None}, + ), + ] + ) + + with pytest.raises(ValueError): + # Should raise an error because no source id function was specified + index(loader, record_manager, vector_store, cleanup="incremental") + + with pytest.raises(ValueError): + # Should raise an error because no source id function was specified + index( + loader, + record_manager, + vector_store, + cleanup="incremental", + source_id_key="source", + ) + + +@pytest.mark.requires("aiosqlite") +async def test_aincremental_fails_with_bad_source_ids( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Test indexing with incremental deletion strategy.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + Document( + page_content="This is yet another document.", + metadata={"source": None}, + ), + ] + ) + + with pytest.raises(ValueError): + # Should raise an error because no source id function was specified + await aindex( + loader, + arecord_manager, + vector_store, + cleanup="incremental", + ) + + with pytest.raises(ValueError): + # Should raise an error because no source id function was specified + await aindex( + loader, + arecord_manager, + vector_store, + cleanup="incremental", + source_id_key="source", + ) + + +def test_no_delete( + record_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Test indexing without a deletion strategy.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + ] + ) + + with patch.object( + record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert index( + loader, + record_manager, + vector_store, + cleanup=None, + source_id_key="source", + ) == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + # If we add the same content twice it should be skipped + with patch.object( + record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert index( + loader, + record_manager, + vector_store, + cleanup=None, + source_id_key="source", + ) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + loader = ToyLoader( + documents=[ + Document( + page_content="mutated content", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + ] + ) + + # Should result in no updates or deletions! 
+ with patch.object( + record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert index( + loader, + record_manager, + vector_store, + cleanup=None, + source_id_key="source", + ) == { + "num_added": 1, + "num_deleted": 0, + "num_skipped": 1, + "num_updated": 0, + } + + +@pytest.mark.requires("aiosqlite") +async def test_ano_delete( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Test indexing without a deletion strategy.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + ] + ) + + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex( + loader, + arecord_manager, + vector_store, + cleanup=None, + source_id_key="source", + ) == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + # If we add the same content twice it should be skipped + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex( + loader, + arecord_manager, + vector_store, + cleanup=None, + source_id_key="source", + ) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + loader = ToyLoader( + documents=[ + Document( + page_content="mutated content", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + ] + ) + + # Should result in no updates or deletions! + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex( + loader, + arecord_manager, + vector_store, + cleanup=None, + source_id_key="source", + ) == { + "num_added": 1, + "num_deleted": 0, + "num_skipped": 1, + "num_updated": 0, + } + + +def test_incremental_delete( + record_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Test indexing with incremental deletion strategy.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + ] + ) + + with patch.object( + record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert index( + loader, + record_manager, + vector_store, + cleanup="incremental", + source_id_key="source", + ) == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + doc_texts = set( + # Ignoring type since doc should be in the store and not a None + vector_store.store.get(uid).page_content # type: ignore + for uid in vector_store.store + ) + assert doc_texts == {"This is another document.", "This is a test document."} + + # Attempt to index again verify that nothing changes + with patch.object( + record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert index( + loader, + record_manager, + vector_store, + cleanup="incremental", + source_id_key="source", + ) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + # Create 2 documents from the same source all with mutated content + loader = ToyLoader( + documents=[ + Document( + page_content="mutated document 1", + metadata={"source": "1"}, + ), + Document( + page_content="mutated document 2", + metadata={"source": "1"}, + ), + 
Document( + page_content="This is another document.", # <-- Same as original + metadata={"source": "2"}, + ), + ] + ) + + # Attempt to index again verify that nothing changes + with patch.object( + record_manager, "get_time", return_value=datetime(2021, 1, 3).timestamp() + ): + assert index( + loader, + record_manager, + vector_store, + cleanup="incremental", + source_id_key="source", + ) == { + "num_added": 2, + "num_deleted": 1, + "num_skipped": 1, + "num_updated": 0, + } + + doc_texts = set( + # Ignoring type since doc should be in the store and not a None + vector_store.store.get(uid).page_content # type: ignore + for uid in vector_store.store + ) + assert doc_texts == { + "mutated document 1", + "mutated document 2", + "This is another document.", + } + + +@pytest.mark.requires("aiosqlite") +async def test_aincremental_delete( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Test indexing with incremental deletion strategy.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + ] + ) + + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex( + loader.lazy_load(), + arecord_manager, + vector_store, + cleanup="incremental", + source_id_key="source", + ) == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + doc_texts = set( + # Ignoring type since doc should be in the store and not a None + vector_store.store.get(uid).page_content # type: ignore + for uid in vector_store.store + ) + assert doc_texts == {"This is another document.", "This is a test document."} + + # Attempt to index again verify that nothing changes + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex( + loader.lazy_load(), + arecord_manager, + vector_store, + cleanup="incremental", + source_id_key="source", + ) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + # Create 2 documents from the same source all with mutated content + loader = ToyLoader( + documents=[ + Document( + page_content="mutated document 1", + metadata={"source": "1"}, + ), + Document( + page_content="mutated document 2", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", # <-- Same as original + metadata={"source": "2"}, + ), + ] + ) + + # Attempt to index again verify that nothing changes + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 3).timestamp() + ): + assert await aindex( + loader.lazy_load(), + arecord_manager, + vector_store, + cleanup="incremental", + source_id_key="source", + ) == { + "num_added": 2, + "num_deleted": 1, + "num_skipped": 1, + "num_updated": 0, + } + + doc_texts = set( + # Ignoring type since doc should be in the store and not a None + vector_store.store.get(uid).page_content # type: ignore + for uid in vector_store.store + ) + assert doc_texts == { + "mutated document 1", + "mutated document 2", + "This is another document.", + } + + +def test_indexing_with_no_docs( + record_manager: SQLRecordManager, vector_store: VectorStore +) -> None: + """Check edge case when loader returns no new docs.""" + loader = ToyLoader(documents=[]) + + assert index(loader, record_manager, vector_store, cleanup="full") == { + "num_added": 0, + "num_deleted": 
0, + "num_skipped": 0, + "num_updated": 0, + } + + +@pytest.mark.requires("aiosqlite") +async def test_aindexing_with_no_docs( + arecord_manager: SQLRecordManager, vector_store: VectorStore +) -> None: + """Check edge case when loader returns no new docs.""" + loader = ToyLoader(documents=[]) + + assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + +def test_deduplication( + record_manager: SQLRecordManager, vector_store: VectorStore +) -> None: + """Check edge case when loader returns no new docs.""" + docs = [ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + ] + + # Should result in only a single document being added + assert index(docs, record_manager, vector_store, cleanup="full") == { + "num_added": 1, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + +@pytest.mark.requires("aiosqlite") +async def test_adeduplication( + arecord_manager: SQLRecordManager, vector_store: VectorStore +) -> None: + """Check edge case when loader returns no new docs.""" + docs = [ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + ] + + # Should result in only a single document being added + assert await aindex(docs, arecord_manager, vector_store, cleanup="full") == { + "num_added": 1, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + +def test_cleanup_with_different_batchsize( + record_manager: SQLRecordManager, vector_store: VectorStore +) -> None: + """Check that we can clean up with different batch size.""" + docs = [ + Document( + page_content="This is a test document.", + metadata={"source": str(d)}, + ) + for d in range(1000) + ] + + assert index(docs, record_manager, vector_store, cleanup="full") == { + "num_added": 1000, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + docs = [ + Document( + page_content="Different doc", + metadata={"source": str(d)}, + ) + for d in range(1001) + ] + + assert index( + docs, record_manager, vector_store, cleanup="full", cleanup_batch_size=17 + ) == { + "num_added": 1001, + "num_deleted": 1000, + "num_skipped": 0, + "num_updated": 0, + } + + +@pytest.mark.requires("aiosqlite") +async def test_async_cleanup_with_different_batchsize( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Check that we can clean up with different batch size.""" + docs = [ + Document( + page_content="This is a test document.", + metadata={"source": str(d)}, + ) + for d in range(1000) + ] + + assert await aindex(docs, arecord_manager, vector_store, cleanup="full") == { + "num_added": 1000, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + docs = [ + Document( + page_content="Different doc", + metadata={"source": str(d)}, + ) + for d in range(1001) + ] + + assert await aindex( + docs, arecord_manager, vector_store, cleanup="full", cleanup_batch_size=17 + ) == { + "num_added": 1001, + "num_deleted": 1000, + "num_skipped": 0, + "num_updated": 0, + } + + +def test_deduplication_v2( + record_manager: SQLRecordManager, vector_store: VectorStore +) -> None: + """Check edge case when loader returns no new docs.""" + docs = [ + Document( + page_content="1", + metadata={"source": "1"}, + ), + Document( + page_content="1", + 
metadata={"source": "1"}, + ), + Document( + page_content="2", + metadata={"source": "2"}, + ), + Document( + page_content="3", + metadata={"source": "3"}, + ), + ] + + assert index(docs, record_manager, vector_store, cleanup="full") == { + "num_added": 3, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + # using in memory implementation here + assert isinstance(vector_store, InMemoryVectorStore) + contents = sorted( + [document.page_content for document in vector_store.store.values()] + ) + assert contents == ["1", "2", "3"] + + +async def _to_async_iter(it: Iterable[Any]) -> AsyncIterator[Any]: + """Convert an iterable to an async iterator.""" + for i in it: + yield i + + +async def test_abatch() -> None: + """Test the abatch function.""" + batches = _abatch(5, _to_async_iter(range(12))) + assert isinstance(batches, AsyncIterator) + assert [batch async for batch in batches] == [ + [0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [10, 11], + ] + + batches = _abatch(1, _to_async_iter(range(3))) + assert isinstance(batches, AsyncIterator) + assert [batch async for batch in batches] == [[0], [1], [2]] + + batches = _abatch(2, _to_async_iter(range(5))) + assert isinstance(batches, AsyncIterator) + assert [batch async for batch in batches] == [[0, 1], [2, 3], [4]] + + +def test_indexing_force_update( + record_manager: SQLRecordManager, upserting_vector_store: VectorStore +) -> None: + """Test indexing with force update.""" + docs = [ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + ] + + assert index(docs, record_manager, upserting_vector_store, cleanup="full") == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + assert index(docs, record_manager, upserting_vector_store, cleanup="full") == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + assert index( + docs, record_manager, upserting_vector_store, cleanup="full", force_update=True + ) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 2, + } + + +@pytest.mark.requires("aiosqlite") +async def test_aindexing_force_update( + arecord_manager: SQLRecordManager, upserting_vector_store: VectorStore +) -> None: + """Test indexing with force update.""" + docs = [ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + ] + + assert await aindex( + docs, arecord_manager, upserting_vector_store, cleanup="full" + ) == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + assert await aindex( + docs, arecord_manager, upserting_vector_store, cleanup="full" + ) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + assert await aindex( + docs, + arecord_manager, + upserting_vector_store, + cleanup="full", + force_update=True, + ) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 2, + } diff --git a/libs/langchain/langchain/indexes/_api.py b/libs/langchain/langchain/indexes/_api.py index 8dcf6e0c93a47..bab2ca37b8c03 100644 --- a/libs/langchain/langchain/indexes/_api.py +++ b/libs/langchain/langchain/indexes/_api.py @@ -1,598 
+1,15 @@ -"""Module contains logic for indexing documents into vector stores.""" -from __future__ import annotations - -import hashlib -import json -import uuid -from itertools import islice -from typing import ( - Any, - AsyncIterable, - AsyncIterator, - Callable, - Dict, - Iterable, - Iterator, - List, - Literal, - Optional, - Sequence, - Set, - TypedDict, - TypeVar, - Union, - cast, +from langchain_community.indexes._api import ( + IndexingResult, + _abatch, + _HashedDocument, + aindex, + index, ) -from langchain_community.document_loaders.base import BaseLoader -from langchain_core.documents import Document -from langchain_core.pydantic_v1 import root_validator -from langchain_core.vectorstores import VectorStore - -from langchain.indexes.base import NAMESPACE_UUID, RecordManager - -T = TypeVar("T") - - -def _hash_string_to_uuid(input_string: str) -> uuid.UUID: - """Hashes a string and returns the corresponding UUID.""" - hash_value = hashlib.sha1(input_string.encode("utf-8")).hexdigest() - return uuid.uuid5(NAMESPACE_UUID, hash_value) - - -def _hash_nested_dict_to_uuid(data: dict[Any, Any]) -> uuid.UUID: - """Hashes a nested dictionary and returns the corresponding UUID.""" - serialized_data = json.dumps(data, sort_keys=True) - hash_value = hashlib.sha1(serialized_data.encode("utf-8")).hexdigest() - return uuid.uuid5(NAMESPACE_UUID, hash_value) - - -class _HashedDocument(Document): - """A hashed document with a unique ID.""" - - uid: str - hash_: str - """The hash of the document including content and metadata.""" - content_hash: str - """The hash of the document content.""" - metadata_hash: str - """The hash of the document metadata.""" - - @classmethod - def is_lc_serializable(cls) -> bool: - return False - - @root_validator(pre=True) - def calculate_hashes(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """Root validator to calculate content and metadata hash.""" - content = values.get("page_content", "") - metadata = values.get("metadata", {}) - - forbidden_keys = ("hash_", "content_hash", "metadata_hash") - - for key in forbidden_keys: - if key in metadata: - raise ValueError( - f"Metadata cannot contain key {key} as it " - f"is reserved for internal use." - ) - - content_hash = str(_hash_string_to_uuid(content)) - - try: - metadata_hash = str(_hash_nested_dict_to_uuid(metadata)) - except Exception as e: - raise ValueError( - f"Failed to hash metadata: {e}. " - f"Please use a dict that can be serialized using json." 
- ) - - values["content_hash"] = content_hash - values["metadata_hash"] = metadata_hash - values["hash_"] = str(_hash_string_to_uuid(content_hash + metadata_hash)) - - _uid = values.get("uid", None) - - if _uid is None: - values["uid"] = values["hash_"] - return values - - def to_document(self) -> Document: - """Return a Document object.""" - return Document( - page_content=self.page_content, - metadata=self.metadata, - ) - - @classmethod - def from_document( - cls, document: Document, *, uid: Optional[str] = None - ) -> _HashedDocument: - """Create a HashedDocument from a Document.""" - return cls( - uid=uid, - page_content=document.page_content, - metadata=document.metadata, - ) - - -def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]: - """Utility batching function.""" - it = iter(iterable) - while True: - chunk = list(islice(it, size)) - if not chunk: - return - yield chunk - - -async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[List[T]]: - """Utility batching function.""" - batch: List[T] = [] - async for element in iterable: - if len(batch) < size: - batch.append(element) - - if len(batch) >= size: - yield batch - batch = [] - - if batch: - yield batch - - -def _get_source_id_assigner( - source_id_key: Union[str, Callable[[Document], str], None], -) -> Callable[[Document], Union[str, None]]: - """Get the source id from the document.""" - if source_id_key is None: - return lambda doc: None - elif isinstance(source_id_key, str): - return lambda doc: doc.metadata[source_id_key] - elif callable(source_id_key): - return source_id_key - else: - raise ValueError( - f"source_id_key should be either None, a string or a callable. " - f"Got {source_id_key} of type {type(source_id_key)}." - ) - - -def _deduplicate_in_order( - hashed_documents: Iterable[_HashedDocument], -) -> Iterator[_HashedDocument]: - """Deduplicate a list of hashed documents while preserving order.""" - seen: Set[str] = set() - - for hashed_doc in hashed_documents: - if hashed_doc.hash_ not in seen: - seen.add(hashed_doc.hash_) - yield hashed_doc - - -# PUBLIC API - - -class IndexingResult(TypedDict): - """Return a detailed a breakdown of the result of the indexing operation.""" - - num_added: int - """Number of added documents.""" - num_updated: int - """Number of updated documents because they were not up to date.""" - num_deleted: int - """Number of deleted documents.""" - num_skipped: int - """Number of skipped documents because they were already up to date.""" - - -def index( - docs_source: Union[BaseLoader, Iterable[Document]], - record_manager: RecordManager, - vector_store: VectorStore, - *, - batch_size: int = 100, - cleanup: Literal["incremental", "full", None] = None, - source_id_key: Union[str, Callable[[Document], str], None] = None, - cleanup_batch_size: int = 1_000, - force_update: bool = False, -) -> IndexingResult: - """Index data from the loader into the vector store. - - Indexing functionality uses a manager to keep track of which documents - are in the vector store. - - This allows us to keep track of which documents were updated, and which - documents were deleted, which documents should be skipped. - - For the time being, documents are indexed using their hashes, and users - are not able to specify the uid of the document. - - IMPORTANT: - if auto_cleanup is set to True, the loader should be returning - the entire dataset, and not just a subset of the dataset. - Otherwise, the auto_cleanup will remove documents that it is not - supposed to. 
- - Args: - docs_source: Data loader or iterable of documents to index. - record_manager: Timestamped set to keep track of which documents were - updated. - vector_store: Vector store to index the documents into. - batch_size: Batch size to use when indexing. - cleanup: How to handle clean up of documents. - - Incremental: Cleans up all documents that haven't been updated AND - that are associated with source ids that were seen - during indexing. - Clean up is done continuously during indexing helping - to minimize the probability of users seeing duplicated - content. - - Full: Delete all documents that haven to been returned by the loader. - Clean up runs after all documents have been indexed. - This means that users may see duplicated content during indexing. - - None: Do not delete any documents. - source_id_key: Optional key that helps identify the original source - of the document. - cleanup_batch_size: Batch size to use when cleaning up documents. - force_update: Force update documents even if they are present in the - record manager. Useful if you are re-indexing with updated embeddings. - - Returns: - Indexing result which contains information about how many documents - were added, updated, deleted, or skipped. - """ - if cleanup not in {"incremental", "full", None}: - raise ValueError( - f"cleanup should be one of 'incremental', 'full' or None. " - f"Got {cleanup}." - ) - - if cleanup == "incremental" and source_id_key is None: - raise ValueError("Source id key is required when cleanup mode is incremental.") - - # Check that the Vectorstore has required methods implemented - methods = ["delete", "add_documents"] - - for method in methods: - if not hasattr(vector_store, method): - raise ValueError( - f"Vectorstore {vector_store} does not have required method {method}" - ) - - if type(vector_store).delete == VectorStore.delete: - # Checking if the vectorstore has overridden the default delete method - # implementation which just raises a NotImplementedError - raise ValueError("Vectorstore has not implemented the delete method") - - if isinstance(docs_source, BaseLoader): - try: - doc_iterator = docs_source.lazy_load() - except NotImplementedError: - doc_iterator = iter(docs_source.load()) - else: - doc_iterator = iter(docs_source) - - source_id_assigner = _get_source_id_assigner(source_id_key) - - # Mark when the update started. - index_start_dt = record_manager.get_time() - num_added = 0 - num_skipped = 0 - num_updated = 0 - num_deleted = 0 - - for doc_batch in _batch(batch_size, doc_iterator): - hashed_docs = list( - _deduplicate_in_order( - [_HashedDocument.from_document(doc) for doc in doc_batch] - ) - ) - - source_ids: Sequence[Optional[str]] = [ - source_id_assigner(doc) for doc in hashed_docs - ] - - if cleanup == "incremental": - # If the cleanup mode is incremental, source ids are required. - for source_id, hashed_doc in zip(source_ids, hashed_docs): - if source_id is None: - raise ValueError( - "Source ids are required when cleanup mode is incremental. " - f"Document that starts with " - f"content: {hashed_doc.page_content[:100]} was not assigned " - f"as source id." - ) - # source ids cannot be None after for loop above. - source_ids = cast(Sequence[str], source_ids) # type: ignore[assignment] - - exists_batch = record_manager.exists([doc.uid for doc in hashed_docs]) - - # Filter out documents that already exist in the record store. 
- uids = [] - docs_to_index = [] - uids_to_refresh = [] - seen_docs: Set[str] = set() - for hashed_doc, doc_exists in zip(hashed_docs, exists_batch): - if doc_exists: - if force_update: - seen_docs.add(hashed_doc.uid) - else: - uids_to_refresh.append(hashed_doc.uid) - continue - uids.append(hashed_doc.uid) - docs_to_index.append(hashed_doc.to_document()) - - # Update refresh timestamp - if uids_to_refresh: - record_manager.update(uids_to_refresh, time_at_least=index_start_dt) - num_skipped += len(uids_to_refresh) - - # Be pessimistic and assume that all vector store write will fail. - # First write to vector store - if docs_to_index: - vector_store.add_documents(docs_to_index, ids=uids) - num_added += len(docs_to_index) - len(seen_docs) - num_updated += len(seen_docs) - - # And only then update the record store. - # Update ALL records, even if they already exist since we want to refresh - # their timestamp. - record_manager.update( - [doc.uid for doc in hashed_docs], - group_ids=source_ids, - time_at_least=index_start_dt, - ) - - # If source IDs are provided, we can do the deletion incrementally! - if cleanup == "incremental": - # Get the uids of the documents that were not returned by the loader. - - # mypy isn't good enough to determine that source ids cannot be None - # here due to a check that's happening above, so we check again. - for source_id in source_ids: - if source_id is None: - raise AssertionError("Source ids cannot be None here.") - - _source_ids = cast(Sequence[str], source_ids) - - uids_to_delete = record_manager.list_keys( - group_ids=_source_ids, before=index_start_dt - ) - if uids_to_delete: - # Then delete from vector store. - vector_store.delete(uids_to_delete) - # First delete from record store. - record_manager.delete_keys(uids_to_delete) - num_deleted += len(uids_to_delete) - - if cleanup == "full": - while uids_to_delete := record_manager.list_keys( - before=index_start_dt, limit=cleanup_batch_size - ): - # First delete from record store. - vector_store.delete(uids_to_delete) - # Then delete from record manager. - record_manager.delete_keys(uids_to_delete) - num_deleted += len(uids_to_delete) - - return { - "num_added": num_added, - "num_updated": num_updated, - "num_skipped": num_skipped, - "num_deleted": num_deleted, - } - - -# Define an asynchronous generator function -async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]: - """Convert an iterable to an async iterator.""" - for item in iterator: - yield item - - -async def aindex( - docs_source: Union[BaseLoader, Iterable[Document], AsyncIterator[Document]], - record_manager: RecordManager, - vector_store: VectorStore, - *, - batch_size: int = 100, - cleanup: Literal["incremental", "full", None] = None, - source_id_key: Union[str, Callable[[Document], str], None] = None, - cleanup_batch_size: int = 1_000, - force_update: bool = False, -) -> IndexingResult: - """Index data from the loader into the vector store. - - Indexing functionality uses a manager to keep track of which documents - are in the vector store. - - This allows us to keep track of which documents were updated, and which - documents were deleted, which documents should be skipped. - - For the time being, documents are indexed using their hashes, and users - are not able to specify the uid of the document. - - IMPORTANT: - if auto_cleanup is set to True, the loader should be returning - the entire dataset, and not just a subset of the dataset. - Otherwise, the auto_cleanup will remove documents that it is not - supposed to. 
- - Args: - docs_source: Data loader or iterable of documents to index. - record_manager: Timestamped set to keep track of which documents were - updated. - vector_store: Vector store to index the documents into. - batch_size: Batch size to use when indexing. - cleanup: How to handle clean up of documents. - - Incremental: Cleans up all documents that haven't been updated AND - that are associated with source ids that were seen - during indexing. - Clean up is done continuously during indexing helping - to minimize the probability of users seeing duplicated - content. - - Full: Delete all documents that haven to been returned by the loader. - Clean up runs after all documents have been indexed. - This means that users may see duplicated content during indexing. - - None: Do not delete any documents. - source_id_key: Optional key that helps identify the original source - of the document. - cleanup_batch_size: Batch size to use when cleaning up documents. - force_update: Force update documents even if they are present in the - record manager. Useful if you are re-indexing with updated embeddings. - - Returns: - Indexing result which contains information about how many documents - were added, updated, deleted, or skipped. - """ - - if cleanup not in {"incremental", "full", None}: - raise ValueError( - f"cleanup should be one of 'incremental', 'full' or None. " - f"Got {cleanup}." - ) - - if cleanup == "incremental" and source_id_key is None: - raise ValueError("Source id key is required when cleanup mode is incremental.") - - # Check that the Vectorstore has required methods implemented - methods = ["adelete", "aadd_documents"] - - for method in methods: - if not hasattr(vector_store, method): - raise ValueError( - f"Vectorstore {vector_store} does not have required method {method}" - ) - - if type(vector_store).adelete == VectorStore.adelete: - # Checking if the vectorstore has overridden the default delete method - # implementation which just raises a NotImplementedError - raise ValueError("Vectorstore has not implemented the delete method") - - async_doc_iterator: AsyncIterator[Document] - if isinstance(docs_source, BaseLoader): - try: - async_doc_iterator = docs_source.alazy_load() - except NotImplementedError: - # Exception triggered when neither lazy_load nor alazy_load are implemented. - # * The default implementation of alazy_load uses lazy_load. - # * The default implementation of lazy_load raises NotImplementedError. - # In such a case, we use the load method and convert it to an async - # iterator. - async_doc_iterator = _to_async_iterator(docs_source.load()) - else: - if hasattr(docs_source, "__aiter__"): - async_doc_iterator = docs_source # type: ignore[assignment] - else: - async_doc_iterator = _to_async_iterator(docs_source) - - source_id_assigner = _get_source_id_assigner(source_id_key) - - # Mark when the update started. - index_start_dt = await record_manager.aget_time() - num_added = 0 - num_skipped = 0 - num_updated = 0 - num_deleted = 0 - - async for doc_batch in _abatch(batch_size, async_doc_iterator): - hashed_docs = list( - _deduplicate_in_order( - [_HashedDocument.from_document(doc) for doc in doc_batch] - ) - ) - - source_ids: Sequence[Optional[str]] = [ - source_id_assigner(doc) for doc in hashed_docs - ] - - if cleanup == "incremental": - # If the cleanup mode is incremental, source ids are required. - for source_id, hashed_doc in zip(source_ids, hashed_docs): - if source_id is None: - raise ValueError( - "Source ids are required when cleanup mode is incremental. 
" - f"Document that starts with " - f"content: {hashed_doc.page_content[:100]} was not assigned " - f"as source id." - ) - # source ids cannot be None after for loop above. - source_ids = cast(Sequence[str], source_ids) - - exists_batch = await record_manager.aexists([doc.uid for doc in hashed_docs]) - - # Filter out documents that already exist in the record store. - uids: list[str] = [] - docs_to_index: list[Document] = [] - uids_to_refresh = [] - seen_docs: Set[str] = set() - for hashed_doc, doc_exists in zip(hashed_docs, exists_batch): - if doc_exists: - if force_update: - seen_docs.add(hashed_doc.uid) - else: - uids_to_refresh.append(hashed_doc.uid) - continue - uids.append(hashed_doc.uid) - docs_to_index.append(hashed_doc.to_document()) - - if uids_to_refresh: - # Must be updated to refresh timestamp. - await record_manager.aupdate(uids_to_refresh, time_at_least=index_start_dt) - num_skipped += len(uids_to_refresh) - - # Be pessimistic and assume that all vector store write will fail. - # First write to vector store - if docs_to_index: - await vector_store.aadd_documents(docs_to_index, ids=uids) - num_added += len(docs_to_index) - len(seen_docs) - num_updated += len(seen_docs) - - # And only then update the record store. - # Update ALL records, even if they already exist since we want to refresh - # their timestamp. - await record_manager.aupdate( - [doc.uid for doc in hashed_docs], - group_ids=source_ids, - time_at_least=index_start_dt, - ) - - # If source IDs are provided, we can do the deletion incrementally! - - if cleanup == "incremental": - # Get the uids of the documents that were not returned by the loader. - - # mypy isn't good enough to determine that source ids cannot be None - # here due to a check that's happening above, so we check again. - for source_id in source_ids: - if source_id is None: - raise AssertionError("Source ids cannot be None here.") - - _source_ids = cast(Sequence[str], source_ids) - - uids_to_delete = await record_manager.alist_keys( - group_ids=_source_ids, before=index_start_dt - ) - if uids_to_delete: - # Then delete from vector store. - await vector_store.adelete(uids_to_delete) - # First delete from record store. - await record_manager.adelete_keys(uids_to_delete) - num_deleted += len(uids_to_delete) - - if cleanup == "full": - while uids_to_delete := await record_manager.alist_keys( - before=index_start_dt, limit=cleanup_batch_size - ): - # First delete from record store. - await vector_store.adelete(uids_to_delete) - # Then delete from record manager. - await record_manager.adelete_keys(uids_to_delete) - num_deleted += len(uids_to_delete) - - return { - "num_added": num_added, - "num_updated": num_updated, - "num_skipped": num_skipped, - "num_deleted": num_deleted, - } +__all__ = [ + "_abatch", + "_HashedDocument", + "aindex", + "index", + "IndexingResult", +] diff --git a/libs/langchain/langchain/indexes/_sql_record_manager.py b/libs/langchain/langchain/indexes/_sql_record_manager.py index 76dfd9672340e..f4695dfdbc412 100644 --- a/libs/langchain/langchain/indexes/_sql_record_manager.py +++ b/libs/langchain/langchain/indexes/_sql_record_manager.py @@ -1,518 +1,5 @@ -"""Implementation of a record management layer in SQLAlchemy. +from langchain_community.indexes._sql_record_manager import SQLRecordManager -The management layer uses SQLAlchemy to track upserted records. - -Currently, this layer only works with SQLite; hopwever, should be adaptable -to other SQL implementations with minimal effort. 
- -Currently, includes an implementation that uses SQLAlchemy which should -allow it to work with a variety of SQL as a backend. - -* Each key is associated with an updated_at field. -* This filed is updated whenever the key is updated. -* Keys can be listed based on the updated at field. -* Keys can be deleted. -""" -import contextlib -import decimal -import uuid -from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence, Union - -from sqlalchemy import ( - URL, - Column, - Engine, - Float, - Index, - String, - UniqueConstraint, - and_, - create_engine, - delete, - select, - text, -) -from sqlalchemy.ext.asyncio import ( - AsyncEngine, - AsyncSession, - async_sessionmaker, - create_async_engine, -) -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Query, Session, sessionmaker - -from langchain.indexes.base import RecordManager - -Base = declarative_base() - - -class UpsertionRecord(Base): # type: ignore[valid-type,misc] - """Table used to keep track of when a key was last updated.""" - - # ATTENTION: - # Prior to modifying this table, please determine whether - # we should create migrations for this table to make sure - # users do not experience data loss. - __tablename__ = "upsertion_record" - - uuid = Column( - String, - index=True, - default=lambda: str(uuid.uuid4()), - primary_key=True, - nullable=False, - ) - key = Column(String, index=True) - # Using a non-normalized representation to handle `namespace` attribute. - # If the need arises, this attribute can be pulled into a separate Collection - # table at some time later. - namespace = Column(String, index=True, nullable=False) - group_id = Column(String, index=True, nullable=True) - - # The timestamp associated with the last record upsertion. - updated_at = Column(Float, index=True) - - __table_args__ = ( - UniqueConstraint("key", "namespace", name="uix_key_namespace"), - Index("ix_key_namespace", "key", "namespace"), - ) - - -class SQLRecordManager(RecordManager): - """A SQL Alchemy based implementation of the record manager.""" - - def __init__( - self, - namespace: str, - *, - engine: Optional[Union[Engine, AsyncEngine]] = None, - db_url: Union[None, str, URL] = None, - engine_kwargs: Optional[Dict[str, Any]] = None, - async_mode: bool = False, - ) -> None: - """Initialize the SQLRecordManager. - - This class serves as a manager persistence layer that uses an SQL - backend to track upserted records. You should specify either a db_url - to create an engine or provide an existing engine. - - Args: - namespace: The namespace associated with this record manager. - engine: An already existing SQL Alchemy engine. - Default is None. - db_url: A database connection string used to create - an SQL Alchemy engine. Default is None. - engine_kwargs: Additional keyword arguments - to be passed when creating the engine. Default is an empty dictionary. - async_mode: Whether to create an async engine. - Driver should support async operations. - It only applies if db_url is provided. - Default is False. - - Raises: - ValueError: If both db_url and engine are provided or neither. - AssertionError: If something unexpected happens during engine configuration. 
- """ - super().__init__(namespace=namespace) - if db_url is None and engine is None: - raise ValueError("Must specify either db_url or engine") - - if db_url is not None and engine is not None: - raise ValueError("Must specify either db_url or engine, not both") - - _engine: Union[Engine, AsyncEngine] - if db_url: - if async_mode: - _engine = create_async_engine(db_url, **(engine_kwargs or {})) - else: - _engine = create_engine(db_url, **(engine_kwargs or {})) - elif engine: - _engine = engine - - else: - raise AssertionError("Something went wrong with configuration of engine.") - - _session_factory: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]] - if isinstance(_engine, AsyncEngine): - _session_factory = async_sessionmaker(bind=_engine) - else: - _session_factory = sessionmaker(bind=_engine) - - self.engine = _engine - self.dialect = _engine.dialect.name - self.session_factory = _session_factory - - def create_schema(self) -> None: - """Create the database schema.""" - if isinstance(self.engine, AsyncEngine): - raise AssertionError("This method is not supported for async engines.") - - Base.metadata.create_all(self.engine) - - async def acreate_schema(self) -> None: - """Create the database schema.""" - - if not isinstance(self.engine, AsyncEngine): - raise AssertionError("This method is not supported for sync engines.") - - async with self.engine.begin() as session: - await session.run_sync(Base.metadata.create_all) - - @contextlib.contextmanager - def _make_session(self) -> Generator[Session, None, None]: - """Create a session and close it after use.""" - - if isinstance(self.session_factory, async_sessionmaker): - raise AssertionError("This method is not supported for async engines.") - - session = self.session_factory() - try: - yield session - finally: - session.close() - - @contextlib.asynccontextmanager - async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]: - """Create a session and close it after use.""" - - if not isinstance(self.session_factory, async_sessionmaker): - raise AssertionError("This method is not supported for sync engines.") - - async with self.session_factory() as session: - yield session - - def get_time(self) -> float: - """Get the current server time as a timestamp. - - Please note it's critical that time is obtained from the server since - we want a monotonic clock. - """ - with self._make_session() as session: - # * SQLite specific implementation, can be changed based on dialect. - # * For SQLite, unlike unixepoch it will work with older versions of SQLite. - # ---- - # julianday('now'): Julian day number for the current date and time. - # The Julian day is a continuous count of days, starting from a - # reference date (Julian day number 0). - # 2440587.5 - constant represents the Julian day number for January 1, 1970 - # 86400.0 - constant represents the number of seconds - # in a day (24 hours * 60 minutes * 60 seconds) - if self.dialect == "sqlite": - query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;") - elif self.dialect == "postgresql": - query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);") - else: - raise NotImplementedError(f"Not implemented for dialect {self.dialect}") - - dt = session.execute(query).scalar() - if isinstance(dt, decimal.Decimal): - dt = float(dt) - if not isinstance(dt, float): - raise AssertionError(f"Unexpected type for datetime: {type(dt)}") - return dt - - async def aget_time(self) -> float: - """Get the current server time as a timestamp. 
- - Please note it's critical that time is obtained from the server since - we want a monotonic clock. - """ - async with self._amake_session() as session: - # * SQLite specific implementation, can be changed based on dialect. - # * For SQLite, unlike unixepoch it will work with older versions of SQLite. - # ---- - # julianday('now'): Julian day number for the current date and time. - # The Julian day is a continuous count of days, starting from a - # reference date (Julian day number 0). - # 2440587.5 - constant represents the Julian day number for January 1, 1970 - # 86400.0 - constant represents the number of seconds - # in a day (24 hours * 60 minutes * 60 seconds) - if self.dialect == "sqlite": - query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;") - elif self.dialect == "postgresql": - query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);") - else: - raise NotImplementedError(f"Not implemented for dialect {self.dialect}") - - dt = (await session.execute(query)).scalar_one_or_none() - - if isinstance(dt, decimal.Decimal): - dt = float(dt) - if not isinstance(dt, float): - raise AssertionError(f"Unexpected type for datetime: {type(dt)}") - return dt - - def update( - self, - keys: Sequence[str], - *, - group_ids: Optional[Sequence[Optional[str]]] = None, - time_at_least: Optional[float] = None, - ) -> None: - """Upsert records into the SQLite database.""" - if group_ids is None: - group_ids = [None] * len(keys) - - if len(keys) != len(group_ids): - raise ValueError( - f"Number of keys ({len(keys)}) does not match number of " - f"group_ids ({len(group_ids)})" - ) - - # Get the current time from the server. - # This makes an extra round trip to the server, should not be a big deal - # if the batch size is large enough. - # Getting the time here helps us compare it against the time_at_least - # and raise an error if there is a time sync issue. - # Here, we're just being extra careful to minimize the chance of - # data loss due to incorrectly deleting records. - update_time = self.get_time() - - if time_at_least and update_time < time_at_least: - # Safeguard against time sync issues - raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}") - - records_to_upsert = [ - { - "key": key, - "namespace": self.namespace, - "updated_at": update_time, - "group_id": group_id, - } - for key, group_id in zip(keys, group_ids) - ] - - with self._make_session() as session: - if self.dialect == "sqlite": - from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType - from sqlalchemy.dialects.sqlite import insert as sqlite_insert - - # Note: uses SQLite insert to make on_conflict_do_update work. - # This code needs to be generalized a bit to work with more dialects. - sqlite_insert_stmt: SqliteInsertType = sqlite_insert( - UpsertionRecord - ).values(records_to_upsert) - stmt = sqlite_insert_stmt.on_conflict_do_update( - [UpsertionRecord.key, UpsertionRecord.namespace], - set_=dict( - updated_at=sqlite_insert_stmt.excluded.updated_at, - group_id=sqlite_insert_stmt.excluded.group_id, - ), - ) - elif self.dialect == "postgresql": - from sqlalchemy.dialects.postgresql import Insert as PgInsertType - from sqlalchemy.dialects.postgresql import insert as pg_insert - - # Note: uses postgresql insert to make on_conflict_do_update work. - # This code needs to be generalized a bit to work with more dialects. 
- pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values( - records_to_upsert - ) - stmt = pg_insert_stmt.on_conflict_do_update( - "uix_key_namespace", # Name of constraint - set_=dict( - updated_at=pg_insert_stmt.excluded.updated_at, - group_id=pg_insert_stmt.excluded.group_id, - ), - ) - else: - raise NotImplementedError(f"Unsupported dialect {self.dialect}") - - session.execute(stmt) - session.commit() - - async def aupdate( - self, - keys: Sequence[str], - *, - group_ids: Optional[Sequence[Optional[str]]] = None, - time_at_least: Optional[float] = None, - ) -> None: - """Upsert records into the SQLite database.""" - if group_ids is None: - group_ids = [None] * len(keys) - - if len(keys) != len(group_ids): - raise ValueError( - f"Number of keys ({len(keys)}) does not match number of " - f"group_ids ({len(group_ids)})" - ) - - # Get the current time from the server. - # This makes an extra round trip to the server, should not be a big deal - # if the batch size is large enough. - # Getting the time here helps us compare it against the time_at_least - # and raise an error if there is a time sync issue. - # Here, we're just being extra careful to minimize the chance of - # data loss due to incorrectly deleting records. - update_time = await self.aget_time() - - if time_at_least and update_time < time_at_least: - # Safeguard against time sync issues - raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}") - - records_to_upsert = [ - { - "key": key, - "namespace": self.namespace, - "updated_at": update_time, - "group_id": group_id, - } - for key, group_id in zip(keys, group_ids) - ] - - async with self._amake_session() as session: - if self.dialect == "sqlite": - from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType - from sqlalchemy.dialects.sqlite import insert as sqlite_insert - - # Note: uses SQLite insert to make on_conflict_do_update work. - # This code needs to be generalized a bit to work with more dialects. - sqlite_insert_stmt: SqliteInsertType = sqlite_insert( - UpsertionRecord - ).values(records_to_upsert) - stmt = sqlite_insert_stmt.on_conflict_do_update( - [UpsertionRecord.key, UpsertionRecord.namespace], - set_=dict( - updated_at=sqlite_insert_stmt.excluded.updated_at, - group_id=sqlite_insert_stmt.excluded.group_id, - ), - ) - elif self.dialect == "postgresql": - from sqlalchemy.dialects.postgresql import Insert as PgInsertType - from sqlalchemy.dialects.postgresql import insert as pg_insert - - # Note: uses SQLite insert to make on_conflict_do_update work. - # This code needs to be generalized a bit to work with more dialects. 
- pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values( - records_to_upsert - ) - stmt = pg_insert_stmt.on_conflict_do_update( - "uix_key_namespace", # Name of constraint - set_=dict( - updated_at=pg_insert_stmt.excluded.updated_at, - group_id=pg_insert_stmt.excluded.group_id, - ), - ) - else: - raise NotImplementedError(f"Unsupported dialect {self.dialect}") - - await session.execute(stmt) - await session.commit() - - def exists(self, keys: Sequence[str]) -> List[bool]: - """Check if the given keys exist in the SQLite database.""" - session: Session - with self._make_session() as session: - filtered_query: Query = session.query(UpsertionRecord.key).filter( - and_( - UpsertionRecord.key.in_(keys), - UpsertionRecord.namespace == self.namespace, - ) - ) - records = filtered_query.all() - found_keys = set(r.key for r in records) - return [k in found_keys for k in keys] - - async def aexists(self, keys: Sequence[str]) -> List[bool]: - """Check if the given keys exist in the SQLite database.""" - async with self._amake_session() as session: - records = ( - ( - await session.execute( - select(UpsertionRecord.key).where( - and_( - UpsertionRecord.key.in_(keys), - UpsertionRecord.namespace == self.namespace, - ) - ) - ) - ) - .scalars() - .all() - ) - found_keys = set(records) - return [k in found_keys for k in keys] - - def list_keys( - self, - *, - before: Optional[float] = None, - after: Optional[float] = None, - group_ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, - ) -> List[str]: - """List records in the SQLite database based on the provided date range.""" - session: Session - with self._make_session() as session: - query: Query = session.query(UpsertionRecord).filter( - UpsertionRecord.namespace == self.namespace - ) - - if after: - query = query.filter(UpsertionRecord.updated_at > after) - if before: - query = query.filter(UpsertionRecord.updated_at < before) - if group_ids: - query = query.filter(UpsertionRecord.group_id.in_(group_ids)) - - if limit: - query = query.limit(limit) - records = query.all() - return [r.key for r in records] - - async def alist_keys( - self, - *, - before: Optional[float] = None, - after: Optional[float] = None, - group_ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, - ) -> List[str]: - """List records in the SQLite database based on the provided date range.""" - session: AsyncSession - async with self._amake_session() as session: - query: Query = select(UpsertionRecord.key).filter( - UpsertionRecord.namespace == self.namespace - ) - - # mypy does not recognize .all() or .filter() - if after: - query = query.filter(UpsertionRecord.updated_at > after) - if before: - query = query.filter(UpsertionRecord.updated_at < before) - if group_ids: - query = query.filter(UpsertionRecord.group_id.in_(group_ids)) - - if limit: - query = query.limit(limit) - records = (await session.execute(query)).scalars().all() - return list(records) - - def delete_keys(self, keys: Sequence[str]) -> None: - """Delete records from the SQLite database.""" - session: Session - with self._make_session() as session: - filtered_query: Query = session.query(UpsertionRecord).filter( - and_( - UpsertionRecord.key.in_(keys), - UpsertionRecord.namespace == self.namespace, - ) - ) - - filtered_query.delete() - session.commit() - - async def adelete_keys(self, keys: Sequence[str]) -> None: - """Delete records from the SQLite database.""" - async with self._amake_session() as session: - await session.execute( - delete(UpsertionRecord).where( - and_( 
- UpsertionRecord.key.in_(keys), - UpsertionRecord.namespace == self.namespace, - ) - ) - ) - - await session.commit() +__all__ = [ + "SQLRecordManager", +] diff --git a/libs/langchain/langchain/indexes/base.py b/libs/langchain/langchain/indexes/base.py index 46ef5bf2efab2..4e26e83ee5e8e 100644 --- a/libs/langchain/langchain/indexes/base.py +++ b/libs/langchain/langchain/indexes/base.py @@ -1,172 +1,5 @@ -from __future__ import annotations +from langchain_community.indexes.base import RecordManager -import uuid -from abc import ABC, abstractmethod -from typing import List, Optional, Sequence - -NAMESPACE_UUID = uuid.UUID(int=1984) - - -class RecordManager(ABC): - """An abstract base class representing the interface for a record manager.""" - - def __init__( - self, - namespace: str, - ) -> None: - """Initialize the record manager. - - Args: - namespace (str): The namespace for the record manager. - """ - self.namespace = namespace - - @abstractmethod - def create_schema(self) -> None: - """Create the database schema for the record manager.""" - - @abstractmethod - async def acreate_schema(self) -> None: - """Create the database schema for the record manager.""" - - @abstractmethod - def get_time(self) -> float: - """Get the current server time as a high resolution timestamp! - - It's important to get this from the server to ensure a monotonic clock, - otherwise there may be data loss when cleaning up old documents! - - Returns: - The current server time as a float timestamp. - """ - - @abstractmethod - async def aget_time(self) -> float: - """Get the current server time as a high resolution timestamp! - - It's important to get this from the server to ensure a monotonic clock, - otherwise there may be data loss when cleaning up old documents! - - Returns: - The current server time as a float timestamp. - """ - - @abstractmethod - def update( - self, - keys: Sequence[str], - *, - group_ids: Optional[Sequence[Optional[str]]] = None, - time_at_least: Optional[float] = None, - ) -> None: - """Upsert records into the database. - - Args: - keys: A list of record keys to upsert. - group_ids: A list of group IDs corresponding to the keys. - time_at_least: if provided, updates should only happen if the - updated_at field is at least this time. - - Raises: - ValueError: If the length of keys doesn't match the length of group_ids. - """ - - @abstractmethod - async def aupdate( - self, - keys: Sequence[str], - *, - group_ids: Optional[Sequence[Optional[str]]] = None, - time_at_least: Optional[float] = None, - ) -> None: - """Upsert records into the database. - - Args: - keys: A list of record keys to upsert. - group_ids: A list of group IDs corresponding to the keys. - time_at_least: if provided, updates should only happen if the - updated_at field is at least this time. - - Raises: - ValueError: If the length of keys doesn't match the length of group_ids. - """ - - @abstractmethod - def exists(self, keys: Sequence[str]) -> List[bool]: - """Check if the provided keys exist in the database. - - Args: - keys: A list of keys to check. - - Returns: - A list of boolean values indicating the existence of each key. - """ - - @abstractmethod - async def aexists(self, keys: Sequence[str]) -> List[bool]: - """Check if the provided keys exist in the database. - - Args: - keys: A list of keys to check. - - Returns: - A list of boolean values indicating the existence of each key. 
- """ - - @abstractmethod - def list_keys( - self, - *, - before: Optional[float] = None, - after: Optional[float] = None, - group_ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, - ) -> List[str]: - """List records in the database based on the provided filters. - - Args: - before: Filter to list records updated before this time. - after: Filter to list records updated after this time. - group_ids: Filter to list records with specific group IDs. - limit: optional limit on the number of records to return. - - Returns: - A list of keys for the matching records. - """ - - @abstractmethod - async def alist_keys( - self, - *, - before: Optional[float] = None, - after: Optional[float] = None, - group_ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, - ) -> List[str]: - """List records in the database based on the provided filters. - - Args: - before: Filter to list records updated before this time. - after: Filter to list records updated after this time. - group_ids: Filter to list records with specific group IDs. - limit: optional limit on the number of records to return. - - Returns: - A list of keys for the matching records. - """ - - @abstractmethod - def delete_keys(self, keys: Sequence[str]) -> None: - """Delete specified records from the database. - - Args: - keys: A list of keys to delete. - """ - - @abstractmethod - async def adelete_keys(self, keys: Sequence[str]) -> None: - """Delete specified records from the database. - - Args: - keys: A list of keys to delete. - """ +__all__ = [ + "RecordManager", +] diff --git a/libs/langchain/langchain/indexes/graph.py b/libs/langchain/langchain/indexes/graph.py index dc8e2ab38ae3f..ae99be1bb3f6b 100644 --- a/libs/langchain/langchain/indexes/graph.py +++ b/libs/langchain/langchain/indexes/graph.py @@ -1,47 +1,5 @@ -"""Graph Index Creator.""" -from typing import Optional, Type +from langchain_community.indexes.graph import GraphIndexCreator -from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, parse_triples -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel - -from langchain.chains.llm import LLMChain -from langchain.indexes.prompts.knowledge_triplet_extraction import ( - KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT, -) - - -class GraphIndexCreator(BaseModel): - """Functionality to create graph index.""" - - llm: Optional[BaseLanguageModel] = None - graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph - - def from_text( - self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT - ) -> NetworkxEntityGraph: - """Create graph index from text.""" - if self.llm is None: - raise ValueError("llm should not be None") - graph = self.graph_type() - chain = LLMChain(llm=self.llm, prompt=prompt) - output = chain.predict(text=text) - knowledge = parse_triples(output) - for triple in knowledge: - graph.add_triple(triple) - return graph - - async def afrom_text( - self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT - ) -> NetworkxEntityGraph: - """Create graph index from text asynchronously.""" - if self.llm is None: - raise ValueError("llm should not be None") - graph = self.graph_type() - chain = LLMChain(llm=self.llm, prompt=prompt) - output = await chain.apredict(text=text) - knowledge = parse_triples(output) - for triple in knowledge: - graph.add_triple(triple) - return graph +__all__ = [ + "GraphIndexCreator", +] diff --git 
a/libs/langchain/langchain/indexes/vectorstore.py b/libs/langchain/langchain/indexes/vectorstore.py index 25a70a65d25ae..94e49ae981727 100644 --- a/libs/langchain/langchain/indexes/vectorstore.py +++ b/libs/langchain/langchain/indexes/vectorstore.py @@ -1,91 +1,5 @@ -from typing import Any, Dict, List, Optional, Type +from langchain_community.indexes.vectorstore import VectorstoreIndexCreator -from langchain_community.document_loaders.base import BaseLoader -from langchain_community.embeddings.openai import OpenAIEmbeddings -from langchain_community.llms.openai import OpenAI -from langchain_community.vectorstores.chroma import Chroma -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.language_models import BaseLanguageModel -from langchain_core.pydantic_v1 import BaseModel, Extra, Field -from langchain_core.vectorstores import VectorStore - -from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain -from langchain.chains.retrieval_qa.base import RetrievalQA -from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter - - -def _get_default_text_splitter() -> TextSplitter: - return RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) - - -class VectorStoreIndexWrapper(BaseModel): - """Wrapper around a vectorstore for easy access.""" - - vectorstore: VectorStore - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - def query( - self, - question: str, - llm: Optional[BaseLanguageModel] = None, - retriever_kwargs: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> str: - """Query the vectorstore.""" - llm = llm or OpenAI(temperature=0) - retriever_kwargs = retriever_kwargs or {} - chain = RetrievalQA.from_chain_type( - llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs - ) - return chain.run(question) - - def query_with_sources( - self, - question: str, - llm: Optional[BaseLanguageModel] = None, - retriever_kwargs: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> dict: - """Query the vectorstore and get back sources.""" - llm = llm or OpenAI(temperature=0) - retriever_kwargs = retriever_kwargs or {} - chain = RetrievalQAWithSourcesChain.from_chain_type( - llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs - ) - return chain({chain.question_key: question}) - - -class VectorstoreIndexCreator(BaseModel): - """Logic for creating indexes.""" - - vectorstore_cls: Type[VectorStore] = Chroma - embedding: Embeddings = Field(default_factory=OpenAIEmbeddings) - text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter) - vectorstore_kwargs: dict = Field(default_factory=dict) - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - def from_loaders(self, loaders: List[BaseLoader]) -> VectorStoreIndexWrapper: - """Create a vectorstore index from loaders.""" - docs = [] - for loader in loaders: - docs.extend(loader.load()) - return self.from_documents(docs) - - def from_documents(self, documents: List[Document]) -> VectorStoreIndexWrapper: - """Create a vectorstore index from documents.""" - sub_docs = self.text_splitter.split_documents(documents) - vectorstore = self.vectorstore_cls.from_documents( - sub_docs, self.embedding, **self.vectorstore_kwargs - ) - return VectorStoreIndexWrapper(vectorstore=vectorstore) +__all__ = [ + "VectorstoreIndexCreator", +] From 
dfdfe3c7e8b7cf71f2cb1afd0d8c604c9771cdc9 Mon Sep 17 00:00:00 2001 From: leo-gan Date: Mon, 26 Feb 2024 12:12:45 -0800 Subject: [PATCH 2/3] fixed typing --- .../langchain_community/indexes/graph.py | 5 ++++- .../langchain_community/indexes/vectorstore.py | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/libs/community/langchain_community/indexes/graph.py b/libs/community/langchain_community/indexes/graph.py index 4998cce5a6467..9d51126ab9ef9 100644 --- a/libs/community/langchain_community/indexes/graph.py +++ b/libs/community/langchain_community/indexes/graph.py @@ -1,7 +1,6 @@ """Graph Index Creator.""" from typing import Optional, Type -from langchain.chains.llm import LLMChain from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate from langchain_core.pydantic_v1 import BaseModel @@ -22,6 +21,8 @@ def from_text( self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT ) -> NetworkxEntityGraph: """Create graph index from text.""" + from langchain.chains.llm import LLMChain + if self.llm is None: raise ValueError("llm should not be None") graph = self.graph_type() @@ -36,6 +37,8 @@ async def afrom_text( self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT ) -> NetworkxEntityGraph: """Create graph index from text asynchronously.""" + from langchain.chains.llm import LLMChain + if self.llm is None: raise ValueError("llm should not be None") graph = self.graph_type() diff --git a/libs/community/langchain_community/indexes/vectorstore.py b/libs/community/langchain_community/indexes/vectorstore.py index 380ef4767a1a0..de0af7e50bfc4 100644 --- a/libs/community/langchain_community/indexes/vectorstore.py +++ b/libs/community/langchain_community/indexes/vectorstore.py @@ -1,8 +1,5 @@ -from typing import Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type -from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain -from langchain.chains.retrieval_qa.base import RetrievalQA -from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel @@ -14,8 +11,13 @@ from langchain_community.llms.openai import OpenAI from langchain_community.vectorstores.chroma import Chroma +if TYPE_CHECKING: + from langchain.text_splitter import TextSplitter + def _get_default_text_splitter() -> TextSplitter: + from langchain.text_splitter import RecursiveCharacterTextSplitter + return RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) @@ -38,6 +40,8 @@ def query( **kwargs: Any, ) -> str: """Query the vectorstore.""" + from langchain.chains.retrieval_qa.base import RetrievalQA + llm = llm or OpenAI(temperature=0) retriever_kwargs = retriever_kwargs or {} chain = RetrievalQA.from_chain_type( @@ -53,6 +57,10 @@ def query_with_sources( **kwargs: Any, ) -> dict: """Query the vectorstore and get back sources.""" + from langchain.chains.qa_with_sources.retrieval import ( + RetrievalQAWithSourcesChain, + ) + llm = llm or OpenAI(temperature=0) retriever_kwargs = retriever_kwargs or {} chain = RetrievalQAWithSourcesChain.from_chain_type( From 1436c09b6317aeb884acf753fb691297ccece74e Mon Sep 17 00:00:00 2001 From: leo-gan Date: Mon, 26 Feb 2024 12:22:50 -0800 Subject: [PATCH 3/3] fixed typing --- 
 libs/community/langchain_community/indexes/vectorstore.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/libs/community/langchain_community/indexes/vectorstore.py b/libs/community/langchain_community/indexes/vectorstore.py
index de0af7e50bfc4..30a4652d441b0 100644
--- a/libs/community/langchain_community/indexes/vectorstore.py
+++ b/libs/community/langchain_community/indexes/vectorstore.py
@@ -16,9 +16,12 @@
 
 
 def _get_default_text_splitter() -> TextSplitter:
-    from langchain.text_splitter import RecursiveCharacterTextSplitter
+    from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
 
-    return RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    text_splitter: TextSplitter = RecursiveCharacterTextSplitter(
+        chunk_size=1000, chunk_overlap=0
+    )
+    return text_splitter
 
 
 class VectorStoreIndexWrapper(BaseModel):
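
Usage sketch (illustrative, not part of the patch): a minimal example of how the indexing API re-exported by this refactor might be exercised end to end. The import paths mirror the shims introduced above (`langchain.indexes._api`, `langchain.indexes._sql_record_manager`); the Chroma and FakeEmbeddings setup is an assumption for illustration only, since any `VectorStore` that implements `add_documents` and `delete` should work.

```python
# Illustrative sketch only: exercises the indexing API re-exported above.
# Assumptions: Chroma + FakeEmbeddings stand in for any VectorStore that
# implements `add_documents` and `delete`; other names follow the modules
# touched by this patch.
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores.chroma import Chroma
from langchain_core.documents import Document

from langchain.indexes._api import index
from langchain.indexes._sql_record_manager import SQLRecordManager

# SQLite-backed record manager; the schema must exist before the first run.
record_manager = SQLRecordManager("demo_ns", db_url="sqlite:///record_manager.db")
record_manager.create_schema()

# Vector store target; fake embeddings keep the sketch self-contained.
vector_store = Chroma(
    collection_name="demo", embedding_function=FakeEmbeddings(size=32)
)

docs = [
    Document(page_content="This is a test document.", metadata={"source": "1"}),
    Document(page_content="This is another document.", metadata={"source": "2"}),
]

# "full" cleanup deletes previously indexed records that this run did not
# return; the returned counters match the shape asserted in the tests above.
result = index(
    docs, record_manager, vector_store, cleanup="full", source_id_key="source"
)
print(result)  # e.g. {"num_added": 2, "num_updated": 0, "num_skipped": 0, "num_deleted": 0}
```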