indexes refactoring #18146

Closed · wants to merge 3 commits
598 changes: 598 additions & 0 deletions libs/community/langchain_community/indexes/_api.py

Large diffs are not rendered by default.

50 changes: 50 additions & 0 deletions libs/community/langchain_community/indexes/graph.py
@@ -0,0 +1,50 @@
"""Graph Index Creator."""
from typing import Optional, Type

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel

from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, parse_triples
from langchain_community.indexes.prompts.knowledge_triplet_extraction import (
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
)


class GraphIndexCreator(BaseModel):
"""Functionality to create graph index."""

llm: Optional[BaseLanguageModel] = None
graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph

def from_text(
self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
) -> NetworkxEntityGraph:
"""Create graph index from text."""
from langchain.chains.llm import LLMChain

if self.llm is None:
raise ValueError("llm should not be None")
graph = self.graph_type()
chain = LLMChain(llm=self.llm, prompt=prompt)
output = chain.predict(text=text)
knowledge = parse_triples(output)
for triple in knowledge:
graph.add_triple(triple)
return graph

async def afrom_text(
self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
) -> NetworkxEntityGraph:
"""Create graph index from text asynchronously."""
from langchain.chains.llm import LLMChain

if self.llm is None:
raise ValueError("llm should not be None")
graph = self.graph_type()
chain = LLMChain(llm=self.llm, prompt=prompt)
output = await chain.apredict(text=text)
knowledge = parse_triples(output)
for triple in knowledge:
graph.add_triple(triple)
return graph
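
For orientation, a minimal usage sketch of GraphIndexCreator (the OpenAI LLM and the sample sentence are assumptions for illustration, not part of this diff):

    # Illustrative only: any BaseLanguageModel works; OpenAI is an assumption here.
    from langchain_community.indexes.graph import GraphIndexCreator
    from langchain_community.llms.openai import OpenAI

    index_creator = GraphIndexCreator(llm=OpenAI(temperature=0))
    graph = index_creator.from_text("Marie Curie won the Nobel Prize in Physics in 1903.")
    # NetworkxEntityGraph exposes the extracted (subject, predicate, object) triples.
    print(graph.get_triples())

Note that from_text imports LLMChain lazily from the langchain package, so langchain must be installed alongside langchain_community for this to run.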
102 changes: 102 additions & 0 deletions libs/community/langchain_community/indexes/vectorstore.py
@@ -0,0 +1,102 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import BaseModel, Extra, Field
from langchain_core.vectorstores import VectorStore

from langchain_community.document_loaders.base import BaseLoader
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.llms.openai import OpenAI
from langchain_community.vectorstores.chroma import Chroma

if TYPE_CHECKING:
    from langchain.text_splitter import TextSplitter


def _get_default_text_splitter() -> TextSplitter:
    from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

    text_splitter: TextSplitter = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=0
    )
    return text_splitter


class VectorStoreIndexWrapper(BaseModel):
    """Wrapper around a vectorstore for easy access."""

    vectorstore: VectorStore

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def query(
        self,
        question: str,
        llm: Optional[BaseLanguageModel] = None,
        retriever_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> str:
        """Query the vectorstore."""
        from langchain.chains.retrieval_qa.base import RetrievalQA

        llm = llm or OpenAI(temperature=0)
        retriever_kwargs = retriever_kwargs or {}
        chain = RetrievalQA.from_chain_type(
            llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs
        )
        return chain.run(question)

    def query_with_sources(
        self,
        question: str,
        llm: Optional[BaseLanguageModel] = None,
        retriever_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> dict:
        """Query the vectorstore and get back sources."""
        from langchain.chains.qa_with_sources.retrieval import (
            RetrievalQAWithSourcesChain,
        )

        llm = llm or OpenAI(temperature=0)
        retriever_kwargs = retriever_kwargs or {}
        chain = RetrievalQAWithSourcesChain.from_chain_type(
            llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs
        )
        return chain({chain.question_key: question})


class VectorstoreIndexCreator(BaseModel):
    """Logic for creating indexes."""

    vectorstore_cls: Type[VectorStore] = Chroma
    embedding: Embeddings = Field(default_factory=OpenAIEmbeddings)
    text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter)
    vectorstore_kwargs: dict = Field(default_factory=dict)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def from_loaders(self, loaders: List[BaseLoader]) -> VectorStoreIndexWrapper:
        """Create a vectorstore index from loaders."""
        docs = []
        for loader in loaders:
            docs.extend(loader.load())
        return self.from_documents(docs)

    def from_documents(self, documents: List[Document]) -> VectorStoreIndexWrapper:
        """Create a vectorstore index from documents."""
        sub_docs = self.text_splitter.split_documents(documents)
        vectorstore = self.vectorstore_cls.from_documents(
            sub_docs, self.embedding, **self.vectorstore_kwargs
        )
        return VectorStoreIndexWrapper(vectorstore=vectorstore)
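
A short usage sketch of VectorstoreIndexCreator with its defaults (the TextLoader path is hypothetical; the defaults additionally assume chromadb and OpenAI credentials are available):

    # Illustrative only: the file path and credentials are assumptions.
    from langchain_community.document_loaders import TextLoader
    from langchain_community.indexes.vectorstore import VectorstoreIndexCreator

    loader = TextLoader("state_of_the_union.txt")  # hypothetical input file
    index = VectorstoreIndexCreator().from_loaders([loader])
    # query() falls back to OpenAI(temperature=0) when no llm is passed.
    print(index.query("What did the president say about the economy?"))

As with graph.py, the chain classes are imported lazily inside the query methods, deferring the langchain dependency to call time.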
50 changes: 50 additions & 0 deletions libs/community/tests/unit_tests/indexes/test_hashed_document.py
@@ -0,0 +1,50 @@
import pytest
from langchain_core.documents import Document

from langchain_community.indexes._api import _HashedDocument


def test_hashed_document_hashing() -> None:
    hashed_document = _HashedDocument(
        uid="123", page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"}
    )
    assert isinstance(hashed_document.hash_, str)


def test_hashing_with_missing_content() -> None:
    """Check that a TypeError is raised if page_content is missing."""
    with pytest.raises(TypeError):
        _HashedDocument(
            metadata={"key": "value"},
        )  # type: ignore


def test_uid_auto_assigned_to_hash() -> None:
    """Test uid is auto-assigned to the hashed_document hash."""
    hashed_document = _HashedDocument(
        page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"}
    )
    assert hashed_document.uid == hashed_document.hash_


def test_to_document() -> None:
    """Test to_document method."""
    hashed_document = _HashedDocument(
        page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"}
    )
    doc = hashed_document.to_document()
    assert isinstance(doc, Document)
    assert doc.page_content == "Lorem ipsum dolor sit amet"
    assert doc.metadata == {"key": "value"}


def test_from_document() -> None:
    """Test from document class method."""
    document = Document(
        page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"}
    )

    hashed_document = _HashedDocument.from_document(document)
    # hash should be deterministic
    assert hashed_document.hash_ == "fd1dc827-051b-537d-a1fe-1fa043e8b276"
    assert hashed_document.uid == hashed_document.hash_
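
The deterministic hash asserted above is produced inside _api.py, whose 598-line diff is not rendered in this view. A minimal sketch of how such content-addressed uids can be derived (the helper name, namespace value, and serialization scheme are assumptions, not taken from this PR):

    import hashlib
    import json
    import uuid

    # Hypothetical namespace constant; the real one lives in the unrendered _api.py.
    NAMESPACE_UUID = uuid.UUID(int=1984)

    def hash_string_to_uuid(input_string: str) -> uuid.UUID:
        # Map a SHA-1 digest into a stable version-5 UUID; the "5" leading the
        # "537d" group of the expected hash above marks a uuid5 value.
        digest = hashlib.sha1(input_string.encode("utf-8")).hexdigest()
        return uuid.uuid5(NAMESPACE_UUID, digest)

    # A document uid would combine the content hash with a hash of the
    # canonically serialized metadata, so identical documents collide on purpose:
    content_hash = hash_string_to_uuid("Lorem ipsum dolor sit amet")
    metadata_hash = hash_string_to_uuid(json.dumps({"key": "value"}, sort_keys=True))
    print(hash_string_to_uuid(str(content_hash) + str(metadata_hash)))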