diff --git a/libs/astradb/langchain_astradb/__init__.py b/libs/astradb/langchain_astradb/__init__.py index 3d0d836..0811d51 100644 --- a/libs/astradb/langchain_astradb/__init__.py +++ b/libs/astradb/langchain_astradb/__init__.py @@ -1,5 +1,6 @@ from langchain_astradb.cache import AstraDBCache, AstraDBSemanticCache from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory +from langchain_astradb.document_loaders import AstraDBLoader from langchain_astradb.storage import AstraDBByteStore, AstraDBStore from langchain_astradb.vectorstores import AstraDBVectorStore @@ -9,5 +10,6 @@ "AstraDBCache", "AstraDBSemanticCache", "AstraDBChatMessageHistory", + "AstraDBLoader", "AstraDBVectorStore", ] diff --git a/libs/astradb/langchain_astradb/document_loaders.py b/libs/astradb/langchain_astradb/document_loaders.py new file mode 100644 index 0000000..ec6c078 --- /dev/null +++ b/libs/astradb/langchain_astradb/document_loaders.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import json +import logging +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Optional, +) + +from astrapy.db import AstraDB, AsyncAstraDB +from langchain_core.document_loaders import BaseLoader +from langchain_core.documents import Document + +from langchain_astradb.utils.astradb import ( + SetupMode, + _AstraDBCollectionEnvironment, +) + +logger = logging.getLogger(__name__) + + +class AstraDBLoader(BaseLoader): + def __init__( + self, + collection_name: str, + *, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + astra_db_client: Optional[AstraDB] = None, + async_astra_db_client: Optional[AsyncAstraDB] = None, + namespace: Optional[str] = None, + filter_criteria: Optional[Dict[str, Any]] = None, + projection: Optional[Dict[str, Any]] = None, + find_options: Optional[Dict[str, Any]] = None, + nb_prefetched: int = 1000, + page_content_mapper: Callable[[Dict], str] = json.dumps, + metadata_mapper: Optional[Callable[[Dict], Dict[str, Any]]] = None, + ) -> None: + """Load DataStax Astra DB documents. + + Args: + collection_name: name of the Astra DB collection to use. + token: API token for Astra DB usage. + api_endpoint: full URL to the API endpoint, + such as `https://-us-east1.apps.astra.datastax.com`. + astra_db_client: *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AstraDB' instance. + async_astra_db_client: *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AsyncAstraDB' instance. + namespace: namespace (aka keyspace) where the + collection is. Defaults to the database's "default namespace". + filter_criteria: Criteria to filter documents. + projection: Specifies the fields to return. + find_options: Additional options for the query. + nb_prefetched: Max number of documents to pre-fetch. Defaults to 1000. + page_content_mapper: Function applied to collection documents to create + the `page_content` of the LangChain Document. Defaults to `json.dumps`. + """ + astra_db_env = _AstraDBCollectionEnvironment( + collection_name=collection_name, + token=token, + api_endpoint=api_endpoint, + astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, + namespace=namespace, + setup_mode=SetupMode.OFF, + ) + self.astra_db_env = astra_db_env + self.filter = filter_criteria + self.projection = projection + self.find_options = find_options or {} + self.nb_prefetched = nb_prefetched + self.page_content_mapper = page_content_mapper + self.metadata_mapper = metadata_mapper or ( + lambda _: { + "namespace": self.astra_db_env.astra_db.namespace, + "api_endpoint": self.astra_db_env.astra_db.base_url, + "collection": collection_name, + } + ) + + def _to_langchain_doc(self, doc: Dict[str, Any]) -> Document: + return Document( + page_content=self.page_content_mapper(doc), + metadata=self.metadata_mapper(doc), + ) + + def lazy_load(self) -> Iterator[Document]: + for doc in self.astra_db_env.collection.paginated_find( + filter=self.filter, + options=self.find_options, + projection=self.projection, + sort=None, + prefetched=self.nb_prefetched, + ): + yield self._to_langchain_doc(doc) + + async def aload(self) -> List[Document]: + """Load data into Document objects.""" + return [doc async for doc in self.alazy_load()] + + async def alazy_load(self) -> AsyncIterator[Document]: + async for doc in self.astra_db_env.async_collection.paginated_find( + filter=self.filter, + options=self.find_options, + projection=self.projection, + sort=None, + prefetched=self.nb_prefetched, + ): + yield self._to_langchain_doc(doc) diff --git a/libs/astradb/poetry.lock b/libs/astradb/poetry.lock index 4b2ef14..2db245b 100644 --- a/libs/astradb/poetry.lock +++ b/libs/astradb/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -780,17 +780,18 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, + {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] name = "langchain" -version = "0.1.11" +version = "0.1.12" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langchain-0.1.11-py3-none-any.whl", hash = "sha256:b5e678ac50d85370b9bc28f2c97ad5f029aac1c0cca79cac9354adf72741bc6e"}, - {file = "langchain-0.1.11.tar.gz", hash = "sha256:03f08cae7cd3f341c54f1042b3fe24d88f39eba7b7eda942735d8ced13fe6da9"}, + {file = "langchain-0.1.12-py3-none-any.whl", hash = "sha256:b4dd1760e2d035daefad08af60a209b96b729ee45492d34e3e127e553a471034"}, + {file = "langchain-0.1.12.tar.gz", hash = "sha256:5f612761ba548b81748ed8dc70535e8de0531445415028a82de3fd8255bfa8a3"}, ] [package.dependencies] @@ -798,8 +799,8 @@ aiohttp = ">=3.8.3,<4.0.0" async-timeout = {version = ">=4.0.0,<5.0.0", markers = "python_version < \"3.11\""} dataclasses-json = ">=0.5.7,<0.7" jsonpatch = ">=1.33,<2.0" -langchain-community = ">=0.0.25,<0.1" -langchain-core = ">=0.1.29,<0.2" +langchain-community = ">=0.0.28,<0.1" +langchain-core = ">=0.1.31,<0.2.0" langchain-text-splitters = ">=0.0.1,<0.1" langsmith = ">=0.1.17,<0.2.0" numpy = ">=1,<2" @@ -825,19 +826,19 @@ text-helpers = ["chardet (>=5.1.0,<6.0.0)"] [[package]] name = "langchain-community" -version = "0.0.27" +version = "0.0.28" description = "Community contributed LangChain integrations." optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langchain_community-0.0.27-py3-none-any.whl", hash = "sha256:377a7429580a71d909012df5aae538d295fa6f21bc479e5dac6fd1589762b3ab"}, - {file = "langchain_community-0.0.27.tar.gz", hash = "sha256:266dffbd4c1666db1889cad953fa5102d4debff782335353b6d78636a761778d"}, + {file = "langchain_community-0.0.28-py3-none-any.whl", hash = "sha256:bdb015ac455ae68432ea104628717583dce041e1abdfcefe86e39f034f5e90b8"}, + {file = "langchain_community-0.0.28.tar.gz", hash = "sha256:8664d243a90550fc5ddc137b712034e02c8d43afc8d4cc832ba5842b44c864ce"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" dataclasses-json = ">=0.5.7,<0.7" -langchain-core = ">=0.1.30,<0.2.0" +langchain-core = ">=0.1.31,<0.2.0" langsmith = ">=0.1.0,<0.2.0" numpy = ">=1,<2" PyYAML = ">=5.3" @@ -847,17 +848,17 @@ tenacity = ">=8.1.0,<9.0.0" [package.extras] cli = ["typer (>=0.9.0,<0.10.0)"] -extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cloudpickle (>=2.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)", "zhipuai (>=1.0.7,<2.0.0)"] +extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cloudpickle (>=2.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "friendli-client (>=1.2.4,<2.0.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "tidb-vector (>=0.0.3,<1.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)", "zhipuai (>=1.0.7,<2.0.0)"] [[package]] name = "langchain-core" -version = "0.1.30" +version = "0.1.31" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langchain_core-0.1.30-py3-none-any.whl", hash = "sha256:c9643505e41d25ba8f20a2e8bf083d0f0d50b9a098d901511fff8df79f831ada"}, - {file = "langchain_core-0.1.30.tar.gz", hash = "sha256:e13a016e55e7f082ff3eeeda2d0cb505b89a8830e3a23c1d134d0a89d7871894"}, + {file = "langchain_core-0.1.31-py3-none-any.whl", hash = "sha256:ff028f00db8ff03565b542cea81be27426022a72c6545b54d8de66fa00948ab3"}, + {file = "langchain_core-0.1.31.tar.gz", hash = "sha256:d660cf209bb6ce61cb1c853107b091aaa809015a55dce9e0ce19b51d4c8f2a70"}, ] [package.dependencies] @@ -892,13 +893,13 @@ extended-testing = ["lxml (>=5.1.0,<6.0.0)"] [[package]] name = "langsmith" -version = "0.1.22" +version = "0.1.24" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langsmith-0.1.22-py3-none-any.whl", hash = "sha256:b877d302bd4cf7c79e9e6e24bedf669132abf0659143390a29350eda0945544f"}, - {file = "langsmith-0.1.22.tar.gz", hash = "sha256:2921ae2297c2fb23baa2641b9cf416914ac7fd65f4a9dd5a573bc30efb54b693"}, + {file = "langsmith-0.1.24-py3-none-any.whl", hash = "sha256:898ef5265bca8fc912f7fbf207e1d69cacd86055faecf6811bd42641e6319840"}, + {file = "langsmith-0.1.24.tar.gz", hash = "sha256:432b829e763f5077df411bc59bb35449813f18174d2ebc8bbbb38427071d5e7d"}, ] [package.dependencies] @@ -1209,13 +1210,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pydantic" -version = "2.6.3" +version = "2.6.4" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"}, - {file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"}, + {file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"}, + {file = "pydantic-2.6.4.tar.gz", hash = "sha256:b1704e0847db01817624a6b86766967f552dd9dbf3afba4004409f908dcc84e6"}, ] [package.dependencies] @@ -1444,6 +1445,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -1451,8 +1453,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -1469,6 +1478,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -1476,6 +1486,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -1609,7 +1620,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""} +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} typing-extensions = ">=4.6.0" [package.extras] @@ -1877,4 +1888,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "45fced605e3bc7769326708c9e17fadf7186c59666f548b1a12e00e718ec362b" +content-hash = "5e5ddff3dc3b67b7d1f6336fe9464ee5a641b07d4ff4e591f9a2d58ee7373c6b" diff --git a/libs/astradb/pyproject.toml b/libs/astradb/pyproject.toml index 2181b1a..58be847 100644 --- a/libs/astradb/pyproject.toml +++ b/libs/astradb/pyproject.toml @@ -12,7 +12,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" -langchain-core = "^0.1.5" +langchain-core = "^0.1.31" astrapy = "^0.7.7" numpy = "^1" diff --git a/libs/astradb/tests/integration_tests/test_document_loaders.py b/libs/astradb/tests/integration_tests/test_document_loaders.py new file mode 100644 index 0000000..affa639 --- /dev/null +++ b/libs/astradb/tests/integration_tests/test_document_loaders.py @@ -0,0 +1,194 @@ +""" +Test of Astra DB document loader class `AstraDBLoader` + +Required to run this test: + - a recent `astrapy` Python package available + - an Astra DB instance; + - the two environment variables set: + export ASTRA_DB_API_ENDPOINT="https://-us-east1.apps.astra.datastax.com" + export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........." + - optionally this as well (otherwise defaults are used): + export ASTRA_DB_KEYSPACE="my_keyspace" +""" +from __future__ import annotations + +import json +import os +import uuid +from typing import AsyncIterator, Iterator + +import pytest +from astrapy.db import AstraDBCollection, AsyncAstraDBCollection + +from langchain_astradb import AstraDBLoader + +ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN") +ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT") +ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE") + + +def _has_env_vars() -> bool: + return all([ASTRA_DB_APPLICATION_TOKEN, ASTRA_DB_API_ENDPOINT]) + + +@pytest.fixture +def astra_db_collection() -> Iterator[AstraDBCollection]: + from astrapy.db import AstraDB + + astra_db = AstraDB( + token=ASTRA_DB_APPLICATION_TOKEN or "", + api_endpoint=ASTRA_DB_API_ENDPOINT or "", + namespace=ASTRA_DB_KEYSPACE, + ) + collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}" + collection = astra_db.create_collection(collection_name) + collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) + collection.insert_many( + [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4 + ) + + yield collection + + astra_db.delete_collection(collection_name) + + +@pytest.fixture +async def async_astra_db_collection() -> AsyncIterator[AsyncAstraDBCollection]: + from astrapy.db import AsyncAstraDB + + astra_db = AsyncAstraDB( + token=ASTRA_DB_APPLICATION_TOKEN or "", + api_endpoint=ASTRA_DB_API_ENDPOINT or "", + namespace=ASTRA_DB_KEYSPACE, + ) + collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}" + collection = await astra_db.create_collection(collection_name) + await collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) + await collection.insert_many( + [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4 + ) + + yield collection + + await astra_db.delete_collection(collection_name) + + +@pytest.mark.requires("astrapy") +@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") +class TestAstraDB: + def test_astradb_loader(self, astra_db_collection: AstraDBCollection) -> None: + loader = AstraDBLoader( + astra_db_collection.collection_name, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + nb_prefetched=1, + projection={"foo": 1}, + find_options={"limit": 22}, + filter_criteria={"foo": "bar"}, + ) + docs = loader.load() + + assert len(docs) == 22 + ids = set() + for doc in docs: + content = json.loads(doc.page_content) + assert content["foo"] == "bar" + assert "baz" not in content + assert content["_id"] not in ids + ids.add(content["_id"]) + assert doc.metadata == { + "namespace": astra_db_collection.astra_db.namespace, + "api_endpoint": astra_db_collection.astra_db.base_url, + "collection": astra_db_collection.collection_name, + } + + def test_page_content_mapper(self, astra_db_collection: AstraDBCollection) -> None: + loader = AstraDBLoader( + astra_db_collection.collection_name, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + find_options={"limit": 30}, + page_content_mapper=lambda x: x["foo"], + ) + docs = loader.lazy_load() + doc = next(docs) + + assert doc.page_content == "bar" + + def test_metadata_mapper(self, astra_db_collection: AstraDBCollection) -> None: + loader = AstraDBLoader( + astra_db_collection.collection_name, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + find_options={"limit": 30}, + metadata_mapper=lambda x: {"a": x["foo"]}, + ) + docs = loader.lazy_load() + doc = next(docs) + + assert doc.metadata == {"a": "bar"} + + async def test_astradb_loader_async( + self, async_astra_db_collection: AsyncAstraDBCollection + ) -> None: + await async_astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) + await async_astra_db_collection.insert_many( + [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4 + ) + + loader = AstraDBLoader( + async_astra_db_collection.collection_name, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + nb_prefetched=1, + projection={"foo": 1}, + find_options={"limit": 22}, + filter_criteria={"foo": "bar"}, + ) + docs = await loader.aload() + + assert len(docs) == 22 + ids = set() + for doc in docs: + content = json.loads(doc.page_content) + assert content["foo"] == "bar" + assert "baz" not in content + assert content["_id"] not in ids + ids.add(content["_id"]) + assert doc.metadata == { + "namespace": async_astra_db_collection.astra_db.namespace, + "api_endpoint": async_astra_db_collection.astra_db.base_url, + "collection": async_astra_db_collection.collection_name, + } + + async def test_page_content_mapper_async( + self, async_astra_db_collection: AsyncAstraDBCollection + ) -> None: + loader = AstraDBLoader( + async_astra_db_collection.collection_name, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + find_options={"limit": 30}, + page_content_mapper=lambda x: x["foo"], + ) + doc = await loader.alazy_load().__anext__() + assert doc.page_content == "bar" + + async def test_metadata_mapper_async( + self, async_astra_db_collection: AsyncAstraDBCollection + ) -> None: + loader = AstraDBLoader( + async_astra_db_collection.collection_name, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + find_options={"limit": 30}, + metadata_mapper=lambda x: {"a": x["foo"]}, + ) + doc = await loader.alazy_load().__anext__() + assert doc.metadata == {"a": "bar"} diff --git a/libs/astradb/tests/unit_tests/test_imports.py b/libs/astradb/tests/unit_tests/test_imports.py index 1164424..c2bcff9 100644 --- a/libs/astradb/tests/unit_tests/test_imports.py +++ b/libs/astradb/tests/unit_tests/test_imports.py @@ -6,6 +6,7 @@ "AstraDBCache", "AstraDBSemanticCache", "AstraDBChatMessageHistory", + "AstraDBLoader", "AstraDBVectorStore", ]