Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add document loader #7

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions libs/astradb/langchain_astradb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from langchain_astradb.cache import AstraDBCache, AstraDBSemanticCache
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
from langchain_astradb.document_loaders import AstraDBLoader
from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
from langchain_astradb.vectorstores import AstraDBVectorStore

Expand All @@ -9,5 +10,6 @@
"AstraDBCache",
"AstraDBSemanticCache",
"AstraDBChatMessageHistory",
"AstraDBLoader",
"AstraDBVectorStore",
]
115 changes: 115 additions & 0 deletions libs/astradb/langchain_astradb/document_loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import annotations

import json
import logging
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
)

from astrapy.db import AstraDB, AsyncAstraDB
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document

from langchain_astradb.utils.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)

# Module-level logger, registered under this module's `__name__`.
logger = logging.getLogger(__name__)


class AstraDBLoader(BaseLoader):
    """Loader turning the documents of an Astra DB collection into LangChain Documents."""

    def __init__(
        self,
        collection_name: str,
        *,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        filter_criteria: Optional[Dict[str, Any]] = None,
        projection: Optional[Dict[str, Any]] = None,
        find_options: Optional[Dict[str, Any]] = None,
        nb_prefetched: int = 1000,
        page_content_mapper: Callable[[Dict], str] = json.dumps,
        metadata_mapper: Optional[Callable[[Dict], Dict[str, Any]]] = None,
    ) -> None:
        """Load DataStax Astra DB documents.

        Args:
            collection_name: name of the Astra DB collection to read from.
            token: API token for Astra DB usage.
            api_endpoint: full URL to the API endpoint,
                such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
            astra_db_client: *alternative to token+api_endpoint*,
                an already-created 'astrapy.db.AstraDB' instance.
            async_astra_db_client: *alternative to token+api_endpoint*,
                an already-created 'astrapy.db.AsyncAstraDB' instance.
            namespace: namespace (aka keyspace) where the collection is.
                Defaults to the database's "default namespace".
            filter_criteria: criteria used to filter the documents.
            projection: specifies which fields to return.
            find_options: additional options for the query. A missing or
                empty value is replaced by a new empty dict.
            nb_prefetched: max number of documents to pre-fetch.
                Defaults to 1000.
            page_content_mapper: function applied to each collection document
                to build the `page_content` of the LangChain Document.
                Defaults to `json.dumps`.
            metadata_mapper: function applied to each collection document to
                build the `metadata` of the LangChain Document. When omitted,
                the metadata holds the "namespace", "api_endpoint" and
                "collection" the document was read from.
        """
        # The environment is created with setup disabled (SetupMode.OFF).
        self.astra_db_env = _AstraDBCollectionEnvironment(
            collection_name=collection_name,
            token=token,
            api_endpoint=api_endpoint,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            namespace=namespace,
            setup_mode=SetupMode.OFF,
        )

        self.filter = filter_criteria
        self.projection = projection
        self.find_options = find_options if find_options else {}
        self.nb_prefetched = nb_prefetched
        self.page_content_mapper = page_content_mapper

        def _default_metadata(_: Dict) -> Dict[str, Any]:
            # Looked up at call time, not captured at construction time.
            return {
                "namespace": self.astra_db_env.astra_db.namespace,
                "api_endpoint": self.astra_db_env.astra_db.base_url,
                "collection": collection_name,
            }

        if metadata_mapper:
            self.metadata_mapper = metadata_mapper
        else:
            self.metadata_mapper = _default_metadata

    def _find_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments passed to both `paginated_find` calls."""
        return {
            "filter": self.filter,
            "options": self.find_options,
            "projection": self.projection,
            "sort": None,
            "prefetched": self.nb_prefetched,
        }

    def _to_langchain_doc(self, doc: Dict[str, Any]) -> Document:
        """Build a LangChain Document from one collection document via the two mappers."""
        content = self.page_content_mapper(doc)
        metadata = self.metadata_mapper(doc)
        return Document(page_content=content, metadata=metadata)

    def lazy_load(self) -> Iterator[Document]:
        """Yield one Document per item produced by the collection's `paginated_find`."""
        found = self.astra_db_env.collection.paginated_find(**self._find_kwargs())
        for raw_doc in found:
            yield self._to_langchain_doc(raw_doc)

    async def aload(self) -> List[Document]:
        """Load data into Document objects."""
        documents: List[Document] = []
        async for document in self.alazy_load():
            documents.append(document)
        return documents

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Asynchronously yield one Document per item from the async collection's `paginated_find`."""
        found = self.astra_db_env.async_collection.paginated_find(
            **self._find_kwargs()
        )
        async for raw_doc in found:
            yield self._to_langchain_doc(raw_doc)
Loading
Loading