diff --git a/.github/labeler.yml b/.github/labeler.yml index a37f44111..93eba1d82 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -4,6 +4,11 @@ integration:amazon-bedrock: - any-glob-to-any-file: "integrations/amazon_bedrock/**/*" - any-glob-to-any-file: ".github/workflows/amazon_bedrock.yml" +integration:astra: + - changed-files: + - any-glob-to-any-file: "integrations/astra/**/*" + - any-glob-to-any-file: ".github/workflows/astra.yml" + integration:chroma: - changed-files: - any-glob-to-any-file: "integrations/chroma/**/*" diff --git a/.github/workflows/astra.yml b/.github/workflows/astra.yml new file mode 100644 index 000000000..b751550de --- /dev/null +++ b/.github/workflows/astra.yml @@ -0,0 +1,60 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / astra + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - 'integrations/astra/**' + - '.github/workflows/astra.yml' + +defaults: + run: + working-directory: integrations/astra + +concurrency: + group: astra-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + max-parallel: 1 + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ['3.9', '3.10'] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . + run: git config --system core.longpaths true + + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all + + - name: Run tests + env: + ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }} + ASTRA_DB_ID: ${{ secrets.ASTRA_DB_ID }} + run: hatch run cov \ No newline at end of file diff --git a/README.md b/README.md index 38246ff43..46517c76b 100644 --- a/README.md +++ b/README.md @@ -77,4 +77,3 @@ deepset-haystack | [pinecone-haystack](integrations/pinecone/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/pinecone-haystack.svg?color=orange)](https://pypi.org/project/pinecone-haystack) | [![Test / pinecone](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pinecone.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pinecone.yml) | | [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg?color=orange)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | | [unstructured-fileconverter-haystack](integrations/unstructured/fileconverter/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured / fileconverter](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured_fileconverter.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured_fileconverter.yml) | -| [weaviate-haystack](integrations/weaviate/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/weaviate-haystack.svg?color=orange)](https://pypi.org/project/weaviate-haystack) | [![Test / weaviate](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml) | diff --git a/integrations/astra/LICENSE b/integrations/astra/LICENSE new file mode 100644 index 000000000..6134ab324 --- /dev/null +++ b/integrations/astra/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023-present deepset GmbH + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/integrations/astra/README.md b/integrations/astra/README.md new file mode 100644 index 000000000..9c3b264cb --- /dev/null +++ b/integrations/astra/README.md @@ -0,0 +1,94 @@ +[![test](https://github.com/deepset-ai/document-store/actions/workflows/test.yml/badge.svg)](https://github.com/deepset-ai/document-store/actions/workflows/test.yml) + +# Astra Store + +## Installation +install astra-haystack package locally to run integration tests: + +Open in gitpod: +[![Open in Gitpod](https://gitpod.io/button/open-in-gitpod.svg)](https://gitpod.io/#https://github.com/Anant/astra-haystack/tree/main) + +Switch Python version to 3.9 (Requires 3.8+ but not 3.12) +``` +pyenv install 3.9 +pyenv local 3.9 +``` + +Local install for the package +`pip install -e .` +To execute integration tests, add needed environment variables +`ASTRA_DB_ID=` +`ASTRA_DB_APPLICATION_TOKEN=` +and execute +`python examples/example.py` + +Install requirements +`pip install -r requirements.txt` + +Export environment variables +``` +export KEYSPACE_NAME= +export COLLECTION_NAME= +export OPENAI_API_KEY= +export ASTRA_DB_ID= +export ASTRA_DB_REGION= +export ASTRA_DB_APPLICATION_TOKEN= +``` + +run the python examples +`python example/example.py` +or +`python example/pipeline_example.py` + +## Usage + +This package includes Astra Document Store and Astra Retriever classes that integrate with Haystack, allowing you to easily perform document retrieval or RAG with Astra, and include those functions in Haystack pipelines. + +### In order to use the Document Store directly: + +Import the Document Store: +``` +from astra_store.document_store import AstraDocumentStore +from haystack.preview.document_stores import DuplicatePolicy +``` + +Load in environment variables: +``` +astra_id = os.getenv("ASTRA_DB_ID", "") +astra_region = os.getenv("ASTRA_DB_REGION", "us-east1") + +astra_application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "") +collection_name = os.getenv("COLLECTION_NAME", "haystack_vector_search") +keyspace_name = os.getenv("KEYSPACE_NAME", "recommender_demo") +``` + +Create the Document Store object: +``` +document_store = AstraDocumentStore( + astra_id=astra_id, + astra_region=astra_region, + astra_collection=collection_name, + astra_keyspace=keyspace_name, + astra_application_token=astra_application_token, + duplicates_policy=DuplicatePolicy.SKIP, + embedding_dim=384, +) +``` + +Then you can use the document store functions like count_document below: +`document_store.count_documents()` + +### Using the Astra Retriever with Haystack Pipelines + +Create the Document Store object like above, then import and create the Pipeline: + +``` +from haystack.preview import Pipeline +pipeline = Pipeline() +``` + +Add your AstraRetriever into the pipeline +`pipeline.add_component(instance=AstraSingleRetriever(document_store=document_store), name="retriever")` + +Add other components and connect them as desired. Then run your pipeline: +`pipeline.run(...)` diff --git a/integrations/astra/__init__.py b/integrations/astra/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/integrations/astra/examples/data/usr_01.txt b/integrations/astra/examples/data/usr_01.txt new file mode 100644 index 000000000..2cb74d47c --- /dev/null +++ b/integrations/astra/examples/data/usr_01.txt @@ -0,0 +1,192 @@ +*usr_01.txt* For Vim version 9.0. Last change: 2019 Nov 21 + + VIM USER MANUAL - by Bram Moolenaar + + About the manuals + + +This chapter introduces the manuals available with Vim. Read this to know the +conditions under which the commands are explained. + +|01.1| Two manuals +|01.2| Vim installed +|01.3| Using the Vim tutor +|01.4| Copyright + + Next chapter: |usr_02.txt| The first steps in Vim +Table of contents: |usr_toc.txt| + +============================================================================== +*01.1* Two manuals + +The Vim documentation consists of two parts: + +1. The User manual + Task oriented explanations, from simple to complex. Reads from start to + end like a book. + +2. The Reference manual + Precise description of how everything in Vim works. + +The notation used in these manuals is explained here: |notation| + + +JUMPING AROUND + +The text contains hyperlinks between the two parts, allowing you to quickly +jump between the description of an editing task and a precise explanation of +the commands and options used for it. Use these two commands: + + Press CTRL-] to jump to a subject under the cursor. + Press CTRL-O to jump back (repeat to go further back). + +Many links are in vertical bars, like this: |bars|. The bars themselves may +be hidden or invisible; see below. An option name, like 'number', a command +in double quotes like ":write" and any other word can also be used as a link. +Try it out: Move the cursor to CTRL-] and press CTRL-] on it. + +Other subjects can be found with the ":help" command; see |help.txt|. + +The bars and stars are usually hidden with the |conceal| feature. They also +use |hl-Ignore|, using the same color for the text as the background. You can +make them visible with: > + :set conceallevel=0 + :hi link HelpBar Normal + :hi link HelpStar Normal + +============================================================================== +*01.2* Vim installed + +Most of the manuals assume that Vim has been properly installed. If you +didn't do that yet, or if Vim doesn't run properly (e.g., files can't be found +or in the GUI the menus do not show up) first read the chapter on +installation: |usr_90.txt|. + *not-compatible* +The manuals often assume you are using Vim with Vi-compatibility switched +off. For most commands this doesn't matter, but sometimes it is important, +e.g., for multi-level undo. An easy way to make sure you are using a nice +setup is to copy the example vimrc file. By doing this inside Vim you don't +have to check out where it is located. How to do this depends on the system +you are using: + +Unix: > + :!cp -i $VIMRUNTIME/vimrc_example.vim ~/.vimrc +MS-Windows: > + :!copy $VIMRUNTIME/vimrc_example.vim $VIM/_vimrc +Amiga: > + :!copy $VIMRUNTIME/vimrc_example.vim $VIM/.vimrc + +If the file already exists you probably want to keep it. + +If you start Vim now, the 'compatible' option should be off. You can check it +with this command: > + + :set compatible? + +If it responds with "nocompatible" you are doing well. If the response is +"compatible" you are in trouble. You will have to find out why the option is +still set. Perhaps the file you wrote above is not found. Use this command +to find out: > + + :scriptnames + +If your file is not in the list, check its location and name. If it is in the +list, there must be some other place where the 'compatible' option is switched +back on. + +For more info see |vimrc| and |compatible-default|. + + Note: + This manual is about using Vim in the normal way. There is an + alternative called "evim" (easy Vim). This is still Vim, but used in + a way that resembles a click-and-type editor like Notepad. It always + stays in Insert mode, thus it feels very different. It is not + explained in the user manual, since it should be mostly + self-explanatory. See |evim-keys| for details. + +============================================================================== +*01.3* Using the Vim tutor *tutor* *vimtutor* + +Instead of reading the text (boring!) you can use the vimtutor to learn your +first Vim commands. This is a 30-minute tutorial that teaches the most basic +Vim functionality hands-on. + +On Unix, if Vim has been properly installed, you can start it from the shell: +> + vimtutor + +On MS-Windows you can find it in the Program/Vim menu. Or execute +vimtutor.bat in the $VIMRUNTIME directory. + +This will make a copy of the tutor file, so that you can edit it without +the risk of damaging the original. + There are a few translated versions of the tutor. To find out if yours is +available, use the two-letter language code. For French: > + + vimtutor fr + +On Unix, if you prefer using the GUI version of Vim, use "gvimtutor" or +"vimtutor -g" instead of "vimtutor". + +For OpenVMS, if Vim has been properly installed, you can start vimtutor from a +VMS prompt with: > + + @VIM:vimtutor + +Optionally add the two-letter language code as above. + + +On other systems, you have to do a little work: + +1. Copy the tutor file. You can do this with Vim (it knows where to find it): +> + vim --clean -c 'e $VIMRUNTIME/tutor/tutor' -c 'w! TUTORCOPY' -c 'q' +< + This will write the file "TUTORCOPY" in the current directory. To use a +translated version of the tutor, append the two-letter language code to the +filename. For French: +> + vim --clean -c 'e $VIMRUNTIME/tutor/tutor.fr' -c 'w! TUTORCOPY' -c 'q' +< +2. Edit the copied file with Vim: +> + vim --clean TUTORCOPY +< + The --clean argument makes sure Vim is started with nice defaults. + +3. Delete the copied file when you are finished with it: +> + del TUTORCOPY +< +============================================================================== +*01.4* Copyright *manual-copyright* + +The Vim user manual and reference manual are Copyright (c) 1988-2003 by Bram +Moolenaar. This material may be distributed only subject to the terms and +conditions set forth in the Open Publication License, v1.0 or later. The +latest version is presently available at: + http://www.opencontent.org/openpub/ + +People who contribute to the manuals must agree with the above copyright +notice. + *frombook* +Parts of the user manual come from the book "Vi IMproved - Vim" by Steve +Oualline (published by New Riders Publishing, ISBN: 0735710015). The Open +Publication License applies to this book. Only selected parts are included +and these have been modified (e.g., by removing the pictures, updating the +text for Vim 6.0 and later, fixing mistakes). The omission of the |frombook| +tag does not mean that the text does not come from the book. + +Many thanks to Steve Oualline and New Riders for creating this book and +publishing it under the OPL! It has been a great help while writing the user +manual. Not only by providing literal text, but also by setting the tone and +style. + +If you make money through selling the manuals, you are strongly encouraged to +donate part of the profit to help AIDS victims in Uganda. See |iccf|. + +============================================================================== + +Next chapter: |usr_02.txt| The first steps in Vim + +Copyright: see |manual-copyright| vim:tw=78:ts=8:noet:ft=help:norl: diff --git a/integrations/astra/examples/example.py b/integrations/astra/examples/example.py new file mode 100644 index 000000000..ac93f43ed --- /dev/null +++ b/integrations/astra/examples/example.py @@ -0,0 +1,119 @@ +import logging +import os +from pathlib import Path + +from haystack import Pipeline +from haystack.components.converters import TextFileToDocument +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder +from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter +from haystack.components.routers import FileTypeRouter +from haystack.components.writers import DocumentWriter +from haystack.document_stores import DuplicatePolicy + +from astra_haystack.document_store import AstraDocumentStore +from astra_haystack.retriever import AstraRetriever + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +HERE = Path(__file__).resolve().parent +file_paths = [HERE / "data" / Path(name) for name in os.listdir("integrations/astra/examples/data")] +logger.info(file_paths) + +astra_id = os.getenv("ASTRA_DB_ID", "") +astra_region = os.getenv("ASTRA_DB_REGION", "us-east1") + +astra_application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "") +collection_name = os.getenv("COLLECTION_NAME", "haystack_vector_search") +keyspace_name = os.getenv("KEYSPACE_NAME", "recommender_demo") + +# We support many different databases. Here, we load a simple and lightweight in-memory database. +document_store = AstraDocumentStore( + astra_id=astra_id, + astra_region=astra_region, + astra_collection=collection_name, + astra_keyspace=keyspace_name, + astra_application_token=astra_application_token, + duplicates_policy=DuplicatePolicy.OVERWRITE, + embedding_dim=384, +) + +# Create components and an indexing pipeline that converts txt files to documents, +# cleans and splits them, and indexes them +p = Pipeline() +p.add_component(instance=FileTypeRouter(mime_types=["text/plain", "application/pdf"]), name="file_type_router") +p.add_component(instance=TextFileToDocument(), name="text_file_converter") +p.add_component(instance=DocumentCleaner(), name="cleaner") +p.add_component(instance=DocumentSplitter(split_by="word", split_length=150, split_overlap=30), name="splitter") +p.add_component( + instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), + name="embedder", +) +p.add_component(instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="writer") + +p.connect("file_type_router.text/plain", "text_file_converter.sources") +p.connect("text_file_converter.documents", "cleaner.documents") +p.connect("cleaner.documents", "splitter.documents") +p.connect("splitter.documents", "embedder.documents") +p.connect("embedder.documents", "writer.documents") + +p.run({"file_type_router": {"sources": file_paths}}) + +# Create a querying pipeline on the indexed data +q = Pipeline() +q.add_component( + instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), + name="embedder", +) +q.add_component("retriever", AstraRetriever(document_store)) + +q.connect("embedder", "retriever") + +question = "This chapter introduces the manuals available with Vim" +result = q.run({"embedder": {"text": question}, "retriever": {"top_k": 1}}) +logger.info(result) + +ALL_DOCUMENTS_COUNT = 9 +documents_count = document_store.count_documents() +logger.info("count:") +logger.info(documents_count) +if documents_count != ALL_DOCUMENTS_COUNT: + msg = f"count mismatch, expected 9 documents, got {documents_count}" + raise ValueError(msg) + +logger.info( + f"""filter results: {document_store.filter_documents( + { + "field": "meta", + "operator": "==", + "value": { + "file_path": "/workspace/astra-haystack/examples/data/usr_01.txt", + "source_id": "5b2d27de79bba97da6fc446180d0d99e1024bc7dd6a757037f0934162cfb0916", + }, + } + ) +}""" +) + +logger.info( + f"""get_document_by_id {document_store.get_document_by_id( + "92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10")}""" +) + +logger.info( + f"""get_documents_by_ids {document_store.get_documents_by_id( + [ + "92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10", + "6f2450a51eaa3eeb9239d875402bcfe24b2d3534ff27f26c1f3fc8133b04e756", + ] + )}""" +) + +document_store.delete_documents(["92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10"]) + +documents_count = document_store.count_documents() +logger.info(f"count: {document_store.count_documents()}") +if documents_count != ALL_DOCUMENTS_COUNT - 1: + msg = f"count mismatch, expected 9 documents, got {documents_count}" + raise ValueError(msg) diff --git a/integrations/astra/examples/pipeline_example.py b/integrations/astra/examples/pipeline_example.py new file mode 100644 index 000000000..fb13c3d93 --- /dev/null +++ b/integrations/astra/examples/pipeline_example.py @@ -0,0 +1,107 @@ +import logging +import os + +from haystack import Document, Pipeline +from haystack.components.builders.answer_builder import AnswerBuilder +from haystack.components.builders.prompt_builder import PromptBuilder +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder +from haystack.components.generators import OpenAIGenerator +from haystack.components.writers import DocumentWriter +from haystack.document_stores import DuplicatePolicy + +from astra_haystack.document_store import AstraDocumentStore +from astra_haystack.retriever import AstraRetriever + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +# Create a RAG query pipeline +prompt_template = """ + Given these documents, answer the question. + + Documents: + {% for doc in documents %} + {{ doc.content }} + {% endfor %} + + Question: {{question}} + + Answer: + """ + +astra_id = os.getenv("ASTRA_DB_ID", "") +astra_region = os.getenv("ASTRA_DB_REGION", "us-east1") + +astra_application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "") +collection_name = os.getenv("COLLECTION_NAME", "haystack_vector_search") +keyspace_name = os.getenv("KEYSPACE_NAME", "recommender_demo") + +# We support many different databases. Here, we load a simple and lightweight in-memory database. +document_store = AstraDocumentStore( + astra_id=astra_id, + astra_region=astra_region, + astra_collection=collection_name, + astra_keyspace=keyspace_name, + astra_application_token=astra_application_token, + duplicates_policy=DuplicatePolicy.SKIP, + embedding_dim=384, +) + + +# Add Documents +documents = [ + Document(content="There are over 7,000 languages spoken around the world today."), + Document( + content="Elephants have been observed to behave in a way that indicates" + " a high level of self-awareness, such as recognizing themselves in mirrors." + ), + Document( + content="In certain parts of the world, like the Maldives, Puerto Rico, " + "and San Diego, you can witness the phenomenon of bioluminescent waves." + ), +] +p = Pipeline() +p.add_component( + instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), + name="embedder", +) +p.add_component(instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="writer") +p.connect("embedder.documents", "writer.documents") + +p.run({"embedder": {"documents": documents}}) + + +# Construct rag pipeline +rag_pipeline = Pipeline() +rag_pipeline.add_component( + instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), + name="embedder", +) +rag_pipeline.add_component(instance=AstraRetriever(document_store=document_store), name="retriever") +rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") +rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm") +rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") +rag_pipeline.connect("embedder", "retriever") +rag_pipeline.connect("retriever", "prompt_builder.documents") +rag_pipeline.connect("prompt_builder", "llm") +rag_pipeline.connect("llm.replies", "answer_builder.replies") +rag_pipeline.connect("llm.meta", "answer_builder.meta") +rag_pipeline.connect("retriever", "answer_builder.documents") + + +# Draw the pipeline +rag_pipeline.draw("./rag_pipeline.png") + + +# Run the pipeline +question = "How many languages are there in the world today?" +result = rag_pipeline.run( + { + "embedder": {"text": question}, + "retriever": {"top_k": 2}, + "prompt_builder": {"question": question}, + "answer_builder": {"query": question}, + } +) + +logger.info(result) diff --git a/integrations/astra/examples/requirements.txt b/integrations/astra/examples/requirements.txt new file mode 100644 index 000000000..be27e7427 --- /dev/null +++ b/integrations/astra/examples/requirements.txt @@ -0,0 +1,3 @@ +haystack-ai==2.0.0b4 +sentence_transformers==2.2.2 +openai==1.6.1 \ No newline at end of file diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml new file mode 100644 index 000000000..b99449e03 --- /dev/null +++ b/integrations/astra/pyproject.toml @@ -0,0 +1,187 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "astra-haystack" +dynamic = ["version"] +description = '' +readme = "README.md" +requires-python = ">=3.7" +license = "Apache-2.0" +keywords = [] +authors = [ + { name = "Anant Corporation", email = "support@anant.us" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "haystack-ai", + "pydantic", + "typing_extensions", +] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/astra#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/astra" + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/astra-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/astra-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", +] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] + +[[tool.hatch.envs.all.matrix]] +python = ["3.7", "3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = [ + "black>=23.1.0", + "mypy>=1.0.0", + "ruff>=0.0.243", +] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive {args:src/astra_haystack tests}" +style = [ + "ruff {args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "ruff --fix {args:.}", + "style", +] +all = [ + "style", + "typing", +] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.black] +target-version = ["py37"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py37" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... True)` + "FBT003", + # Ignore checks for possible passwords + "S105", "S106", "S107", + # Ignore complexity + "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] +exclude = ["example"] + +[tool.ruff.isort] +known-first-party = ["astra_haystack"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +source_pkgs = ["astra_haystack", "tests"] +branch = true +parallel = true +omit = [ + "example" +] + +[tool.coverage.paths] +astra_haystack = ["src/astra_haystack", "*/astra-store/src/astra_haystack"] +tests = ["tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + +[tool.pytest.ini_options] +minversion = "6.0" +markers = [ + "unit: unit tests", + "integration: integration tests" +] + +[[tool.mypy.overrides]] +module = [ + "astra_haystack.*", + "astra_client.*", + "pydantic.*", + "haystack.*", + "pytest.*" +] +ignore_missing_imports = true diff --git a/integrations/astra/src/astra_haystack/__init__.py b/integrations/astra/src/astra_haystack/__init__.py new file mode 100644 index 000000000..5c99dedf6 --- /dev/null +++ b/integrations/astra/src/astra_haystack/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023-present Anant Corporation +# +# SPDX-License-Identifier: Apache-2.0 +from astra_haystack.document_store import AstraDocumentStore + +__all__ = ["AstraDocumentStore"] diff --git a/integrations/astra/src/astra_haystack/astra_client.py b/integrations/astra/src/astra_haystack/astra_client.py new file mode 100644 index 000000000..ec0263a5a --- /dev/null +++ b/integrations/astra/src/astra_haystack/astra_client.py @@ -0,0 +1,298 @@ +import json +import logging +from typing import Dict, List, Optional, Union + +import requests +from pydantic.dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class Response: + document_id: str + text: Optional[str] + values: Optional[list] + metadata: Optional[dict] + score: Optional[float] + + +@dataclass +class QueryResponse: + matches: List[Response] + + def get(self, key): + return self.__dict__[key] + + +class AstraClient: + """ + A client for interacting with an Astra index via JSON API + """ + + def __init__( + self, + astra_id: str, + astra_region: str, + astra_application_token: str, + keyspace_name: str, + collection_name: str, + embedding_dim: int, + similarity_function: str, + ): + self.astra_id = astra_id + self.astra_application_token = astra_application_token + self.astra_region = astra_region + self.keyspace_name = keyspace_name + self.collection_name = collection_name + self.embedding_dim = embedding_dim + self.similarity_function = similarity_function + + self.request_url = f"https://{self.astra_id}-{self.astra_region}.apps.astra.datastax.com/api/json/v1/{self.keyspace_name}/{self.collection_name}" + self.request_header = { + "x-cassandra-token": self.astra_application_token, + "Content-Type": "application/json", + } + self.create_url = ( + f"https://{self.astra_id}-{self.astra_region}.apps.astra.datastax.com/api/json/v1/{self.keyspace_name}" + ) + + index_exists = self.find_index() + if not index_exists: + self.create_index() + + def find_index(self): + find_query = {"findCollections": {"options": {"explain": True}}} + response = requests.request("POST", self.create_url, headers=self.request_header, data=json.dumps(find_query)) + response.raise_for_status() + response_dict = json.loads(response.text) + + if "status" in response_dict: + collection_name_matches = list( + filter(lambda d: d["name"] == self.collection_name, response_dict["status"]["collections"]) + ) + + if len(collection_name_matches) == 0: + logger.warning( + f"Astra collection {self.collection_name} not found under {self.keyspace_name}. Will be created." + ) + return False + + collection_embedding_dim = collection_name_matches[0]["options"]["vector"]["dimension"] + if collection_embedding_dim != self.embedding_dim: + msg = ( + f"Collection vector dimension is not valid, expected {self.embedding_dim}, " + f"found {collection_embedding_dim}" + ) + raise Exception(msg) + + else: + msg = f"status not in response: {response.text}" + raise Exception(msg) + + return True + + def create_index(self): + create_query = { + "createCollection": { + "name": self.collection_name, + "options": {"vector": {"dimension": self.embedding_dim, "metric": self.similarity_function}}, + } + } + response = requests.request("POST", self.create_url, headers=self.request_header, data=json.dumps(create_query)) + response.raise_for_status() + response_dict = json.loads(response.text) + if "errors" in response_dict: + raise Exception(response_dict["errors"]) + logger.info(f"Collection {self.collection_name} created: {response.text}") + + def query( + self, + vector: Optional[List[float]] = None, + query_filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + top_k: Optional[int] = None, + include_metadata: Optional[bool] = None, + include_values: Optional[bool] = None, + ) -> QueryResponse: + """ + The Query operation searches a namespace, using a query vector. + It retrieves the ids of the most similar items in a namespace, along with their similarity scores. + + Args: + vector (List[float]): The query vector. This should be the same length as the dimension of the index + being queried. Each `query()` request can contain only one of the parameters + `queries`, `id` or `vector`... [optional] + top_k (int): The number of results to return for each query. Must be an integer greater than 1. + query_filter (Dict[str, Union[str, float, int, bool, List, dict]): + The filter to apply. You can use vector metadata to limit your search. [optional] + include_metadata (bool): Indicates whether metadata is included in the response as well as the ids. + If omitted the server will use the default value of False [optional] + include_values (bool): Indicates whether values/vector is included in the response as well as the ids. + If omitted the server will use the default value of False [optional] + + Returns: object which contains the list of the closest vectors as ScoredVector objects, + and namespace name. + """ + # get vector data and scores + if vector is None: + responses = self._query_without_vector(top_k, query_filter) + else: + responses = self._query(vector, top_k, query_filter) + + # include_metadata means return all columns in the table (including text that got embedded) + # include_values means return the vector of the embedding for the searched items + formatted_response = self._format_query_response(responses, include_metadata, include_values) + + return formatted_response + + def _query_without_vector(self, top_k, filters=None): + query = {"filter": filters, "options": {"limit": top_k}} + return self.find_documents(query) + + @staticmethod + def _format_query_response(responses, include_metadata, include_values): + final_res = [] + if responses is None: + return QueryResponse(matches=[]) + for response in responses: + _id = response.pop("_id") + score = response.pop("$similarity", None) + text = response.pop("content", None) + values = response.pop("$vector", None) if include_values else [] + metadata = response if include_metadata else {} # Add all remaining fields to the metadata + rsp = Response(_id, text, values, metadata, score) + final_res.append(rsp) + return QueryResponse(final_res) + + def _query(self, vector, top_k, filters=None): + query = {"sort": {"$vector": vector}, "options": {"limit": top_k, "includeSimilarity": True}} + if filters is not None: + query["filter"] = filters + result = self.find_documents(query) + return result + + def find_documents(self, find_query): + query = json.dumps({"find": find_query}) + response = requests.request( + "POST", + self.request_url, + headers=self.request_header, + data=query, + ) + response.raise_for_status() + response_dict = json.loads(response.text) + if "errors" in response_dict: + raise Exception(response_dict["errors"]) + if "data" in response_dict and "documents" in response_dict["data"]: + return response_dict["data"]["documents"] + else: + logger.warning(f"No documents found: {response_dict}") + + def get_documents(self, ids: List[str], batch_size: int = 20) -> QueryResponse: + document_batch = [] + + def batch_generator(chunks, batch_size): + for i in range(0, len(chunks), batch_size): + i_end = min(len(chunks), i + batch_size) + batch = chunks[i:i_end] + yield batch + + for id_batch in batch_generator(ids, batch_size): + document_batch.extend(self.find_documents({"filter": {"_id": {"$in": id_batch}}})) + formatted_docs = self._format_query_response(document_batch, include_metadata=True, include_values=True) + return formatted_docs + + def insert(self, documents: List[Dict]): + query = json.dumps({"insertMany": {"options": {"ordered": False}, "documents": documents}}) + response = requests.request( + "POST", + self.request_url, + headers=self.request_header, + data=query, + ) + response.raise_for_status() + response_dict = json.loads(response.text) + + inserted_ids = ( + response_dict["status"]["insertedIds"] + if "status" in response_dict and "insertedIds" in response_dict["status"] + else [] + ) + if "errors" in response_dict: + logger.error(response_dict["errors"]) + return inserted_ids + + def update_document(self, document: Dict, id_key: str): + document_id = document.pop(id_key) + query = json.dumps( + { + "findOneAndUpdate": { + "filter": {id_key: document_id}, + "update": {"$set": document}, + "options": {"returnDocument": "after"}, + } + } + ) + response = requests.request( + "POST", + self.request_url, + headers=self.request_header, + data=query, + ) + response.raise_for_status() + response_dict = json.loads(response.text) + document[id_key] = document_id + + if "status" in response_dict and "errors" not in response_dict: + if "matchedCount" in response_dict["status"] and "modifiedCount" in response_dict["status"]: + if response_dict["status"]["matchedCount"] == 1 and response_dict["status"]["modifiedCount"] == 1: + return True + logger.warning(f"Documents {document_id} not updated in Astra {response.text}") + return False + + def delete( + self, + ids: Optional[List[str]] = None, + delete_all: Optional[bool] = None, + filters: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + ) -> int: + if delete_all: + query = {"deleteMany": {}} # type: dict + if ids is not None: + query = {"deleteMany": {"filter": {"_id": {"$in": ids}}}} + if filters is not None: + query = {"deleteMany": {"filter": filters}} + + deletion_counter = 0 + moredata = True + while moredata: + response = requests.request( + "POST", + self.request_url, + headers=self.request_header, + data=json.dumps(query), + ) + response.raise_for_status() + response_dict = response.json() + if "errors" in response_dict: + raise Exception(response_dict["errors"]) + if "moreData" not in response_dict.get("status", {}): + moredata = False + deletion_counter += int(response_dict["status"].get("deletedCount", 0)) + + return deletion_counter + + def count_documents(self) -> int: + """ + Returns how many documents are present in the document store. + """ + response = requests.request( + "POST", + self.request_url, + headers=self.request_header, + data=json.dumps({"countDocuments": {}}), + ) + response.raise_for_status() + if "errors" in response.json(): + raise Exception(response.json()["errors"]) + return response.json()["status"]["count"] diff --git a/integrations/astra/src/astra_haystack/document_store.py b/integrations/astra/src/astra_haystack/document_store.py new file mode 100644 index 000000000..a9a02c148 --- /dev/null +++ b/integrations/astra/src/astra_haystack/document_store.py @@ -0,0 +1,383 @@ +# SPDX-FileCopyrightText: 2023-present Anant Corporation +# +# SPDX-License-Identifier: Apache-2.0 +import json +import logging +from dataclasses import asdict +from typing import Any, Dict, List, Optional, Union + +import pandas as pd +from haystack import default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores import ( + DuplicateDocumentError, + DuplicatePolicy, + MissingDocumentError, +) + +from astra_haystack.astra_client import AstraClient +from astra_haystack.errors import AstraDocumentStoreFilterError +from astra_haystack.filters import _convert_filters + +logger = logging.getLogger(__name__) + + +MAX_BATCH_SIZE = 20 + + +def _batches(input_list, batch_size): + input_length = len(input_list) + for ndx in range(0, input_length, batch_size): + yield input_list[ndx : min(ndx + batch_size, input_length)] + + +class AstraDocumentStore: + """ + An AstraDocumentStore document store for Haystack. + """ + + def __init__( + self, + astra_id: str, + astra_region: str, + astra_application_token: str, + astra_keyspace: str, + astra_collection: str, + embedding_dim: Optional[int] = 768, + duplicates_policy: DuplicatePolicy = DuplicatePolicy.NONE, + similarity: str = "cosine", + ): + """ + The connection to Astra DB is established and managed through the JSON API. + The required credentials (database ID, region, and application token) can be generated + through the UI by clicking and the connect tab, and then selecting JSON API and + Generate Configuration. + + :param astra_id: id of the Astra DB instance. + :param astra_region: Region of cloud servers (can be found when generating the token). + :param astra_application_token: the connection token for Astra. + :param astra_keyspace: The keyspace for the current Astra DB. + :param astra_collection: The current collection in the keyspace in the current Astra DB. + :param embedding_dim: Dimension of embedding vector. + :param similarity: The similarity function used to compare document vectors. + :param duplicates_policy: Handle duplicate documents based on DuplicatePolicy parameter options. + Parameter options : (SKIP, OVERWRITE, FAIL, NONE) + - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists, + it is skipped and not written. + - `DuplicatePolicy.SKIP`: If a Document with the same id already exists, it is skipped and not written. + - `DuplicatePolicy.OVERWRITE`: If a Document with the same id already exists, it is overwritten. + - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised. + """ + + self.duplicates_policy = duplicates_policy + self.astra_id = astra_id + self.astra_region = astra_region + self.astra_application_token = astra_application_token + self.astra_keyspace = astra_keyspace + self.astra_collection = astra_collection + self.embedding_dim = embedding_dim + self.similarity = similarity + + self.index = AstraClient( + astra_id=self.astra_id, + astra_region=self.astra_region, + astra_application_token=self.astra_application_token, + keyspace_name=self.astra_keyspace, + collection_name=self.astra_collection, + embedding_dim=self.embedding_dim, + similarity_function=self.similarity, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AstraDocumentStore": + return default_from_dict(cls, data) + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, + duplicates_policy=self.duplicates_policy.name, + astra_id=self.astra_id, + astra_region=self.astra_region, + astra_keyspace=self.astra_keyspace, + astra_collection=self.astra_collection, + embedding_dim=self.embedding_dim, + similarity=self.similarity, + ) + + def write_documents( + self, + documents: List[Document], + index: Optional[str] = None, + batch_size: int = 20, + policy: DuplicatePolicy = DuplicatePolicy.NONE, + ): + """ + Indexes documents for later queries. + + :param documents: a list of Haystack Document objects. + :param index: Optional name of index where the documents shall be written to. + If None, the DocumentStore's default index (self.index) will be used. + :param batch_size: Number of documents that are passed to bulk function at a time. + :param policy: Handle duplicate documents based on DuplicatePolicy parameter options. + Parameter options : (SKIP, OVERWRITE, FAIL, NONE) + - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists, + it is skipped and not written. + - `DuplicatePolicy.SKIP`: If a Document with the same id already exists, + it is skipped and not written. + - `DuplicatePolicy.OVERWRITE`: If a Document with the same id already exists, it is overwritten. + - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised. + :return: int + """ + + if index is None and self.index is None: + msg = "No Astra client provided" + raise ValueError(msg) + + if index is None: + index = self.index + + if policy is None or policy == DuplicatePolicy.NONE: + if self.duplicates_policy is not None and self.duplicates_policy != DuplicatePolicy.NONE: + policy = self.duplicates_policy + else: + policy = DuplicatePolicy.SKIP + + if batch_size > MAX_BATCH_SIZE: + logger.warning( + f"batch_size set to {batch_size}, " + f"but maximum batch_size for Astra when using the JSON API is 20. batch_size set to 20." + ) + batch_size = MAX_BATCH_SIZE + + def _convert_input_document(document: Union[dict, Document]): + if isinstance(document, Document): + document_dict = asdict(document) + elif isinstance(document, dict): + document_dict = document + else: + msg = f"Unsupported type for documents, documents is of type {type(document)}." + raise ValueError(msg) + + if "id" in document_dict: + if "_id" not in document_dict: + document_dict["_id"] = document_dict.pop("id") + elif "_id" in document_dict: + msg = f"Duplicate id definitions, both 'id' and '_id' present in document {document_dict}" + raise Exception(msg) + if "_id" in document_dict: + if not isinstance(document_dict["_id"], str): + msg = ( + f"Document id {document_dict['_id']} is not a string, " + f"but is of type {type(document_dict['_id'])}" + ) + raise Exception(msg) + + if "dataframe" in document_dict and document_dict["dataframe"] is not None: + document_dict["dataframe"] = document_dict.pop("dataframe").to_json() + document_dict["$vector"] = document_dict.pop("embedding", None) + + return document_dict + + documents_to_write = [_convert_input_document(doc) for doc in documents] + + duplicate_documents = [] + new_documents = [] + i = 0 + while i < len(documents_to_write): + doc = documents_to_write[i] + response = self.index.find_documents({"filter": {"_id": doc["_id"]}}) + if response: + if policy == DuplicatePolicy.FAIL: + msg = f"ID '{doc['_id']}' already exists." + raise DuplicateDocumentError(msg) + duplicate_documents.append(doc) + else: + new_documents.append(doc) + i = i + 1 + + insertion_counter = 0 + if policy == DuplicatePolicy.SKIP: + if len(new_documents) > 0: + for batch in _batches(new_documents, batch_size): + inserted_ids = index.insert(batch) # type: ignore + insertion_counter += len(inserted_ids) + logger.info(f"write_documents inserted documents with id {inserted_ids}") + else: + logger.warning("No documents written. Argument policy set to SKIP") + + elif policy == DuplicatePolicy.OVERWRITE: + if len(new_documents) > 0: + for batch in _batches(new_documents, batch_size): + inserted_ids = index.insert(batch) # type: ignore + insertion_counter += len(inserted_ids) + logger.info(f"write_documents inserted documents with id {inserted_ids}") + else: + logger.warning("No documents written. Argument policy set to OVERWRITE") + + if len(duplicate_documents) > 0: + updated_ids = [] + for duplicate_doc in duplicate_documents: + updated = index.update_document(duplicate_doc, "_id") # type: ignore + if updated: + updated_ids.append(duplicate_doc["_id"]) + insertion_counter = insertion_counter + len(updated_ids) + logger.info(f"write_documents updated documents with id {updated_ids}") + else: + logger.info("No documents updated. Argument policy set to OVERWRITE") + + elif policy == DuplicatePolicy.FAIL: + if len(new_documents) > 0: + for batch in _batches(new_documents, batch_size): + inserted_ids = index.insert(batch) # type: ignore + insertion_counter = insertion_counter + len(inserted_ids) + logger.info(f"write_documents inserted documents with id {inserted_ids}") + else: + logger.warning("No documents written. Argument policy set to FAIL") + + return insertion_counter + + def count_documents(self) -> int: + """ + Returns how many documents are present in the document store. + """ + return self.index.count_documents() + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """Returns at most 1000 documents that match the filter + + Args: + filters (Optional[Dict[str, Any]], optional): Filters to apply. Defaults to None. + + Raises: + AstraDocumentStoreFilterError: If the filter is invalid or not supported by this class. + + Returns: + List[Document]: A list of matching documents. + """ + if not isinstance(filters, dict) and filters is not None: + msg = "Filters must be a dictionary or None" + raise AstraDocumentStoreFilterError(msg) + + if filters is not None: + if "id" in filters: + filters["_id"] = filters.pop("id") + vector = None + if filters is not None and "embedding" in filters.keys(): + if "$in" in filters["embedding"]: + embeds = filters.pop("embedding") + vectors = embeds["$in"] + else: + filters["$vector"] = filters.pop("embedding") + vectors = [filters.pop("$vector")] + documents = [] + for vector in vectors: + converted_filters = _convert_filters(filters) + results = self.index.query( + vector=vector, + query_filter=converted_filters, + top_k=1000, + include_values=True, + include_metadata=True, + ) + documents.extend(self._get_result_to_documents(results)) + else: + converted_filters = _convert_filters(filters) + results = self.index.query( + vector=vector, query_filter=converted_filters, top_k=1000, include_values=True, include_metadata=True + ) + documents = self._get_result_to_documents(results) + return documents + + @staticmethod + def _get_result_to_documents(results) -> List[Document]: + documents = [] + for match in results.matches: + dataframe = match.metadata.pop("dataframe", None) + if dataframe is not None: + df = pd.DataFrame.from_dict(json.loads(dataframe)) + else: + df = None + document = Document( + content=match.text, + id=match.document_id, + embedding=match.values, + dataframe=df, + blob=match.metadata.pop("blob", None), + meta=match.metadata.pop("meta", None), + score=match.score, + ) + documents.append(document) + return documents + + def get_documents_by_id(self, ids: List[str]) -> List[Document]: + """ + Returns documents with given ids. + """ + results = self.index.get_documents(ids=ids) + ret = self._get_result_to_documents(results) + return ret + + def get_document_by_id(self, document_id: str) -> Document: + """ + :param document_id: id of the document to retrieve + Returns documents with given ids. + Raises MissingDocumentError when document_id does not exist in document store + """ + document = self.index.get_documents(ids=[document_id]) + ret = self._get_result_to_documents(document) + if not ret: + msg = f"Document {document_id} does not exist" + raise MissingDocumentError(msg) + return ret[0] + + def search( + self, query_embedding: List[float], top_k: int, filters: Optional[Dict[str, Any]] = None + ) -> List[Document]: + """Perform a search for a list of queries. + + Args: + query_embedding (List[float]): A list of query embeddings. + top_k (int): The number of results to return. + filters (Optional[Dict[str, Any]], optional): Filters to apply during search. Defaults to None. + + Returns: + List[Document]: A list of matching documents. + """ + converted_filters = _convert_filters(filters) + + result = self._get_result_to_documents( + self.index.query( + vector=query_embedding, + top_k=top_k, + query_filter=converted_filters, + include_metadata=True, + include_values=True, + ) + ) + logger.debug(f"Raw responses: {result}") # leaving for debugging + + return result + + def delete_documents(self, document_ids: Optional[List[str]] = None, delete_all: Optional[bool] = None) -> None: + """ + Deletes all documents with a matching document_ids from the document store. + Fails with `MissingDocumentError` if no document with this id is present in the store. + + :param document_ids: the document_ids to delete. + :param delete_all: delete all documents. + """ + + deletion_counter = 0 + if self.index.count_documents() > 0: + if document_ids is not None: + for batch in _batches(document_ids, MAX_BATCH_SIZE): + deletion_counter += self.index.delete(ids=batch) + else: + deletion_counter = self.index.delete(delete_all=delete_all) + logger.info(f"{deletion_counter} documents deleted") + + if document_ids is not None and deletion_counter == 0: + msg = f"Document {document_ids} does not exist" + raise MissingDocumentError(msg) + else: + logger.info("No documents in document store") diff --git a/integrations/astra/src/astra_haystack/errors.py b/integrations/astra/src/astra_haystack/errors.py new file mode 100644 index 000000000..186a8fef2 --- /dev/null +++ b/integrations/astra/src/astra_haystack/errors.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: 2023-present Anant Corporation +# +# SPDX-License-Identifier: Apache-2.0 +from haystack.document_stores.errors import DocumentStoreError +from haystack.errors import FilterError + + +class AstraDocumentStoreError(DocumentStoreError): + pass + + +class AstraDocumentStoreFilterError(FilterError): + pass + + +class AstraDocumentStoreConfigError(AstraDocumentStoreError): + pass diff --git a/integrations/astra/src/astra_haystack/filters.py b/integrations/astra/src/astra_haystack/filters.py new file mode 100644 index 000000000..6b628486b --- /dev/null +++ b/integrations/astra/src/astra_haystack/filters.py @@ -0,0 +1,133 @@ +from typing import Any, Dict, List, Optional + +import pandas as pd +from haystack.errors import FilterError + + +def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]: + """ + Converts Haystack filters to Astra compatible filters. + """ + if not isinstance(filters, dict): + msg = "Filters must be a dictionary" + raise FilterError(msg) + + if "field" in filters: + return _parse_comparison_condition(filters) + return _parse_logical_condition(filters) + + +def _convert_filters(filters: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]: + """ + Convert haystack filters to astra filterstring capturing all boolean operators + """ + if not filters: + return None + filters = _normalize_filters(filters) + + filter_statements = {} + for key, value in filters.items(): + if key in {"$and", "$or"}: + filter_statements[key] = value + else: + if key == "id": + filter_statements[key] = {"_id": value} + if key != "$in" and isinstance(value, list): + filter_statements[key] = {"$in": value} + elif isinstance(value, pd.DataFrame): + filter_statements[key] = value.to_json() + elif isinstance(value, dict): + for dkey, dvalue in value.items(): + if dkey == "$in" and not isinstance(dvalue, list): + exception_message = f"$in operator must have `ARRAY`, got {dvalue} of type {type(dvalue)}" + raise FilterError(exception_message) + converted = {dkey: dvalue} + filter_statements[key] = converted + else: + filter_statements[key] = value + + return filter_statements + + +# TODO consider other operators, or filters that are not with the same structure as field operator value +OPERATORS = { + "==": "$eq", + "!=": "$neq", + ">": "$gt", + ">=": "$gte", + "<": "$lt", + "<=": "$lte", + "in": "$in", + "not in": "$nin", + "AND": "$and", + "OR": "$or", +} + + +def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "conditions" not in condition: + msg = f"'conditions' key missing in {condition}" + raise FilterError(msg) + + operator = condition["operator"] + conditions = [_parse_comparison_condition(c) for c in condition["conditions"]] + if len(conditions) > 1: + conditions = _normalize_ranges(conditions) + if operator not in OPERATORS: + msg = f"Unknown operator {operator}" + raise FilterError(msg) + return {OPERATORS[operator]: conditions} + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: + if "field" not in condition: + msg = f"'field' key missing in {condition}" + raise FilterError(msg) + field: str = condition["field"] + + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "value" not in condition: + msg = f"'value' key missing in {condition}" + raise FilterError(msg) + operator: str = condition["operator"] + value: Any = condition["value"] + if isinstance(value, pd.DataFrame): + value = value.to_json() + + return {field: {OPERATORS[operator]: value}} + + +def _normalize_ranges(conditions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Merges range conditions acting on a same field. + + Example usage: + + ```python + conditions = [ + {"range": {"date": {"lt": "2021-01-01"}}}, + {"range": {"date": {"gte": "2015-01-01"}}}, + ] + conditions = _normalize_ranges(conditions) + assert conditions == [ + {"range": {"date": {"lt": "2021-01-01", "gte": "2015-01-01"}}}, + ] + ``` + """ + range_conditions = [next(iter(c["range"].items())) for c in conditions if "range" in c] + if range_conditions: + conditions = [c for c in conditions if "range" not in c] + range_conditions_dict: Dict[str, Any] = {} + for field_name, comparison in range_conditions: + if field_name not in range_conditions_dict: + range_conditions_dict[field_name] = {} + range_conditions_dict[field_name].update(comparison) + + for field_name, comparisons in range_conditions_dict.items(): + conditions.append({"range": {field_name: comparisons}}) + return conditions diff --git a/integrations/astra/src/astra_haystack/retriever.py b/integrations/astra/src/astra_haystack/retriever.py new file mode 100644 index 000000000..47304df2c --- /dev/null +++ b/integrations/astra/src/astra_haystack/retriever.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: 2023-present Anant Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +from haystack import Document, component, default_from_dict, default_to_dict + +from astra_haystack.document_store import AstraDocumentStore + + +@component +class AstraRetriever: + """ + A component for retrieving documents from an AstraDocumentStore. + """ + + def __init__(self, document_store: AstraDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10): + """ + Create an AstraRetriever component. Usually you pass some basic configuration + parameters to the constructor. + + :param filters: A dictionary with filters to narrow down the search space (default is None). + :param top_k: The maximum number of documents to retrieve (default is 10). + """ + self.filters = filters + self.top_k = top_k + self.document_store = document_store + + if not isinstance(document_store, AstraDocumentStore): + message = "document_store must be an instance of AstraDocumentStore" + raise Exception(message) + + @component.output_types(documents=List[Document]) + def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + """Run the retriever on the given list of queries. + + Args: + query_embedding (List[str]): An input list of queries + filters (Optional[Dict[str, Any]], optional): A dictionary with filters to narrow down the search space. + Defaults to None. + top_k (Optional[int], optional): The maximum number of documents to retrieve. Defaults to None. + """ + + if not top_k: + top_k = self.top_k + + if not filters: + filters = self.filters + + return {"documents": self.document_store.search(query_embedding, top_k, filters=filters)} + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, + filters=self.filters, + top_k=self.top_k, + document_store=self.document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AstraRetriever": + document_store = AstraDocumentStore.from_dict(data["init_parameters"]["document_store"]) + data["init_parameters"]["document_store"] = document_store + return default_from_dict(cls, data) diff --git a/integrations/astra/tests/__init__.py b/integrations/astra/tests/__init__.py new file mode 100644 index 000000000..f5e799e88 --- /dev/null +++ b/integrations/astra/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present Anant Corporation +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/astra/tests/conftest.py b/integrations/astra/tests/conftest.py new file mode 100644 index 000000000..863b0aa7f --- /dev/null +++ b/integrations/astra/tests/conftest.py @@ -0,0 +1,35 @@ +import os + +import pytest +from haystack.document_stores import DuplicatePolicy + +from astra_haystack.document_store import AstraDocumentStore + + +@pytest.fixture +def document_store() -> AstraDocumentStore: + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + astra_id = os.getenv("ASTRA_DB_ID", "") + astra_region = os.getenv("ASTRA_DB_REGION", "us-east-2") + + astra_application_token = os.getenv( + "ASTRA_DB_APPLICATION_TOKEN", + "", + ) + + keyspace_name = "astra_haystack_test" + collection_name = "haystack_integration" + + astra_store = AstraDocumentStore( + astra_id=astra_id, + astra_region=astra_region, + astra_application_token=astra_application_token, + astra_keyspace=keyspace_name, + astra_collection=collection_name, + duplicates_policy=DuplicatePolicy.OVERWRITE, + embedding_dim=768, + ) + return astra_store diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py new file mode 100644 index 000000000..972c565b2 --- /dev/null +++ b/integrations/astra/tests/test_document_store.py @@ -0,0 +1,245 @@ +# SPDX-FileCopyrightText: 2023-present Anant Corporation +# +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import List + +import pytest +from haystack import Document +from haystack.document_stores import DuplicatePolicy, MissingDocumentError +from haystack.testing.document_store import DocumentStoreBaseTests + +from astra_haystack.document_store import AstraDocumentStore + + +@pytest.mark.skipif( + os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN is not set" +) +@pytest.mark.skipif(os.environ.get("ASTRA_DB_ID", "") == "", reason="ASTRA_DB_ID is not set") +class TestDocumentStore(DocumentStoreBaseTests): + """ + Common test cases will be provided by `DocumentStoreBaseTests` but + you can add more to this class. + """ + + @pytest.fixture + @pytest.mark.usefixtures + def document_store(self, document_store) -> AstraDocumentStore: + return document_store + + @pytest.fixture(autouse=True) + def run_before_and_after_tests(self, document_store: AstraDocumentStore): + """ + Cleaning up document store + """ + document_store.delete_documents(delete_all=True) + assert document_store.count_documents() == 0 + + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + """ + Assert that two lists of Documents are equal. + This is used in every test, if a Document Store implementation has a different behaviour + it should override this method. + + This can happen for example when the Document Store sets a score to returned Documents. + Since we can't know what the score will be, we can't compare the Documents reliably. + """ + import operator + + received.sort(key=operator.attrgetter("id")) + expected.sort(key=operator.attrgetter("id")) + assert received == expected + + def test_comparison_equal_with_none(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"field": "meta.number", "operator": "==", "value": None}) + # Astra does not support filtering on None, it returns empty list + self.assert_documents_are_equal(result, []) + + def test_write_documents(self, document_store: AstraDocumentStore): + """ + Test write_documents() overwrites stored Document when trying to write one with same id + using DuplicatePolicy.OVERWRITE. + """ + doc1 = Document(id="1", content="test doc 1") + doc2 = Document(id="1", content="test doc 2") + + assert document_store.write_documents([doc2], policy=DuplicatePolicy.OVERWRITE) == 1 + self.assert_documents_are_equal(document_store.filter_documents(), [doc2]) + assert document_store.write_documents(documents=[doc1], policy=DuplicatePolicy.OVERWRITE) == 1 + self.assert_documents_are_equal(document_store.filter_documents(), [doc1]) + + def test_delete_documents_non_existing_document(self, document_store: AstraDocumentStore): + """ + Test delete_documents() doesn't delete any Document when called with non existing id. + """ + doc = Document(content="test doc") + document_store.write_documents([doc]) + assert document_store.count_documents() == 1 + + with pytest.raises(MissingDocumentError): + document_store.delete_documents(["non_existing_id"]) + + # No Document has been deleted + assert document_store.count_documents() == 1 + + def test_delete_documents_more_than_twenty_delete_all(self, document_store: AstraDocumentStore): + """ + Test delete_documents() deletes all documents when called on an Astra DB with + more than 20 documents. Twenty documents is the maximum number of deleted + documents in one call for Astra. + """ + docs = [] + for i in range(1, 26): + doc = Document(content=f"test doc {i}", id=str(i)) + docs.append(doc) + document_store.write_documents(docs) + assert document_store.count_documents() == 25 + + document_store.delete_documents(delete_all=True) + + assert document_store.count_documents() == 0 + + def test_delete_documents_more_than_twenty_delete_ids(self, document_store: AstraDocumentStore): + """ + Test delete_documents() deletes all documents when called on an Astra DB with + more than 20 documents. Twenty documents is the maximum number of deleted + documents in one call for Astra. + """ + docs = [] + document_ids = [] + for i in range(1, 26): + doc = Document(content=f"test doc {i}", id=str(i)) + docs.append(doc) + document_ids.append(str(i)) + document_store.write_documents(docs) + assert document_store.count_documents() == 25 + + document_store.delete_documents(document_ids=document_ids) + + # No Document has been deleted + assert document_store.count_documents() == 0 + + @pytest.mark.skip(reason="Unsupported filter operator not.") + def test_not_operator(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $neq.") + def test_comparison_not_equal_with_none(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $neq.") + def test_comparison_not_equal(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $neq.") + def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $nin.") + def test_comparison_not_in(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $nin.") + def test_comparison_not_in_with_with_non_list(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $nin.") + def test_comparison_not_in_with_with_non_list_iterable(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $gt.") + def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $gt.") + def test_comparison_greater_than_with_string(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $gt.") + def test_comparison_greater_than_with_dataframe(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $gt.") + def test_comparison_greater_than_with_list(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $gt.") + def test_comparison_greater_than_with_none(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $gt.") + def test_comparison_greater_than(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $gte.") + def test_comparison_greater_than_equal(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $gte.") + def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $gte.") + def test_comparison_greater_than_equal_with_list(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $gte.") + def test_comparison_greater_than_equal_with_dataframe(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $gte.") + def test_comparison_greater_than_equal_with_string(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $gte.") + def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $lte.") + def test_comparison_less_than_equal(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $lte.") + def test_comparison_less_than_equal_with_string(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $lte.") + def test_comparison_less_than_equal_with_dataframe(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $lte.") + def test_comparison_less_than_equal_with_list(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $lte.") + def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $lte.") + def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $lt.") + def test_comparison_less_than_with_none(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $lt.") + def test_comparison_less_than_with_list(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $lt.") + def test_comparison_less_than_with_dataframe(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $lt.") + def test_comparison_less_than_with_string(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $lt.") + def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): + pass + + @pytest.mark.skip(reason="Unsupported filter operator $lt.") + def test_comparison_less_than(self, document_store, filterable_docs): + pass diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py new file mode 100644 index 000000000..2212d44fd --- /dev/null +++ b/integrations/astra/tests/test_retriever.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: 2023-present Anant Corporation +# +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest + +from astra_haystack.retriever import AstraRetriever + + +@pytest.mark.skipif( + os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN is not set" +) +@pytest.mark.skipif(os.environ.get("ASTRA_DB_ID", "") == "", reason="ASTRA_DB_ID is not set") +@pytest.mark.integration +def test_retriever_to_json(document_store): + retriever = AstraRetriever(document_store, filters={"foo": "bar"}, top_k=99) + assert retriever.to_dict() == { + "type": "astra_haystack.retriever.AstraRetriever", + "init_parameters": { + "filters": {"foo": "bar"}, + "top_k": 99, + "document_store": { + "init_parameters": { + "astra_collection": "haystack_integration", + "astra_id": "63195634-ba44-49be-8a3c-12e830eb1c01", + "astra_keyspace": "astra_haystack_test", + "astra_region": "us-east-2", + "duplicates_policy": "OVERWRITE", + "embedding_dim": 768, + "similarity": "cosine", + }, + "type": "astra_haystack.document_store.AstraDocumentStore", + }, + }, + } + + +@pytest.mark.skipif( + os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN is not set" +) +@pytest.mark.skipif(os.environ.get("ASTRA_DB_ID", "") == "", reason="ASTRA_DB_ID is not set") +@pytest.mark.integration +def test_retriever_from_json(): + data = { + "type": "astra_haystack.retriever.AstraRetriever", + "init_parameters": { + "filters": {"bar": "baz"}, + "top_k": 42, + "document_store": { + "init_parameters": { + "astra_collection": "haystack_integration", + "astra_id": "63195634-ba44-49be-8a3c-12e830eb1c01", + "astra_application_token": os.getenv("ASTRA_DB_APPLICATION_TOKEN", ""), + "astra_keyspace": "astra_haystack_test", + "astra_region": "us-east-2", + "duplicates_policy": "overwrite", + "embedding_dim": 768, + "similarity": "cosine", + }, + "type": "astra_haystack.document_store.AstraDocumentStore", + }, + }, + } + retriever = AstraRetriever.from_dict(data) + assert retriever.top_k == 42 + assert retriever.filters == {"bar": "baz"}